import argparse
import numpy as np
import optuna 
import pickle
import torch 
import transformers 

from pathlib import Path 
from torch.utils.data import DataLoader

from mi_estimation_utils import load_data, collate_function


class F_Model(torch.nn.Module):
    """Linear classifier that outputs log-probabilities over ``n_class`` labels.

    Architecture: one ``torch.nn.Linear(input_dim, n_class)`` layer followed by
    ``torch.nn.LogSoftmax`` over the last dimension. :meth:`loss` pairs it with
    an unreduced negative-log-likelihood loss.

    Args:
        input_dim: Size of the last dimension of the input features.
        n_class: Number of output classes.
    """

    def __init__(self, input_dim, n_class):
        super().__init__()
        # Keep the attribute name ``fc``: it determines the state_dict keys
        # ("fc.weight", "fc.bias") that end up in saved checkpoints.
        self.fc = torch.nn.Linear(input_dim, n_class)
        self.act_fn = torch.nn.LogSoftmax(dim=-1)

    def forward(self, X):
        """Return log-probabilities for ``X`` (cast to float32 first)."""
        logits = self.fc(X.float())
        return self.act_fn(logits)

    def loss(self, X, Y):
        """Return the per-example negative log-likelihood (no reduction).

        Args:
            X: Input features, passed through :meth:`forward`.
            Y: Target class indices; cast to ``long`` before use.

        Returns:
            A tensor with one loss value per example.
        """
        # forward() already yields log-probabilities, which is what NLLLoss expects.
        log_probs = self.forward(X)
        nll = torch.nn.NLLLoss(reduction='none')
        targets = Y.long()
        return nll(log_probs, targets)


def eval_forward(f, dl, average=True):
    """Compute per-example losses of model ``f`` over every batch in ``dl``.

    The model is switched to eval mode and the forward passes run under
    ``torch.no_grad()``, so no autograd graph is built during evaluation.

    Args:
        f: Model exposing ``eval()`` and ``loss(E, Y)``; ``loss`` must return
            one loss value per example.
        dl: Iterable of dict batches containing at least the keys "E"
            (features fed to the model) and "Y" (targets).
        average: If True, return the mean loss; otherwise return every
            per-example loss.

    Returns:
        ``np.mean`` of the losses when ``average`` is True (``np.nan`` if ``dl``
        yielded no examples), else a 1-D ``np.ndarray`` of per-example losses.
    """
    f.eval()
    result = []
    # Gradients are never needed here; disabling autograd saves memory and time.
    with torch.no_grad():
        for b in dl:
            loss = f.loss(b["E"], b["Y"])
            result.extend(loss.detach().tolist())
    if not average:
        return np.array(result)
    if not result:
        # np.mean([]) would emit a RuntimeWarning; make the empty case explicit.
        return np.nan
    return np.mean(result)
    

def train_model(all_data, args, trial):
    """Train one F_Model for an Optuna trial and return its best validation loss.

    Hyperparameters sampled from ``trial``:
      * "batch_size_choice": index into ``[4, 8, 16]``
      * "lr": log-uniform float in ``[1e-6, 1e-2]``

    The model is trained with Adam for ``max_epochs`` epochs to predict the
    batch "Y" from the batch "E". After each epoch the validation loss is
    computed, and the weights from the epoch with the lowest validation loss
    are saved with ``torch.save`` to
    ``{args.export_dir}/bsz{batch_size}_lr{lr:.6f}.pkl``.

    Args:
        all_data: ``(train_ds, val_ds, test_ds)`` datasets usable by DataLoader.
        args: Namespace providing ``device``, ``dim``, ``n_class`` and
            ``export_dir``.
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The lowest per-epoch validation loss (the Optuna objective).
    """
    import copy

    # List all hyperparameters to tune here
    batch_sizes = [4, 8, 16]
    batch_size = batch_sizes[trial.suggest_int("batch_size_choice", low=0, high=2, step=1)]
    lr = trial.suggest_float("lr", low=1e-6, high=1e-2, log=True)
    max_epochs = 10

    def collate(batch):
        return collate_function(batch, args.device)

    # Train the model
    train_ds, val_ds, test_ds = all_data
    train_dl = DataLoader(train_ds, batch_size, shuffle=True, collate_fn=collate)
    val_dl = DataLoader(val_ds, batch_size, shuffle=False, collate_fn=collate)
    test_dl = DataLoader(test_ds, batch_size, shuffle=False, collate_fn=collate)

    f = F_Model(
        input_dim=args.dim,
        n_class=args.n_class
    )
    f.to(args.device)
    optim = torch.optim.Adam(f.parameters(), lr=lr)

    val_losses = []
    test_losses = []  # collected per epoch; not part of the objective
    min_vloss = np.inf
    best_model_checkpoint = None
    for _ in range(max_epochs):
        f.train()
        for b in train_dl:
            # Clear gradients from the previous step; without this they accumulate.
            optim.zero_grad()
            loss = f.loss(b["E"], b["Y"]).mean()
            loss.backward()
            # Actually apply the update.
            optim.step()

        # eval_forward switches the model to eval mode itself.
        vloss = eval_forward(f, val_dl)
        val_losses.append(vloss)
        if vloss < min_vloss:
            min_vloss = vloss
            # state_dict() references the live parameter tensors; deep-copy so
            # later optimizer steps do not overwrite this snapshot.
            best_model_checkpoint = copy.deepcopy(f.state_dict())
        test_losses.append(eval_forward(f, test_dl))

    if best_model_checkpoint is None:
        # Every validation loss was NaN (never < inf); save the final weights
        # rather than writing ``None`` to disk.
        best_model_checkpoint = copy.deepcopy(f.state_dict())

    # NOTE(review): lr is rounded to 6 decimals, so trials with the same batch
    # size whose lr rounds identically overwrite each other's checkpoint.
    fname = f"{args.export_dir}/bsz{batch_size}_lr{lr:.6f}.pkl"
    torch.save(best_model_checkpoint, fname)

    best_epoch_id = np.argmin(val_losses)
    return val_losses[best_epoch_id]


