import numpy as np 
import pickle
import torch 

from typing import List 
from transformers import Trainer 

from estimators import (CLUB, DoE, InfoNCE, KNIFE, MINE, NWJ, SMILE)
from infonce_new import InfoNCEPointwise

class Args:
    """Minimal attribute container: each keyword argument becomes an attribute."""

    def __init__(self, **args):
        # Copy every keyword/value pair onto the instance under the same name.
        for attr_name, attr_value in args.items():
            setattr(self, attr_name, attr_value)


def auto_estimator_from_default_parameters(name, dim):
    """
    Build an MI estimator configured with this module's default hyperparameters.

    Inputs:
        name: the name of the estimator. One of "CLUB", "DoE", "InfoNCE",
            "InfoNCE_Pointwise", "MINE", "NWJ", "SMILE".
        dim: the dimension of the data to estimate MI from. It is passed as
            both dimension arguments of the estimator constructor.

    Returns:
        An instance of the requested estimator, constructed as
        ``EstimatorClass(Args(**defaults), dim, dim)``.

    Raises:
        NotImplementedError: if `name` is not one of the supported estimators.
    """
    # Default hyperparameters per estimator; each entry is wrapped in an
    # `Args` object before being handed to the estimator constructor.
    default_kwargs = {
        "CLUB": dict(
            ff_residual_connection=True,
            ff_layers=2,
            ff_layer_norm=True,
            ff_activation="relu",
            use_tanh=True,
        ),
        "DoE": dict(
            ff_residual_connection=True,
            ff_layers=2,
            ff_layer_norm=False,
            ff_activation="relu",
        ),
        "InfoNCE": dict(
            ff_residual_connection=False,
            ff_layers=1,
            ff_layer_norm=True,
            ff_activation="relu",
        ),
        # This returns pointwise results, but other estimators return batch-wise results
        "InfoNCE_Pointwise": dict(
            ff_residual_connection=False,
            ff_layers=1,
            ff_layer_norm=True,
            ff_activation="relu",
        ),
        "MINE": dict(
            ff_residual_connection=False,
            ff_layers=2,
            ff_layer_norm=True,
            ff_activation="relu",
        ),
        "NWJ": dict(
            nwj_measure="W1",
            ff_residual_connection=False,
            ff_layers=1,
            ff_layer_norm=False,
            ff_activation="relu",
        ),
        "SMILE": dict(
            clip=None,
            ff_residual_connection=False,
            ff_layers=1,
            ff_layer_norm=True,
            ff_activation="relu",
        ),
    }

    if name not in default_kwargs:
        # `NotImplemented` is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError. Raise the real exception.
        raise NotImplementedError(
            "Unknown estimator {!r}; supported estimators: {}".format(
                name, ", ".join(default_kwargs)
            )
        )

    # Resolved only after the name check so an unsupported name fails fast.
    estimator_classes = {
        "CLUB": CLUB,
        "DoE": DoE,
        "InfoNCE": InfoNCE,
        "InfoNCE_Pointwise": InfoNCEPointwise,
        "MINE": MINE,
        "NWJ": NWJ,
        "SMILE": SMILE,
    }
    args = Args(**default_kwargs[name])
    return estimator_classes[name](args, dim, dim)


class MyDataset(torch.utils.data.Dataset):
    """Dataset over three collections X, Y and E that are indexed with the same position."""

    def __init__(self, X, Y, E):
        super().__init__()
        self.X, self.Y, self.E = X, Y, E

    def __len__(self):
        # The dataset size is taken from X alone.
        return len(self.X)

    def __getitem__(self, i):
        """Return the i-th entries of X, Y and E as a dict keyed "X", "Y", "E"."""
        return {field: getattr(self, field)[i] for field in ("X", "Y", "E")}

def collate_function(batch, device):
    """Collate a list of sample dicts into one dict of tensors on `device`.

    Inputs:
        batch: sequence of dicts, each providing the keys "X", "Y" and "E".
        device: target passed to ``Tensor.to``.

    Returns:
        A dict with keys "X", "Y", "E"; each value is a tensor built from the
        per-sample values of that key, moved to `device`.
    """
    return {
        field: torch.tensor([sample[field] for sample in batch]).to(device)
        for field in ("X", "Y", "E")
    }
    
def load_data(args, rawtext=False):
    """Load the prepared train/val/test splits and wrap each in a ``MyDataset``.

    Please refer to prepare_data.py for the preprocessing of the X, Y, E data.

    Inputs:
        args: object providing the ``dataset``, ``method``, ``embedding`` and
            ``downsample`` attributes that together name the pickle file.
        rawtext: if truthy, build the datasets from the ``Xs_*`` / ``Es_*``
            entries instead of the ``X_*`` / ``E_*`` entries. The ``Y_*``
            entries are used in both cases.

    Returns:
        The tuple ``(train_ds, val_ds, test_ds)``.
    """
    pickle_path = (
        f"../data/prepared/{args.dataset}_{args.method}_{args.embedding}_{args.downsample}.pkl"
    )
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(pickle_path, "rb") as handle:
        stored = pickle.load(handle)

    split_names = ("train", "val", "test")
    # Every field of every split is read up front, so a missing key fails
    # here regardless of the value of `rawtext`.
    fields = {
        field: [stored[f"{field}_{split}"] for split in split_names]
        for field in ("X", "Xs", "Y", "E", "Es")
    }

    x_field, e_field = ("Xs", "Es") if rawtext else ("X", "E")
    train_ds, val_ds, test_ds = (
        MyDataset(x_part, y_part, e_part)
        for x_part, y_part, e_part in zip(fields[x_field], fields["Y"], fields[e_field])
    )
    return train_ds, val_ds, test_ds