import pytorch_lightning as pl
import torch.nn as nn
import wandb

import transformers
import sys
import os
import torch

from .config import config
from .util import ContrastiveLoss, AnalysisModule, EmbeddingKVStore

import torch
import torch.nn as nn
import pytorch_lightning as pl

### Nonlinearity Models ###

class FNNNonlinearityNetwork(nn.Module):
    """Feed-forward projection head applied to a single LLM layer's embeddings.

    The network is ``n_nonlinearity_layers`` repetitions of
    ``Linear -> ReLU -> BatchNorm1d -> Dropout`` followed by one output
    ``Linear`` layer that maps to ``output_dim`` features.

    Args:
        n_nonlinearity_layers: Number of hidden blocks. ``0`` yields a single
            linear map from ``llm_hidden_dim_size`` to ``output_dim``.
        llm_hidden_dim_size: Feature size of the incoming embeddings.
        nonlinearity_hidden_dim_size: Width of every hidden block.
        dropout_prob: Dropout probability used after every hidden block.
        output_dim: Feature size of the output (defaults to 16, the value
            that used to be hard-coded).
    """

    def __init__(self, n_nonlinearity_layers, llm_hidden_dim_size, nonlinearity_hidden_dim_size, dropout_prob, output_dim=16):
        super().__init__()
        layers = []

        # Track the width feeding the next Linear so that the output layer is
        # sized correctly even when there are no hidden blocks at all.
        in_features = llm_hidden_dim_size
        for _ in range(n_nonlinearity_layers):
            layers.append(nn.Linear(in_features, nonlinearity_hidden_dim_size))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(nonlinearity_hidden_dim_size))
            layers.append(nn.Dropout(dropout_prob))
            in_features = nonlinearity_hidden_dim_size

        # Final projection down to the output embedding size.
        layers.append(nn.Linear(in_features, output_dim))

        # Define the network as a sequential container
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` through the network.

        Args:
            x: Tensor whose last dimension is ``llm_hidden_dim_size``; the
                hidden blocks use ``BatchNorm1d`` over the feature axis, so
                the expected layout is ``(batch, llm_hidden_dim_size)``.

        Returns:
            Tensor with the last dimension replaced by ``output_dim``.
        """
        return self.network(x)
    
class AttentionNonlinearityNetwork(nn.Module):
    """Transformer-encoder projection head.

    The input is linearly embedded to ``nonlinearity_hidden_dim_size``, passed
    through ``n_nonlinearity_layers`` ``nn.TransformerEncoderLayer`` blocks and
    finally projected to ``output_dim`` features.

    Args:
        n_nonlinearity_layers: Number of stacked encoder layers.
        llm_hidden_dim_size: Feature size of the incoming embeddings.
        nonlinearity_hidden_dim_size: Model width (``d_model``) and
            feed-forward width of every encoder layer.
        dropout_prob: Dropout probability inside the encoder layers.
        num_heads: Number of attention heads; must divide
            ``nonlinearity_hidden_dim_size``.
        output_dim: Feature size of the output (defaults to 16, the value
            that used to be hard-coded).
    """

    def __init__(self, n_nonlinearity_layers, llm_hidden_dim_size, nonlinearity_hidden_dim_size, dropout_prob, num_heads=1, output_dim=16):
        super().__init__()

        # Embedding layer to match the encoder's model dimension
        self.embedding = nn.Linear(llm_hidden_dim_size, nonlinearity_hidden_dim_size)

        # Transformer encoder stack (sequence-first layout, the PyTorch default)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=nonlinearity_hidden_dim_size,
            nhead=num_heads,
            dim_feedforward=nonlinearity_hidden_dim_size,
            dropout=dropout_prob,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_nonlinearity_layers)

        # Downsample to the desired output dimension
        self.downsample = nn.Linear(nonlinearity_hidden_dim_size, output_dim)

    def forward(self, x):
        """Encode ``x`` and project it to ``output_dim`` features.

        Args:
            x: Either ``[batch_size, seq_length, feature_dim]`` or
                ``[batch_size, feature_dim]``. A 2-D input is treated as a
                batch of length-1 sequences so that samples never attend to
                each other.

        Returns:
            Tensor of shape ``[batch_size, seq_length, output_dim]`` for 3-D
            input, or ``[batch_size, output_dim]`` for 2-D input.
        """
        # Callers that slice one layer out of a (batch, layer, feature) tensor
        # hand over 2-D data; give it a singleton sequence axis so the
        # permutes below are valid, and remove that axis again at the end.
        added_seq_axis = x.dim() == 2
        if added_seq_axis:
            x = x.unsqueeze(1)

        # Embed input to the size expected by the transformer
        x = self.embedding(x)

        # The encoder expects [seq_length, batch_size, feature_dim]
        x = x.permute(1, 0, 2)
        x = self.transformer_encoder(x)

        # Back to [batch_size, seq_length, feature_dim]
        x = x.permute(1, 0, 2)

        # Downsample layer acts on the last dimension only
        x = self.downsample(x)

        if added_seq_axis:
            x = x.squeeze(1)

        return x

class NonlinearityModel(pl.LightningModule):
    """Lightning module that trains one nonlinearity network per LLM layer.

    Every network in ``self.nn`` receives the slice of the input that belongs
    to its own layer index. The per-layer outputs are scored with
    ``ContrastiveLoss`` on pairs formed by splitting each batch in half.
    All hyper-parameters are read from the project ``config`` object.
    """

    def __init__(self):
        super().__init__()
        self.nn = self.create_nonlinearity_model()

        self.criterion = ContrastiveLoss()

        # Separate accumulators so train and validation data never mix.
        self.train_end_epoch_analysis = AnalysisModule()
        self.val_end_epoch_analysis = AnalysisModule()

        # Number of finished training epochs (see on_train_epoch_end).
        self.epoch = 0

    def create_nonlinearity_model(self):
        """Build one nonlinearity network per LLM layer from ``config``.

        Returns:
            ``nn.ModuleList`` holding ``n_llm_layers`` separately constructed
            networks of the class selected by ``model.nonlinearity_type``.

        Raises:
            ValueError: If ``model.nonlinearity_type`` is neither ``"fnn"``
                nor ``"attention"``.
        """
        n_nonlinearity_layers = config.get('model', 'n_nonlinearity_layers')
        n_llm_layers = config.get('model', 'n_llm_layers')
        llm_hidden_dim_size = config.get('model', 'llm_hidden_dim_size')
        nonlinearity_hidden_dim_size = config.get('model', 'nonlinearity_hidden_dim_size')
        dropout_prob = config.get('model', 'dropout_prob')

        # Determine the type of nonlinearity network to create
        nonlinearity_type = config.get('model', 'nonlinearity_type')
        if nonlinearity_type == "fnn":
            network_cls = FNNNonlinearityNetwork
        elif nonlinearity_type == "attention":
            network_cls = AttentionNonlinearityNetwork
        else:
            raise ValueError(
                f"Unknown nonlinearity_type {nonlinearity_type!r}; expected 'fnn' or 'attention'"
            )

        # One independent network per LLM layer, of the configured type.
        return nn.ModuleList([
            network_cls(n_nonlinearity_layers, llm_hidden_dim_size, nonlinearity_hidden_dim_size, dropout_prob)
            for _ in range(n_llm_layers)
        ])

    def forward(self, x: torch.Tensor):
        """Run every layer's embeddings through that layer's own network.

        Args:
            x: 3-D tensor indexed as ``(batch, layer, feature)``; slice
                ``x[:, i, :]`` is fed to ``self.nn[i]``.

        Returns:
            The per-layer outputs stacked along dimension 1.
        """
        assert x.dim() == 3

        # Pass each layer's slice through the matching network
        outputs = [
            network(x[:, layer_idx, :])
            for layer_idx, network in enumerate(self.nn)
        ]

        return torch.stack(outputs, dim=1)

    # Batch: (batch_size, layer_size, hidden_dim_size)
    def training_step(self, batch, batch_idx):
        """Compute the mean per-layer contrastive loss for one training batch.

        Reads ``batch['embedding']`` and ``batch['gt']``, logs the per-layer
        losses under the ``train`` split and stores the outputs for the
        end-of-epoch analysis.

        Returns:
            The loss averaged over layers.
        """
        embedding = batch['embedding']
        gt = batch['gt']

        # Forward pass
        nonlinear_output = self.propagate_batch(embedding)

        # Create and log the losses
        losses = self.get_layer_losses(nonlinear_output, gt)
        average_loss = sum(losses) / len(losses)
        self.log_evaluation_metrics(losses, "loss", "train")

        # Save the output and gt for analysis on training epoch end
        self.train_end_epoch_analysis.save_batch_data_for_analysis(nonlinear_output, gt)

        return average_loss

    def on_train_epoch_end(self):
        """Log the epoch's analysis metrics and plots, then advance ``self.epoch``."""
        self._run_epoch_end_analysis(self.train_end_epoch_analysis, "train")

        # Increment the epoch counter
        self.epoch += 1

    def validation_step(self, batch, batch_idx):
        """Compute the mean per-layer contrastive loss for one validation batch.

        Mirrors ``training_step``: reads ``batch['embedding']`` and
        ``batch['gt']``, logs the per-layer losses under the ``val`` split and
        stores the outputs for the end-of-epoch analysis.

        Returns:
            Dict with the layer-averaged ``"loss"`` and the raw
            ``"nonlinear_output"``.
        """
        # NOTE(review): assumes validation batches carry the same
        # 'embedding'/'gt' keys as training batches - confirm against the
        # validation dataloader.
        embedding = batch['embedding']
        gt = batch['gt']

        # Forward pass
        nonlinear_output = self.propagate_batch(embedding)

        # Create and log the losses
        losses = self.get_layer_losses(nonlinear_output, gt)
        average_loss = sum(losses) / len(losses)
        self.log_evaluation_metrics(losses, "loss", "val")

        # Save the output and gt for analysis on validation epoch end
        self.val_end_epoch_analysis.save_batch_data_for_analysis(nonlinear_output, gt)

        return {
            "loss": average_loss,
            "nonlinear_output": nonlinear_output
        }

    def on_validation_epoch_end(self):
        """Log the validation analysis metrics and plots for the epoch."""
        # Validation plots get their own files so they do not share a path
        # with the training plots when wandb is disabled.
        self._run_epoch_end_analysis(self.val_end_epoch_analysis, "val", plot_file_prefix="val_")

    def _run_epoch_end_analysis(self, analysis, split, plot_file_prefix=""):
        """Log metrics and scatter plots held by ``analysis``, then reset it.

        Args:
            analysis: The ``AnalysisModule`` that accumulated this split's
                batches.
            split: Split name used in metric and plot keys (``"train"`` or
                ``"val"``).
            plot_file_prefix: Prepended to the PNG filenames that are written
                when wandb logging is disabled.
        """
        # Entropy (KDE) and KNN values on the data as collected
        eval_metrics = analysis.calculate_eval_metrics()
        self.log_evaluation_metrics(eval_metrics["kde"], "entropy", split)
        self.log_evaluation_metrics(eval_metrics["knn"], "knn", split)

        # Same metrics with balanced classes
        eval_metrics = analysis.calculate_eval_metrics(balanced=True)
        self.log_evaluation_metrics(eval_metrics["kde"], "entropy", f"{split}_balanced")
        self.log_evaluation_metrics(eval_metrics["knn"], "knn", f"{split}_balanced")

        use_wandb = config.get('logging', 'use_wandb')

        # NOTE(review): the figures are not closed here; if
        # create_scatter_plots returns pyplot-managed figures they will
        # accumulate across epochs - confirm and close them where they are made.

        # Scatter plots for the nonlinearity
        fig, _ = analysis.create_scatter_plots()
        if use_wandb:
            wandb.log({f"{split}_plot": [wandb.Image(fig)]})
        else:
            fig.savefig(f"{plot_file_prefix}nonlinearity_scatter_plots.png")

        # Scatter plots for the nonlinearity with balanced classes
        fig, _ = analysis.create_scatter_plots(balanced=True)
        if use_wandb:
            wandb.log({f"{split}_plot_balanced": [wandb.Image(fig)]})
        else:
            fig.savefig(f"{plot_file_prefix}nonlinearity_scatter_plots_balanced.png")

        # Reset the batch data for the next epoch
        analysis.reset_batch_data_for_analysis()

    def test_step(self, batch, batch_idx):
        """Intentionally a no-op: no test-time evaluation is implemented."""
        pass

    # Batch: (batch_size, layer_size, hidden_dim_size)
    # Output: (batch_size, layer_size, network output size)
    def propagate_batch(self, embedding):
        """Propagate the batch through the network and return its output."""
        return self.forward(embedding)

    def get_layer_losses(self, pred, gt):
        """Compute one contrastive loss per layer.

        Args:
            pred: 3-D network output indexed as ``(batch, layer, feature)``.
            gt: Ground-truth labels aligned with the batch dimension.

        Returns:
            List with one ``self.criterion`` result per layer.
        """
        (pred1, pred2), gt_matches = self.get_data_pairs(pred, gt)

        llm_n_layers = pred1.shape[1]

        return [
            self.criterion(pred1[:, i, :], pred2[:, i, :], gt_matches)
            for i in range(llm_n_layers)
        ]

    # Puts the batch into pairs for the contrastive loss
    def get_data_pairs(self, data, gt):
        """Pair sample ``i`` of the first batch half with sample ``i`` of the second.

        Args:
            data: 3-D tensor with an even-sized first dimension.
            gt: Labels aligned with ``data``; must support slicing and an
                elementwise ``==`` whose result has ``.float()``.

        Returns:
            ``((data1, data2), gt_matches)`` where ``gt_matches`` is 1.0 for
            pairs with equal labels and 0.0 otherwise.
        """
        assert data.dim() == 3
        assert data.shape[0] % 2 == 0, "Batch size must be even"

        # 1) Split the data into two halves
        half_size = data.shape[0] // 2
        data1 = data[:half_size]
        data2 = data[half_size:]

        # 2) Create the ground truth matches
        gt1 = gt[:half_size]
        gt2 = gt[half_size:]

        # If the ground truth matches, the label is 1. Otherwise, it is 0.
        gt_matches = (gt1 == gt2).float()

        return (data1, data2), gt_matches

    ###### Logging ######
    def log_evaluation_metrics(self, values, metric_type: str, split: str):
        """Log the mean of ``values`` plus one entry per layer.

        The mean is logged through Lightning as ``avg_{split}_{metric_type}``.
        Per-layer values go to wandb when ``logging.use_wandb`` is enabled and
        through Lightning's ``log_dict`` otherwise.

        Args:
            values: Sized iterable with one value per layer.
            metric_type: Metric name used in the keys, e.g. ``"loss"``.
            split: Split name used in the keys, e.g. ``"train"``.
        """
        # Nothing to average or log; also avoids a division by zero.
        if len(values) == 0:
            return

        avg_value = sum(values) / len(values)
        self.log(f'avg_{split}_{metric_type}', avg_value)

        # Log individual layer values
        metrics = {f'{split}_{metric_type}/layer_{str(i+1).zfill(2)}': float(value) for i, value in enumerate(values)}
        if config.get('logging', 'use_wandb'):
            wandb.log(metrics)
        else:
            # Without a wandb run, wandb.log cannot be used; keep the values
            # by routing them through the Lightning logger instead.
            self.log_dict(metrics)

    def configure_optimizers(self):
        """Return an Adam optimizer using ``training.learning_rate`` from config."""
        optimizer = torch.optim.Adam(self.parameters(), lr=config.get('training', 'learning_rate'))
        return optimizer
