import os
import argparse
import pickle
import numpy as np

from data import load_income_data
from model import train_mlp_classifier
from poset import build_income_poset, build_ordered_partition_poset
from utils import (
    make_empirical_global_utility_fn, 
    lambda_block_wise,
    are_comparable,
    is_predecessor,
    is_successor,
    swap_delay,
    swap_advance,
    compute_utility_trajectory,
)
from value import estimate_poset_shapley_with_rank


# Feature blocks.
# `upstream_block` and `downstream_block` are handed to
# build_ordered_partition_poset when --poset=op; in sweep mode --block picks
# one of the three lists as the block passed to lambda_block_wise.
# NOTE(review): the names are presumably columns of the frame returned by
# load_income_data -- confirm against that loader.
upstream_block = [
    "age",
    "sex",
    "race",
    "native-country",
]
middle_block = [
    "marital-status",
    "education",
    "occupation",
    "workclass",
]
downstream_block = [
    "relationship",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
]


def parse_exponents(exp_str: str) -> list:
    """Parse a comma-separated string of numbers into a list of floats.

    Whitespace around each entry is ignored, and entries that are empty after
    stripping (for example from a trailing comma or an empty string) are
    skipped.

    Args:
        exp_str: Text such as ``"0,1,0,0"``.

    Returns:
        The parsed values, in the order they appear in ``exp_str``.

    Raises:
        ValueError: If a non-empty entry is rejected by ``float``.
    """
    values = []
    for token in exp_str.split(","):
        token = token.strip()
        if token:
            values.append(float(token))
    return values


def _build_arg_parser():
    """Create the command-line parser for the experiment script."""
    parser = argparse.ArgumentParser()

    # Sampler / replication controls, forwarded to
    # estimate_poset_shapley_with_rank.
    parser.add_argument("--num_permutations", type=int, default=3000)
    parser.add_argument("--burnin_steps", type=int, default=10000)
    parser.add_argument("--steps_between", type=int, default=1000)
    parser.add_argument("--nreps", type=int, default=10)
    parser.add_argument("--random_state", type=int, default=42)

    # Utility-function controls, forwarded to
    # make_empirical_global_utility_fn (as n_eval, K and k_nn).
    parser.add_argument("--n_eval", type=int, default=1000)
    parser.add_argument("--k", type=int, default=100)
    parser.add_argument("--k_nn", type=int, default=500)

    # Experiment mode. --block, --exponents, --base and --base_array are only
    # read when --weight=sweep.
    parser.add_argument("--weight", type=str, default="uniform", choices=["sweep", "uniform"])
    parser.add_argument("--poset", type=str, default="general", choices=["general", "op"])
    parser.add_argument("--block", type=str, default="upstream", choices=["upstream", "middle", "downstream"])
    parser.add_argument("--exponents", type=str, default="0,1,0,0")
    parser.add_argument("--base", type=float, default=None)
    parser.add_argument("--base_array", type=str, default=None)

    # Output / logging. --swap_feature is only read for uniform weights on
    # the general poset.
    parser.add_argument("--save_dir", type=str, default=None)
    parser.add_argument("--verbose_every", type=int, default=500)
    parser.add_argument("--swap_feature", type=str, default="capital-gain")

    return parser


def _resolve_sweep_settings(args):
    """Turn the sweep-only CLI options into concrete values.

    Returns:
        A tuple ``(exponents, base_array, block_features)``.

    Raises:
        ValueError: If no base value is left after parsing ``--base_array``,
            if ``--exponents`` does not contain exactly four values, or if a
            list entry cannot be converted to ``float``.
    """
    # --base takes precedence over --base_array; with neither given the sweep
    # covers the powers of two from 2**-8 to 2**8.
    if args.base is not None:
        base_array = [args.base]
    elif args.base_array is not None:
        base_array = [float(tok.strip()) for tok in args.base_array.split(",") if tok.strip()]
    else:
        base_array = [2**i for i in range(-8, 9)]
    if not base_array:
        # Otherwise the whole pipeline would run and pickle empty results.
        raise ValueError("base_array must contain at least one value")

    exponents = parse_exponents(args.exponents)
    if len(exponents) != 4:
        raise ValueError(f"exponents must have 4 values, got {len(exponents)}")

    # argparse `choices` guarantees args.block is one of these three keys.
    blocks = {
        "upstream": upstream_block,
        "middle": middle_block,
        "downstream": downstream_block,
    }
    return exponents, base_array, blocks[args.block]


