[力扣]4.寻找两个正序数组的中位数

二分查找应用

Posted by CloudingYu on April 13, 2025

题目描述

给定两个大小分别为 mn 的正序(从小到大)数组 nums1nums2。请你找出并返回这两个正序数组的 中位数

算法的时间复杂度应该为 O(log (m+n))

示例

示例 1:

  • 输入
    1
    
    nums1 = [1,3], nums2 = [2]
    
  • 输出
    1
    
    2.00000
    
  • 解释:合并数组 = [1,2,3] ,中位数 2

示例 2:

  • 输入
    1
    
    nums1 = [1,2], nums2 = [3,4]
    
  • 输出
    1
    
    2.50000
    
  • 解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

问题分析

这道题的难点在于要求时间复杂度为 $O\left(\log\left(m+n\right)\right)$。最直观的解法是合并两个数组然后找中位数,但这样的时间复杂度是 $O\left(m+n\right)$,不满足题目要求。

要达到 $O\left(\log\left(m+n\right)\right)$ 的时间复杂度,我们需要使用二分查找的思想。事实上,我们可以将问题转化为寻找两个数组中第 $k$ 小的元素的问题。

解决思路

I. 问题转化

首先将「找中位数」的问题转化为「找第 $k$ 小数」的问题:

  • 如果合并后数组长度为奇数,中位数是第 $\dfrac{m+n}{2}+1$ 小的元素

  • 如果合并后数组长度为偶数,中位数是第 $\dfrac{m+n}{2}$ 小和第 $\dfrac{m+n}{2}+1$ 小的元素的平均值

II. 二分策略

核心思想是在两个有序数组中找到一个分割线,使得:

1
2
3
nums1: [a[1], a[2], a[3], ..., a[i-1] | a[i], a[i+1], ..., a[m]]

nums2: [b[1], b[2], ..., b[j-1] | b[j], b[j+1], ..., b[n]]
  1. 分割线左边的所有元素 $\leq$ 分割线右边的所有元素
  2. 分割线左边的元素个数 $=$ 分割线右边的元素个数(或比右边多一个)

为了实现上述目标,我们需要:

  • 确保短数组在前,长数组在后(便于处理边界情况)
  • 在较短的数组上进行二分查找,寻找合适的分割位置 $\mathrm{i}$
  • 根据 $\mathrm{i}$ 计算出较长数组的分割位置 $\mathrm{j}$ ,满足 $\mathrm{i} + \mathrm{j} = \dfrac{m+n+1}{2}$

III. 终止条件

当我们找到满足以下条件的分割位置时,问题解决: $$\mathrm{maxLeft1} \leq \mathrm{minRight2}$$ $$\mathrm{maxLeft2} \leq \mathrm{minRight1}$$ 这表示分割线左边的所有元素都小于或等于右边的所有元素。

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {

        if (nums1.size() > nums2.size())
            return findMedianSortedArrays(nums2, nums1);
        
        if (nums1.empty())
        {
            int n = nums2.size();
            if (n % 2 == 0)
                return (nums2[n/2-1] + nums2[n/2]) / 2.0;
            else
                return nums2[n/2];
        }
        
        int m = nums1.size();
        int n = nums2.size();
        int left = 0, right = m;
        
        while (left <= right) 
        {
            int i = (left + right) / 2;
            int j = (m + n + 1) / 2 - i;
            
            int maxLeft1 = (i == 0) ? -1e6-1 : nums1[i-1];
            int minRight1 = (i == m) ? 1e6+1 : nums1[i];
            int maxLeft2 = (j == 0) ? -1e6-1 : nums2[j-1];
            int minRight2 = (j == n) ? 1e6+1 : nums2[j];
            
            if (maxLeft1 <= minRight2 && maxLeft2 <= minRight1) 
            {
                if ((m + n) % 2 != 0) 
                    return max(maxLeft1, maxLeft2);
                else
                    return (max(maxLeft1, maxLeft2) + min(minRight1, minRight2)) / 2.0;
            } 
            else if (maxLeft1 > minRight2)
                right = i - 1;

            else
                left = i + 1;
        }
        return 0;
    }
};