ゴリ押しで解けないこともないけど、Editorialはスマートな解法だね。
https://yukicoder.me/problems/no/1552
問題
整数列Xに対し、f(X)=(max(X)-min(X))*sum(X)とする。
整数N,Mが与えられる。
Xとして、N要素の数列で各値が1~Mであるものは、M^N通りある。
これらすべてのXに対するf(X)の総和を求めよ。
解法
X’を、Xの各要素を(M+1)から引いたものとする。
この場合max(X)-min(X)=max(X')-min(X')である。
また、sum(X)+sum(X')=N*(M+1)となる。
Xに対するf(X)の総和と、X'に対するf(X')の総和は当然等しいので、sum(X)の部分は、各要素がすべて平均値(M+1)/2であると考えてよい。
よってあとはmax(X)-min(X)が一致するものが何通りあるかという問題になる。
最小値がL、最大値がRとなる数列の個数は、最小値がL+1以上、最大値がR-1以下のケースを除くことを考えると(R-L+1)^N-2*(R-L)^N+(R-L-1)^Nとなる。
そこで(R-L)が一致するものをまとめて数え上げればよい。
ll N,M; const ll mo=998244353; ll modpow(ll a, ll n = mo-2) { ll r=1;a%=mo; while(n) r=r*((n%2)?a:1)%mo,a=a*a%mo,n>>=1; return r; } void solve() { int i,j,k,l,r,x,y; string s; cin>>N>>M; ll ret=0; for(int d=1;d<=M-1;d++) { ll pat=(modpow(d+1,N)-2*modpow(d,N)+modpow(d-1,N)+2*mo)%mo; (ret+=1LL*(M-d)*d%mo*pat)%=mo; } ret=ret*(N%mo)%mo*(M+1)%mo*modpow(2)%mo; cout<<ret<<endl; }
まとめ
sum(X)の計算をサボれるのいいな。