def _result_filename(args, exponents):
    """Return the pickle file name for this run (no directory part)."""
    if args.weight == "uniform":
        return f"result_exp2_FA_uniform_{args.poset}.pkl"
    # Integral exponents are written without a decimal point ("0100"), others
    # with one decimal place. float.is_integer() is False for inf/nan, so
    # those are formatted instead of raising the way int() would.
    exp_tag = "".join(str(int(e)) if e.is_integer() else f"{e:.1f}" for e in exponents)
    return f"result_exp2_FA_{args.block}_exp{exp_tag}.pkl"


def _common_config(args):
    """Return the config entries stored in the payload for both weight modes."""
    return {
        "num_permutations": args.num_permutations,
        "burnin_steps": args.burnin_steps,
        "steps_between": args.steps_between,
        "nreps": args.nreps,
        "random_state": args.random_state,
        "n_eval": args.n_eval,
        "k": args.k,
        "k_nn": args.k_nn,
    }


def _count_edges(P):
    """Count the items across all entries of ``P['succs']`` (printed as edges)."""
    return sum(1 for edges in P["succs"] for _ in edges)


def _estimate_rep(P, lam, feat_names, v_emp, args, rep_index):
    """Run one replication of estimate_poset_shapley_with_rank.

    Replication ``rep_index`` (0-based) uses seed
    ``args.random_state + rep_index + 1``.
    """
    return estimate_poset_shapley_with_rank(
        P=P,
        lamb=lam,
        feat_names=feat_names,
        v_global_fn=v_emp,
        M=args.num_permutations,
        seed=args.random_state + rep_index + 1,
        laziness=0.1,
        burnin_steps=args.burnin_steps,
        steps_between=args.steps_between,
        verbose_every=args.verbose_every,
    )


def _swap_trajectories(perms, swap_feature, P, feat2idx, feat_names, v_emp):
    """Apply swap_advance / swap_delay to every permutation and score them.

    Returns:
        ``(advance_perms, advance_utils, delay_perms, delay_utils)``; each is a
        list with one entry per permutation in ``perms``.
    """
    # Utilities are memoised per feature set (order-insensitive key).
    # NOTE(review): the cache lives for one call, i.e. one replication; if
    # v_emp is deterministic it could be shared across replications -- confirm.
    cache = {}

    def v_cached(S_list):
        key = frozenset(S_list)
        if key not in cache:
            cache[key] = float(v_emp(S_list))
        return cache[key]

    advance_perms, advance_utils = [], []
    delay_perms, delay_utils = [], []
    for perm in perms:
        moved_up = swap_advance(perm, swap_feature, P, feat2idx)
        advance_perms.append(moved_up)
        advance_utils.append(compute_utility_trajectory(moved_up, v_cached, feat_names))

        moved_down = swap_delay(perm, swap_feature, P, feat2idx)
        delay_perms.append(moved_down)
        delay_utils.append(compute_utility_trajectory(moved_down, v_cached, feat_names))

    return advance_perms, advance_utils, delay_perms, delay_utils


