import torchsde
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import torch.nn.functional as F
import numpy as np
import random

import timeit
from torch.nn import functional as F

import psutil
from geomloss import SamplesLoss
import logging
import os
import gc
from constants import *

# When CUDA use is switched on, ask PyTorch's CUDA caching allocator to use
# expandable segments by setting PYTORCH_CUDA_ALLOC_CONF.
# `use_cuda` is not defined in this file; presumably it is provided by the
# star import from `constants` -- confirm.
# NOTE(review): this environment variable presumably has to be set before the
# CUDA allocator is first initialised in order to take effect -- confirm.
if use_cuda == True:
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"



def print_memory_usage(device_id=0):
    """Log a short memory report at INFO level.

    If CUDA is available, the report covers the GPU selected by ``device_id``
    (total, allocated and reserved memory). Otherwise it covers host memory as
    reported by ``psutil.virtual_memory()``.

    Args:
        device_id: CUDA device index used for the GPU queries. It is not used
            when CUDA is unavailable.
    """
    bytes_per_gib = 1024**3

    if not torch.cuda.is_available():
        # No CUDA device: fall back to system RAM statistics.
        host = psutil.virtual_memory()
        host_report = (
            ("Total memory", host.total),
            ("Available memory", host.available),
            ("Used memory", host.used),
        )
        for label, n_bytes in host_report:
            logging.info(f"{label}: {n_bytes / bytes_per_gib:.2f} GB")
        logging.info(f"Memory usage: {host.percent}%")
        return

    # All three values are queried before anything is logged.
    gpu_report = (
        ("Total GPU Memory", torch.cuda.get_device_properties(device_id).total_memory),
        ("Allocated Memory", torch.cuda.memory_allocated(device_id)),
        ("Cached Memory", torch.cuda.memory_reserved(device_id)),
    )
    for label, n_bytes in gpu_report:
        logging.info(f"{label}: {n_bytes / bytes_per_gib:.2f} GB")


def run_one_epoch(loss_fct, train_flag, initial_dist_tensor, t, end_dist_tensor, ode_func_train, optimizer, L1 = False, group_lasso = False, sparsity_coefficient = 0.001):
    """Integrate the SDE once, compute the loss and optionally take one optimizer step.

    The state at the last time point of ``torchsde.sdeint`` is compared with
    ``end_dist_tensor`` through ``loss_fct``; an optional sparsity penalty on the
    model's ``B_0`` / ``B_1`` attributes is added on top.

    Args:
        loss_fct: Callable invoked as ``loss_fct(end_dist_tensor, prediction)``.
        train_flag: If true, the model is put in train mode, gradients are
            enabled and one backward pass + optimizer step is performed. If
            false, the model is put in eval mode and gradients are disabled.
        initial_dist_tensor: Initial state passed to ``torchsde.sdeint``.
        t: Time grid passed to ``torchsde.sdeint``.
        end_dist_tensor: Target samples for the loss.
        ode_func_train: The SDE module. Must expose ``B_0`` when ``L1`` or
            ``group_lasso`` is set, and ``B_1`` when both are set.
        optimizer: Optimizer stepped when ``train_flag`` is true.
        L1: Add an L1 penalty (on ``B_0`` alone, or on ``B_1`` when combined
            with ``group_lasso``).
        group_lasso: Add a penalty equal to the sum of the norms of ``B_0``
            taken along ``dim=0``.
        sparsity_coefficient: Weight applied to each penalty term.

    Returns:
        The loss tensor, or ``torch.tensor(-1)`` as an error sentinel when the
        squeezed integration output has fewer than three dimensions.
    """
    if train_flag:
        ode_func_train.train()
    else:
        ode_func_train.eval()

    # Scope the autograd mode to this call. Calling torch.set_grad_enabled as
    # a plain function would leave gradients globally disabled after every
    # validation pass, silently affecting unrelated code that runs afterwards.
    with torch.set_grad_enabled(train_flag):
        output = torchsde.sdeint(ode_func_train, initial_dist_tensor, t, method='euler')

        # Drop every size-1 dimension (the original intent was to remove a
        # spurious channel dimension).
        output = output.squeeze()

        if output.dim() < 3:
            # ODE integration went wrong: signal it with the -1 sentinel.
            return torch.tensor(-1)

        # State at the final time point.
        prediction = output[-1, :, :]

        loss = loss_fct(end_dist_tensor, prediction)

        if L1 and group_lasso:
            # NOTE(review): in the combined case the L1 term is taken on B_1,
            # whereas the L1-only case uses B_0 -- confirm this is intended.
            group_lasso_penalty = torch.sum(torch.norm(ode_func_train.B_0, dim=0))
            L1_penalty = torch.sum(torch.abs((ode_func_train.B_1).flatten()))
            joint_penalty = sparsity_coefficient * group_lasso_penalty + sparsity_coefficient * L1_penalty
            loss = loss + joint_penalty
        elif L1:
            L1_penalty = torch.sum(torch.abs((ode_func_train.B_0).flatten()))
            loss = loss + sparsity_coefficient * L1_penalty
        elif group_lasso:
            group_lasso_penalty = torch.sum(torch.norm(ode_func_train.B_0, dim=0))
            loss = loss + sparsity_coefficient * group_lasso_penalty

        if train_flag:
            # NOTE(review): retain_graph=True keeps the autograd graph alive
            # after backward; kept from the original in case the model reuses
            # graph pieces across calls -- confirm whether it is still needed.
            loss.backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()

    return loss



