import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.autograd as autograd
import numpy as np
import time 

from utils import utils_os

from datasets import NonlinearGaussian, MoG, SwissRoll
from nde.GGC import GGC
from estimators.MRE import MRE
from estimators.NAE import NAE
from estimators.NAE4 import NAE4
from estimators.MRE2 import MRE2
from estimators.MINE import MINE
from estimators.MRE_flow import MRE_flow
from estimators.FLE import FLE
from estimators.InfoNCE import InfoNCE


# Hyperparams of MI estimators
class Hyperparams:
    """Mutable bag of training settings handed to every MI estimator.

    All values are ordinary instance attributes, so a caller may override
    any of them on an instance after construction.
    """

    def __init__(self):
        # (attribute, default) pairs, assigned in this order.
        defaults = (
            ('critic', 'neural'),       # one of ('neural', 'quadratic')
            ('lr', 5e-4),               # presumably the learning rate
            ('bs', 500),                # presumably the batch size
            ('wd', 1e-5),               # presumably the weight decay
            ('n_bridges', 4),
            ('max_iteration', 1000),
        )
        for attr_name, attr_value in defaults:
            setattr(self, attr_name, attr_value)

# Shared settings instance; _evaluate_core passes this same object to every
# estimator it constructs.
hyperparams=Hyperparams()

# Architecture of MI estimators
# The first entry of architecture_critic is only a placeholder: _evaluate_core
# replaces it with the joint width of (X, Y) before building each estimator.
architecture_critic = [np.inf, 500, 500, 500, 1]
# Passed unchanged as both architecture_encoder_x and architecture_encoder_y.
architecture_encode = None

# MI estimators considered
# estimator_names[i] is the results key for the class at estimators[i], so the
# two lists must stay in the same order. The commented-out pairs below are
# alternative configurations.
# estimator_names = ['Adaptive-new']  # ['Multinomial', 'Adaptive', 'MINE', 'SMILE', 'DoE', 'InfoNCE']
# estimators = [NAE4]              # [MRE, MRE2, MINE, SMILE, FLE, InfoNCE]

# estimator_names = ['Adaptive']  
# estimators = [NAE]             

# estimator_names = ['MRE-flow']  
# estimators = [MRE_flow]     

estimator_names = ['Multinomial', 'Adaptive', 'MINE', 'DoE', 'InfoNCE']
estimators = [MRE, NAE, MINE, FLE, InfoNCE]


# Data saving directory (result files are read from and written to DIR)
DIR = 'results/synthetic/data_n1000'
# Number of samples drawn per dataset; evaluates to 1250.
# NOTE(review): presumably 1000 training points under an 80/20 split -- confirm.
n_data = int(1000/0.8)



