- You are given an array arr of integers.
- You need to count the number of triplets (i, j, k) with i < j ≤ k such that:
  - The XOR of the subarray from arr[i] to arr[j-1] is equal to the XOR of the subarray from arr[j] to arr[k].
Mathematically, the condition is:
arr[i]⊕arr[i+1]⊕...⊕arr[j−1]=arr[j]⊕arr[j+1]⊕...⊕arr[k]
If we define:
XOR(i,j−1)=arr[i]⊕arr[i+1]⊕...⊕arr[j−1]
XOR(j,k)=arr[j]⊕arr[j+1]⊕...⊕arr[k]
then, XORing both sides with XOR(j,k), the condition simplifies to:
XOR(i,k)=0
because XOR(i,j−1)⊕XOR(j,k)=XOR(i,k) and a⊕a=0.
Conversely, if XOR(i,k)=0, then for any j with i<j≤k:
XOR(i,j−1)=XOR(j,k)
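To make the equivalence concrete, here is a minimal brute-force sketch that counts triplets straight from the definition and checks, for every triplet, that "both halves are equal" agrees with "XOR(i,k) is 0". The class name BruteForceTriplets and the helpers xorRange and countByDefinition are illustrative names, not part of the solution further down, and the code is far too slow for large inputs.

```java
final class BruteForceTriplets {
    // XOR of arr[from..to], both ends inclusive (hypothetical helper).
    static int xorRange(int[] arr, int from, int to) {
        int x = 0;
        for (int t = from; t <= to; t++) {
            x ^= arr[t];
        }
        return x;
    }

    // Counts triplets (i, j, k) with i < j <= k by checking the definition directly.
    static int countByDefinition(int[] arr) {
        int n = arr.length;
        int count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j; k < n; k++) {
                    boolean halvesEqual = xorRange(arr, i, j - 1) == xorRange(arr, j, k);
                    boolean wholeIsZero = xorRange(arr, i, k) == 0;
                    // The two conditions should always agree.
                    if (halvesEqual != wholeIsZero) {
                        throw new AssertionError("equivalence violated at i=" + i + ", j=" + j + ", k=" + k);
                    }
                    if (halvesEqual) {
                        count++;
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countByDefinition(new int[] {2, 3, 1, 6, 7})); // prints 4
    }
}
```

A slow reference like this is mainly useful for cross-checking a faster implementation on small arrays.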
Expressing XOR(i, k) in Terms of Prefix XOR
We need to compute:
XOR(i,k)=arr[i]⊕arr[i+1]⊕...⊕arr[k]
Using the prefix XOR definition:
prefix[k+1]=arr[0]⊕arr[1]⊕...⊕arr[k] // this will be XOR(0, k)
prefix[i]=arr[0]⊕arr[1]⊕...⊕arr[i−1] // this will be XOR(0, i−1); for i=0 it is the empty XOR, i.e. 0
If we take:
prefix[k+1]⊕prefix[i]
We get:
(arr[0]⊕arr[1]⊕...⊕arr[k])⊕(arr[0]⊕arr[1]⊕...⊕arr[i−1])
Since everything before index i cancels out, we are left with:
arr[i]⊕arr[i+1]⊕...⊕arr[k]=XOR(i,k)
Thus,
XOR(i,k)=prefix[k+1]⊕prefix[i]
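The identity can be tried out with a small sketch like the one below. The class name PrefixXorDemo and the helpers buildPrefix and xorRange are assumed names chosen for illustration; the prefix array follows the same convention as above (prefix[0] = 0, length n + 1).

```java
final class PrefixXorDemo {
    // prefix[i] holds arr[0] ^ ... ^ arr[i-1]; prefix[0] is 0 (empty XOR).
    static int[] buildPrefix(int[] arr) {
        int[] prefix = new int[arr.length + 1];
        for (int i = 0; i < arr.length; i++) {
            prefix[i + 1] = prefix[i] ^ arr[i];
        }
        return prefix;
    }

    // XOR(i, k) = prefix[k+1] ^ prefix[i], answered in O(1) per query.
    static int xorRange(int[] prefix, int i, int k) {
        return prefix[k + 1] ^ prefix[i];
    }

    public static void main(String[] args) {
        int[] arr = {2, 3, 1, 6, 7};
        int[] prefix = buildPrefix(arr); // {0, 2, 1, 0, 6, 1}
        System.out.println(xorRange(prefix, 0, 2)); // 2 ^ 3 ^ 1, prints 0
        System.out.println(xorRange(prefix, 1, 3)); // 3 ^ 1 ^ 6, prints 4
    }
}
```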
Efficient Solution using Prefix XOR
Instead of using nested loops to check all possible triplets (which would be too slow), we use a prefix XOR array:
- Compute prefix[i], where prefix[i] = arr[0] ⊕ arr[1] ⊕ ... ⊕ arr[i-1] (the XOR of all elements up to index i-1, so prefix[0] = 0).
- The key observation is that:
  - If prefix[i] == prefix[k+1], then XOR(i, k) = 0.
  - This means any j from i+1 to k (inclusive) forms a valid triplet: since XOR(i, j-1) ⊕ XOR(j, k) = XOR(i, k) = 0, we get XOR(i, j-1) = XOR(j, k).
  - That gives k − i valid values of j for each such pair (i, k).

XOR(i,k)=0 ⇒ XOR(i,j−1)=XOR(j,k) for all j∈(i,k]
This is the core property that allows us to count valid triplets in O(N²) instead of O(N³).
So, the solution is:
public class Solution {
    public int countTriplets(int[] arr) {
        int n = arr.length;
        int count = 0;

        // Prefix XOR (cumulative XOR): prefix[i] = arr[0] ^ ... ^ arr[i-1]
        int[] prefix = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] ^ arr[i];
        }

        // Finding valid triplets: every pair (i, k) with XOR(i, k) == 0
        for (int i = 0; i < n; i++) {
            for (int k = i + 1; k < n; k++) {
                if (prefix[i] == prefix[k + 1]) {
                    count += (k - i); // j can be any index from i+1 to k
                }
            }
        }
        return count;
    }
}
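As a quick usage check, the sketch below runs the same O(N²) counting logic on two small arrays. The class name CountTripletsDemo and the sample inputs are chosen here for illustration; the method body repeats the solution above so the example compiles on its own.

```java
final class CountTripletsDemo {
    // Same logic as Solution.countTriplets above, repeated to keep this example standalone.
    static int countTriplets(int[] arr) {
        int n = arr.length;
        int count = 0;
        int[] prefix = new int[n + 1];
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] ^ arr[i];
        }
        for (int i = 0; i < n; i++) {
            for (int k = i + 1; k < n; k++) {
                if (prefix[i] == prefix[k + 1]) {
                    count += k - i; // one triplet per j in (i, k]
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // prefix = {0, 2, 1, 0, 6, 1}: pairs (i=0, k=2) and (i=2, k=4) contribute 2 each.
        System.out.println(countTriplets(new int[] {2, 3, 1, 6, 7})); // prints 4
        // prefix = {0, 1, 0, 1, 0, 1}: six matching pairs contribute 1 + 3 + 1 + 1 + 3 + 1.
        System.out.println(countTriplets(new int[] {1, 1, 1, 1, 1})); // prints 10
    }
}
```

Note that k starts at i + 1 because a valid triplet needs at least two elements (i < j ≤ k), so single-element ranges are never counted.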