# -*- coding: utf-8 -*-

import sys
import os

# Import libraries, the code is built on PyTorch
import torch
import random
import warnings
import time

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset


warnings.filterwarnings("ignore")

from utils import *
from models import *
from estimators import *

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#############################
taskid = 0  # Change if running on a cluster

# Wall-clock reference for the total-runtime report printed at the end of the script
tic_all = time.time()

# Sweep settings:
#   dz_lst   - embedding dimension(s); the first entry becomes the critic 'embed_dim'
#   beta_lst - beta value(s) for DVSIB; only the first entry is used below
#   mi_lst   - per-iteration 'info_tot' passed to the sampler (indexed by the
#              batch counter in train_model, so it needs at least n_iter entries)
#   n_trials - number of independent repetitions of the whole experiment
dz_lst = np.asarray([32], dtype=int)
beta_lst = np.array([512])
mi_lst = mi_schedule(20_000)
n_trials = 1


print('Starting High Dim', flush=True)
###############################
# Run configuration

# Optimisation settings consumed by train_model.
# 'n_iter' is the number of batches per epoch; train_model reads mi_lst[i] for
# every i < n_iter, so keep it no larger than the length of mi_lst.
opt_params = dict(
    epochs=1,
    n_iter=20000,
    batch_size=128,
    learning_rate=5e-4,
)

# Data settings: Nx / Ny are used as the teacher output dimensions,
# ncom is passed as the teacher 'dz' and as the sampler 'ncom'.
data_params = dict(
    Nx=500,
    Ny=500,
    ncom=10,
)

# Critic settings shared by DSIB and DVSIB (DVSIB additionally receives 'beta')
base_critic_params = dict(
    Nx=data_params['Nx'],
    Ny=data_params['Ny'],
    layers=2,
    embed_dim=int(dz_lst[0]),
    hidden_dim=256,
    activation='leaky_relu',
)

# Critic architectures to run
model_types = ["concat"]

# Define a Teacher Model for X and Y
# Both teachers take dz = ncom inputs; they differ only in output size (Nx vs Ny).
# NOTE(review): `teacher` comes from one of the star imports above; its
# architecture and initialisation are not visible in this file.
teacher_model_x = teacher(dz=data_params['ncom'], output_dim=data_params['Nx'])
teacher_model_y = teacher(dz=data_params['ncom'], output_dim=data_params['Ny'])


# Disable gradients on every teacher parameter: the teachers are passed to the
# data sampler in train_model and are never handed to an optimizer in this script.
for param_x in teacher_model_x.parameters():
    param_x.requires_grad_(False)  # Freeze
for param_y in teacher_model_y.parameters():
    param_y.requires_grad_(False)


#############################
# Train function (add early stopping for the finite regime)
def train_model(model, dataset_type="teacher", model_type="dsib", epochs=20):
    """
    Train a DSIB or DVSIB model on freshly sampled batches and record the
    per-iteration estimate.

    Every iteration draws a new batch from ``sample_correlated_data`` with
    ``info_tot=mi_lst[i]``, where ``i`` is the batch index inside the current
    epoch (so the schedule restarts at every epoch and ``mi_lst`` must hold at
    least ``opt_params['n_iter']`` entries).

    Args:
        model: The model to train. It is called as ``model(x, y)``.
        dataset_type: ``"teacher"`` samples through ``teacher_model_x`` /
            ``teacher_model_y`` (with ``dev=True``); ``"noisy"`` samples through
            ``teacher_model_prob_x`` / ``teacher_model_prob_y``.
            NOTE(review): the ``*_prob_*`` teachers are not defined in this
            file; presumably they come from a star import -- confirm before
            using ``"noisy"``.
        model_type: ``"dsib"`` expects ``model(x, y)`` to return one tensor,
            which is both back-propagated and recorded. ``"dvsib"`` expects
            three outputs ``(loss, _, mi)``; ``loss`` is back-propagated and
            ``mi`` is recorded.
            NOTE(review): per the original author's note the DSIB output is
            the loss, i.e. the *negative* MI, so DSIB and DVSIB records have
            opposite signs -- confirm against the model definitions.
        epochs: Number of passes of ``opt_params['n_iter']`` batches.

    Returns:
        ``np.ndarray`` built from the recorded value of every iteration, in
        order (``epochs * opt_params['n_iter']`` entries).

    Raises:
        ValueError: If ``dataset_type`` or ``model_type`` is not one of the
            supported values. Raised before the model is moved to the device
            or any data is sampled.
    """
    # Validate once, up front, instead of on every one of the n_iter iterations;
    # this also fails before any side effect (device move, optimizer, sampling).
    if dataset_type not in ("teacher", "noisy"):
        raise ValueError("Invalid dataset_type. Choose 'teacher' or 'noisy'.")
    if model_type not in ("dsib", "dvsib"):
        raise ValueError("Invalid model_type. Choose 'dsib' or 'dvsib'.")

    model.to(device)  # Move the model to the selected device ('cuda' or 'cpu')
    opt = torch.optim.Adam(model.parameters(), lr=opt_params['learning_rate'])

    # Sampler arguments that depend only on the dataset type (loop-invariant).
    if dataset_type == "teacher":
        sampler_kwargs = {'mlp_x': teacher_model_x, 'mlp_y': teacher_model_y, 'dev': True}
    else:
        sampler_kwargs = {'mlp_x': teacher_model_prob_x, 'mlp_y': teacher_model_prob_y}

    estimates_mi_train = []

    for epoch in range(epochs):
        for i in range(opt_params['n_iter']):
            # Fresh batch every step; the target information follows the schedule.
            x, y = sample_correlated_data(
                ncom=data_params['ncom'],
                total_size=data_params['Nx'],
                batch_size=opt_params['batch_size'],
                info_tot=mi_lst[i],
                **sampler_kwargs,
            )

            opt.zero_grad()

            if model_type == "dsib":
                mi = model(x, y)  # single output: optimised and recorded
                mi.backward()
            else:
                loss, _, mi = model(x, y)  # three outputs: optimise loss, record mi
                loss.backward()

            opt.step()

            # Detach and copy to host so the stored history holds no graph/GPU memory.
            estimator_tr = mi.to('cpu').detach().numpy()

            if i % 100 == 0:
                print(f"Epoch: {epoch+1}, Batch: {i}, MI: {mi_lst[i]}, {model_type}, train: {estimator_tr}", flush=True)

            estimates_mi_train.append(estimator_tr)

    return np.array(estimates_mi_train)


