# SPOJ COT2 Count on a tree II

### 题目大意

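给定一棵 $n$ 个节点的树，每个节点有一个整数权值。共 $m$ 次询问，每次给出两个节点 $u, v$，求 $u$ 到 $v$ 的路径上（包括两个端点）有多少种不同的权值。

数据范围：$n \le 40000$，$m \le 100000$。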

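### Solution

树上莫队。先把权值离散化，再对树做一次 DFS：进入和离开每个节点时各把它写入序列一次（括号序，长度为 $2n$），记两次出现的位置为 $in[u]$ 和 $out[u]$。对询问 $(u, v)$，不妨设 $in[u] < in[v]$：

- 若 $u$ 是 $v$ 的祖先，则区间 $[in[u], in[v]]$ 中恰好出现一次的节点，正好是 $u \to v$ 路径上的节点；
- 否则，区间 $[in[u], in[v]]$ 中恰好出现一次的是路径上除 $u$ 和 $\mathrm{lca}(u, v)$ 以外的节点（$u$ 在区间内出现两次，lca 不出现），回答时需要单独补上这两个点的权值。

这样问题就转化为序列上的莫队：维护每个节点在当前区间中的出现次数，出现一次时计入它的权值，出现两次时撤销；同时维护每种权值的计数以及不同权值的个数。总复杂度约为 $O((n+m)\sqrt{n})$。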

##### 代码

```cpp
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
using namespace std;
// Tree Mo's algorithm over the DFS bracket sequence (SPOJ COT2).
//
// dfs() writes every node into the sequence twice, at in[u] on entry and at
// out[u] on exit. For a query (u, v) with in[u] < in[v], the nodes that occur
// exactly once in the window [in[u], in[v]] are:
//   * if u is an ancestor of v: exactly the nodes on the u-v path;
//   * otherwise: the path nodes except u (it occurs twice, since out[u] < in[v])
//     and lca(u, v) (it does not occur), so those two weights are added back
//     separately when the query is answered.

const int MAXN=100010;  // nodes / bracket-sequence positions (sequence length is 2n)
const int MAXE=200010;  // directed edge slots (two per tree edge)
const int MAXQ=200010;  // queries
const int LOG=17;       // binary-lifting levels; 2^17 > 40000 covers every depth

struct Que{
    int l,r;  // window [l, r] in the bracket sequence
    int id;   // input position of the query, used to print answers in order
    int L;    // lca(u, v)
    int f;    // 1 when u or v is the lca (one endpoint is an ancestor of the other)
    int u,v;  // endpoints, swapped so that in[u] <= in[v]
}q[MAXQ];

// Forward-star adjacency list: head[u] is u's most recent edge slot, nxt[]
// links slots of the same node and vet[] holds each slot's target node.
int head[MAXN],nxt[MAXE],vet[MAXE],num;
// fa[u][i]: 2^i-th ancestor of u (0 above the root); deep[u]: depth, root = 1.
int fa[MAXN][LOG],deep[MAXN];
// ti: bracket-sequence length so far; x[p] / key[p]: weight / node at position p.
int ti,in[MAXN],out[MAXN],x[MAXN],key[MAXN];
// a: node weights (coordinate-compressed in main); b: sorted copy used for that.
int a[MAXN],b[MAXN];
// cnt[c]: window nodes of weight c that occur exactly once; siz[u]: copies of u in the window.
int cnt[MAXN],siz[MAXN];
int Ans[MAXQ];  // answer of each query, indexed by input position
int pos[MAXN];  // Mo block number of each sequence position
int ans;        // number of weights c with cnt[c] > 0

// Mo's ordering: sort by block of the left end; inside a block order r
// ascending for odd blocks and descending for even ones so r zig-zags.
bool cmp(const Que &s,const Que &t)
{
    if (pos[s.l]!=pos[t.l]) return pos[s.l]<pos[t.l];
    if (pos[s.l]&1) return s.r<t.r;
    return s.r>t.r;
}

// Append the directed edge u -> v to the adjacency list.
void add_edge(int u,int v)
{
    nxt[++num]=head[u];
    vet[num]=v;
    head[u]=num;
}

// Fill parent/depth tables and write the bracket sequence for u's subtree.
// NOTE: recursion depth equals the tree height (up to n on a path graph).
void dfs(int u)
{
    in[u]=++ti;
    x[ti]=a[u],key[ti]=u;
    for (int i=head[u];i;i=nxt[i])
    {
        int v=vet[i];
        if (v!=fa[u][0])  // skip the edge back to the parent
        {
            fa[v][0]=u;
            deep[v]=deep[u]+1;
            dfs(v);
        }
    }
    out[u]=++ti;
    x[ti]=a[u],key[ti]=u;
}

// Lowest common ancestor by binary lifting; requires fa[][] fully built.
int lca(int u,int v)
{
    if (deep[u]<deep[v]) std::swap(u,v);
    // Lift u up to v's depth.
    for (int i=LOG-1;i>=0;i--)
        if (fa[u][i]&&deep[fa[u][i]]>=deep[v]) u=fa[u][i];
    if (u==v) return u;
    // Lift both to the children of their lowest common ancestor.
    for (int i=LOG-1;i>=0;i--)
        if (fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}

// Bring sequence position X into the window. A node's first copy makes its
// weight count; its second copy means it is off the path, so the weight is removed.
void add(int X)
{
    siz[key[X]]++;
    if (siz[key[X]]==2)
    {
        if (--cnt[x[X]]==0) ans--;
    }
    else
    {
        if (++cnt[x[X]]==1) ans++;
    }
}

// Drop sequence position X from the window; exact inverse of add().
void del(int X)
{
    siz[key[X]]--;
    if (siz[key[X]]==1)
    {
        if (++cnt[x[X]]==1) ans++;
    }
    else
    {
        if (--cnt[x[X]]==0) ans--;
    }
}

// Reads n, m, the n weights, n-1 edges and m queries "u v" from stdin and
// prints, one per line, the number of distinct weights on each u-v path.
int main()
{
    int n,m;
    if (scanf("%d%d",&n,&m)!=2) return 0;
    for (int i=1;i<=n;i++)
        scanf("%d",&a[i]),b[i]=a[i];

    // Coordinate-compress weights to 1..n1 so they can index cnt[].
    std::sort(b+1,b+1+n);
    int n1=std::unique(b+1,b+1+n)-b-1;
    // Search only the deduplicated prefix b[1..n1]: std::unique leaves the
    // elements after it in an unspecified (generally unsorted) state.
    for (int i=1;i<=n;i++)
        a[i]=std::lower_bound(b+1,b+1+n1,a[i])-b;

    for (int i=1;i<n;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        add_edge(u,v),add_edge(v,u);
    }
    deep[1]=1;
    dfs(1);
    for (int i=1;i<LOG;i++)
        for (int j=1;j<=n;j++)
            fa[j][i]=fa[fa[j][i-1]][i-1];

    for (int i=1;i<=m;i++)
    {
        int u,v;
        scanf("%d%d",&u,&v);
        if (in[u]>in[v]) std::swap(u,v);
        q[i].u=u,q[i].v=v,q[i].id=i;
        q[i].l=in[u],q[i].r=in[v];
        q[i].L=lca(u,v);
        q[i].f=(q[i].L==u||q[i].L==v);
    }

    // Block size about sqrt(sequence length); never 0.
    int blk=std::max(1,static_cast<int>(std::sqrt(static_cast<double>(ti))));
    for (int i=1;i<=ti;i++) pos[i]=(i-1)/blk+1;
    std::sort(q+1,q+1+m,cmp);

    ans=0;
    int l=1,r=0;  // current window [l, r], initially empty
    for (int i=1;i<=m;i++)
    {
        // Grow before shrinking so no position is removed before it was added.
        while (l>q[i].l) add(--l);
        while (r<q[i].r) add(++r);
        while (l<q[i].l) del(l++);
        while (r>q[i].r) del(r--);

        int res=ans;
        if (!q[i].f)
        {
            // lca(u, v) is outside the window and u occurs twice inside it:
            // count their weights here unless the window already has them.
            int wl=a[q[i].L],wu=a[q[i].u];
            if (cnt[wl]==0) res++;
            if (cnt[wu]==0&&wu!=wl) res++;
        }
        Ans[q[i].id]=res;
    }
    for (int i=1;i<=m;i++)
        printf("%d\n",Ans[i]);
    return 0;
}