def _run_uniform(args, P, feat_names, feat2idx, v_emp):
    """Run all replications with every weight equal to 1.

    Returns:
        ``(phi_reps, rank_reps, extras)``. The two arrays have shape
        ``(nreps, n_features)``. ``extras`` is empty unless ``--poset`` is
        ``general``; in that case it holds the per-replication permutation
        and utility lists plus ``swap_feature``.

    Raises:
        ValueError: If extensions are tracked and ``--swap_feature`` is not
            one of ``feat_names``.
    """
    print("Running experiment with uniform weights...")
    print("  Weight: uniform (all weights = 1)")
    print(f"  Poset: {args.poset}")
    print(f"  Permutations: {args.num_permutations} Burnin: {args.burnin_steps} Steps_between: {args.steps_between}")
    print(f"  Reps: {args.nreps}\n")

    n_features = len(feat_names)
    lam = np.ones(n_features, dtype=np.float64)
    phi_reps = np.zeros((args.nreps, n_features), dtype=np.float64)
    rank_reps = np.zeros((args.nreps, n_features), dtype=np.float64)

    track_extensions = args.poset == "general"
    if track_extensions and args.swap_feature not in feat_names:
        raise ValueError(f"swap_feature '{args.swap_feature}' not in feat_names: {feat_names}")

    # Insertion order here is the key order of the saved payload.
    extras = {}
    if track_extensions:
        extras = {
            "linear_extensions": [],
            "utility_values": [],
            "linear_extensions_advance": [],
            "utility_values_advance": [],
            "linear_extensions_delay": [],
            "utility_values_delay": [],
        }

    for r in range(args.nreps):
        out_r = _estimate_rep(P, lam, feat_names, v_emp, args, r)
        phi_reps[r] = out_r["phi"]
        rank_reps[r] = out_r["mean_ranks"]

        if not track_extensions:
            continue

        extras["linear_extensions"].append(out_r["linear_extensions"])
        extras["utility_values"].append(out_r["utility_values"])

        adv_perms, adv_utils, del_perms, del_utils = _swap_trajectories(
            out_r["linear_extensions"], args.swap_feature, P, feat2idx, feat_names, v_emp
        )
        extras["linear_extensions_advance"].append(adv_perms)
        extras["utility_values_advance"].append(adv_utils)
        extras["linear_extensions_delay"].append(del_perms)
        extras["utility_values_delay"].append(del_utils)

    if track_extensions:
        extras["swap_feature"] = args.swap_feature
    return phi_reps, rank_reps, extras


def _run_sweep(args, P, feat_names, feat2idx, v_emp, exponents, base_array, block_features):
    """Run all replications once per base value.

    Returns:
        ``(phi_reps_by_base, rank_reps_by_base)``: dicts keyed by base value,
        each mapping to an array of shape ``(nreps, n_features)``.
    """
    print("Iterating over base values...")
    total_bases = len(base_array)
    print(f"  Total base values: {total_bases}\n")

    phi_reps_by_base = {}
    rank_reps_by_base = {}
    n_features = len(feat_names)

    for base_idx, base in enumerate(base_array, start=1):
        lam = lambda_block_wise(base, block_features, exponents, feat_names, feat2idx)

        # Name the base so per-base output can be told apart in the log.
        print(f"[base {base_idx}/{total_bases}] base={base}")
        print("[Running experiment with replications]")
        print(f"  Permutations={args.num_permutations} Burnin={args.burnin_steps} Steps_between={args.steps_between}")
        print(f"  Reps={args.nreps}\n")

        phi_reps = np.zeros((args.nreps, n_features), dtype=np.float64)
        rank_reps = np.zeros((args.nreps, n_features), dtype=np.float64)

        for r in range(args.nreps):
            print(f"[rep {r+1}/{args.nreps}] Starting...")
            out_r = _estimate_rep(P, lam, feat_names, v_emp, args, r)
            phi_reps[r] = out_r["phi"]
            rank_reps[r] = out_r["mean_ranks"]

        phi_reps_by_base[base] = phi_reps
        rank_reps_by_base[base] = rank_reps

    return phi_reps_by_base, rank_reps_by_base


