# -*- coding: utf-8 -*-

import sys
import os

# Import libraries, the code is built on PyTorch
import torch
import random
import warnings
import time

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset


warnings.filterwarnings("ignore")

from utils import *
from models import *
from estimators import *

# Select the compute device: use the GPU when CUDA is available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#############################
taskid = 0  # change if running on a cluster

# Wall-clock start of the whole script (reported in the final "Done!" message)
tic_all = time.time()


# Sweep settings: embedding sizes, beta values (for dvsib), MI schedule, number of trials
dz_lst = np.array([32], dtype=int)  # dz_lst[0] is used as the critic embedding dimension below
beta_lst = np.array([512])
# train_model indexes mi_lst[i] for every training iteration i, so the schedule
# must provide at least opt_params['n_iter'] entries.
mi_lst = mi_schedule(20000)

n_trials = 1

# Names of the estimators to run; each one is passed to DSIB as-is
est_lst = [
    'infonce',
    'nwj',
    'js',
    'js_fgan',
    'dv',
    'tuba',
    'smile',
    'smile_5',
    'mine_0',
    'mine_0.5',
    'mine_1',
    'ialpha_0.5',
]

print('Starting Low Dim', flush=True)
###############################
# parameters
# Optimisation settings shared by every training run
opt_params = dict(
    epochs=1,
    n_iter=20000,
    batch_size=128,
    learning_rate=5e-4,
)

# Data settings; 'samples' is the sample count used for the direct calculation
# (one fifth of the total number of samples seen during training)
data_params = dict(
    samples=int(opt_params['n_iter'] * opt_params['batch_size'] / 5),
    Nx=10,
    Ny=10,
    ncom=10,
)

# Unified critic parameters for DSIB and DVSIB (except beta)
base_critic_params = dict(
    Nx=data_params['Nx'],
    Ny=data_params['Ny'],
    layers=2,
    embed_dim=int(dz_lst[0]),
    hidden_dim=256,
    activation='leaky_relu',
)

# Critic architectures to run
model_types = ["concat"]


# Teacher models that map the ncom shared components to the X and Y spaces
teacher_model_x = teacher(dz=data_params['ncom'], output_dim=data_params['Nx'])
teacher_model_y = teacher(dz=data_params['ncom'], output_dim=data_params['Ny'])

# Freeze both teachers: they only generate data and must never be trained
for frozen_teacher in (teacher_model_x, teacher_model_y):
    for frozen_param in frozen_teacher.parameters():
        frozen_param.requires_grad_(False)
#############################
# Train function (add early stopping for the finite regime)
def train_model(model, dataset_type="teacher", model_type="dsib", epochs=20):
    """
    Train a DSIB or DVSIB model on freshly sampled data and record the per-step estimate.

    A new batch is drawn at every iteration with the target information
    ``mi_lst[i]``, so the model never sees the same batch twice.

    Args:
        model: The model to train (DSIB or DVSIB).
        dataset_type: Which data to sample: "teacher" (samples passed through the
            frozen teacher models), "raw", or "cubed" (raw samples with y cubed).
        model_type: Either "dsib" (the model returns a single value, which is both
            back-propagated and recorded) or "dvsib" (the model returns three values;
            the first is back-propagated and the third is recorded).
            NOTE(review): the original notes describe the "dsib" output as the
            negative MI -- confirm the sign convention against the DSIB class.
        epochs: Number of passes over the ``opt_params['n_iter']`` iterations.

    Returns:
        numpy array with one recorded value per training step
        (``epochs * opt_params['n_iter']`` entries).

    Raises:
        ValueError: If ``dataset_type`` or ``model_type`` is not one of the supported values.
    """
    # Fail fast on bad arguments, before touching the device or sampling any data
    if dataset_type not in ("teacher", "raw", "cubed"):
        raise ValueError("Invalid dataset_type. Choose 'teacher' or 'raw' or 'cubed'.")
    if model_type not in ("dsib", "dvsib"):
        raise ValueError("Invalid model_type. Choose 'dsib' or 'dvsib'.")

    model.to(device)  # Ensure the model lives on the selected device
    opt = torch.optim.Adam(model.parameters(), lr=opt_params['learning_rate'])

    # Sampler arguments shared by all dataset types; only info_tot changes per step
    sampler_kwargs = {
        'ncom': data_params['ncom'],
        'total_size': data_params['Nx'],
        'batch_size': opt_params['batch_size'],
        'dev': True,
    }
    if dataset_type == "teacher":
        sampler_kwargs['mlp_x'] = teacher_model_x
        sampler_kwargs['mlp_y'] = teacher_model_y

    estimates_mi_train = []

    for epoch in range(epochs):
        for i in range(opt_params['n_iter']):
            x, y = sample_correlated_data(info_tot=mi_lst[i], **sampler_kwargs)
            if dataset_type == "cubed":
                y = y**3

            opt.zero_grad()

            # Compute loss based on model type
            if model_type == "dsib":
                mi = model(x, y)  # DSIB returns a single value used as the loss
                mi.backward()
            else:
                loss, _, mi = model(x, y)  # DVSIB returns three outputs
                loss.backward()

            opt.step()

            # Detach from the graph before copying to host memory
            estimator_tr = mi.detach().cpu().numpy()

            if i % 100 == 0:
                print(f"Epoch: {epoch+1}, Batch: {i}, MI: {mi_lst[i]}, {model_type}, train: {estimator_tr}", flush=True)

            estimates_mi_train.append(estimator_tr)

    return np.array(estimates_mi_train)


