import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
import os
import torch
import argparse
from tools.utils import ModelNameParser
from models.s_model import S_SimDec
from loaders.s_loader import S_Loader
from environments.environment import Env
from sessions.cb_session import CB_Session
from tools import feature_list
import torch.nn.functional as F

# ===================== Argument Parsing =====================
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` hands the raw string to ``bool()``, so every non-empty
    value (including ``"False"`` and ``"0"``) becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: The raw argument string (or an already-parsed bool).

    Returns:
        True or False according to the textual value.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching what bool("") returned before.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Evaluate sim accuracy by action category on test set for both calibrated and uncalibrated models")
    parser.add_argument('--calibrated_sim_ckpt', type=str, required=True, help='Path to calibrated sim checkpoint')
    parser.add_argument('--uncalibrated_sim_ckpt', type=str, required=True, help='Path to uncalibrated sim checkpoint')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for evaluation')
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dm_lr", type=float, default=0.01)
    # _str2bool instead of type=bool so that "--use_perturbation False" is False.
    parser.add_argument("--use_calibration", type=_str2bool, default=False)
    parser.add_argument("--use_perturbation", type=_str2bool, default=True)
    parser.add_argument("--eta_c", type=float, default=1)
    parser.add_argument("--eta_p", type=float, default=1)
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    parser.add_argument("--decay_coeff", type=float, default=0.00001)
    parser.add_argument('--suffix', type=str, default='eval', help='Suffix for checkpoint naming')
    parser.add_argument('--CKPT_PATH', type=str, default='./', help='Checkpoint directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--use_gpu', action='store_true', help='Use GPU if available')
    parser.add_argument('--device_id', type=str, default='cuda:0', help='Device id for torch')
    return parser.parse_args(argv)

# ===================== Main Evaluation Script =====================
def _load_sim(env, name_parser, ckpt_path, dataset, expected_flag, flag_error):
    """Build an S_SimDec, validate the checkpoint name, load weights, set eval mode.

    Args:
        env: Environment object; ``env.device`` is used for placement.
        name_parser: Object whose ``parse_name(path)`` returns a mapping with
            the keys ``"mode"``, ``"use_flag"`` and ``"dataset"``.
        ckpt_path: Path of the checkpoint to load.
        dataset: Dataset name the checkpoint must have been trained on.
        expected_flag: Required value of the parsed ``"use_flag"`` entry.
        flag_error: Message raised when ``"use_flag"`` does not match.

    Returns:
        The loaded model in eval mode.

    Raises:
        AssertionError: If the parsed checkpoint name fails any check.
    """
    model = S_SimDec(env).to(env.device)
    meta = name_parser.parse_name(ckpt_path)
    # Raised explicitly (same exception type as the former assert statements)
    # so the checks are not stripped when Python runs with -O.
    if meta["mode"] != "sim":
        raise AssertionError(f"wrong model mode {meta['mode']}")
    if meta["use_flag"] != expected_flag:
        raise AssertionError(flag_error)
    if meta["dataset"] != dataset:
        raise AssertionError(f"wrong dataset for model {ckpt_path}")
    # NOTE(security): torch.load unpickles the file; only pass checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(ckpt_path, map_location=env.device))
    model.eval()
    return model


def main():
    """Print per-action test accuracy of the calibrated and uncalibrated sims.

    Loads both checkpoints, selects the test split through a CB_Session,
    groups the rows by the action value stored at column ``feature_dim``,
    and for each action 0..3 prints per-label and mean argmax accuracy
    for both models.
    """
    model_name_parser = ModelNameParser()
    args = parse_args()
    args.epochs = 0  # For compatibility
    args.ckpt_start_epoch = 0  # For compatibility
    args.wandb = False  # Disable wandb logging for evaluation

    # Set up environment and dataset
    env = Env(args)
    dataset = S_Loader(env)
    env.feature_classes = dataset.feature_classes

    # Build and load both sims (S_SimDec)
    model_calibrated = _load_sim(
        env,
        model_name_parser,
        args.calibrated_sim_ckpt,
        args.dataset,
        True,
        "uncalibrated sim loaded to calibrated test position",
    )
    model_uncalibrated = _load_sim(
        env,
        model_name_parser,
        args.uncalibrated_sim_ckpt,
        args.dataset,
        False,
        "calibrated sim loaded to uncalibrated test position",
    )

    # Set up CB_Session for data utilities (use calibrated session for scaler)
    session = CB_Session(env, model_calibrated, dataset)

    # Fit the scaler on the training inputs when the session exposes both.
    if hasattr(session, 'scaler') and hasattr(session, 'train_inputs'):
        session.scaler.fit(session.train_inputs)

    # Prepare test data
    input_id = session._select_dataset_by_mode('test')
    feature_dim = session._get_feature_dim()
    label_dim = session._get_label_dim()
    action_indices = input_id[:, feature_dim].long()  # [N]
    # Not used below; the lookup is kept so it still runs as it did before.
    label_names = feature_list.label[args.dataset]

    # Loop-invariant, so computed once. Clamped to at least 1 because a
    # batch size of 1 yields 0 here and range() rejects a zero step.
    chunk_size = max(1, int(args.batch_size // 1.5))

    # Split row indices by action (0, 1, 2, 3)
    num_actions = 4
    action_to_indices = {
        a: (action_indices == a).nonzero(as_tuple=True)[0] for a in range(num_actions)
    }

    print("Evaluating accuracy by action category on TEST set...")
    for action, indices in action_to_indices.items():
        if len(indices) == 0:
            print(f"[ACTION {action}] No samples found.")
            continue
        action_input = input_id[indices]
        # The returned feature dimension is kept under its own name so the
        # value read from the session above is not overwritten.
        ori_input, processed_input_id, feat_dim = session._prepare_input_data(action_input, mode='test')

        # Both models score the same rows, so one total per label suffices.
        correct_cal = [0] * label_dim
        correct_uncal = [0] * label_dim
        totals = [0] * label_dim
        with torch.no_grad():
            for start in range(0, len(processed_input_id), chunk_size):
                input_chunk = processed_input_id[start : start + chunk_size]
                ori_chunk = ori_input[start : start + chunk_size]
                # Forward pass for both models on identical inputs
                predicted_tokens_cal, predicted_tokens_uncal = [
                    model(
                        c_input=input_chunk[:, :feat_dim],
                        shipping_mode=ori_chunk[:, feat_dim].long(),
                        tgt=ori_chunk[:, feat_dim + 1 :],
                    )
                    for model in (model_calibrated, model_uncalibrated)
                ]
                class_labels = ori_chunk[:, -label_dim:].long().to(env.device)
                for j in range(label_dim):
                    y_true = class_labels[:, j]
                    pred_cal = torch.argmax(predicted_tokens_cal[j], dim=1)
                    pred_uncal = torch.argmax(predicted_tokens_uncal[j], dim=1)
                    correct_cal[j] += int((pred_cal == y_true).sum().item())
                    correct_uncal[j] += int((pred_uncal == y_true).sum().item())
                    totals[j] += len(y_true)

        acc_cal = [c / t if t > 0 else 0.0 for c, t in zip(correct_cal, totals)]
        acc_uncal = [c / t if t > 0 else 0.0 for c, t in zip(correct_uncal, totals)]
        mean_acc_cal = sum(acc_cal) / len(acc_cal) if acc_cal else 0.0
        mean_acc_uncal = sum(acc_uncal) / len(acc_uncal) if acc_uncal else 0.0
        print(f"[ACTION {action}] Calibrated accuracy per label: {[f'{a:.4f}' for a in acc_cal]}")
        print(f"[ACTION {action}] Uncalibrated accuracy per label: {[f'{a:.4f}' for a in acc_uncal]}")
        # NOTE(review): N is the sum of the per-label totals, unchanged from
        # the original output; confirm whether a plain sample count was meant.
        print(f"[ACTION {action}] Calibrated mean accuracy: {mean_acc_cal:.4f} | Uncalibrated mean accuracy: {mean_acc_uncal:.4f} | N={sum(totals)}\n")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()