import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
import os
import torch
import argparse
from models.s_model import S_SimDec
from models.v_model import ValueNetwork
from models.recalibrator import Recalibrator
from loaders.s_loader import S_Loader
from environments.environment import Env
from sessions.cb_session import CB_Session
from tools import feature_list

# ===================== Argument Parsing =====================
def _str_to_bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy. This converter parses the text
    instead, so ``--flag False`` actually yields ``False``.

    Args:
        value: The raw string from the command line (or an existing bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, matching what ``bool("")`` used to return.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser for the evaluation script and parse ``sys.argv``.

    Four arguments are required: the calibrated and uncalibrated predictor
    checkpoints, the decision-maker checkpoint, and the dataset name. Every
    other option has a default. Options whose help text says "Dummy for
    compatibility" are declared only so the resulting namespace carries those
    attributes.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Evaluate b-weighted accuracy of Decision Maker with Calibrated Predictor")

    # Required inputs.
    parser.add_argument('--calibrated_predictor_ckpt', type=str, required=True, help='Path to calibrated predictor checkpoint')
    parser.add_argument('--uncalibrated_predictor_ckpt', type=str, required=True, help='Path to uncalibrated predictor checkpoint')
    parser.add_argument('--dm_ckpt', type=str, required=True, help='Path to trained decision maker checkpoint')
    parser.add_argument('--dataset', type=str, required=True, help='Dataset name')

    # Evaluation and model hyper-parameters.
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size for evaluation')
    parser.add_argument("--embed_dim", type=int, default=64)
    parser.add_argument("--decoder_num_layers", type=int, default=1)
    parser.add_argument("--encoder_num_layers", type=int, default=1)
    parser.add_argument("--decay_coeff", type=float, default=0.00001)
    parser.add_argument("--dm_decay_coeff", type=float, default=0.0005)
    parser.add_argument("--mi_coeff", type=float, default=10)
    parser.add_argument("--ma_coeff", type=float, default=1)
    parser.add_argument("--otr_reward_coeff", type=float, default=1)
    parser.add_argument("--reward_smoothing_factor", type=float, default=0.5)
    parser.add_argument("--mip_coeff", type=float, default=1)
    parser.add_argument("--mil_coeff", type=float, default=1)
    parser.add_argument("--wandb", type=int, default=0)
    parser.add_argument("--save", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--dm_lr", type=float, default=0.01)

    # Boolean switches: parsed from text, because ``type=bool`` would treat
    # the string "False" as True.
    parser.add_argument("--use_calibration", type=_str_to_bool, default=True)
    parser.add_argument("--use_perturbation", type=_str_to_bool, default=False)

    parser.add_argument("--eta_c", type=float, default=1)
    parser.add_argument("--eta_p", type=float, default=1)
    parser.add_argument('--dm_epochs', type=int, default=1, help='Dummy for compatibility')
    parser.add_argument('--eva_interval', type=int, default=1, help='Dummy for compatibility')
    parser.add_argument('--early_stop', type=int, default=10, help='Dummy for compatibility')
    parser.add_argument('--suffix', type=str, default='eval', help='Suffix for checkpoint naming')
    parser.add_argument('--CKPT_PATH', type=str, default='./', help='Checkpoint directory')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--use_gpu', action='store_true', help='Use GPU if available')
    parser.add_argument('--device_id', type=str, default='cuda:0', help='Device id for torch')
    return parser.parse_args()

# ===================== Main Evaluation Script =====================
def _load_eval_model(model, ckpt_path, device):
    """Move ``model`` to ``device``, load its state dict from ``ckpt_path``, and set eval mode."""
    model = model.to(device)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    return model


def _predict(model, input_chunk, ori_chunk, feature_dim):
    """Run one predictor forward pass on a chunk, slicing its inputs at ``feature_dim``."""
    return model(
        c_input=input_chunk[:, :feature_dim],
        shipping_mode=ori_chunk[:, feature_dim].long(),
        tgt=ori_chunk[:, feature_dim + 1 :],
    )


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0.0 for an empty sequence."""
    return sum(values) / len(values) if values else 0.0


def main():
    """Evaluate b-weighted and plain accuracy of both predictors on "val" and "test".

    Steps:
      1. Parse arguments and build the environment, dataset, both ``S_SimDec``
         predictors, a ``Recalibrator``, the ``ValueNetwork`` and a
         ``CB_Session``.
      2. For each mode, run both predictors chunk by chunk under
         ``torch.no_grad()``.
      3. For every label column, take the arg-max prediction of each predictor
         and weight its correctness by the per-sample maximum of the
         recalibrator's ``info["b"]`` output, which is computed from the
         calibrated predictor's softmax.
      4. Print per-label and mean accuracies plus the elapsed time.

    Raises:
        ValueError: If ``--use_calibration`` is false; the b weights come from
            the recalibrator, so the evaluation cannot run without it.
    """
    import time  # function-local so the block needs no new file-level import

    args = parse_args()
    args.epochs = 0  # For compatibility
    args.ckpt_start_epoch = 0  # For compatibility
    args.wandb = False  # Disable wandb logging for evaluation

    # Fail fast with a real exception: an ``assert`` would be stripped under
    # ``python -O`` and would otherwise only fire after all checkpoints loaded.
    if not args.use_calibration:
        raise ValueError(
            "Recalibrator is not initialized: b-weighted evaluation requires --use_calibration to be true"
        )

    # Set up environment and dataset
    env = Env(args)
    dataset = S_Loader(env)
    env.feature_classes = dataset.feature_classes

    # Build and load both predictors (S_SimDec)
    model_calibrated = _load_eval_model(S_SimDec(env), args.calibrated_predictor_ckpt, env.device)
    model_uncalibrated = _load_eval_model(S_SimDec(env), args.uncalibrated_predictor_ckpt, env.device)

    # Build the recalibrator. Its input width is the number of configured
    # feature names for this dataset.
    recal_input_dim = len(
        feature_list.product_info[args.dataset]
        + feature_list.order_info[args.dataset]
        + feature_list.customer_info[args.dataset]
        + feature_list.shipping_info[args.dataset]
    )
    # NOTE(review): hard-coded; presumably must equal the class dimension of
    # the predictor logits fed to the recalibrator below - confirm.
    action_dim = 4
    recalibrator = Recalibrator(K=action_dim, input_dim=recal_input_dim).to(env.device)
    recalibrator.eval()
    # NOTE(review): no checkpoint is loaded for the recalibrator in this
    # script, so it runs with the weights it was constructed with. Confirm
    # that this is intended for evaluation.

    # Build and load decision maker (ValueNetwork)
    value_network = _load_eval_model(ValueNetwork(env), args.dm_ckpt, env.device)

    # Set up CB_Session for data utilities
    session = CB_Session(env, model_calibrated, dataset)
    session.init_value_network(value_network)
    session.recalibrator = recalibrator
    session.use_calibration = True

    # Ensure the session's scaler is fitted before evaluation (when present)
    if hasattr(session, 'scaler') and hasattr(session, 'train_inputs'):
        session.scaler.fit(session.train_inputs)

    # Clamp to 1 so a tiny --batch_size cannot produce a zero range() step.
    chunk_size = max(1, int(args.batch_size // 1.5))

    print("Evaluating b-weighted accuracy of Calibrated and Uncalibrated Predictors...")
    for mode in ["val", "test"]:
        # Prepare data
        input_id = session._select_dataset_by_mode(mode)
        ori_input, processed_input_id, feature_dim = session._prepare_input_data(input_id, mode)
        label_dim = session._get_label_dim()

        # Per-label accumulators. Both predictors share the same b weights,
        # so a single weight total serves both b-weighted ratios.
        weighted_hits_cal = [0.0] * label_dim
        weighted_hits_uncal = [0.0] * label_dim
        weight_total = [0.0] * label_dim
        hits_cal = [0] * label_dim
        hits_uncal = [0] * label_dim
        total_samples = [0] * label_dim

        started = time.perf_counter()  # monotonic clock for interval timing
        with torch.no_grad():
            for start in range(0, len(processed_input_id), chunk_size):
                input_chunk = processed_input_id[start : start + chunk_size]
                ori_chunk = ori_input[start : start + chunk_size]
                # The last label_dim columns of the original input are the labels.
                class_labels = ori_chunk[:, -label_dim:].long().to(env.device)

                predicted_tokens_cal = _predict(model_calibrated, input_chunk, ori_chunk, feature_dim)
                predicted_tokens_uncal = _predict(model_uncalibrated, input_chunk, ori_chunk, feature_dim)

                for j in range(label_dim):
                    logits_cal = predicted_tokens_cal[j]
                    logits_uncal = predicted_tokens_uncal[j]
                    y_true = class_labels[:, j]

                    correct_cal = torch.argmax(logits_cal, dim=1) == y_true
                    correct_uncal = torch.argmax(logits_uncal, dim=1) == y_true

                    # b weights always come from the calibrated predictor's
                    # probabilities, for both predictors' weighted scores.
                    y_true_onehot = torch.nn.functional.one_hot(y_true, num_classes=logits_cal.shape[1]).float()
                    p_prev_cal = torch.softmax(logits_cal, dim=-1)
                    _, info = recalibrator(
                        x=input_chunk[:, :feature_dim],
                        p_prev=p_prev_cal,
                        y_true=y_true_onehot,
                    )
                    b_xa = info["b"]
                    # Value of b along its arg-max ("main") direction per
                    # sample; same result as indexing b_xa with its arg-max.
                    b_main = b_xa.max(dim=1).values

                    weighted_hits_cal[j] += (correct_cal.float() * b_main).sum().item()
                    weighted_hits_uncal[j] += (correct_uncal.float() * b_main).sum().item()
                    weight_total[j] += b_main.sum().item()
                    hits_cal[j] += int(correct_cal.sum().item())
                    hits_uncal[j] += int(correct_uncal.sum().item())
                    total_samples[j] += len(y_true)
        eval_time = time.perf_counter() - started

        tag = mode.upper()
        results = (
            ("Calibrated", weighted_hits_cal, hits_cal),
            ("Uncalibrated", weighted_hits_uncal, hits_uncal),
        )
        for name, weighted_hits, hits in results:
            b_weighted_accuracy = [_safe_ratio(weighted_hits[j], weight_total[j]) for j in range(label_dim)]
            accuracy = [_safe_ratio(hits[j], total_samples[j]) for j in range(label_dim)]
            print(f"[{tag}] {name} predictor b-weighted accuracy per label: {b_weighted_accuracy}")
            print(f"[{tag}] {name} predictor b-weighted accuracy mean: {_mean(b_weighted_accuracy):.4f}")
            print(f"[{tag}] {name} predictor normal accuracy per label: {accuracy}")
            print(f"[{tag}] {name} predictor normal accuracy mean: {_mean(accuracy):.4f}")
        print(f"[{tag}] Evaluation time: {eval_time:.2f}s")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()