#############################
# Specify the folder name
folder_name = "Infinite_Data_Lowdim_All"
# Create the output folder up front so that np.save cannot fail with
# FileNotFoundError only after a full training run has finished.
os.makedirs(folder_name, exist_ok=True)

# Result containers: mi_direct for the direct calculation, mi_dsib keyed by critic type
mi_direct = dict()
mi_dsib = {key: dict() for key in model_types}

for j in range(n_trials):
    tic_trial = time.time()
    for est in est_lst:
        for model_type in model_types:
            print(f'{model_type.capitalize()} - DSIB - Trial: {j+1}, Estimator: {est}', flush=True)

            tic = time.time()

            # Train a fresh DSIB model on each dataset variant, in the order raw, cubed, teacher
            for dataset_label, dataset_type in (("Raw", "raw"), ("Cubed", "cubed"), ("Teacher", "teacher")):
                torch.cuda.empty_cache()
                dsib_model = DSIB(est, base_critic_params, model_type, None)
                print(f'{dataset_label} Dataset', flush=True)

                # Store DSIB results under "<trial>_<estimator>_<dataset>"
                mi_dsib[model_type][f"{j}_{est}_{dataset_type}"] = train_model(
                    dsib_model, dataset_type, "dsib", opt_params['epochs']
                )

            print(f'Time for {model_type.capitalize()} - DSIB = {round(time.time()-tic)} sec', flush=True)

            # Save DSIB results after every estimator so partial progress survives a crash.
            # The dict is stored as a pickled object array (load with allow_pickle=True).
            np.save(os.path.join(folder_name, f"{model_type.capitalize()}.npy"), mi_dsib[model_type])

    print(f'Time for trial: {j+1} = {round(time.time() - tic_trial, 3)} for all methods', flush=True)

# Direct Calculation
# Make sure the output folder exists before the first save
os.makedirs(folder_name, exist_ok=True)

for j in range(n_trials):
    tic_trial = time.time()
    # One direct estimate per distinct MI value in the schedule; the smallest value is skipped
    for mi in np.unique(mi_lst)[1:]:
        print(f'MI: {mi} - Direct - Trial: {j+1}', flush=True)

        tic = time.time()

        # Generate data
        train_X_raw, train_Y_raw = sample_correlated_data(ncom=data_params['ncom'], total_size=data_params['Nx'], batch_size=data_params['samples'], info_tot=mi)
        train_X_teacher, train_Y_teacher = sample_correlated_data(ncom=data_params['ncom'], total_size=data_params['Nx'], batch_size=data_params['samples'], info_tot=mi, mlp_x=teacher_model_x, mlp_y=teacher_model_y)

        train_X_raw, train_Y_raw = train_X_raw.numpy(), train_Y_raw.numpy()
        # The cubed dataset reuses the raw X and cubes the raw Y element-wise
        train_X_cubed, train_Y_cubed = train_X_raw, train_Y_raw**3
        train_X_teacher, train_Y_teacher = train_X_teacher.numpy(), train_Y_teacher.numpy()

        # Store Direct results in mi_direct (the dict that is saved below),
        # keyed by "<trial>_<mi>_<dataset>"
        mi_direct[f"{j}_{mi}_raw"] = mut_info_opt(train_X_raw, train_Y_raw)
        mi_direct[f"{j}_{mi}_cubed"] = mut_info_opt(train_X_cubed, train_Y_cubed)
        mi_direct[f"{j}_{mi}_teacher"] = mut_info_opt(train_X_teacher, train_Y_teacher)

        print(f'Time for MI: {mi} - Direct = {round(time.time()-tic)} sec', flush=True)

        # Save Direct results after every MI value so partial progress survives a crash.
        # The dict is stored as a pickled object array (load with allow_pickle=True).
        np.save(os.path.join(folder_name, "Direct.npy"), mi_direct)


    print(f'Time for trial: {j+1} = {round(time.time() - tic_trial, 3)} for direct', flush=True)

print(f'Done!, Total time = {round(time.time()-tic_all, 3)}', flush=True)