A 题 #
/// Solves one test case of problem A.
///
/// Reads `n` and a sequence `a` of `n` non-negative integers from standard input,
/// walks the sequence while maintaining a counter, and prints the final counter.
fn solve() {
    input! {
        n: usize,
        a: [usize; n]
    }
    let mut cur = 0;
    for &x in &a {
        // Advance the counter once for every element...
        cur += 1;
        // ...and once more when it would coincide with the current element.
        if x == cur {
            cur += 1;
        }
    }
    println!("{}", cur);
}
/// Entry point for problem A: reads the number of test cases and solves each in turn.
fn main() {
    input! {
        t: usize
    }
    (0..t).for_each(|_| solve());
}
B 题 #
做法:把出现过的每个元素对应到一个二进制位。枚举每一位,贪心地把所有不包含该位的集合取并集;答案就是这些并集大小的最大值。
/// Solves one test case of problem B.
///
/// Reads `n` sets from standard input, encodes each set as a 64-bit mask, and for
/// every distinct element computes the union of all sets that do not contain it.
/// Prints the largest number of elements found in any such union.
fn solve() {
    input! {
        n: usize,
    }
    let mut masks: Vec<u64> = Vec::with_capacity(n);
    // Maps each distinct input value to the bit index assigned on first sight.
    // NOTE(review): assumes at most 64 distinct values per test case, otherwise the
    // shift below overflows — confirm against the problem constraints.
    let mut bit_of: HashMap<usize, usize> = HashMap::new();
    for _ in 0..n {
        input! {
            x: usize,
            p: [usize; x]
        }
        let mut mask = 0u64;
        for item in p {
            // A value seen for the first time gets the next free bit index.
            let next_bit = bit_of.len();
            let bit = *bit_of.entry(item).or_insert(next_bit);
            mask |= 1 << bit;
        }
        masks.push(mask);
    }
    let res = (0..bit_of.len())
        .map(|bit| {
            // Union of every set that lacks `bit`; this union never contains `bit`.
            masks
                .iter()
                .filter(|&&mask| (mask >> bit) & 1 == 0)
                .fold(0u64, |acc, &mask| acc | mask)
                .count_ones()
        })
        .max()
        .unwrap_or(0);
    println!("{}", res);
}
/// Entry point for problem B: reads the number of test cases and solves each in turn.
fn main() {
    input! {
        t: usize
    }
    (0..t).for_each(|_| solve());
}
C 题 #
/// Solves one test case of problem C.
///
/// Reads `n` card values from standard input, runs a three-value recurrence over
/// them from left to right, and prints the largest of the three final values.
fn solve() {
    input! {
        n: usize,
        cards: [i64; n]
    }
    // NOTE(review): the original notes named four states (select/skip a card at an
    // odd/even position) while only three values are tracked; the exact mapping of
    // the tuple fields to those states is not evident from the code — confirm.
    let first = cards[0];
    let start = (first, first.max(0), i64::MIN);
    let (a, b, c) = cards.iter().skip(1).fold(start, |(a, b, c), &card| {
        // Both "taking" transitions extend the better of the previous `a` and `c`.
        let carry = a.max(c);
        (carry + card, carry + card.max(0), b.max(c))
    });
    println!("{}", a.max(b).max(c));
}
/// Entry point for problem C: reads the number of test cases and solves each in turn.
fn main() {
    input! {
        t: usize
    }
    for _ in 0..t {
        solve();
    }
}
D 题 #
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer used for subtree sizes, node values and accumulated sums.
using ll = long long;
// Capacity of the static node array `ns`; solve() indexes nodes from 1 to n.
const int maxn = 200004;
// Number of nodes in the current test case; read in solve() and used in dfs2().
int n;
// Per-node data shared by the two tree traversals.
struct Node {
    ll size;  // number of nodes in this node's subtree, filled in by dfs1
    ll val;   // value read from input for this node
    ll dp1;   // sum accumulated by dfs1 over the node's children
    ll dp2;   // value assigned by dfs2 from the parent's totals
};
// Adjacency lists keyed by node pointer; cleared and rebuilt by solve() for each test case.
// NOTE(review): every e[u] access is an ordered-map lookup; an adjacency array indexed by
// node number would avoid that cost.
map<Node *, vector<Node *>> e;
// Static storage for all nodes; solve() uses entries 1..n.
Node ns[maxn];
// First traversal: for the subtree hanging below `u` (entered from `f`), fills in
// `size` and accumulates `dp1` from the children.
// Each child contributes its own dp1 plus size(child) * (val(u) XOR val(child)).
void dfs1(Node *u, Node *f) {
    u->size = 1;
    const vector<Node *> &neighbours = e[u];
    for (Node *child : neighbours) {
        if (child == f) {
            continue;  // do not walk back up the edge we arrived through
        }
        dfs1(child, u);
        const ll edge = u->val ^ child->val;
        u->size += child->size;
        u->dp1 += child->dp1 + child->size * edge;
    }
}
// Second traversal: pushes totals from each node down to its children.
// For a child, dp2 is the parent's dp1 + dp2 with the child's own dfs1 contribution
// removed, plus (n - size(child)) * (val(parent) XOR val(child)).
void dfs2(Node *u, Node *f) {
    const vector<Node *> &neighbours = e[u];
    for (Node *child : neighbours) {
        if (child == f) {
            continue;  // skip the parent
        }
        const ll edge = u->val ^ child->val;
        // Parent's totals without what this child added during dfs1.
        const ll withoutChild = u->dp2 + u->dp1 - child->dp1 - child->size * edge;
        child->dp2 = withoutChild + (n - child->size) * edge;
        dfs2(child, u);
    }
}
void solve() {
cin >> n;
e.clear();
for (int i = 1; i <= n; i++) {
cin >> ns[i].val;
ns[i].dp1 = ns[i].dp2 = ns[i].size = 0;
}
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
e[ns + u].push_back(ns + v);
e[ns + v].push_back(ns + u);
}
dfs1(ns + 1, nullptr);
dfs2(ns + 1, nullptr);
for (int i = 1; i <= n; i++) {
printf("%lld ", ns[i].dp1 + ns[i].dp2);
}
putchar('\n');
}
// Entry point for problem D: unties cin, disables stdio synchronisation, then reads
// the number of test cases and solves each one.
int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int T;
    cin >> T;
    for (int remaining = T; remaining != 0; --remaining) {
        solve();
    }
    return 0;
}