def main():
    """Run the experiment selected on the command line and pickle the results.

    Steps: parse and validate arguments, load the income data, train the MLP
    classifier, build the poset (``general`` or ordered-partition ``op``),
    build the empirical utility function, run ``--nreps`` replications of
    ``estimate_poset_shapley_with_rank`` (once with uniform weights, or once
    per base value in sweep mode), and write one pickle file into
    ``--save_dir`` (default: a ``save`` directory next to this script).

    Raises:
        ValueError: For ``--weight sweep`` combined with ``--poset op``, for
            invalid sweep options (see ``_resolve_sweep_settings``), or when
            ``--swap_feature`` is not a feature name while extensions are
            tracked.
    """
    args = _build_arg_parser().parse_args()

    if args.weight == "sweep" and args.poset == "op":
        raise ValueError("weight=sweep and poset=op combination is not allowed")

    exponents = base_array = block_features = None
    if args.weight == "sweep":
        exponents, base_array, block_features = _resolve_sweep_settings(args)

    if args.save_dir is None:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        save_dir = os.path.join(script_dir, "save")
    else:
        save_dir = args.save_dir
    # Build the output path before the expensive work so that a naming
    # problem cannot discard a finished run.
    outpath = os.path.join(save_dir, _result_filename(args, exponents))
    os.makedirs(save_dir, exist_ok=True)

    print(f"  Permutations: {args.num_permutations}")
    print(f"  Burnin: {args.burnin_steps}, Steps between: {args.steps_between}")
    print(f"  Reps: {args.nreps}")
    print()

    print("Loading data...")
    X_train, X_test, y_train, y_test = load_income_data(random_state=args.random_state)

    print("Training MLP classifier...")
    model, acc, auc = train_mlp_classifier(
        X_train, y_train, X_test, y_test,
        random_state=args.random_state,
        verbose=False,
    )
    print(f"  Test Accuracy: {acc:.4f}, ROC-AUC: {auc:.4f}")

    print("Building poset...")
    # Poset nodes are the training-frame columns, in column order.
    feat_names = list(X_train.columns)
    feat2idx = {f: i for i, f in enumerate(feat_names)}

    if args.poset == "op":
        # Build ordered partition (blocked) poset
        P = build_ordered_partition_poset(
            node_names=feat_names,
            upstream_block=upstream_block,
            downstream_block=downstream_block,
        )
        print(f"  Ordered partition poset: {len(feat_names)} nodes, {_count_edges(P)} edges")
        print(f"    Upstream block: {upstream_block}")
        print(f"    Downstream block: {downstream_block}")
    else:
        P = build_income_poset(feat_names)
        print(f"  General DAG poset: {len(feat_names)} nodes, {_count_edges(P)} edges")

    print("Building empirical utility function...")
    y_eval = y_test.to_numpy() if hasattr(y_test, "to_numpy") else np.asarray(y_test)
    v_emp = make_empirical_global_utility_fn(
        model=model,
        X_train_raw=X_train,
        X_eval_raw=X_test,
        y_eval=y_eval,
        feat_names=feat_names,
        n_eval=args.n_eval,
        seed=args.random_state,
        K=args.k,
        k_nn=args.k_nn,
    )

    if args.weight == "uniform":
        phi_reps, rank_reps, extras = _run_uniform(args, P, feat_names, feat2idx, v_emp)
        payload = {
            "phi_reps": phi_reps,
            "rank_reps": rank_reps,
            "feat_names": feat_names,
            "config": {
                "weight": args.weight,
                "poset": args.poset,
                **_common_config(args),
            },
        }
        # Empty unless the general poset was used.
        payload.update(extras)
    else:
        phi_reps_by_base, rank_reps_by_base = _run_sweep(
            args, P, feat_names, feat2idx, v_emp, exponents, base_array, block_features
        )
        payload = {
            "phi_reps_by_base": phi_reps_by_base,
            "rank_reps_by_base": rank_reps_by_base,
            "base_array": base_array,
            "feat_names": feat_names,
            "config": {
                "weight": args.weight,
                "poset": args.poset,
                "selected_block": args.block,
                "block_features": block_features,
                "exponents": exponents,
                "base_array": base_array,
                **_common_config(args),
            },
        }

    with open(outpath, "wb") as f:
        pickle.dump(payload, f)
    print(f"Saved results to {outpath}")


# Run the experiment only when this file is executed as a script, so that
# importing the module does not parse sys.argv or start a run.
if __name__ == "__main__":
    main()

