import sys 
sys.path.insert(0, "../")
import argparse 
import numpy as np 
import random 
import torch 
import optuna
import pickle 

from estimators import (CLUB, DoE, InfoNCE, KNIFE, MINE, NWJ, SMILE)

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    When CUDA is available, every CUDA device's generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, to_cuda=False, cubic=False):
    """Generate samples from a correlated Gaussian distribution.

    ``batch_size * dim`` scalar pairs are drawn from a zero-mean bivariate
    normal with unit variances and correlation ``rho``. The pairs are then
    laid out row-major as two ``(batch_size, dim)`` arrays ``x`` and ``y``.

    Args:
        rho: correlation between each paired coordinate of ``x`` and ``y``.
        dim: number of coordinates per sample.
        batch_size: number of samples (rows).
        to_cuda: if True, return float32 CUDA tensors; otherwise return the
            NumPy arrays produced by the sampler.
        cubic: if True, ``y`` is replaced by ``y ** 3``.

    Returns:
        Tuple ``(x, y)``.
    """
    covariance = [[1.0, rho], [rho, 1.0]]
    pairs = np.random.multivariate_normal([0, 0], covariance, batch_size * dim)

    # Column 0 feeds x and column 1 feeds y; each column is folded into rows of `dim`.
    x = pairs[:, 0].reshape(-1, dim)
    y = pairs[:, 1].reshape(-1, dim)

    if cubic:
        y = y ** 3

    if not to_cuda:
        return x, y
    return torch.from_numpy(x).float().cuda(), torch.from_numpy(y).float().cuda()

class Args:
    """Minimal attribute container: each keyword argument becomes an attribute."""

    def __init__(self, **args):
        for name, value in args.items():
            setattr(self, name, value)

def rho_to_mi(rho, dim):
    """Return the mutual information of ``dim`` independent pairs with correlation ``rho``.

    Each correlated Gaussian pair contributes ``-0.5 * log(1 - rho**2)`` nats,
    and the pairs are summed over ``dim`` coordinates.
    """
    per_coordinate = -0.5 * np.log(1 - rho ** 2)
    return dim * per_coordinate
    
def mi_to_rho(mi, dim):
    """Return the per-coordinate correlation giving total mutual information ``mi``.

    This inverts ``-dim / 2 * log(1 - rho**2) = mi`` for the non-negative root.
    """
    remaining = np.exp(-2 * mi / dim)
    return np.sqrt(1 - remaining)

def _build_estimator(estimator, trial, sample_dim, batch_size):
    """Instantiate the named MI estimator with hyperparameters sampled from ``trial``.

    Optuna parameter names, the order in which they are suggested, and the
    index-based encoding of list-valued choices (``*_choice``) are kept
    stable so report CSVs from different runs remain comparable.

    Args:
        estimator: one of "CLUB", "DoE", "InfoNCE", "KNIFE", "MINE", "NWJ", "SMILE".
        trial: Optuna trial used to sample the architecture hyperparameters.
        sample_dim: dimensionality of both inputs (passed as zc_dim and zd_dim).
        batch_size: training batch size (only forwarded to KNIFE).

    Returns:
        The constructed estimator model.

    Raises:
        NotImplementedError: if ``estimator`` is not a supported name.
    """
    if estimator == "CLUB":
        args = Args(
            ff_residual_connection=trial.suggest_int("ff_residual_connection", low=0, high=1, step=1),
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=trial.suggest_int("ff_layer_norm", low=0, high=1, step=1),
            ff_activation="relu",
            use_tanh=True
        )
        return CLUB(args, zc_dim=sample_dim, zd_dim=sample_dim)
    if estimator == "DoE":
        args = Args(
            ff_residual_connection=trial.suggest_int("ff_residual_connection", low=0, high=1, step=1),
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=trial.suggest_int("ff_layer_norm", low=0, high=1, step=1),
            ff_activation="relu"
        )
        return DoE(args, zc_dim=sample_dim, zd_dim=sample_dim)
    if estimator == "InfoNCE":
        args = Args(
            ff_residual_connection=False,
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=True,
            ff_activation="relu",
        )
        return InfoNCE(args, zc_dim=sample_dim, zd_dim=sample_dim)
    if estimator == "KNIFE":
        print("Warning: Sometimes there are memory issues, but this problem can't be reliably reproduced.")
        modes_list = [64, 128, 256]
        args = Args(
            batch_size=batch_size,
            optimize_mu=trial.suggest_int("optimize_mu", low=0, high=1, step=1),
            marg_modes=modes_list[trial.suggest_int("marg_modes_choice", low=0, high=len(modes_list)-1, step=1)],
            cond_modes=modes_list[trial.suggest_int("cond_modes_choice", low=0, high=len(modes_list)-1, step=1)],
            use_tanh=True,
            init_std=trial.suggest_float("init_std", low=1e-4, high=1e-3, log=True),
            cov_diagonal="var",
            cov_off_diagonal="var",
            average="var",
            ff_residual_connection=True,
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=True,
            ff_activation="relu",
        )
        return KNIFE(args, zc_dim=sample_dim, zd_dim=sample_dim)
    if estimator == "MINE":
        args = Args(
            ff_residual_connection=False,
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=trial.suggest_int("ff_layer_norm", low=0, high=1, step=1),
            ff_activation="relu"
        )
        return MINE(args, zc_dim=sample_dim, zd_dim=sample_dim)
    if estimator == "NWJ":
        possible_nwj_measures = ["GAN", "JSD", "X2", "KL", "RKL", "DV", "H2", "W1"]
        args = Args(
            nwj_measure=possible_nwj_measures[trial.suggest_int("nwj_measure_choice", low=0, high=len(possible_nwj_measures)-1, step=1)],
            ff_residual_connection=False,
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=trial.suggest_int("ff_layer_norm", low=0, high=1, step=1),
            ff_activation="relu"
        )
        return NWJ(args, zc_dim=sample_dim, zd_dim=sample_dim)
    if estimator == "SMILE":
        possible_clip_vals = [None, 0.1, 1, 10]
        args = Args(
            clip=possible_clip_vals[trial.suggest_int("clip_choice", low=0, high=len(possible_clip_vals)-1)],
            ff_residual_connection=False,
            ff_layers=trial.suggest_int("ff_layers", low=1, high=3, step=1),
            ff_layer_norm=trial.suggest_int("ff_layer_norm", low=0, high=1, step=1),
            ff_activation="relu"
        )
        return SMILE(args, zc_dim=sample_dim, zd_dim=sample_dim)
    # `raise NotImplemented` (the original) raises TypeError: NotImplemented is
    # a sentinel value, not an exception class.
    raise NotImplementedError(f"Unknown estimator: {estimator!r}")