#############################
# Specify the folder name
folder_name = "Infinite_Data_Highdim"

# np.save does not create missing directories; make sure the output folder
# exists up front so a long training run is not lost at the first save.
os.makedirs(folder_name, exist_ok=True)

# Result containers: one dict per critic architecture, filled with
# {run label -> array returned by train_model}. Keyed from model_types so the
# containers always match the architectures iterated below.
mi_dsib = {key: dict() for key in model_types}

mi_dvsib = {key: dict() for key in model_types}


for j in range(n_trials):
    tic_trial = time.time()
    for est in ['infonce', 'smile_5']:
        for model_type in model_types:
            print(f'{model_type.capitalize()} - DSIB - Trial: {j+1}, Estimator: {est}', flush=True)

            tic = time.time()
            torch.cuda.empty_cache()

            # Initialize and train DSIB
            dsib_model = DSIB(est, base_critic_params, model_type, None)
            print('Teacher Dataset', flush=True)
            mis_dsib_teacher = train_model(dsib_model, "teacher", "dsib", opt_params['epochs'])

            # Store DSIB results
            mi_dsib[model_type][f"{j}_{est}_teacher"] = mis_dsib_teacher

            print(f'Time for {model_type.capitalize()} - DSIB = {round(time.time()-tic)} sec', flush=True)

            # Save DSIB results. The whole accumulated dict is rewritten after
            # every run; np.save stores a dict as a pickled object array, so it
            # must be read back with np.load(..., allow_pickle=True).item().
            np.save(os.path.join(folder_name, f"{model_type.capitalize()}.npy"), mi_dsib[model_type])

            # ---- DVSIB Training ----
            beta = beta_lst[0]
            print(f'{model_type.capitalize()} - DVSIB - Trial: {j+1}, Estimator: {est}, Beta: {beta}', flush=True)

            tic = time.time()
            torch.cuda.empty_cache()

            # Define critic_params for DVSIB with beta
            critic_params_dvsib = {**base_critic_params, 'beta': beta}

            # Initialize and train DVSIB
            dvsib_model = DVSIB(est, critic_params_dvsib, model_type, None)
            print('Teacher Dataset', flush=True)
            mis_dvsib_teacher = train_model(dvsib_model, "teacher", "dvsib", opt_params['epochs'])

            # Store DVSIB results
            mi_dvsib[model_type][f"{j}_{est}_{beta}_teacher"] = mis_dvsib_teacher

            print(f'Time for {model_type.capitalize()} - DVSIB - Beta {beta} = {round(time.time()-tic)} sec', flush=True)

            # Save DVSIB results (same pickled-dict format as the DSIB file)
            dvsib_file = f"{model_type.capitalize()}_DVSIB_beta{beta}.npy"
            np.save(os.path.join(folder_name, dvsib_file), mi_dvsib[model_type])

    print('Time for trial: '+str(j+1)+' = '+str(round(time.time() - tic_trial,3)), flush=True)

print('Done!, Total time = '+str(round(time.time()-tic_all,3)), flush = True)