树形DP。很好的一道题目,这里认真写写。
题目很长,大意就是有一个树形图,出发点是树根,目标(即终点)以等概率的形式分布于每个树枝末端,要求的是在未知哪根树枝为终点的情况,最优搜索方式下的期望总搜索长度。另外这里还有一点附加的是,每个分支点可能有提示,即到达该树枝时可以立即知道目标是否在该子树下。
先不考虑存在提示的情况。对每棵子树的搜索可能有两种结果,成功或失败。若成功则要分别计算下一层每个子树的成功或失败情况,并计算出最优搜索下的搜索长度,这里记为ai;若失败则简单将所有子树的搜索总长累加即可,这里记为bi。另外,由于子树搜索成功是要计算概率的,这和子树上包含的叶子总数有关,因此需要维护一个子树所含叶子总数ni。由此便可以开始分析ai的计算了。
ai的计算也可以分为两部分,首先由概率公式可以得到,每棵子树j搜索成功的这部分概率是不受子树间先后顺序影响的,其对最优搜索期望总长的贡献总是(aj+1)*nj/ni。那么,这里最关键的就是搜索失败时的处理情况。这是本题的难点。
当已知某棵子树搜索失败时,后续子树搜索失败的概率会变低,若设剩余子树所含的总叶子数为nl,具体期望值公式为(bj+2)*(nl-nj)/ni,也即终点应在剩余的nl-nj个叶子中的某个上。这里是确定最优搜索方式的关键,我想了很久。若设l为剩余子树编号,尝试了每次根据(sum(bl+2)-bj-2)*(nl-nj)的最大值来贪心选择下一次搜索的子树,这里的思想是每次都使剩余期望值降低最多,能过test cases,但是submit后WA。这里思路一下子卡死了。便看了discuss,恰巧是magic_pig拯救了我(他是之前情书作者kinfkong的中大校友),他提到要按(bj+2)/nj排序。尝试了一下果然AC了。。。
晚上仔细想了想这样能求解的原因。观察到其形式很类似轮换不等式,朝这个方向推了一下之后便明白了:若交换搜索中次序相邻子树的j和k,是不影响之前和之后其他子树的期望值的。同时,每个子树受其他子树影响的量是bj+2,影响其他子树期望的量是nj,由此便能通过对(bj+2)/nj进行排序来确定最优搜索顺序了。实在是太过精巧的思路。
到这里为止,本题的所有trick基本上排除完了。bi的计算则是简单的sum(bj+2),若该子树存在提示,取bi=0即可。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; int N,w[1000],hd[1000],el,fl,n[1000]; double a[1000],b[1000]; struct E { int v,p; }e[999]; struct F { double bv,e; int nv; bool operator < (const F &p) const { return e<p.e; } }f[1000]; inline void adde(int u,int v) { e[el].v=v; e[el].p=hd[u]; hd[u]=el++; } void calc(int u) { if(hd[u]==-1) { n[u]=1; a[u]=b[u]=0; return; } int i,l,v,nl; n[u]=0; a[u]=b[u]=0; for(l=hd[u];l!=-1;l=e[l].p) { v=e[l].v; calc(v); n[u]+=n[v]; } for(l=hd[u],fl=0;l!=-1;l=e[l].p) { v=e[l].v; f[fl].bv=b[v]+2; f[fl].nv=n[v]; f[fl].e=(b[v]+2)/n[v]; fl++; a[u]+=(a[v]+1)*n[v]/n[u]; if(!w[u]) b[u]+=b[v]+2; } sort(f,f+fl); for(i=0,nl=n[u];i<fl-1;i++) { a[u]+=f[i].bv*(nl-f[i].nv)/n[u]; nl-=f[i].nv; } } int main() { int i,j,k,l; i=1;j=2; double e,f; char s[2]; while(~scanf("%d",&N)&&N) { el=0; memset(hd,-1,sizeof hd); scanf("%*d%*s"); for(i=1;i<N;i++) { scanf("%d%s",&j,s); adde(j-1,i); if(s[0]=='N') w[i]=0; else w[i]=1; } calc(0); printf("%.4f\n",a[0]); } return 0; } |
这样的题目,完全弄懂之后总是有种难以言喻的快感… 马上又要开始做项目了,还能像现在这样无忧无虑做题的时间也不多了。