def _sample_batch(rho, dim, batch_size, use_cuda, cubic):
    """Draw one correlated-Gaussian batch and return it as float32 tensors.

    ``sample_correlated_gaussian`` only converts to tensors on its CUDA path.
    With ``to_cuda=False`` it returns NumPy arrays, but the estimators are
    torch modules (they expose ``.parameters()``/``.cuda()``) and presumably
    expect tensors. ``torch.as_tensor`` returns an existing float32 tensor
    unchanged and converts NumPy arrays.
    """
    batch_x, batch_y = sample_correlated_gaussian(
        rho, dim=dim, batch_size=batch_size, to_cuda=use_cuda, cubic=cubic
    )
    return (torch.as_tensor(batch_x, dtype=torch.float32),
            torch.as_tensor(batch_y, dtype=torch.float32))


def main(estimator, trial):
    """Optuna objective: train an MI estimator on correlated Gaussians and score it.

    For each seed in ``all_random_seeds``, the model is trained for
    ``training_steps`` steps at every ground-truth MI value in ``mi_list``
    (in order). On the last step for each value, its estimate on the current
    batch is recorded.

    Args:
        estimator: one of "CLUB", "DoE", "InfoNCE", "KNIFE", "MINE", "NWJ", "SMILE".
        trial: the Optuna trial supplying hyperparameters.

    Returns:
        Mean squared error between the recorded estimates and the
        ground-truth MI values. The study minimizes this.

    Raises:
        NotImplementedError: if ``estimator`` is not a supported name.
    """
    cubic = False
    training_steps = trial.suggest_int("training_steps", low=100, high=200, step=100)
    batch_size = trial.suggest_int("batch_size", low=8, high=16, step=8)
    sample_dim = 1024
    learning_rate = trial.suggest_float("learning_rate", low=1e-5, high=1e-3, log=True)
    mi_list = [2., 4., 6., 8., 10.]
    all_random_seeds = [0, 1, 2, 31, 42]
    use_cuda = torch.cuda.is_available()

    # Seed before construction so every trial starts from the same initial weights.
    set_seed(1)
    model = _build_estimator(estimator, trial, sample_dim, batch_size)
    if use_cuda:
        model.cuda()

    mi_est_values = []
    mi_ground_truth_values = []
    for rs in all_random_seeds:
        np.random.seed(rs)
        # NOTE(review): only the optimizer is re-created per seed; the model
        # keeps the weights learned under the previous seed. Confirm intended.
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        for mi_value in mi_list:
            rho = mi_to_rho(mi_value, sample_dim)
            for step in range(training_steps):
                batch_x, batch_y = _sample_batch(rho, sample_dim, batch_size, use_cuda, cubic)

                if step == training_steps - 1:
                    model.eval()
                    # Evaluation needs no autograd graph; skipping it saves memory.
                    with torch.no_grad():
                        mi_est_values.append(model(batch_x, batch_y)[0].item())
                    mi_ground_truth_values.append(mi_value)

                model.train()
                model_loss = model.learning_loss(batch_x, batch_y)
                optimizer.zero_grad()
                model_loss.backward()
                optimizer.step()

    # MSE of the MI estimates; this is the value reported to Optuna.
    mean_loss = ((np.array(mi_est_values) - np.array(mi_ground_truth_values)) ** 2).mean()
    return mean_loss


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tune MI-estimator hyperparameters with Optuna (minimizing MSE to ground-truth MI)."
    )
    parser.add_argument("--estimator", type=str, choices=["CLUB", "DoE", "InfoNCE", "KNIFE", "MINE", "NWJ", "SMILE"], required=True)
    parser.add_argument("--report", type=str, default="optuna_report.csv",
                        help="Path of the CSV file receiving the per-trial results.")
    parser.add_argument("--n_trials", type=int, default=200,
                        help="Number of Optuna trials to run.")
    args = parser.parse_args()
    print(args)

    study = optuna.create_study(direction="minimize")
    study.optimize(lambda trial: main(args.estimator, trial), n_trials=args.n_trials)

    # Persist the trial history first: `best_params` raises if no trial completed,
    # and the report is still useful for diagnosing such runs.
    df = study.trials_dataframe()
    df.to_csv(args.report)

    print(study.best_params)
