import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.stats import gaussian_kde
from sklearn.preprocessing import LabelEncoder
import json
import os

from copy import copy

from .visualize import create_KDE_scatter_plots
#from .data import reduce_dimensions
from .config import config
import matplotlib.pyplot as plt
import matplotlib

# Select the Agg backend at import time; every figure created through
# ``plt`` in this module is therefore rendered off-screen.
matplotlib.use('Agg')  # Use non-interactive backend


class ContrastiveLoss(nn.Module):
    """Pairwise contrastive loss with a distance margin.

    For a pair at Euclidean distance ``d`` the per-pair loss is
    ``(1 - label) * d**2 + label * max(margin - d, 0)**2``: pairs labelled
    ``0`` are penalised for being far apart, pairs labelled ``1`` are
    penalised only while they are closer than ``margin``.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss over all pairs in the batch."""
        dist = F.pairwise_distance(output1, output2)

        # Term active for label == 0: squared distance.
        attract = (1 - label) * dist.pow(2)

        # Term active for label == 1: squared hinge on the margin.
        hinge = torch.clamp(self.margin - dist, min=0.0)
        repel = label * hinge.pow(2)

        return (attract + repel).mean()
    
    
# Analysis run at the end of each validation epoch and at the end of testing
class AnalysisModule:
    """Accumulate per-batch outputs and labels, then analyse them per layer.

    Batches are appended with :meth:`save_batch_data_for_analysis`. The
    analysis methods index the accumulated ``linear_output`` as
    ``[:, i, :]``, i.e. they expect a 3-D array whose second axis
    enumerates the layers that are analysed and plotted.
    """

    def __init__(self):
        self.reset_batch_data_for_analysis()

    def save_batch_data_for_analysis(self, linear_output, gt):
        """Append one batch of outputs and labels to the accumulated data.

        Args:
            linear_output: torch tensor of outputs for the batch.
            gt: torch tensor of ground-truth labels for the batch.
        """
        # Detach and move to host memory so nothing keeps the graph alive.
        linear_output_np = linear_output.detach().cpu().numpy()
        gt_np = gt.detach().cpu().numpy()

        # Guard against instances whose storage was never initialised.
        if not hasattr(self, 'saved_batch_values'):
            self.reset_batch_data_for_analysis()

        stored = self.saved_batch_values

        if stored['linear_output'] is None:
            stored['linear_output'] = linear_output_np
        else:
            stored['linear_output'] = np.vstack((stored['linear_output'], linear_output_np))

        if stored['gt'] is None:
            stored['gt'] = gt_np
        else:
            stored['gt'] = np.hstack((stored['gt'], gt_np))

    def reset_batch_data_for_analysis(self):
        """Drop everything accumulated so far."""
        self.saved_batch_values = {
            'linear_output': None,
            'gt': None
        }

    def get_batch_data_for_analysis(self):
        """Return the dict holding the accumulated ``linear_output`` and ``gt``."""
        return self.saved_batch_values

    def _require_saved_data(self):
        """Raise ``ValueError`` when no batch has been saved yet."""
        stored = getattr(self, 'saved_batch_values', None)
        if not stored or stored['linear_output'] is None or stored['gt'] is None:
            raise ValueError(
                "No batch data has been saved; call "
                "save_batch_data_for_analysis() before running the analysis."
            )

    def _select_data(self, balanced):
        """Return a shallow copy of the balanced or the raw accumulated data."""
        if balanced:
            return copy(self.return_balanced_data())
        self._require_saved_data()
        return copy(self.saved_batch_values)

    @staticmethod
    def _project_to_2d(layer_embeddings):
        """Reduce one layer's embeddings to 2 components using the configured method."""
        # NOTE(review): `reduce_dimensions` is not imported in the visible part
        # of this file (the `.data` import is commented out) - confirm it is in
        # scope at runtime.
        return reduce_dimensions(
            layer_embeddings,
            n_components=2,
            method=config.get('logging', 'dim_reduction_logging'),
        )

    # Information crunching for the KDE
    def calculate_eval_metrics(self, balanced = False):
        """Compute the KDE-entropy and KNN metrics for every layer.

        Args:
            balanced: when true, use the class-balanced subsample from
                :meth:`return_balanced_data` instead of all saved samples.

        Returns:
            dict with keys ``'kde'`` and ``'knn'``, each a list with one
            value per layer.

        Raises:
            ValueError: if no batch data has been saved.
        """
        saved_values = self._select_data(balanced)

        values = {
            'kde': [],
            'knn': []
        }

        for i in range(saved_values['linear_output'].shape[1]):
            curr_embeddings = self._project_to_2d(saved_values['linear_output'][:, i, :])

            values["kde"].append(calculate_KDE_entropy(curr_embeddings, saved_values['gt']))
            values["knn"].append(calculate_KNN(curr_embeddings, saved_values['gt']))

        return values

    def create_scatter_plots(self, balanced = False):
        """Draw one KDE scatter plot per layer on a shared figure.

        The grid has 7 columns and at least 5 rows; extra rows are added
        when there are more than 35 layers so every layer gets an axis.

        Args:
            balanced: when true, plot the class-balanced subsample.

        Returns:
            ``(fig, axes)`` where ``axes`` is the flattened array of all
            grid axes (including the ones removed from the figure).

        Raises:
            ValueError: if no batch data has been saved.
        """
        saved_values = self._select_data(balanced)
        n_layers = saved_values['linear_output'].shape[1]

        n_cols = 7
        # Ceil-divide so the grid always has room for every layer; keep the
        # historical 5x7 / (20, 15) layout whenever it is large enough.
        n_rows = max(5, -(-n_layers // n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 3 * n_rows))
        axes = axes.flatten()

        for i in range(n_layers):
            curr_embeddings = self._project_to_2d(saved_values['linear_output'][:, i, :])

            create_KDE_scatter_plots(axes[i], curr_embeddings, saved_values['gt'])
            axes[i].set_title(f'Layer {i + 1}')
            axes[i].set_xticks([])
            axes[i].set_yticks([])

        # Remove any unused subplots (slicing also works when n_layers == 0).
        for unused_axis in axes[n_layers:]:
            fig.delaxes(unused_axis)

        fig.tight_layout()
        plt.show()

        return fig, axes

    def return_balanced_data(self):
        """Return a random subsample with equally many samples per class.

        Every class is down-sampled without replacement to the size of the
        smallest class, using NumPy's global random state.

        Returns:
            dict with keys ``'linear_output'`` and ``'gt'`` holding the
            selected rows, grouped by class in sorted label order.

        Raises:
            ValueError: if no batch data has been saved.
        """
        self._require_saved_data()
        labels = self.saved_batch_values['gt']

        unique_labels, class_counts = np.unique(labels, return_counts=True)
        min_class_count = class_counts.min()

        # Randomly sample the same number of samples from each class
        balanced_indices = []
        for label in unique_labels:
            label_indices = np.where(labels == label)[0]
            balanced_indices.extend(np.random.choice(label_indices, min_class_count, replace=False))

        return {
            'linear_output': self.saved_batch_values['linear_output'][balanced_indices],
            'gt': labels[balanced_indices]
        }
    
# Calculating the information gain within the model
def calculate_KDE_entropy(embeddings, labels):
    """Return the mean per-class KDE entropy score of ``embeddings``.

    For every distinct label a Gaussian KDE is fitted to that class's
    points and evaluated at those same points; the class score is
    ``-sum(p * log(p + 1e-10))`` over the resulting densities. The scores
    of all classes are averaged.

    A class is skipped (with a ``RuntimeWarning``) when a KDE cannot be
    fitted to it: it has no more samples than the embedding has
    dimensions, so its covariance is singular, or ``gaussian_kde``
    rejects the data for another reason.

    NOTE(review): the score sums ``p * log p`` over the sample points and
    is not the usual resubstitution estimate ``-mean(log p)``; the formula
    is kept unchanged so logged values stay comparable - confirm intent.

    Args:
        embeddings: array with one row per sample (``n_samples, n_dims``).
        labels: one label per row of ``embeddings``.

    Returns:
        The mean score over all classes that could be scored, or ``nan``
        when no class could be scored.
    """
    import warnings

    embeddings = np.asarray(embeddings)
    # Integer-encode the labels so classes can be selected by code.
    classes, encoded_labels = np.unique(np.asarray(labels).ravel(), return_inverse=True)
    encoded_labels = encoded_labels.ravel()
    n_dims = embeddings.shape[1] if embeddings.ndim > 1 else 1

    entropies = []

    for code, label in enumerate(classes):
        # Select embeddings for the current label
        class_embeddings = embeddings[encoded_labels == code]

        # With n_samples <= n_dims the sample covariance is always singular.
        if class_embeddings.shape[0] <= n_dims:
            warnings.warn(
                f"Skipping class {label!r} in KDE entropy: "
                f"{class_embeddings.shape[0]} sample(s) for {n_dims} dimension(s).",
                RuntimeWarning,
            )
            continue

        try:
            kde = gaussian_kde(class_embeddings.T)  # gaussian_kde expects (n_dims, n_samples)
        except (ValueError, np.linalg.LinAlgError) as err:
            # E.g. all points identical or collinear.
            warnings.warn(
                f"Skipping class {label!r} in KDE entropy: {err}",
                RuntimeWarning,
            )
            continue

        # Evaluate the KDE on the same points to get the PDF
        pdf = kde(class_embeddings.T)

        # Epsilon avoids log(0)
        entropies.append(-np.sum(pdf * np.log(pdf + 1e-10)))

    if not entropies:
        return float('nan')

    # Average the entropies of each class
    return np.mean(entropies)

# Returns the average fraction of each point's k nearest neighbours that share its class
def calculate_KNN(embeddings, labels, k=5):
    """Return the mean fraction of nearest neighbours sharing a point's label.

    For every point the ``k`` nearest other points (Euclidean distance)
    are looked up and the fraction with the same label is computed; the
    result is the average of that fraction over all points.

    The point itself is excluded explicitly, so duplicate points at
    distance zero are treated as genuine neighbours. When fewer than ``k``
    other points exist, all of them are used and the fraction is taken
    over that smaller number.

    Args:
        embeddings: array with one row per sample (``n_samples, n_dims``).
        labels: one label per row of ``embeddings``.
        k: number of neighbours to inspect per point.

    Returns:
        The average same-class fraction in ``[0, 1]``, or ``nan`` when no
        point has any neighbour to inspect (fewer than two samples or
        ``k < 1``).
    """
    embeddings = np.asarray(embeddings)
    # Integer-encode the labels so they can be compared as plain arrays.
    _, encoded_labels = np.unique(np.asarray(labels).ravel(), return_inverse=True)
    encoded_labels = encoded_labels.ravel()

    n = embeddings.shape[0]
    n_neighbors = min(k, n - 1)
    if n_neighbors < 1:
        return float('nan')

    # Calculate the distance matrix
    distances = np.linalg.norm(embeddings[:, None] - embeddings, axis=-1)

    # A point must never be its own neighbour. Relying on it sorting first
    # breaks when another point is at distance zero, so mask it out instead.
    np.fill_diagonal(distances, np.inf)

    # Stable sort keeps tie-breaking deterministic (lowest index first).
    nearest_indices = np.argsort(distances, axis=1, kind='stable')[:, :n_neighbors]

    # Boolean (n, n_neighbors) matrix: neighbour has the same label as the point.
    same_class = encoded_labels[nearest_indices] == encoded_labels[:, None]

    # Every row has the same width, so the overall mean equals the mean of
    # the per-point fractions.
    return same_class.mean()


class EmbeddingKVStore:
    """JSON-file-backed cache of records keyed by their ``'prompt'``.

    The store is a dict with two entries: ``"embeddings"`` maps each
    prompt to its full record (including ``'embedding'``), and ``"data"``
    is a list of the same records with the ``'embedding'`` field removed.
    The whole store is rewritten to disk whenever a new prompt is saved.
    """

    def __init__(self, filename):
        """Open (or lazily create) ``<nifty_path>/<filename>_bank.json``."""
        filepath = os.path.join(config.get('dataset', 'nifty_path'), filename + "_bank.json")

        self.filepath = filepath
        self.store = self._load_store()

    def _load_store(self):
        """Load the store from disk, or return an empty store if the file is absent."""
        try:
            with open(self.filepath, 'r', encoding='utf-8') as file:
                return json.load(file)
        except FileNotFoundError:
            return {
                "data": [],
                "embeddings": {}
            }

    def _save_store(self):
        """Write the store to disk without ever leaving a half-written file.

        The JSON is dumped to a sibling temporary file which then replaces
        the real file, so a failed or interrupted dump keeps the previous
        contents intact.
        """
        tmp_path = self.filepath + ".tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as file:
                json.dump(self.store, file)
            os.replace(tmp_path, self.filepath)
        except BaseException:
            # Do not leave the partial temp file behind; re-raise the real error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def get(self, key, default=None):
        """Return the top-level store entry ``key`` (``"data"`` or ``"embeddings"``).

        NOTE(review): this reads the top level of the store, not the
        per-prompt records - confirm callers do not expect a prompt lookup.
        """
        return self.store.get(key, default)

    def save(self, save_value):
        """Store ``save_value`` under its ``'prompt'`` unless already present.

        Args:
            save_value: dict with at least ``'prompt'`` and ``'embedding'``
                keys. It is not modified.

        Returns:
            For a new prompt: a copy of ``save_value`` without the
            ``'embedding'`` field (the same dict that is appended to
            ``"data"``). For a known prompt: the previously stored full
            record, and nothing is written.

        Raises:
            KeyError: if ``'prompt'`` is missing, or if the prompt is new
                and ``'embedding'`` is missing. The store is left untouched.
        """
        key = save_value['prompt']

        if key in self.store["embeddings"]:
            return self.store["embeddings"][key]

        # Validate before touching the store so a bad record cannot leave
        # "embeddings" and "data" out of sync.
        if 'embedding' not in save_value:
            raise KeyError('embedding')

        self.store["embeddings"][key] = copy(save_value)

        # Build the embedding-free view as a new dict instead of popping
        # from the caller's object.
        without_embedding = {
            field: value for field, value in save_value.items() if field != 'embedding'
        }
        self.store["data"].append(without_embedding)
        self._save_store()
        return without_embedding