import argparse 
import numpy as np 
import pandas as pd 
import pickle
import time
import torch 

from pathlib import Path 
from typing import List
from torch.utils.data import DataLoader
from mi_estimation_utils import auto_estimator_from_default_parameters, load_data, collate_function


def evaluate_forward_pointwise(estimator, dl, compute_loss=False) -> List:
    """
    Evaluate ``estimator`` on every batch of ``dl`` and concatenate the per-batch outputs.

    The estimator is put in eval mode and run under ``torch.no_grad()``.
    Each batch's output tensor is moved to the CPU and converted with
    ``tolist()``, then appended element-by-element to one flat result list.
    For ``*_Pointwise`` estimators this is presumably one value per example
    (TODO confirm against the estimator implementations).

    Args:
        estimator: object exposing ``eval()``, ``learning_loss(X, E)`` and a
            call ``estimator(X, E)`` whose first element is the I(X;E) output.
        dl: iterable of batches; each batch is a mapping with ``"X"`` and
            ``"E"`` tensors.
        compute_loss: if True collect ``learning_loss`` outputs, otherwise the
            I(X;E) outputs.

    Returns:
        A list of the collected values across all batches.
    """
    estimator.eval()
    result = []
    # No gradients are needed for evaluation; avoids building autograd graphs.
    with torch.no_grad():
        for b in dl:
            batch_X, batch_E = b["X"], b["E"]
            if compute_loss:
                values = estimator.learning_loss(batch_X, batch_E)
            else:
                values = estimator(batch_X, batch_E)[0]
            # .cpu() is required: batches may live on a CUDA device, where .numpy()/.tolist() conversion needs host memory.
            result.extend(values.detach().cpu().tolist())
    return result

def evaluate_forward_batchwise(estimator, dl, compute_loss=False) -> List:
    """
    Evaluate ``estimator`` on every batch of ``dl`` and return one float per batch.

    The estimator is put in eval mode and run under ``torch.no_grad()``.

    Args:
        estimator: object exposing ``eval()``, ``learning_loss(X, E)`` and a
            call ``estimator(X, E)`` whose first element is a single-element
            I(X;E) tensor.
        dl: iterable of batches; each batch is a mapping with ``"X"`` and
            ``"E"`` tensors.
        compute_loss: if True collect the learning loss, otherwise the I(X;E)
            estimate.

    Returns:
        A list of Python floats, one per batch. If ``learning_loss`` returns
        a per-sample tensor it is averaged over the batch first, the same
        reduction the training loop applies with ``.mean()``.
    """
    estimator.eval()
    result = []
    # No gradients are needed for evaluation; avoids building autograd graphs.
    with torch.no_grad():
        for b in dl:
            batch_X, batch_E = b["X"], b["E"]
            if compute_loss:
                # .mean() is a no-op on a scalar loss and reduces a per-sample loss so .item() is valid.
                result.append(estimator.learning_loss(batch_X, batch_E).mean().item())
            else:
                result.append(estimator(batch_X, batch_E)[0].item())
    return result

