import argparse 
import numpy as np 
import pandas as pd 
import pickle
import time
import torch 

from pathlib import Path 
from torch.utils.data import DataLoader

from mi_estimation_utils import load_data, collate_function
from hyperparameter_tuning_vinfo import F_Model, eval_forward


def compute_I_Y_E(all_data, f_Y_Q, f_Y_E, args):
    """Score the validation and test splits with two already-trained models.

    For each split, ``eval_forward`` is called with ``average=False`` on
    ``f_Y_E`` and then on ``f_Y_Q``. The pointwise score of a split is the
    ``f_Y_Q`` output minus the ``f_Y_E`` output.
    NOTE(review): the ``h_*`` outputs are presumably per-example conditional
    entropies (losses), which would make the difference a pointwise
    V-information estimate -- confirm against ``eval_forward``.

    Args:
        all_data: Three-element sequence ``(train, val, test)`` of datasets.
            The training split is not used here.
        f_Y_Q: Trained model evaluated to obtain the ``h_Y_Q`` values.
        f_Y_E: Trained model evaluated to obtain the ``h_Y_E`` values.
        args: Namespace providing ``batch_size`` and ``device``.

    Returns:
        A tuple ``(report_df, vinfo_pointwise_val, vinfo_pointwise_test)``.
        ``report_df`` is a one-row DataFrame holding the mean of every
        pointwise quantity; the other two items are the pointwise
        differences for the validation and test splits.
    """
    _, val_ds, test_ds = all_data

    def _collate(batch):
        # Delegate batching to the shared helper, passing the target device.
        return collate_function(batch, args.device)

    def _make_loader(dataset):
        # shuffle=False keeps iteration in dataset order.
        return DataLoader(dataset, args.batch_size, shuffle=False, collate_fn=_collate)

    # The models are already trained, so only forward evaluation is needed.
    loaders = [("val", _make_loader(val_ds)), ("test", _make_loader(test_ds))]

    scores = {}
    for split, loader in loaders:
        scores["E", split] = eval_forward(f_Y_E, loader, average=False)
        scores["Q", split] = eval_forward(f_Y_Q, loader, average=False)

    vinfo = {split: scores["Q", split] - scores["E", split] for split, _ in loaders}

    # Column order matters: rows are appended to an existing CSV without a header.
    summary = {
        "h_Y_E_val": [scores["E", "val"].mean()],
        "h_Y_Q_val": [scores["Q", "val"].mean()],
        "h_Y_E_test": [scores["E", "test"].mean()],
        "h_Y_Q_test": [scores["Q", "test"].mean()],
        "vinfo_val": [vinfo["val"].mean()],
        "vinfo_test": [vinfo["test"].mean()],
    }
    return pd.DataFrame(summary), vinfo["val"], vinfo["test"]


if __name__ == "__main__":
    # Script entry point: load the two best checkpoints (f_Y_Q and f_Y_E) for
    # the chosen dataset/method/embedding/downsample combination, evaluate them
    # on the val and test splits, append one summary row to the report CSV and
    # write the pointwise scores to per-split CSV files.
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, choices=["rationale", "nle"], default="nle")
    parser.add_argument("--dataset", type=str, default="esnli")
    parser.add_argument("--embedding", type=str, choices=["roberta", "openai", "cohere"], required=True)
    parser.add_argument("--downsample", type=int, default=1200)
    parser.add_argument("--estimator", type=str, default="vinfo", choices=["vinfo"])
    parser.add_argument("--dim", type=int, help="The dimensions of X and E. To be determined from data")
    parser.add_argument("--n_class", type=int, help="To be automatically decided from dataset")
    parser.add_argument("--YQcheckpoint", type=str, help="Need to manually prepare and place the best checkpoint pkl")
    parser.add_argument("--YEcheckpoint", type=str, help="Need to manually prepare and place the best checkpoint pkl")
    parser.add_argument("--batch_size", type=int, default=16, help="A dummy batch_size. Need this to instantiate a dataloader")
    parser.add_argument("--report_path", type=str, default="../reports/y_e_info_report.csv")
    parser.add_argument("--pointwise_export_path", type=str, default="../data/scored")
    args = parser.parse_args()
    args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    all_data = load_data(args)

    # The embedding dimension is fixed by the embedding model, so any --dim
    # given on the command line is overridden here.
    if args.embedding in ("roberta", "cohere"):
        args.dim = 1024
    elif args.embedding == "openai":
        args.dim = 1536
    else:
        # NotImplementedError is the exception class; the original raised the
        # NotImplemented constant, which is not callable.
        raise NotImplementedError("Haven't implemented embedding method {} yet".format(args.embedding))
    args.n_class = {"esnli": 3}[args.dataset]

    # Derive the conventional checkpoint locations only when the user did not
    # pass them explicitly, so --YEcheckpoint/--YQcheckpoint are honoured
    # instead of being silently overwritten.
    checkpoint_stem = f"../checkpoints_best/{args.dataset}_{args.method}_{args.embedding}_{args.downsample}"
    if args.YEcheckpoint is None:
        args.YEcheckpoint = f"{checkpoint_stem}_f_Y_E/best.pkl"
    if args.YQcheckpoint is None:
        args.YQcheckpoint = f"{checkpoint_stem}_f_Y_Q/best.pkl"

    print(args)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # NOTE(review): F_Model exposes load_state_dict, so it is assumed to be a
    # torch.nn.Module. The collate function is given args.device, so the
    # models are moved there as well; this is a no-op if eval_forward already
    # does it -- confirm.
    f_Y_Q = F_Model(args.dim, args.n_class)
    f_Y_Q.load_state_dict(torch.load(args.YQcheckpoint, map_location=args.device))
    f_Y_Q.to(args.device)

    f_Y_E = F_Model(args.dim, args.n_class)
    f_Y_E.load_state_dict(torch.load(args.YEcheckpoint, map_location=args.device))
    f_Y_E.to(args.device)

    results_df, vinfo_pointwise_val, vinfo_pointwise_test = compute_I_Y_E(all_data, f_Y_Q, f_Y_E, args)
    # Tag the summary row with the run configuration.
    for key in ["method", "dataset", "embedding", "downsample", "YQcheckpoint", "YEcheckpoint"]:
        results_df[key] = [getattr(args, key)]

    # Export the aggregate results: create the report with a header on first
    # use, otherwise append a header-less row.
    report_file = Path(args.report_path)
    report_file.parent.mkdir(parents=True, exist_ok=True)
    results_df.to_csv(report_file, index=False, header=not report_file.exists(), mode="a")

    # Export the pointwise results, one CSV per split.
    export_dir = Path(args.pointwise_export_path)
    export_dir.mkdir(parents=True, exist_ok=True)
    for split, vinfo_pointwise in (("val", vinfo_pointwise_val), ("test", vinfo_pointwise_test)):
        pointwise_file = export_dir / "{}_{}_{}_{}_{}_IYE.csv".format(
            args.dataset, args.method, args.downsample, args.embedding, split
        )
        pd.DataFrame({"v_info": vinfo_pointwise.tolist()}).to_csv(pointwise_file, index=False)