# -*- coding: utf-8 -*-

import sys
import os

# Import libraries, the code is built on PyTorch
import torch
import random
import warnings
import time

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset


warnings.filterwarnings("ignore")

from utils import *
from models import *
from estimators import *

# Select the compute device: CUDA when PyTorch reports it available, else CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#############################
# Task index. It only determines the number of chunks below (taskid + 1).
taskid = 0

# Wall-clock start of the whole run, used for the final timing print
tic_all = time.time()

# Embedding dimensions swept by the experiment loop
dz_lst = np.array([4, 8, 10, 32, 128], dtype=int)
# NOTE(review): beta_lst is not referenced anywhere else in this section
beta_lst = np.array([512])

# Size of the training pool that gets divided into chunks
total_samples = 2 ** 16

# Split the pool into (taskid + 1) equally sized chunks
num_chunks = taskid + 1
sample_size = total_samples // num_chunks

print(f"Running taskid {taskid} with sample size {sample_size}, processing {num_chunks} chunks")


###############################
# parameters
# Optimisation settings
opt_params = dict(
    epochs=100,
    batch_size=128,
    learning_rate=5e-4,
)

# Input dimensions passed to the critic, and the size of the evaluation subset
data_params = dict(
    Nx=784,
    Ny=784,
    test_size=128,
)


# Model Types
model_types = ["sep"]

#############################
# Train function with early stopping
def train_model(model, data, model_type="dsib", epochs=20, patience=50, min_delta=0.001):
    """
    Train a DSIB or DVSIB model with Adam and record per-epoch MI estimates.

    After every epoch the model is evaluated under ``torch.no_grad()`` on the
    module-level tensors ``eval_X``/``eval_Y`` and ``test_X``/``test_Y``.
    Those tensors, together with ``device`` and
    ``opt_params['learning_rate']``, are read from module globals and must be
    defined before this function is called.

    Args:
        model: The model to train. Called as ``model(x, y)``.
        data: Iterable yielding ``(x, y)`` batches (e.g. a DataLoader).
        model_type: ``"dsib"`` (the model returns a single loss; its negation
            is recorded as the estimate) or ``"dvsib"`` (the model returns
            three values; the first is the training loss and the third,
            lossGout, is recorded as the estimate).
        epochs: Maximum number of training epochs.
        patience: Number of consecutive epochs without improvement of the
            test estimate after which training stops early.
        min_delta: Minimum increase over the best test estimate that counts
            as an improvement.

    Returns:
        A tuple ``(train_estimates, test_estimates)`` of numpy arrays holding
        one estimate per completed epoch.

    Raises:
        ValueError: If ``model_type`` is neither ``"dsib"`` nor ``"dvsib"``.
    """
    # Validate once, before any work is done, instead of on every batch.
    if model_type not in ("dsib", "dvsib"):
        raise ValueError("Invalid model_type. Choose 'dsib' or 'dvsib'.")
    is_dsib = model_type == "dsib"

    model.to(device)  # Ensure model is on the selected device
    opt = torch.optim.Adam(model.parameters(), lr=opt_params['learning_rate'])

    estimates_mi_train = []
    estimates_mi_test = []

    # Early stopping state
    best_estimator_ts = float('-inf')
    no_improvement_count = 0

    for epoch in range(epochs):
        for x, y in data:
            x, y = x.to(device), y.to(device)

            opt.zero_grad()

            if is_dsib:
                loss = model(x, y)  # DSIB returns a single loss
            else:
                loss, _, _ = model(x, y)  # DVSIB returns three outputs

            loss.backward()
            opt.step()

        # Evaluate the model at every epoch.
        # NOTE: the train/eval mode of the model is not switched here; it is
        # evaluated in whatever mode the caller left it.
        with torch.no_grad():
            if is_dsib:
                estimator_tr = -model(eval_X, eval_Y)
                estimator_ts = -model(test_X, test_Y)
            else:  # The third output (lossGout) is the MI value
                _, _, estimator_tr = model(eval_X, eval_Y)
                _, _, estimator_ts = model(test_X, test_Y)

            estimator_tr = estimator_tr.to('cpu').detach().numpy()
            estimator_ts = estimator_ts.to('cpu').detach().numpy()

            estimates_mi_train.append(estimator_tr)
            estimates_mi_test.append(estimator_ts)

        print(f"Epoch: {epoch+1}, {model_type}, train: {estimator_tr}, test: {estimator_ts}", flush=True)

        # An epoch counts as an improvement only when the test estimate is
        # non-negative and beats the best value so far by more than min_delta.
        # Negative (or NaN) estimates are always counted as "no improvement".
        if estimator_ts >= 0 and estimator_ts > best_estimator_ts + min_delta:
            best_estimator_ts = estimator_ts
            no_improvement_count = 0
        else:
            no_improvement_count += 1

        # Check if we should stop early
        if no_improvement_count >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs. Best estimator_ts: {best_estimator_ts}")
            break

    return np.array(estimates_mi_train), np.array(estimates_mi_test)


