# -*- coding: utf-8 -*-

import sys
import os

# Import libraries, the code is built on PyTorch
import torch
import random
import warnings
import time

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset


warnings.filterwarnings("ignore")

from utils import *
from models import *
from estimators import *

# Select the compute device: use the GPU when CUDA is available, otherwise
# fall back to the CPU. The choice is printed so it appears in the run log.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#############################
# Index into ncom_lst selecting which of the three conditions to run (0-2).
taskid = 0

# Wall-clock reference for the total runtime reported at the end of the script.
tic_all = time.time()

# Sweep definitions.
mi_lst = np.array([4])
dz_lst = np.array([2, 4, 8, 10, 32, 128, 512], dtype=int)
beta_lst = np.array([512])
# 12 training-set sizes: 2**5, 2**6, ..., 2**16.
samples_lst = np.logspace(5, 16, num=12, base=2).astype(int)
ncom_lst = np.array([10, 100, 500], dtype=int)
n_trials = 10

# The single `ncom` value handled by this run.
ncom_ToUse = int(ncom_lst[taskid])
print(f'Starting ncom = {ncom_ToUse}', flush=True)

# Optimisation settings.
opt_params = dict(epochs=100, batch_size=128, learning_rate=5e-4)

# Data-generation settings; `ncom` and `mi` are forwarded to
# sample_correlated_data, `Nx`/`Ny` set the teacher output sizes.
data_params = dict(
    Nx=500,
    Ny=500,
    ncom=ncom_ToUse,
    test_size=128,
    mi=mi_lst[0],
)

# Critic architectures to run.
model_types = ["sep"]

# One teacher network per view (X and Y), both fed with `ncom` inputs.
teacher_model_x = teacher(dz=data_params['ncom'], output_dim=data_params['Nx'])
teacher_model_y = teacher(dz=data_params['ncom'], output_dim=data_params['Ny'])

# The teachers only generate data, so none of their parameters are trained.
for frozen_teacher in (teacher_model_x, teacher_model_y):
    for weight in frozen_teacher.parameters():
        weight.requires_grad_(False)

#############################
# Train function with early stopping
def train_model(model, data, dataset_type="teacher", model_type="dsib", epochs=20, patience=50, min_delta=0.001):
    """
    Train a DSIB or DVSIB model and record its mutual-information estimates, with early stopping.

    After every epoch the model is evaluated (without gradients) on the
    module-level evaluation and test tensors, and both estimates are stored.
    Training stops once the test estimate has failed to improve for
    ``patience`` consecutive epochs; a negative test estimate always counts
    as "no improvement".

    Args:
        model: The model to train. It is called as ``model(x, y)``.
        data: Iterable of ``(x, y)`` training batches (e.g. a DataLoader).
        dataset_type: Selects the module-level evaluation/test tensors. Only
            "teacher" is supported.
        model_type: "dsib" (``model(x, y)`` returns the loss, whose negative
            is used as the MI estimate) or "dvsib" (``model(x, y)`` returns
            ``(loss, lossGin, lossGout)`` and ``lossGout`` is used as the MI
            estimate).
        epochs: Maximum number of training epochs.
        patience: Number of consecutive epochs without improvement of the
            test estimate before stopping. With the default values
            ``patience`` exceeds ``epochs``, so early stopping only takes
            effect when a larger ``epochs`` is passed.
        min_delta: Minimum increase over the best test estimate that counts
            as an improvement.
    Returns:
        A tuple (train_estimates, test_estimates) of NumPy arrays holding one
        mutual-information estimate per completed epoch.
    Raises:
        ValueError: If ``dataset_type`` or ``model_type`` is not supported.
    """
    # Validate up front so a bad argument fails before any work is done, and
    # so the evaluation step below can never run with an unknown model_type.
    if dataset_type != "teacher":
        raise ValueError("Invalid dataset_type. Only 'teacher' is supported.")
    if model_type not in ("dsib", "dvsib"):
        raise ValueError("Invalid model_type. Choose 'dsib' or 'dvsib'.")

    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=opt_params['learning_rate'])

    # Move the fixed evaluation/test tensors to the device once.
    test_X, test_Y = test_X_teacher.to(device), test_Y_teacher.to(device)
    eval_X, eval_Y = eval_X_teacher.to(device), eval_Y_teacher.to(device)

    estimates_mi_train = []
    estimates_mi_test = []

    # Early-stopping state.
    best_estimator_ts = float('-inf')
    no_improvement_count = 0

    for epoch in range(epochs):
        for x, y in data:
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            if model_type == "dsib":
                loss = model(x, y)  # DSIB returns a single loss
            else:
                loss, _, _ = model(x, y)  # DVSIB returns three outputs
            loss.backward()
            opt.step()

        # Evaluate once per epoch.
        # NOTE(review): the model is left in training mode here; if it
        # contains dropout/batch-norm layers, consider model.eval() - confirm.
        with torch.no_grad():
            if model_type == "dsib":
                estimator_tr = -model(eval_X, eval_Y)
                estimator_ts = -model(test_X, test_Y)
            else:  # lossGout is the MI value
                _, _, estimator_tr = model(eval_X, eval_Y)
                _, _, estimator_ts = model(test_X, test_Y)

        estimator_tr = estimator_tr.detach().cpu().numpy()
        estimator_ts = estimator_ts.detach().cpu().numpy()
        estimates_mi_train.append(estimator_tr)
        estimates_mi_test.append(estimator_ts)

        print(f"Epoch: {epoch+1}, {model_type}, train_{dataset_type}: {estimator_tr}, test_{dataset_type}: {estimator_ts}", flush=True)

        # Only a non-negative estimate that beats the best one by more than
        # min_delta resets the patience counter.
        improved = estimator_ts >= 0 and estimator_ts > best_estimator_ts + min_delta
        if improved:
            best_estimator_ts = estimator_ts
            no_improvement_count = 0
        else:
            no_improvement_count += 1

        if no_improvement_count >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs. Best estimator_ts: {best_estimator_ts}", flush=True)
            break

    return np.array(estimates_mi_train), np.array(estimates_mi_test)

