Tuesday 24 November 2020

Intersection of two sorted arrays

In an interview, I was asked to find the intersection of two sorted arrays. My initial approach was to iterate over each element of the first array and binary-search for it in the second array, then do the same for each element of the second array. This runs in O(N log N) time. It also has a problem: every common element is found twice, once when searching the second array while iterating over the first, and again the other way round. This is easy to avoid by checking whether the element is already in the output.
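A minimal sketch of the binary-search idea, assuming neither array contains repeated values. Every common element is reached by iterating over just one of the arrays, so searching in one direction is already enough and the double-counting never happens. Iterating over the shorter array (length m) and searching the longer one (length n) costs O(m log n). The class name `BinarySearchIntersection` is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BinarySearchIntersection {

    // Assumption: both arrays are sorted and contain no repeated values.
    // Iterates over the shorter array and binary-searches each value in the longer one.
    static List<Integer> intersection(int[] a, int[] b) {
        int[] small = a.length <= b.length ? a : b;
        int[] large = a.length <= b.length ? b : a;
        List<Integer> result = new ArrayList<>();
        for (int x : small) {
            // Arrays.binarySearch returns a non-negative index only when x is present.
            if (Arrays.binarySearch(large, x) >= 0) {
                result.add(x);
            }
        }
        return result; // sorted, because 'small' is sorted
    }

    public static void main(String[] args) {
        int[] a = {1, 3, 5, 7, 9};
        int[] b = {3, 4, 5, 9, 12};
        System.out.println(intersection(a, b)); // prints [3, 5, 9]
    }
}
```

With repeated values this sketch can over-count (for example, three 2s in the shorter array and two 2s in the longer one yields three 2s), which is one reason the two-pointer approach is preferable.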

A simpler solution is to walk both arrays with two pointers, always advancing the pointer that sits on the smaller value.

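List<Integer> intersection(int[] a, int[] b) {
    List<Integer> result = new ArrayList<>();

    int i = 0;
    int j = 0;
    while (i < a.length && j < b.length) {
        if (a[i] == b[j]) {
            // Common element: record it and advance both pointers.
            result.add(a[i]);
            i++;
            j++;
        } else if (a[i] < b[j]) {
            // a[i] cannot appear in b at or after j, so skip it.
            i++;
        } else {
            // b[j] cannot appear in a at or after i, so skip it.
            j++;
        }
    }
    return result;
}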

This solution runs in linear time, O(m + n) for arrays of lengths m and n, because every iteration advances at least one pointer. It can be made much faster for skewed inputs.

Consider a large array as follows:

a = [10,11,12,13,14,15,16,17,...10000,1000000]

Now consider finding its intersection with a small array:

b = [10000,1000000]

After initializing both pointers to the start of the arrays, we advance through the first array one element at a time until we reach 10000, which is almost its entire length. We can do better by jumping directly to the next element that is greater than or equal to the element it is being compared with. Since the array is sorted, binary search can find that element.
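To make the cost concrete, the sketch below builds the two example arrays and counts the loop iterations each variant needs. The skip-ahead variant here uses a standard lower-bound binary search (first index whose value is greater than or equal to the key). The class name, method names, and iteration counter are hypothetical instrumentation added for illustration.

```java
public class SkewedInputDemo {

    // Plain two-pointer intersection; returns the number of loop iterations.
    static int twoPointerIterations(int[] a, int[] b) {
        int i = 0, j = 0, iterations = 0;
        while (i < a.length && j < b.length) {
            iterations++;
            if (a[i] == b[j]) { i++; j++; }
            else if (a[i] < b[j]) { i++; }
            else { j++; }
        }
        return iterations;
    }

    // Two-pointer intersection that jumps using a lower-bound binary search;
    // returns the number of loop iterations.
    static int skipAheadIterations(int[] a, int[] b) {
        int i = 0, j = 0, iterations = 0;
        while (i < a.length && j < b.length) {
            iterations++;
            if (a[i] == b[j]) { i++; j++; }
            else if (a[i] < b[j]) { i = lowerBound(a, i + 1, b[j]); }
            else { j = lowerBound(b, j + 1, a[i]); }
        }
        return iterations;
    }

    // First index k in [from, arr.length) with arr[k] >= key, or arr.length if none.
    static int lowerBound(int[] arr, int from, int key) {
        int low = from, high = arr.length;
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (arr[mid] < key) low = mid + 1;
            else high = mid;
        }
        return low;
    }

    // Builds a = [10, 11, ..., 10000, 1000000] (9992 elements).
    static int[] buildA() {
        int[] a = new int[9992];
        for (int k = 0; k <= 9990; k++) a[k] = 10 + k;
        a[9991] = 1000000;
        return a;
    }

    public static void main(String[] args) {
        int[] a = buildA();
        int[] b = {10000, 1000000};
        System.out.println(twoPointerIterations(a, b)); // prints 9992
        System.out.println(skipAheadIterations(a, b));  // prints 3
    }
}
```

Almost all of the plain version's iterations step through values below 10000. The skip-ahead version replaces them with one binary search of roughly 14 probes, followed by the two matches.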

List<Integer> intersection(int[] a, int[] b) {
    List<Integer> result = new ArrayList<>();

    int i = 0;
    int j = 0;
    while (i < a.length && j < b.length) {
        if (a[i] == b[j]) {
            result.add(a[i]);
            i++;
            j++;
        } else if (a[i] < b[j]) {
            // Instead of i++, jump to the first position after i holding a value >= b[j].
            i = lowerBound(a, i + 1, b[j]);
        } else {
            // Instead of j++, jump to the first position after j holding a value >= a[i].
            j = lowerBound(b, j + 1, a[i]);
        }
    }
    return result;
}

// Binary search for the first index k in [from, arr.length) with arr[k] >= key.
// Returns arr.length if every remaining element is smaller than key.
// Landing on the FIRST such index (not just any equal one) ensures that
// repeated values are not skipped, so results match the simple version.
int lowerBound(int[] arr, int from, int key) {
    int low = from;
    int high = arr.length; // exclusive upper bound
    while (low < high) {
        int mid = low + (high - low) / 2;
        if (arr[mid] < key) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }
    return low;
}

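This greatly improves performance on skewed inputs like the one above. One caveat: each jump is a binary search over the whole remainder of the array, so it costs O(log n) even when the next candidate is only one position away. For interleaved inputs such as a = [1, 3, 5, ...] and b = [2, 4, 6, ...], almost every step is a jump, and the worst case grows to O((m + n) log(m + n)) instead of staying linear.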
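A common refinement, swapped in here rather than taken from the approach above, is exponential ("galloping") search. It probes positions from, from+1, from+3, from+7, ... until it reaches a value that is not smaller than the key, then binary-searches only that last window. A jump over d elements then costs O(log d) instead of O(log n), and since log d ≤ d, the total work stays O(m + n) in the worst case while skewed inputs still get large jumps. This is a sketch; the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class GallopingIntersection {

    static List<Integer> intersection(int[] a, int[] b) {
        List<Integer> result = new ArrayList<>();
        int i = 0, j = 0;
        while (i < a.length && j < b.length) {
            if (a[i] == b[j]) {
                result.add(a[i]);
                i++;
                j++;
            } else if (a[i] < b[j]) {
                i = gallop(a, i + 1, b[j]);
            } else {
                j = gallop(b, j + 1, a[i]);
            }
        }
        return result;
    }

    // Swapped-in technique: galloping (exponential) search.
    // Returns the first index k in [from, arr.length) with arr[k] >= key, or arr.length if none.
    static int gallop(int[] arr, int from, int key) {
        int n = arr.length;
        int lo = from;   // every index in [from, lo) holds a value < key
        int hi = from;   // position currently being probed
        int step = 1;
        // Exponential phase: probe from, from+1, from+3, from+7, ...
        while (hi < n && arr[hi] < key) {
            lo = hi + 1;
            hi += step;
            step *= 2;
        }
        hi = Math.min(hi, n);
        // Binary phase: lower bound within [lo, hi); returns hi if no index in the window qualifies.
        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            if (arr[mid] < key) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }

    public static void main(String[] args) {
        int[] a = {1, 2, 2, 3, 5, 8, 13, 21};
        int[] b = {2, 2, 2, 13, 40};
        System.out.println(intersection(a, b)); // prints [2, 2, 13]
    }
}
```

The window searched in the binary phase is never larger than the distance already galloped, which is what bounds each jump by O(log d).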
