import os
import argparse
import pickle
import numpy as np
import torch

from utils import set_seed, log
from data import load_mnist, subsample_mnist, build_providers
from knn import flat_mnist
from poset import build_poset_from_node_dag
from sampling import make_sampler_kwargs, sample_linear_extension
from value import compute_value


def _build_arg_parser():
    """Return the command-line parser for a single valuation experiment.

    The four booster providers share an identical option set
    (``--booster{1..4}_source``, ``_variant_owner``, ``_variant_anchor``,
    ``_samples_root_owner``, ``_samples_root_anchor``) and differ only in the
    default generative source, so those options are registered in a loop.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--method", type=str, default="PASV", choices=["SV", "WSV", "PSV", "PASV"])
    parser.add_argument(
        "--limit_case",
        type=str,
        default="none",
        choices=["none", "booster", "copier", "poisoner", "anchor"],
    )
    parser.add_argument("--nreps", type=int, default=10)
    parser.add_argument("--num_permutations", type=int, default=10000)
    parser.add_argument("--burnin_steps", type=int, default=100000, help="MCMC burn-in steps")
    parser.add_argument("--steps_between", type=int, default=10000, help="MCMC steps between samples (thinning)")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--data_root", type=str, default="./data")
    parser.add_argument("--per_class_train", type=int, default=10)
    parser.add_argument("--per_class_test", type=int, default=100)
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    parser.add_argument("--embed", type=str, default=None, choices=["pixels", "vit_b16", "resnet18", "resnet34", "resnet50"])

    parser.add_argument("--poison_prob", type=float, default=1.0)
    parser.add_argument("--booster_owner_count", type=int, default=50)
    parser.add_argument("--copier_owner_count", type=int, default=50)
    parser.add_argument("--poisoner_owner_count", type=int, default=50)

    source_choices = ["none", "gan", "ddpm", "ddim", "fm"]
    variant_choices = ["orig", "aug"]
    for b, default_source in enumerate(["gan", "ddpm", "ddim", "fm"], start=1):
        parser.add_argument(f"--booster{b}_source", type=str, default=default_source, choices=source_choices)
        parser.add_argument(f"--booster{b}_variant_owner", type=str, default="orig", choices=variant_choices)
        parser.add_argument(f"--booster{b}_variant_anchor", type=str, default="aug", choices=variant_choices)
        parser.add_argument(f"--booster{b}_samples_root_owner", type=str, default=None)
        parser.add_argument(f"--booster{b}_samples_root_anchor", type=str, default=None)

    parser.add_argument("--k", type=int, default=20)
    parser.add_argument("--lam_base", type=float, default=2.0, help="Base b for priority weights; weights use b**rank")
    parser.add_argument(
        "--lam_exponents",
        type=str,
        default=None,
        help=(
            "Five comma-separated exponents for owner,anchor,booster,copier,poisoner "
            "priority weights (default: 1,2,3,4,5)"
        ),
    )
    parser.add_argument("--mar_con", action="store_true", default=False)
    return parser


def _parse_lam_exponents(spec):
    """Parse the ``--lam_exponents`` string into a list of five floats.

    Args:
        spec: the raw option value, or None when the option was not given.

    Returns:
        A list of five floats, or None. None is returned silently when *spec*
        is None, and after printing a warning when *spec* contains a
        non-numeric entry or does not hold exactly five values; the caller then
        falls back to the default exponents 1..5.
    """
    if spec is None:
        return None
    try:
        values = [float(x.strip()) for x in str(spec).split(",") if x.strip() != ""]
    except ValueError as e:  # float() on a malformed entry
        print(f"Warning: failed to parse --lam_exponents='{spec}': {e}. Falling back to default (1..5).")
        return None
    if len(values) != 5:
        print(f"Warning: --lam_exponents expects 5 comma-separated floats, got {len(values)}. Falling back to default (1..5).")
        return None
    return values


def _restrict_poset(P_full, keep_ids):
    """Return the sub-poset of *P_full* induced by the points in *keep_ids*.

    Point ``keep_ids[k]`` becomes point ``k`` of the result; predecessor and
    successor sets keep only edges whose other endpoint is also kept, with
    indices remapped the same way.

    Returns:
        dict with keys ``"n"`` (number of kept points), ``"preds"`` and
        ``"succs"`` (lists of sets of new indices).
    """
    kept = [int(x) for x in np.asarray(keep_ids).tolist()]
    old_to_new = {old: new for new, old in enumerate(kept)}
    preds_r = [
        {old_to_new[int(p)] for p in P_full["preds"][old] if int(p) in old_to_new}
        for old in kept
    ]
    succs_r = [
        {old_to_new[int(s)] for s in P_full["succs"][old] if int(s) in old_to_new}
        for old in kept
    ]
    return {"n": len(kept), "preds": preds_r, "succs": succs_r}


def _mean_provider_ranks(method, rng, sampler_kwargs, n_perms, group_index_dict, provider_names):
    """Estimate each provider's mean 1-based position over sampled permutations.

    Draws *n_perms* permutations with ``sample_linear_extension`` (consuming
    *rng*). For each provider, its points' positions are averaged within every
    permutation, and those per-permutation means are averaged again.

    Returns:
        float array of length ``len(provider_names)``. Entries are NaN for
        providers with no points; all entries are NaN when *n_perms* <= 0.
    """
    mean_ranks = np.full(len(provider_names), np.nan, dtype=float)
    n_perms = int(n_perms)
    if n_perms <= 0:
        return mean_ranks

    # positions[m, i] = 1-based position of point i in the m-th permutation.
    # The inverse permutation is computed once per sample, not once per
    # (provider, sample) pair.
    positions = None
    one_based = None
    for m in range(n_perms):
        perm = np.asarray(sample_linear_extension(method, rng=rng, **sampler_kwargs), dtype=np.int64)
        if positions is None:
            n = perm.shape[0]
            positions = np.empty((n_perms, n), dtype=np.int64)
            one_based = np.arange(1, n + 1, dtype=np.int64)
        positions[m, perm] = one_based

    for j, pname in enumerate(provider_names):
        prov_indices = group_index_dict[pname]
        if prov_indices.size == 0:
            continue
        per_perm = np.mean(positions[:, prov_indices], axis=1)
        mean_ranks[j] = float(np.mean(per_perm))
    return mean_ranks


def main():
    """Run one data-valuation experiment end to end and pickle the results.

    Steps:
      1. Parse CLI options, call ``set_seed`` and create a NumPy
         ``RandomState`` that drives data subsampling and provider building.
      2. Load and subsample MNIST or CIFAR-10. Build the base providers
         (owner, anchor, copier, poisoner) and four booster variants.
      3. Concatenate all node data into one training pool. Derive per-point
         node ids and provider ids, and the node-level poset.
      4. For limit cases "booster", "copier" and "poisoner", restrict the
         sampler's poset to the points outside that group.
      5. For each repetition (seeded by its 1-based index, independently of
         ``--seed``), compute values with ``compute_value``. Then estimate each
         provider's mean permutation position from freshly sampled
         permutations.
      6. Save everything to ``save/<dataset>/result_*.pkl``.
    """
    args = _build_arg_parser().parse_args()

    set_seed(args.seed)
    # Shared RNG for subsampling and all builder calls below; they consume it
    # in a fixed order, which must be preserved for reproducibility.
    rng = np.random.RandomState(args.seed)

    # Default embedding depends on the dataset.
    if args.embed is None:
        args.embed = "pixels" if args.dataset == "mnist" else "resnet18"

    # Load data and choose the provider builder / feature extractor.
    if args.dataset == "mnist":
        X_orig_full, y_orig_full, X_test_full, y_test_full = load_mnist(args.data_root)
        X_owner, y_owner, X_test, y_test = subsample_mnist(
            X_orig_full, y_orig_full, X_test_full, y_test_full,
            per_class_train=args.per_class_train,
            per_class_test=args.per_class_test,
            rng=rng,
        )
        builder_fn = build_providers
        # NOTE(review): --embed is not consulted for MNIST; raw pixels are used.
        flatten_fn_eval = flat_mnist
    else:
        from data import load_cifar10, subsample_cifar10, build_providers_cifar
        X_orig_full, y_orig_full, X_test_full, y_test_full = load_cifar10(args.data_root)
        X_owner, y_owner, X_test, y_test = subsample_cifar10(
            X_orig_full, y_orig_full, X_test_full, y_test_full,
            per_class_train=args.per_class_train,
            per_class_test=args.per_class_test,
            rng=rng,
        )
        builder_fn = build_providers_cifar
        if args.embed == "pixels":
            flatten_fn_eval = flat_mnist
        else:
            from knn import build_flatten_fn
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
            flatten_fn_eval = build_flatten_fn(args.embed, device_name)

    shared_builder_kwargs = dict(
        rng=rng,
        poison_prob=args.poison_prob,
        booster_owner_count=args.booster_owner_count,
        copier_owner_count=args.copier_owner_count,
        poisoner_owner_count=args.poisoner_owner_count,
        booster_samples_root=None,
    )

    # Base providers (no booster data).
    nodes_base, _ = builder_fn(X_owner, y_owner, booster_source="none", **shared_builder_kwargs)

    # One build per booster configuration; only its booster_from_* nodes are used.
    booster_nodes = []
    for b in range(1, 5):
        nodes_b, _ = builder_fn(
            X_owner, y_owner,
            booster_source=getattr(args, f"booster{b}_source"),
            booster_variant_owner=getattr(args, f"booster{b}_variant_owner"),
            booster_variant_anchor=getattr(args, f"booster{b}_variant_anchor"),
            booster_samples_root_owner=getattr(args, f"booster{b}_samples_root_owner"),
            booster_samples_root_anchor=getattr(args, f"booster{b}_samples_root_anchor"),
            **shared_builder_kwargs,
        )
        booster_nodes.append(nodes_b)

    # Node order: the six base nodes, then booster{b}_from_owner /
    # booster{b}_from_anchor for b = 1..4.
    base_node_names = [
        "owner",
        "anchor",
        "copier_from_owner",
        "copier_from_anchor",
        "poisoner_from_owner",
        "poisoner_from_anchor",
    ]
    node_names = list(base_node_names)
    node_arrays = [(nodes_base[nname]["X"], nodes_base[nname]["y"]) for nname in base_node_names]
    for b, nodes_b in enumerate(booster_nodes, start=1):
        for side in ("owner", "anchor"):
            part = nodes_b[f"booster_from_{side}"]
            node_names.append(f"booster{b}_from_{side}")
            node_arrays.append((part["X"], part["y"]))

    X_train_all = np.concatenate([X for X, _ in node_arrays], axis=0)
    y_train_all = np.concatenate([y for _, y in node_arrays], axis=0)
    n_all = X_train_all.shape[0]

    node_sizes = [X.shape[0] for X, _ in node_arrays]
    # node_ids[i] = index into node_names of the node that point i came from.
    node_ids = np.repeat(np.arange(len(node_sizes), dtype=np.int64), node_sizes)

    provider_names = ["owner", "anchor", "booster1", "booster2", "booster3", "booster4", "copier", "poisoner"]
    name_to_pid = {name: pid for pid, name in enumerate(provider_names)}

    # Every node is named "<provider>" or "<provider>_from_<source>".
    node2provider = {nname: nname.split("_from_")[0] for nname in node_names}
    node_pid = np.array([name_to_pid[node2provider[nname]] for nname in node_names], dtype=np.int64)
    provider_ids = node_pid[node_ids]

    owner_pid = name_to_pid["owner"]
    owner_mask = (provider_ids == owner_pid)
    owner_ids = np.where(owner_mask)[0]
    non_owner_ids = np.where(~owner_mask)[0]

    group_index_dict = {pname: np.where(provider_ids == pid)[0] for pname, pid in name_to_pid.items()}

    P_full = build_poset_from_node_dag(node_ids, node_names)

    limit_case = str(args.limit_case)

    anchor_ids = group_index_dict.get("anchor", np.array([], dtype=np.int64))
    # U_a: every point derived from the anchor (the *_from_anchor nodes).
    ua_node_ids = [nid for nid, nname in enumerate(node_names) if nname.endswith("_from_anchor")]
    U_a_ids = np.where(np.isin(node_ids, ua_node_ids))[0]

    booster_provider_names = ["booster1", "booster2", "booster3", "booster4"]
    G_ids = None
    rest_ids = None
    P_for_sampler = P_full
    lam_uniform_for_sampler = None
    if limit_case == "booster":
        booster_ids_list = [
            np.asarray(group_index_dict[bname], dtype=np.int64)
            for bname in booster_provider_names
            if bname in group_index_dict
        ]
        G_ids = np.concatenate(booster_ids_list, axis=0) if booster_ids_list else np.array([], dtype=np.int64)
    elif limit_case in ("copier", "poisoner"):
        G_ids = np.asarray(group_index_dict.get(limit_case, np.array([], dtype=np.int64)), dtype=np.int64)

    if limit_case in {"booster", "copier", "poisoner"}:
        if G_ids is None or G_ids.size == 0:
            raise ValueError(f"limit_case={limit_case} requires non-empty G_ids")
        in_G = np.zeros(n_all, dtype=bool)
        in_G[G_ids] = True
        rest_ids = np.where(~in_G)[0].astype(np.int64)
        # The sampler works on the poset restricted to the points outside G.
        P_for_sampler = _restrict_poset(P_full, rest_ids)
        lam_uniform_for_sampler = np.ones(int(rest_ids.size), dtype=float)

    lam_uniform = np.ones(n_all, dtype=float)
    if lam_uniform_for_sampler is None:
        lam_uniform_for_sampler = lam_uniform

    # Priority weights: lam_base ** exponent, with one exponent per provider
    # kind (all boosters share one exponent).
    lam_base = float(args.lam_base)
    lam_exp_vec = _parse_lam_exponents(args.lam_exponents)
    if lam_exp_vec is None:
        lam_exp_owner, lam_exp_anchor, lam_exp_booster, lam_exp_copier, lam_exp_poisoner = 1.0, 2.0, 3.0, 4.0, 5.0
    else:
        lam_exp_owner, lam_exp_anchor, lam_exp_booster, lam_exp_copier, lam_exp_poisoner = lam_exp_vec
    lam_exps = [lam_exp_owner, lam_exp_anchor, lam_exp_booster, lam_exp_copier, lam_exp_poisoner]
    exponent_of = {
        "owner": lam_exp_owner,
        "anchor": lam_exp_anchor,
        "copier": lam_exp_copier,
        "poisoner": lam_exp_poisoner,
    }
    lam_priority = np.ones(n_all, dtype=float)
    for pid, name in enumerate(provider_names):
        exponent = lam_exp_booster if name.startswith("booster") else exponent_of[name]
        lam_priority[provider_ids == pid] = lam_base ** float(exponent)

    print("X_train_all:", X_train_all.shape)
    print("y_train_all:", y_train_all.shape)
    print("X_test:", X_test.shape)
    print("y_test:", y_test.shape)
    print("n_all:", n_all)
    print("provider_names:", provider_names)
    print("group_index_dict keys:", list(group_index_dict.keys()))

    method = args.method
    print("\n" + "=" * 50)
    print(f"Running method = {method}")
    if limit_case != "none":
        print(f"[limit_case enabled] {limit_case}")

    nreps = int(args.nreps)
    n_providers = len(provider_names)
    group_sums_reps = np.zeros((nreps, n_providers), dtype=float)
    num_perms_reps = np.zeros(nreps, dtype=int)
    phi_individual_reps = np.zeros((nreps, n_all), dtype=float)
    elapsed_seconds_reps = np.zeros(nreps, dtype=float)
    mean_rank_reps = np.zeros((nreps, n_providers), dtype=float)

    marginal_contribution = None
    if args.mar_con:
        marginal_contribution = {
            "coalition_sizes_reps": [],
            "marginal_contributions_reps": [],
            "providers_reps": [],
        }

    pid_to_name = dict(enumerate(provider_names))
    for rep in range(1, nreps + 1):
        r = rep - 1
        rng_rep = np.random.RandomState(rep)
        sampler_kwargs = make_sampler_kwargs(
            method=method,
            rng=rng_rep,
            n_all=n_all,
            owner_ids=owner_ids,
            non_owner_ids=non_owner_ids,
            P_full=P_for_sampler,
            lam_uniform=lam_uniform_for_sampler,
            lam_priority=lam_priority,
            limit_case=limit_case,
            G_ids=G_ids,
            anchor_ids=anchor_ids,
            U_a_ids=U_a_ids,
            rest_ids=rest_ids,
            steps_between=args.steps_between,
            burnin_steps=args.burnin_steps,
        )
        res = compute_value(
            method=method,
            sampler_kwargs=sampler_kwargs,
            X_train_all=X_train_all,
            y_train_all=y_train_all,
            X_eval=X_test,
            y_eval=y_test,
            flatten_fn=flatten_fn_eval,
            group_index_dict=group_index_dict,
            k=args.k,
            num_permutations=args.num_permutations,
            rng=rng_rep,
            record_marginal_contrib=args.mar_con,
            provider_ids=provider_ids if args.mar_con else None,
        )
        num_perms_reps[r] = res["num_permutations"]
        elapsed_seconds_reps[r] = res["elapsed_seconds"]
        phi_individual_reps[r] = res["phi_individual"]
        phi_group_sum = res["phi_group_sum"]
        group_sums_reps[r] = [phi_group_sum[pname] for pname in provider_names]

        if args.mar_con and "marginal_contribution" in res:
            mc_data = res["marginal_contribution"]
            marginal_contribution["coalition_sizes_reps"].append(mc_data["coalition_sizes"])
            marginal_contribution["marginal_contributions_reps"].append(mc_data["marginal_contributions"])
            marginal_contribution["providers_reps"].append([pid_to_name[pid] for pid in mc_data["providers"]])

        # Rank statistics come from a fresh batch of permutations drawn after
        # compute_value, continuing the same rng_rep stream.
        # NOTE(review): in limit cases P_for_sampler covers only rest_ids;
        # confirm that sample_linear_extension still returns permutations over
        # all n_all global indices, otherwise group_index_dict (global
        # indices) will not line up with the permutation positions.
        mean_rank_reps[r] = _mean_provider_ranks(
            method, rng_rep, sampler_kwargs, res["num_permutations"], group_index_dict, provider_names
        )

        print(f"[rep {rep}/{nreps}] num_permutations used: {res['num_permutations']}")
        print(f"[rep {rep}/{nreps}] group-level Shapley sums:")
        for g in provider_names:
            print(f"  {g:10s}: {phi_group_sum[g]:.6f}")
        print(f"[rep {rep}/{nreps}] mean ranks:")
        for g, mr in zip(provider_names, mean_rank_reps[r]):
            print(f"  {g:10s}: {mr:.2f}")

    save_root = os.path.join("save", args.dataset)
    os.makedirs(save_root, exist_ok=True)
    # per_class_train samples for each of the 10 classes (MNIST and CIFAR-10
    # both have 10 classes).
    n_per_provider = 10 * int(args.per_class_train)

    if method == "PASV" and limit_case != "none":
        out_name = f"result_PASV_limit_{limit_case}_{n_per_provider}.pkl"
    else:
        # NOTE(review): for SV/WSV/PSV the limit case is not part of the file
        # name, so a limit-case run overwrites that method's regular result.
        if method != "PASV":
            lam_suffix = ""
        elif lam_exp_vec is not None:
            exp_str = "".join(f"{x:g}" for x in lam_exps)
            lam_suffix = f"_b{args.lam_base:g}_exp{exp_str}"
        else:
            lam_suffix = f"_{args.lam_base:g}"
        out_name = f"result_{method}_{n_per_provider}{lam_suffix}.pkl"
    out_pkl_path = os.path.join(save_root, out_name)

    payload = {
        "method": method,
        "limit_case": limit_case,
        "provider_names": provider_names,
        "nreps": nreps,
        "num_permutations_per_rep": num_perms_reps,
        "elapsed_seconds_per_rep": elapsed_seconds_reps,
        "group_sums_reps": group_sums_reps,
        "phi_individual_reps": phi_individual_reps,
        "mean_rank_reps": mean_rank_reps,
        "poison_prob": args.poison_prob,
        "booster_owner_count": args.booster_owner_count,
        "copier_owner_count": args.copier_owner_count,
        "poisoner_owner_count": args.poisoner_owner_count,
        "k": args.k,
        "booster1_source": args.booster1_source,
        "booster1_variant_owner": args.booster1_variant_owner,
        "booster1_variant_anchor": args.booster1_variant_anchor,
        "booster2_source": args.booster2_source,
        "booster2_variant_owner": args.booster2_variant_owner,
        "booster2_variant_anchor": args.booster2_variant_anchor,
        "booster3_source": args.booster3_source,
        "booster3_variant_owner": args.booster3_variant_owner,
        "booster3_variant_anchor": args.booster3_variant_anchor,
        "booster4_source": args.booster4_source,
        "booster4_variant_owner": args.booster4_variant_owner,
        "booster4_variant_anchor": args.booster4_variant_anchor,
        "lam_base": args.lam_base,
        "lam_exponents": None if lam_exp_vec is None else lam_exps,
    }

    if args.mar_con and marginal_contribution is not None:
        payload["marginal_contribution"] = marginal_contribution

    with open(out_pkl_path, "wb") as f:
        pickle.dump(payload, f)
    print(f"\nSaved results (all reps) to: {out_pkl_path}")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()