#############################
# Specify the folder name
folder_name = "Finite_Data_summarizable"
# np.save does not create missing directories; make sure the output folder
# exists before any (expensive) training run tries to write into it.
os.makedirs(folder_name, exist_ok=True)

# Per-model-type result containers, keyed by
# "<ncom>_<samples>_<dz>_<trial>_<estimator>_teacher".
mi_dsib = {key: dict() for key in model_types}
mi_dsib_test = {key: dict() for key in model_types}

# Test data stays constant during the training, eval is sampled from the train.
test_X_teacher, test_Y_teacher = sample_correlated_data(ncom=data_params['ncom'], total_size=data_params['Nx'], batch_size=data_params['test_size'], info_tot=data_params['mi'], mlp_x=teacher_model_x, mlp_y=teacher_model_y)

for sample in samples_lst:
    print('Samples =', sample, flush=True)
    # Draw a fresh training set of `sample` points for this sample size.
    train_X_teacher, train_Y_teacher = sample_correlated_data(ncom=data_params['ncom'], total_size=data_params['Nx'], batch_size=sample, info_tot=data_params['mi'], mlp_x=teacher_model_x, mlp_y=teacher_model_y)
    # The evaluation set is the first `test_size` training points; train_model
    # reads these module-level names directly.
    eval_X_teacher, eval_Y_teacher = train_X_teacher[:data_params['test_size']], train_Y_teacher[:data_params['test_size']]

    # NOTE(review): `Dataset` is imported from torch.utils.data at the top of
    # the file but is called here with two tensors; presumably one of the
    # star imports rebinds it to a project dataset class - confirm.
    trainData_teacher = Dataset(train_X_teacher, train_Y_teacher)
    data_train_teacher = torch.utils.data.DataLoader(trainData_teacher, batch_size=opt_params['batch_size'], shuffle=True)

    for j in range(n_trials):
        tic_trial = time.time()
        for est in ['infonce', 'smile_5']:
            model_type = model_types[0]
            for dz in dz_lst:
                print(f'Starting {dz}', flush=True)
                base_critic_params = {
                    'Nx': data_params['Nx'],
                    'Ny': data_params['Ny'],
                    'layers': 2,
                    'embed_dim': dz,
                    'hidden_dim': 256,
                    'activation': 'leaky_relu',
                }

                tic = time.time()
                torch.cuda.empty_cache()
                dsib_model = DSIB(est, base_critic_params, model_type, None)
                mis_dsib_teacher, mis_dsib_test_teacher = train_model(dsib_model, data_train_teacher, "teacher", "dsib", opt_params['epochs'])

                # Store DSIB results
                result_key = f"{ncom_ToUse}_{sample}_{dz}_{j}_{est}_teacher"
                mi_dsib[model_type][result_key] = mis_dsib_teacher
                mi_dsib_test[model_type][result_key] = mis_dsib_test_teacher

                print(f'Time for {model_type.capitalize()} - DSIB = {round(time.time()-tic)} sec', flush=True)

                # Checkpoint after every run so partial results survive an
                # interrupted job. np.save pickles the dict into a 0-d object
                # array; load it back with
                # np.load(path, allow_pickle=True).item().
                np.save(os.path.join(folder_name, f"{model_type.capitalize()}_ncom{ncom_ToUse}.npy"), mi_dsib[model_type])
                np.save(os.path.join(folder_name, f"test_{model_type.capitalize()}_ncom{ncom_ToUse}.npy"), mi_dsib_test[model_type])

        # Report the duration of this trial (must sit inside the trial loop so
        # every trial is timed, not just the last one per sample size).
        print('Time for trial: '+str(j+1)+' = '+str(round(time.time() - tic_trial,3)), flush=True)

print('Done!, Total time = '+str(round(time.time()-tic_all,3)), flush = True)