import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
import os
import torch
import argparse
from tools.utils import ModelNameParser
from models.s_model import S_SimDec
from loaders.s_loader import S_Loader
from environments.environment import Env
from sessions.cb_session import CB_Session
from tools import feature_list
import torch.nn.functional as F

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty token -- including ``"False"`` and ``"0"`` -- turns into
    ``True``. This converter interprets the common spellings explicitly.

    Args:
        value: Raw option value (a string from the command line, or a bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, matching what ``bool("")`` used to give.
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    ``--sim_ckpt`` and ``--dataset`` are required; every other option has a
    default. The two boolean options accept spellings such as ``true``/``false``,
    ``1``/``0`` and ``yes``/``no``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Evaluate predictor accuracy under perturbation by action category")
    parser.add_argument('--sim_ckpt', type=str, required=True, help='Path to predictor checkpoint')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dm_lr", type=float, default=0.01)
    # type=bool would treat "False" as True; parse the token explicitly instead.
    parser.add_argument("--use_calibration", type=_str2bool, default=False)
    parser.add_argument("--use_perturbation", type=_str2bool, default=True)
    parser.add_argument("--eta_c", type=float, default=1)
    parser.add_argument("--eta_p", type=float, default=1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for evaluation')
    parser.add_argument("--random_noise", type=int, default=1, help='Add perturbation noise to input features')
    parser.add_argument("--epsilon_p", type=float, default=0.01, help='Noise strength for perturbation')
    parser.add_argument('--suffix', type=str, default='perturb_eval', help='Suffix for checkpoint naming')
    parser.add_argument('--CKPT_PATH', type=str, default='./', help='Checkpoint directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--use_gpu', action='store_true', help='Use GPU if available')
    parser.add_argument('--device_id', type=str, default='cuda:0', help='Device id for torch')
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    parser.add_argument("--decay_coeff", type=float, default=0.00001)
    return parser.parse_args()

def evaluate_under_perturbation(env, session, model, input_id, feature_dim, label_dim, epsilon_p, args,
                                num_actions=4):
    """Measure per-label classification accuracy with Gaussian noise on the input features.

    Rows of ``input_id`` are grouped by the action id stored in column
    ``feature_dim``. Each group is passed through
    ``session._prepare_input_data``, scored in chunks without gradients, and
    the hit counts are pooled over all groups.

    Args:
        env: Environment object; only ``env.device`` is read.
        session: Object exposing ``_prepare_input_data(rows, mode='test')``,
            which returns a 3-tuple whose first two items are used as the
            original and the processed rows.
        model: Callable invoked as
            ``model(c_input=..., shipping_mode=..., tgt=...)``. Its result is
            indexed by label position and each entry is arg-maxed over dim 1.
        input_id: 2-D tensor of samples; column ``feature_dim`` holds the
            action id.
        feature_dim: Number of leading feature columns that receive noise.
        label_dim: Number of trailing columns holding the class labels.
        epsilon_p: Multiplier applied to standard-normal noise.
        args: Namespace; ``args.batch_size`` and ``args.random_noise`` are read.
        num_actions: Action ids ``0 .. num_actions - 1`` are evaluated; rows
            with any other action id are not scored.

    Returns:
        tuple: ``(acc_all, mean_acc)`` where ``acc_all`` is a list of
        ``label_dim`` accuracies (``0.0`` for a label with no scored rows) and
        ``mean_acc`` is their arithmetic mean (``0.0`` when ``label_dim`` is 0).
    """
    action_ids = input_id[:, feature_dim].long()
    correct_per_label = [0] * label_dim
    seen_per_label = [0] * label_dim

    # ``batch_size // 1.5`` floors to 0 for batch_size == 1 and range() rejects
    # a zero step, so clamp the chunk length to at least one row.
    chunk_size = max(1, int(args.batch_size // 1.5))

    for action in range(num_actions):
        row_idx = (action_ids == action).nonzero(as_tuple=True)[0]
        if row_idx.numel() == 0:
            continue
        ori_input, processed_input, _ = session._prepare_input_data(input_id[row_idx], mode='test')

        with torch.no_grad():
            for start in range(0, len(processed_input), chunk_size):
                stop = start + chunk_size
                ori_chunk = ori_input[start:stop]
                features = processed_input[start:stop][:, :feature_dim]

                if args.random_noise:
                    noise = torch.randn_like(features) * epsilon_p
                    features = features + noise.to(env.device)

                # The action id and the remaining columns come from the
                # unperturbed rows; only the feature slice carries noise.
                pred_tokens = model(
                    c_input=features,
                    shipping_mode=ori_chunk[:, feature_dim].long(),
                    tgt=ori_chunk[:, feature_dim + 1:]
                )

                class_labels = ori_chunk[:, -label_dim:].long().to(env.device)
                for j in range(label_dim):
                    y_true = class_labels[:, j]
                    y_pred = torch.argmax(pred_tokens[j], dim=1)
                    correct_per_label[j] += int((y_pred == y_true).sum().item())
                    seen_per_label[j] += len(y_true)

    acc_all = [
        hits / seen if seen > 0 else 0.0
        for hits, seen in zip(correct_per_label, seen_per_label)
    ]
    mean_acc = sum(acc_all) / len(acc_all) if acc_all else 0.0
    return acc_all, mean_acc

def main():
    """Load a predictor checkpoint and report its accuracy over a noise sweep.

    The parsed options are adjusted before the environment is built:
    ``use_flag`` is set, ``wandb`` and ``use_calibration`` are switched off,
    and ``use_perturbation`` mirrors ``--random_noise``. The test split is then
    scored at five fixed noise levels (0.0 to 0.5 in steps of 0.125); the
    ``--epsilon_p`` option is not consulted for this sweep. Per-level
    accuracies and summary robustness metrics are printed to stdout.

    Raises:
        ValueError: If the checkpoint name does not describe a ``sim`` model
            or was produced for a dataset other than ``--dataset``.
    """
    import numpy as np
    args = parse_args()
    args.use_flag = True
    args.wandb = False
    args.use_calibration = False
    args.use_perturbation = bool(args.random_noise)

    env = Env(args)
    dataset = S_Loader(env)
    env.feature_classes = dataset.feature_classes

    model = S_SimDec(env).to(env.device)
    model_args = ModelNameParser().parse_name(args.sim_ckpt)
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if model_args["mode"] != "sim":
        raise ValueError(
            f"checkpoint {args.sim_ckpt!r} has mode {model_args['mode']!r}, expected 'sim'"
        )
    if model_args["dataset"] != args.dataset:
        raise ValueError(
            f"checkpoint {args.sim_ckpt!r} was built for dataset {model_args['dataset']!r}, "
            f"but --dataset is {args.dataset!r}"
        )
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    model.load_state_dict(torch.load(args.sim_ckpt, map_location=env.device))
    model.eval()

    session = CB_Session(env, model, dataset)
    # Refit the session's scaler on its training inputs when both are present.
    if hasattr(session, 'scaler') and hasattr(session, 'train_inputs'):
        session.scaler.fit(session.train_inputs)

    input_id = session._select_dataset_by_mode('test')
    feature_dim = session._get_feature_dim()
    label_dim = session._get_label_dim()

    epsilons = [round(i * 0.125, 3) for i in range(5)]  # [0.0, 0.125, 0.25, 0.375, 0.5]
    mean_acc_list = []

    print("===== Evaluating Under Varying Perturbation Levels =====")
    for epsilon_p in epsilons:
        acc_all, mean_acc = evaluate_under_perturbation(
            env, session, model, input_id, feature_dim, label_dim, epsilon_p, args
        )
        acc_fmt = [f"{a:.4f}" for a in acc_all]
        print(f"[epsilon_p={epsilon_p:.3f}] Accuracy per label: {acc_fmt}")
        print(f"[epsilon_p={epsilon_p:.3f}] Mean accuracy: {mean_acc:.4f}\n")
        mean_acc_list.append(mean_acc)

    # ===== Robustness Metrics =====
    # The first sweep entry (epsilon 0) is the clean baseline; the remaining
    # entries are summarised as the perturbed results.
    acc_clean = mean_acc_list[0]
    acc_perturbed = mean_acc_list[1:]
    avg_perturbed_acc = sum(acc_perturbed) / len(acc_perturbed)
    worst_acc = min(acc_perturbed)
    acc_var = np.var(acc_perturbed)
    drop_rate = acc_clean - avg_perturbed_acc

    print("========== Robustness Metrics ==========")
    print(f"Clean Accuracy (ε=0):        {acc_clean:.4f}")
    print(f"Avg Perturbed Accuracy:      {avg_perturbed_acc:.4f}")
    print(f"Worst-case Accuracy:         {worst_acc:.4f}")
    print(f"Stability (Variance):        {acc_var:.6f}")
    print(f"Drop Rate (ΔAcc):            {drop_rate:.4f}")
    
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