def _evaluate_core(fn, X, Y, ground_truth_MI, n_exps, device):
    """Run every configured MI estimator on (X, Y) and persist the estimates.

    Each name in ``estimator_names`` is paired with the class at the same
    position in ``estimators``. That class is instantiated ``n_exps`` times;
    every instance is moved to ``device``, trained with ``learn(X, Y)``,
    switched to eval mode and queried with ``MI(X, Y)``. The estimates are
    stored in a dict keyed by estimator name, the ground truth is stored
    under ``'Truth'``, and the dict is saved into ``DIR``.

    If a result file already exists it is loaded first, so keys that are not
    recomputed in this run are carried over into the saved file.

    Args:
        fn: Result file name without the ``'.npy'`` suffix.
        X: 2-D tensor; its second dimension is counted into the critic input
            width.
        Y: 2-D tensor; its second dimension is counted into the critic input
            width.
        ground_truth_MI: Reference value stored under the ``'Truth'`` key.
        n_exps: Number of independent repetitions per estimator.
        device: Device every estimator is moved to.

    Returns:
        The results dict as re-loaded from ``DIR`` after saving.

    Raises:
        IndexError: If ``estimators`` has fewer entries than
            ``estimator_names``. Raised up front so that no training time is
            spent before the mismatch is detected.
    """
    if len(estimators) < len(estimator_names):
        raise IndexError('estimators has fewer entries than estimator_names')

    # Only the joint feature width is needed to size the critic input, so read
    # it from the shapes instead of materialising the concatenated tensor.
    d = X.size(1) + Y.size(1)

    # Create result file
    fn_full = fn + '.npy'
    if utils_os.is_file_exist(DIR, fn_full):
        print('result file already exist, will update file. \n')
        results = utils_os.load_object(DIR, fn_full).item()
    else:
        print('result file not exist, will create new file. \n')
        results = {}

    # Disabled experiment variant: neural density estimate applied to (X, Y)
    # before MI estimation. Kept for reference.
    # # Neural density estimate
    # from nde import FMGGC, GGC
    # if True:
    #     gc = FMGGC(n_inputs=d//2)
    #     gc.to(device)
    # else:
    #     gc = GGC(n_blocks=2, n_inputs=d//2, n_hidden=500, n_cond_inputs=2)
    #     gc.to(device)
    #     gc.maf1.max_iteration = 200
    #     gc.maf2.max_iteration = 200
    #     gc.max_iteration = 100
    #     gc.bs = 250
    # gc.learn(X, Y)
    # with torch.no_grad():
    #     v, w = gc.forward(X, Y)

    # MI estimate
    for name, estimator_cls in zip(estimator_names, estimators):
        # Any estimates previously stored under this name are replaced.
        estimates = []
        results[name] = estimates
        print('estimator:', name)
        # consider each configuration
        for _ in range(n_exps):
            # A fresh architecture list is built per instance so estimators
            # never share (and cannot mutate) the module-level template.
            estimator = estimator_cls(
                architecture_encoder_x=architecture_encode,
                architecture_encoder_y=architecture_encode,
                architecture_critic=[d] + architecture_critic[1:],
                hyperparams=hyperparams)
            estimator.to(device)
            # Disabled counterpart of the density-estimate variant above:
            # if name == 'Adaptive' or name == 'Adaptive2':
            #     estimator.set_nde(gc.normal, gc.normal2)
            #     estimator.learn(v, w)
            #     estimator.eval()
            #     MI_est = estimator.MI(v, w)
            # else:
            #     estimator.learn(X, Y)
            #     estimator.eval()
            #     MI_est = estimator.MI(X, Y)

            estimator.learn(X, Y)
            estimator.eval()
            estimates.append(estimator.MI(X, Y))
        print('estimator=', name, 'results=', estimates, '\n')
    results['Truth'] = ground_truth_MI
    # save result
    utils_os.save_object(DIR, fn, results)

    # Hand back what was actually written to disk.
    return utils_os.load_object(DIR, fn_full).item()




def evaluate_nonlinear_gaussian(case, dim, rho, n_exps=3, device='cuda:0'):
    """Benchmark the configured estimators on a nonlinear-Gaussian dataset.

    Draws ``n_data`` samples from ``NonlinearGaussian`` built with the given
    ``case``, ``dim`` and ``rho``, applies the dataset's ``transformation``,
    moves both tensors to ``device`` and delegates to ``_evaluate_core``.

    Args:
        case: Forwarded to the dataset and embedded in the result file name.
        dim: Passed to the dataset as ``n_dims``.
        rho: Passed to the dataset as ``rho``.
        n_exps: Number of repetitions per estimator.
        device: Device for the data and the estimators.

    Returns:
        Whatever ``_evaluate_core`` returns for this run.
    """
    print('case', case, 'dim', dim, 'rho', rho)

    # Dataset preparation
    generator = NonlinearGaussian.NonlinearGaussian(
        n_samples=n_data, n_dims=dim, rho=rho, mu=0, case=case)
    latent_x, latent_y = generator.sample_data(n_samples=n_data)
    X, Y = generator.transformation(latent_x, latent_y)
    X, Y = X.to(device), Y.to(device)

    # The result file name encodes the full configuration.
    return _evaluate_core(
        fn=f'nonlinear_gaussian_{case}_{dim}_{rho}',
        X=X,
        Y=Y,
        ground_truth_MI=generator.true_mutual_info(),
        n_exps=n_exps,
        device=device,
    )