#############################
# Folder that receives the saved estimates. It is created up front so that
# np.save cannot fail on a missing directory after a full training run.
folder_name = "MNIST_Subsampling"
os.makedirs(folder_name, exist_ok=True)

# Result stores, one dict per model type: {model_type: {run_key: estimates}}
mi_dsib = {key: dict() for key in model_types}
mi_dsib_test = {key: dict() for key in model_types}


# Load the paired datasets. Each is indexed as [view, sample, feature]:
# index 0 on the first axis is used as X and index 1 as Y.
paired_train = torch.load('paired_augmented_mnist_train.pt', weights_only=True)
paired_test = torch.load('paired_augmented_mnist_test.pt', weights_only=True)

# Test data stays constant during the training, eval is sampled from the train.
test_X, test_Y = paired_test[0, :, :], paired_test[1, :, :]

# Move to the selected device once; train_model reads these as globals
test_X, test_Y = test_X.to(device), test_Y.to(device)

# Train one model per (chunk, estimator, embedding dimension) combination
for chunk in range(num_chunks):
    start_idx = chunk * sample_size
    # The final chunk runs up to total_samples so that any remainder left by
    # the integer division is included.
    is_last_chunk = chunk == num_chunks - 1
    end_idx = total_samples if is_last_chunk else start_idx + sample_size

    print(f"Processing chunk {chunk+1}/{num_chunks}, range: {start_idx}:{end_idx}")

    # Cut the current chunk out of the training pool
    chunk_rows = slice(start_idx, end_idx)
    train_X = paired_train[0, chunk_rows, :]
    train_Y = paired_train[1, chunk_rows, :]

    # The evaluation subset is the head of the chunk; train_model reads
    # eval_X / eval_Y as module globals, so these names must not change.
    n_eval = data_params['test_size']
    eval_X = train_X[:n_eval, :].to(device)
    eval_Y = train_Y[:n_eval, :].to(device)

    # NOTE(review): Dataset is called with two tensors, so it presumably
    # resolves to a class supplied by one of the star imports rather than
    # torch.utils.data.Dataset - confirm.
    trainData = Dataset(train_X, train_Y)
    data_train = torch.utils.data.DataLoader(
        trainData,
        batch_size=opt_params['batch_size'],
        shuffle=True,
    )

    # Chunk information becomes part of the key under which results are stored
    chunk_id = f"size{sample_size}_chunk{chunk}"

    tic_trial = time.time()
    for est in ['infonce']:
        model_type = model_types[0]
        label = model_type.capitalize()
        train_path = os.path.join(folder_name, f"{label}_sample{sample_size}.npy")
        test_path = os.path.join(folder_name, f"test_{label}_sample{sample_size}.npy")

        for dz in dz_lst:
            print(f'Starting {dz}', flush=True)
            base_critic_params = dict(
                Nx=data_params['Nx'],
                Ny=data_params['Ny'],
                layers=4,
                embed_dim=dz,
                hidden_dim=512,
                activation='leaky_relu',
            )
            print(f'{label} - DSIB, Estimator: {est}', flush=True)

            tic = time.time()
            torch.cuda.empty_cache()
            dsib_model = DSIB(est, base_critic_params, model_type, None)
            mis_dsib, mis_dsib_test = train_model(dsib_model, data_train, "dsib", opt_params['epochs'])

            # Store DSIB results
            run_key = f"{chunk_id}_{dz}_{est}"
            mi_dsib[model_type][run_key] = mis_dsib
            mi_dsib_test[model_type][run_key] = mis_dsib_test

            print(f'Time for {label} - DSIB = {round(time.time()-tic)} sec', flush=True)

            # Rewrite both result files after every run so partial results
            # survive an interrupted sweep. np.save pickles the dict, so it
            # has to be read back with np.load(..., allow_pickle=True).
            np.save(train_path, mi_dsib[model_type])
            np.save(test_path, mi_dsib_test[model_type])

    print(f'Time for trial = {round(time.time() - tic_trial, 3)}', flush=True)

print(f'Done!, Total time = {round(time.time() - tic_all, 3)}', flush=True)