fork download
  1. #include <iostream>
  2. #include <vector>
  3.  
  4. using namespace std;
  5.  
  6. const long long MOD = 998244353;
  7.  
  8. long long power(long long base, long long exp) {
  9. long long res = 1;
  10. base %= MOD;
  11. while (exp > 0) {
  12. if (exp % 2 == 1) res = (res * base) % MOD;
  13. base = (base * base) % MOD;
  14. exp /= 2;
  15. }
  16. return res;
  17. }
  18.  
  19. long long modInverse(long long n) {
  20. return power(n, MOD - 2);
  21. }
  22.  
  23. long long total_sum = 0;
  24.  
  25. int dfs(int u, int p, const vector<vector<int>>& adj, int N) {
  26. int sz = 1;
  27. for (int v : adj[u]) {
  28. if (v != p) {
  29. int sub_sz = dfs(v, u, adj, N);
  30. sz += sub_sz;
  31. long long cur_edge_contrib = 1LL * sub_sz * (N - sub_sz) % MOD;
  32. total_sum = (total_sum + cur_edge_contrib) % MOD;
  33. }
  34. }
  35. return sz;
  36. }
  37.  
  38. void solve() {
  39. int N;
  40. long long K;
  41. cin >> N >> K;
  42. vector<vector<int>> adj(N + 1);
  43. for (int i = 0; i < N - 1; ++i) {
  44. int u, v;
  45. cin >> u >> v;
  46. adj[u].push_back(v);
  47. adj[v].push_back(u);
  48. }
  49.  
  50. total_sum = 0;
  51. dfs(1, 0, adj, N);
  52.  
  53. long long S = (N + K - 1) % MOD;
  54. long long num = S;
  55. num = (num * ((S + 1) % MOD)) % MOD;
  56. num = (num * ((S + 2) % MOD)) % MOD;
  57.  
  58. long long den = (N - 1) % MOD;
  59. den = (den * (N % MOD)) % MOD;
  60. den = (den * ((N + 1) % MOD)) % MOD;
  61.  
  62. long long mult = (num * modInverse(den)) % MOD;
  63.  
  64. long long ans = (total_sum * mult) % MOD;
  65. cout << ans << "\n";
  66. }
  67.  
  68. int main() {
  69. ios_base::sync_with_stdio(false);
  70. cin.tie(NULL);
  71. int T;
  72. if (cin >> T) {
  73. while (T--) {
  74. solve();
  75. }
  76. }
  77. return 0;
  78. }
Success #stdin #stdout 0s 5324KB
stdin
3
4 2
1 2
1 3
1 4
4 5
1 2
2 3
3 4
6 13
1 2
1 3
2 4
2 5
3 6
stdout
499122208
120
713032723