def evaluate_MoG(case, dim, n_exps=3, device='cuda:0'):
    """Benchmark the configured estimators on a mixture-of-Gaussians dataset.

    ``case`` indexes one of two hard-coded presets of per-component shifts
    and correlations (five components each). ``n_data`` samples are drawn,
    moved to ``device`` and handed to ``_evaluate_core``; the reference value
    is the dataset's ``empirical_mutual_info()``.

    Args:
        case: Index into the preset tables; also embedded in the file name.
        dim: Passed to the dataset as ``n_dims``.
        n_exps: Number of repetitions per estimator.
        device: Device for the data and the estimators.

    Returns:
        Whatever ``_evaluate_core`` returns for this run.
    """
    print('case', case, 'dim', dim)

    # Preset tables: row ``case`` holds one value per mixture component.
    shift_presets = [
        [-0.4, -0.1, 0, 0.1, 0.4],
        [-0.2, -0.1, 0, 0.3, 0.4]
    ]
    rho_presets = [
        [0.5, 0.6, 0.7, 0.8, 0.9],
        [-0.3, 0.5, 0.2, 0.4, 0.9]
    ]

    # Dataset preparation
    mixture = MoG.MoG(
        n_samples=n_data,
        n_dims=dim,
        K=5,
        shifts=shift_presets[case],
        rhos=rho_presets[case],
    )
    X, Y = mixture.sample_data(n_samples=n_data)
    X, Y = X.to(device), Y.to(device)

    return _evaluate_core(
        fn=f'MoG_{case}_{dim}',
        X=X,
        Y=Y,
        ground_truth_MI=mixture.empirical_mutual_info(),
        n_exps=n_exps,
        device=device,
    )




def evaluate_swiss_roll(dim, rho, n_exps=3, device='cuda:0'):
    """Benchmark the configured estimators on the swiss-roll dataset.

    This evaluation is marked unfinished by its author; a warning carrying
    that message is emitted on every call. The previous
    ``assert 1!=2, ...`` guard could never fail and therefore reported
    nothing.

    The estimators are meant to run with ``n_bridges = 0`` here. Because
    ``_evaluate_core`` reads the module-level ``hyperparams`` object, the
    override is applied to that shared instance for the duration of the call
    and restored afterwards, even on error. Previously the override was set
    on a local ``Hyperparams`` object that never reached the estimators.

    Args:
        dim: Passed to the dataset as ``n_dims``.
        rho: Passed to the dataset as ``rho``.
        n_exps: Number of repetitions per estimator.
        device: Device for the data and the estimators.

    Returns:
        Whatever ``_evaluate_core`` returns for this run.
    """
    import warnings

    print('case', 'default', 'dim', dim, 'rho', rho)
    warnings.warn('this part is unfinished.', stacklevel=2)

    case = 'default'

    # Dataset preparation
    dataset = SwissRoll.SwissRoll(n_samples=n_data, n_dims=dim, rho=rho, mu=0)
    X, Y = dataset.sample_data(n_samples=n_data)
    X, Y = X.to(device), Y.to(device)

    # File name to save
    fn = f'swissroll_{case}_{dim}_{rho}'
    ground_truth_MI = dataset.empirical_mutual_info()

    # Temporarily override the shared setting; _evaluate_core derives the
    # joint width from the data itself, so no dimension bookkeeping is needed.
    previous_n_bridges = hyperparams.n_bridges
    hyperparams.n_bridges = 0
    try:
        return _evaluate_core(X=X, Y=Y, fn=fn, ground_truth_MI=ground_truth_MI,
                              n_exps=n_exps, device=device)
    finally:
        hyperparams.n_bridges = previous_n_bridges


        