def randomize_E(all_data):
    """Overwrite each dataset's ``E`` with Gaussian noise of the same shape.

    The datasets are mutated in place. Each ``E`` becomes an ``np.ndarray``
    drawn from ``np.random.normal(loc=0, scale=0.1)`` (numpy's global RNG, so
    not seeded here). Its shape is ``np.array(old_E).shape``.

    Args:
        all_data: Sequence of exactly three datasets (train, val, test), each
            with an ``E`` attribute.

    Returns:
        The same three dataset objects, as a tuple.
    """
    train_ds, val_ds, test_ds = all_data
    for dataset in all_data:
        target_shape = np.array(dataset.E).shape
        dataset.E = np.random.normal(loc=0, scale=0.1, size=target_shape)
    return (train_ds, val_ds, test_ds)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, choices=["rationale", "nle"], default="nle")
    parser.add_argument("--dataset", type=str, default="esnli")
    parser.add_argument("--embedding", type=str, choices=["roberta", "openai", "cohere"])
    parser.add_argument("--downsample", type=int, default=1200)
    parser.add_argument("--dim", type=int, help="To be automatically decided from embedding")
    parser.add_argument("--n_class", type=int, help="To be automatically decided from dataset")
    parser.add_argument("--option", type=str, choices=["f_Y_E", "f_Y_Q"])
    parser.add_argument("--export_dir", type=str, help="To be filled in automatically")
    # Previously any user-supplied value was silently overwritten; now it is
    # honoured, and the path is derived from the other arguments only when omitted.
    parser.add_argument(
        "--report",
        type=str,
        default=None,
        help="CSV path for the Optuna trials table (derived automatically if omitted)",
    )
    args = parser.parse_args()
    args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Validate inputs before touching the filesystem.
    # --dim and --n_class are always derived here, overriding any CLI value.
    embedding_dims = {"roberta": 1024, "cohere": 1024, "openai": 1536}
    if args.embedding not in embedding_dims:
        # NotImplementedError, not NotImplemented: the latter is a constant and
        # calling it raises an unrelated TypeError.
        raise NotImplementedError("Haven't implemented embedding method {} yet".format(args.embedding))
    args.dim = embedding_dims[args.embedding]

    n_classes_by_dataset = {"esnli": 3}
    if args.dataset not in n_classes_by_dataset:
        raise NotImplementedError("Haven't implemented dataset {} yet".format(args.dataset))
    args.n_class = n_classes_by_dataset[args.dataset]

    run_name = f"{args.dataset}_{args.method}_{args.embedding}_{args.downsample}_{args.option}"
    args.export_dir = f"../checkpoints/{run_name}"
    Path(args.export_dir).mkdir(parents=True, exist_ok=True)

    if args.report is None:
        args.report = f"../reports/tune_{run_name}.csv"
    # The report is written only after every trial has finished; create its
    # directory now so a missing folder cannot throw away the whole study.
    Path(args.report).parent.mkdir(parents=True, exist_ok=True)

    print(args)
    all_data = load_data(args)
    if args.option == "f_Y_Q":
        all_data = randomize_E(all_data)
        print("Randomized the E data")

    study = optuna.create_study(direction="minimize")
    study.optimize(lambda trial: train_model(all_data, args, trial), n_trials=50)
    print(study.best_params)

    df = study.trials_dataframe()
    df.to_csv(args.report)