def _apply_intervention(ode_func, genes, zero_tensor, one_tensor, knock_out, perfect_interv):
    """Write the intervention for ``genes`` into ``ode_func.f_`` and ``ode_func.intervention``."""
    # NOTE(review): both vectors are reset on every iteration, so if `genes`
    # holds several indices only the last one stays marked, and if it is empty
    # the previous intervention is left untouched. This mirrors the original
    # behaviour -- confirm whether multi-gene interventions are intended.
    for gene_index in genes:
        ode_func.f_[:] = zero_tensor
        ode_func.f_[gene_index] = -1 if knock_out else 1

        ode_func.intervention[:] = one_tensor
        if perfect_interv:
            ode_func.intervention[gene_index] = 0


def _save_loss_plot(output_path, train_graph_dict, val_graph_dict):
    """Plot the tracked per-TF train/validation loss curves and save them as a PNG."""
    fig = plt.figure(figsize=(10, 6))
    try:
        for tf, losses in val_graph_dict.items():
            plt.plot(range(0, len(losses)), losses, label= str(tf) + ' validation', linestyle='--')

        for tf, losses in train_graph_dict.items():
            plt.plot(range(0, len(losses)), losses, label= str(tf) + ' train')

        plt.title('Brownian Train/Validation Loss by Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
        plt.subplots_adjust(right=0.75)
        plt.grid(True)
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        plt.savefig(output_path, format='png')
    finally:
        # One figure is created per epoch; close it so figures do not pile up.
        plt.close(fig)


def train_boolode(new_folder_path, interv_dict, initial_dist, brownian_dict, train_dict, validation_dict, ode_func_train_W2, optimizer, t_flow = 1., t_span = 50.,max_epoch = 200.,patience = 10., max_time = 23., knock_out = False, L1 = False, perfect_interv = False):
    """Train the SDE model with early stopping and restore the best parameters.

    Each epoch runs one optimisation step per entry of ``train_dict``, then
    evaluates every entry of ``validation_dict``. The mean validation loss
    drives early stopping. A loss plot for the 20 validation keys with the
    most samples is rewritten after every epoch that does not trigger early
    stopping. On exit the parameters of the best epoch are loaded back into
    ``ode_func_train_W2``.

    Args:
        new_folder_path: Directory that receives
            ``brownian_revolver_validation_losses.png``.
        interv_dict: Maps each key of ``train_dict`` / ``validation_dict`` to
            an iterable of gene indices used to index ``f_`` and
            ``intervention`` on the model.
        initial_dist: Object whose ``.X`` is converted to the initial-state
            tensor.
        brownian_dict: Only used to read ``shape[1]`` of one randomly chosen
            entry, which sets the length of the zero/one reset vectors.
        train_dict: Maps keys to objects whose ``.X`` holds training targets.
        validation_dict: Maps keys to objects whose ``.X`` holds validation
            targets; ``len()`` of each value ranks the keys for plotting.
        ode_func_train_W2: The model; must expose ``f_`` and ``intervention``.
        optimizer: Optimizer forwarded to ``run_one_epoch``.
        t_flow: Currently unused; the integration grid is always [0, 1].
        t_span: Number of points in the time grid (converted to ``int``).
        max_epoch: Maximum number of epochs (converted to ``int``).
        patience: Number of consecutive non-improving epochs tolerated.
        max_time: Wall-clock budget in hours.
        knock_out: If true the intervened entry of ``f_`` is set to -1,
            otherwise to 1.
        L1: Forwarded to ``run_one_epoch``.
        perfect_interv: If true the intervened entry of ``intervention`` is
            set to 0.
    """
    import copy  # function-scope import: only needed for the best-parameter snapshot

    # The float defaults (50., 200.) cannot be passed to torch.linspace(steps=...)
    # or range(...) directly; both require integers.
    n_time_points = int(t_span)
    n_epochs = int(max_epoch)

    loss_fct = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

    # Track the 20 validation keys with the most cells/samples.
    # This is for the loss plot only.
    top_keys_val = sorted(validation_dict, key=lambda tf: len(validation_dict[tf]), reverse=True)[:20]
    val_graph_dict = {tf: [] for tf in top_keys_val}
    train_graph_dict = {tf: [] for tf in top_keys_val}

    random_key = random.choice(list(brownian_dict.keys()))
    n_genes = brownian_dict[random_key].shape[1]
    zero_tensor = torch.zeros(n_genes).to(device).type(dtype)
    one_tensor = torch.ones(n_genes).to(device).type(dtype)

    T = torch.linspace(0., 1, n_time_points).to(device).type(dtype)

    initial_dist_tensor = torch.tensor(initial_dist.X).to(device).type(dtype)

    train_tensor_dic = {
        tf: torch.tensor(target.X).to(device).type(dtype)
        for tf, target in train_dict.items()
    }
    validation_tensor_dic = {
        tf: torch.tensor(target.X).to(device).type(dtype)
        for tf, target in validation_dict.items()
    }

    validation_graph_output_path = os.path.join(new_folder_path, 'brownian_revolver_validation_losses.png')

    ode_func_train_W2.train()

    start_time_total = timeit.default_timer()

    best_val_loss = float('inf')
    best_params = None
    epochs_without_improvement = 0

    for epoch in range(n_epochs):
        start_time_epoch = timeit.default_timer()

        # Training
        for tf, target_dist_tensor in train_tensor_dic.items():
            start_time_tf = timeit.default_timer()

            _apply_intervention(ode_func_train_W2, interv_dict[tf], zero_tensor, one_tensor, knock_out, perfect_interv)

            train_loss = run_one_epoch(loss_fct, True, initial_dist_tensor, T, target_dist_tensor, ode_func_train_W2, optimizer, L1 = L1)
            if tf in train_graph_dict:
                train_graph_dict[tf].append(train_loss.cpu().detach().numpy())
            elapsed_tf = float(timeit.default_timer() - start_time_tf)
            torch.cuda.empty_cache()
            gc.collect()

            logging.info(str(tf) + " took %.2fs. Train loss: %.4f." %
                (elapsed_tf, train_loss))

        # Validation
        ode_func_train_W2.eval()
        val_loss_lst = []
        for tf, validation_dist_tensor in validation_tensor_dic.items():
            _apply_intervention(ode_func_train_W2, interv_dict[tf], zero_tensor, one_tensor, knock_out, perfect_interv)

            val_loss = run_one_epoch(loss_fct, False, initial_dist_tensor, T, validation_dist_tensor, ode_func_train_W2, optimizer, L1 = L1)
            val_loss_lst.append(val_loss.cpu().detach().numpy())
            logging.info(str(tf) + " Validation loss: %.4f." %
                (val_loss))

            if tf in val_graph_dict:
                val_graph_dict[tf].append(val_loss.cpu().detach().numpy())

        # Take the average of the validation losses
        mean_val_loss = np.mean(val_loss_lst)
        ode_func_train_W2.train()

        elapsed_epoch = float(timeit.default_timer() - start_time_epoch)
        logging.info("Epoch %i took %.2fs. Validation loss: %.4f." %
                (epoch+1, elapsed_epoch, mean_val_loss))
        logging.info('--------------------------------------------------------------------------------------------------------')

        torch.cuda.empty_cache()
        gc.collect()

        print_memory_usage()

        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            epochs_without_improvement = 0
            # state_dict() returns references to the live tensors, which the
            # optimizer keeps updating in place; take a real copy so that the
            # final load_state_dict actually restores this epoch's weights.
            best_params = copy.deepcopy(ode_func_train_W2.state_dict())
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                logging.info(f'Validation loss did not improve for {patience} consecutive epochs. Stopping training.')
                break

        _save_loss_plot(validation_graph_output_path, train_graph_dict, val_graph_dict)

        # Time out
        elapsed_total = float(timeit.default_timer() - start_time_total)
        if elapsed_total > 3600 * max_time:
            logging.info("Training ended for exceeding max time.")
            break

    if best_params is not None:
        ode_func_train_W2.load_state_dict(best_params)
    else:
        # Happens when no epoch ran or the validation loss never compared
        # below infinity (e.g. NaN); there is nothing to restore.
        logging.warning("No best parameters were recorded; keeping the current model parameters.")