def compute_I_X_E(all_data, args):
    """
    I(X;E)
    Following the pre-selected hyperparameter, train the estimator to compute the I(X;E) values.

    A fresh estimator is trained with Adam for ``args.train_epochs`` epochs.
    After every epoch the mean validation loss, the validation I(X;E) values
    and the test I(X;E) values are recorded. The epoch with the lowest mean
    validation loss is reported.

    Args:
        all_data: ``(train_ds, val_ds, test_ds)`` tuple of datasets. Their
            collated batches expose ``"X"`` and ``"E"`` entries.
        args: parsed CLI namespace. Reads ``batch_size``, ``device``,
            ``estimator``, ``dim``, ``lr`` and ``train_epochs``.

    Returns:
        ``(report_df, pointwise_df)``:
        - ``report_df``: a one-row summary of the best epoch (steps, val loss,
          val MI, test MI, val MI variance).
        - ``pointwise_df``: the test I(X;E) values of that epoch in column
          ``"I_X_E"``. These are one entry per batch for non-``_Pointwise``
          estimators.

    Raises:
        ValueError: if ``args.train_epochs`` < 1, since no epoch could be selected.
    """
    if args.train_epochs < 1:
        raise ValueError(
            "train_epochs must be >= 1 so that a best epoch can be selected, "
            f"got {args.train_epochs}"
        )

    train_ds, val_ds, test_ds = all_data

    def _collate(batch):
        # collate_function receives the target device (presumably moves the batch tensors there).
        return collate_function(batch, args.device)

    train_dl = DataLoader(train_ds, args.batch_size, shuffle=True, collate_fn=_collate)
    # Evaluation loaders keep dataset order (shuffle=False), so test values line up with test_ds items.
    val_dl = DataLoader(val_ds, args.batch_size, shuffle=False, collate_fn=_collate)
    test_dl = DataLoader(test_ds, args.batch_size, shuffle=False, collate_fn=_collate)

    estim_X_E = auto_estimator_from_default_parameters(args.estimator, args.dim)
    estim_X_E.to(args.device)
    optim_X_E = torch.optim.Adam(estim_X_E.parameters(), lr=args.lr)

    evaluate_forward = (
        evaluate_forward_pointwise
        if args.estimator.endswith("_Pointwise")
        else evaluate_forward_batchwise
    )

    # Per-epoch histories, each of length args.train_epochs.
    val_losses, val_mis, val_mis_var, test_mis, steps = [], [], [], [], []
    test_mis_pointwise = []

    step = 0
    for _ in range(args.train_epochs):
        # The evaluate_forward_* helpers switch the estimator to eval mode, so re-enable training mode every epoch.
        estim_X_E.train()

        for b in train_dl:
            loss = estim_X_E.learning_loss(b["X"], b["E"]).mean()
            loss.backward()
            optim_X_E.step()
            optim_X_E.zero_grad()
            step += 1

        val_losses.append(np.mean(evaluate_forward(estim_X_E, val_dl, compute_loss=True)))
        epoch_val_mis = evaluate_forward(estim_X_E, val_dl, compute_loss=False)
        val_mis.append(np.mean(epoch_val_mis))
        val_mis_var.append(np.var(epoch_val_mis))

        # Test values are kept for every epoch: no checkpoint is saved, and the best epoch is only known after training.
        epoch_test_mis = evaluate_forward(estim_X_E, test_dl, compute_loss=False)
        test_mis.append(np.mean(epoch_test_mis))
        test_mis_pointwise.append(epoch_test_mis)
        steps.append(step)

    # Model selection: the epoch with the lowest mean validation loss.
    best_epoch_id = int(np.argmin(val_losses))
    report_df = pd.DataFrame({
        "steps": [steps[best_epoch_id]],
        "best_val_loss": [val_losses[best_epoch_id]],
        "best_val_mi": [val_mis[best_epoch_id]],
        "best_test_mi": [test_mis[best_epoch_id]],
        "val_mi_var": [val_mis_var[best_epoch_id]],
    })
    pointwise_df = pd.DataFrame({
        "I_X_E": test_mis_pointwise[best_epoch_id],
    })
    return report_df, pointwise_df

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, choices=["shap", "rationale", "nle"], default="nle")
    parser.add_argument("--dataset", type=str, default="esnli")
    parser.add_argument("--embedding", type=str, choices=["roberta", "openai", "cohere"], required=True)
    parser.add_argument("--downsample", type=int, default=1200)
    parser.add_argument("--estimator", type=str, default="InfoNCE_Pointwise")
    parser.add_argument("--dim", type=int, help="The dimensions of X and E. Set automatically from --embedding")
    parser.add_argument("--lr", type=float, default=0.00003)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--train_epochs", type=int, default=10)
    parser.add_argument("--report_path", type=str, default="../reports/mi_report.csv")
    parser.add_argument("--pointwise_export_dir", type=str, default="../data/scored/", help="This only needs to be run once, after the best config has been selected")
    parser.add_argument("--export_pointwise", action="store_true", help="Save a dataframe to args.pointwise_export_dir/{dataset}_{method}_{downsample}_{embedding}_test_IXE.csv")
    args = parser.parse_args()
    args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(args)
    start_time = time.time()

    all_data = load_data(args)
    # The embedding width is fixed by the embedding model and always overrides any --dim value.
    embedding_dims = {"roberta": 1024, "cohere": 1024, "openai": 1536}
    if args.embedding not in embedding_dims:
        raise NotImplementedError("Haven't implemented embedding method {} yet".format(args.embedding))
    args.dim = embedding_dims[args.embedding]

    results_df, pointwise_df = compute_I_X_E(all_data, args)
    for k in ["batch_size", "lr", "dim", "embedding", "downsample", "method", "dataset", "estimator"]:
        results_df[k] = [getattr(args, k)]
    results_df["runtime"] = [time.time() - start_time]

    # The first run creates the report with a header; later runs append one row each without a header.
    report_path = Path(args.report_path)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    if report_path.exists():
        results_df.to_csv(report_path, index=False, header=False, mode="a")
    else:
        results_df.to_csv(report_path, index=False)

    if args.export_pointwise:
        # NOTE(review): assumes load_data(..., rawtext=True) yields the test split in the same order as the embedded one.
        _, _, test_ds = load_data(args, rawtext=True)
        if len(pointwise_df) != len(test_ds.X):
            raise ValueError(
                f"Got {len(pointwise_df)} I(X;E) values for {len(test_ds.X)} test examples; "
                "non-_Pointwise estimators produce one value per batch, use a *_Pointwise estimator for --export_pointwise"
            )
        pointwise_df["X"] = test_ds.X
        pointwise_df["E"] = test_ds.E

        export_dir = Path(args.pointwise_export_dir)
        export_dir.mkdir(parents=True, exist_ok=True)
        export_path = export_dir / f"{args.dataset}_{args.method}_{args.downsample}_{args.embedding}_test_IXE.csv"
        pointwise_df.to_csv(export_path, index=False)