
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from torch.optim.lr_scheduler import ExponentialLR
from transformers import AutoModel
#from torchmetrics import Accuracy, F1
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score



from .config import config
from ..util import load_LORA_model

import math

import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA
import wandb

from sklearn.metrics import f1_score
from torchmetrics import Accuracy
from transformers import AutoTokenizer, AutoModel

class ClassifierModel(pl.LightningModule):
    """MLP classifier over precomputed embeddings.

    The input width is chosen from ``config.get("training_class", "type")``:
    ``"llm"`` -> 4096, ``"proj"`` -> 128, ``"both"`` -> 4096 + 128.

    Args:
        num_classes: Number of output classes (default 3).
        lr: Learning rate for Adam (default 1e-3).

    Raises:
        ValueError: If the configured run type is not one of the values above.
    """

    # Input feature width for each supported ``training_class.type``.
    _INPUT_DIMS = {"llm": 4096, "proj": 128, "both": 4096 + 128}

    def __init__(self, num_classes: int = 3, lr: float = 1e-3):
        super().__init__()

        self.run_type = config.get("training_class", "type")
        if self.run_type not in self._INPUT_DIMS:
            raise ValueError(f"Invalid run type: {self.run_type}")
        self.dim_size = self._INPUT_DIMS[self.run_type]

        self.num_classes = num_classes
        self.lr = lr

        self.classifier_model = self.initialize_classifier()

        # Metrics are stateful torchmetrics modules. Logging the module itself (rather
        # than a per-batch value) lets Lightning compute the epoch-level value and
        # reset the accumulated state at the end of every epoch.
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_f1 = MulticlassF1Score(num_classes=self.num_classes, average="weighted")

    def initialize_classifier(self):
        """Build the MLP: dim_size -> 128 -> 64 -> num_classes.

        Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout(0.3); the
        output layer is a plain Linear producing raw logits.
        """
        return nn.Sequential(
            nn.Linear(self.dim_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(64, self.num_classes),  # logits; no BatchNorm/Dropout here
        )

    def forward(self, x):
        """Return class logits of shape (batch_size, num_classes) for inputs ``x``."""
        return self.classifier_model(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy step on an ``(inputs, labels)`` batch; logs loss and accuracy."""
        inputs, labels = batch
        logits = self(inputs)

        loss = F.cross_entropy(logits, labels)

        # Calling the metric updates its state and caches the batch value for step logging.
        self.train_accuracy(logits, labels)

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', self.train_accuracy, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step; accuracy and weighted F1 are aggregated over the whole epoch."""
        inputs, labels = batch
        logits = self(inputs)

        loss = F.cross_entropy(logits, labels)

        # Accumulate; Lightning calls compute() at epoch end because the metric
        # objects are logged below (validation logs on_epoch by default).
        self.val_accuracy.update(logits, labels)
        self.val_f1.update(logits, labels)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the learning rate given at construction."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

class LlamaEmbeddingModel(pl.LightningModule):
    """Projector head trained on precomputed embeddings.

    Maps embeddings of size ``training.hidden_size`` to L2-normalised 128-d
    vectors. The training objective between each anchor and its augmentations
    is selected by ``training.loss_type`` ("mse", "contrastive" or "both").
    During validation/test, the anchor projections of labelled samples
    (label > -0.5) are collected. At epoch end they are scored with
    neighbourhood metrics (nearest-neighbour accuracy, information gain,
    KL divergence, JSD) and optionally plotted to wandb.
    """

    def __init__(self):
        super().__init__()

        self.projector_model = self.initialize_projector()
        self.temperature = config.get("training", "temperature")
        self.loss_type = config.get("training", "loss_type")  # "mse", "contrastive" or "both"

        # Define the MSE loss function
        self.mse_loss_fn = nn.MSELoss()

        self.num_classes = 3  # Number of categories/classes -> Rise, Fall, Neutral

        # Per-epoch buffers of {'embeddings': Tensor, 'labels': Tensor} dicts,
        # filled by validation_step / test_step and cleared at epoch end.
        self.val_outputs = []
        self.test_outputs = []

    def initialize_projector(self):
        """Build the MLP projector: hidden_size -> [hidden_layer_size]*N -> 128.

        Each hidden layer is Linear -> (optional BatchNorm1d) -> ReLU -> Dropout(0.1).
        """
        hidden_size = config.get("training", "hidden_size")  # Should match the size of precomputed embeddings
        num_hidden_layers = config.get("training", "num_hidden_layers")
        hidden_layer_size = config.get("training", "hidden_layer_size")
        use_batch_norm = config.get("training", "use_batch_norm")
        dropout_rate = 0.1

        layers = []
        input_size = hidden_size

        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(input_size, hidden_layer_size))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(hidden_layer_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            input_size = hidden_layer_size

        # Final output layer
        layers.append(nn.Linear(input_size, 128))

        return nn.Sequential(*layers)

    def forward(self, embeddings):
        """
        Forward pass through the projector network.

        Parameters:
        - embeddings: Tensor of shape (batch_size, hidden_size)

        Returns:
        - projected_embeddings: L2-normalised tensor of shape (batch_size, 128)
        """
        projected_embeddings = self.projector_model(embeddings)
        return F.normalize(projected_embeddings, p=2, dim=-1)

    def compute_loss(self, anchor_proj, augmented_proj, similiarity_scores, split: str):
        """Compute (and log under ``{split}_...``) the loss selected by ``self.loss_type``.

        Raises:
            ValueError: If ``self.loss_type`` is not "mse", "contrastive" or "both".
        """
        if self.loss_type == "mse":
            loss = self.calculate_mse_loss(anchor_proj, augmented_proj, similiarity_scores)
            self.log(f'{split}_mse_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        elif self.loss_type == "contrastive":
            loss = self.calculate_weighted_contrastive_loss(anchor_proj, augmented_proj, similiarity_scores)
            self.log(f'{split}_contrastive_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        elif self.loss_type == "both":
            mse_loss = self.calculate_mse_loss(anchor_proj, augmented_proj, similiarity_scores)
            contrastive_loss = self.calculate_weighted_contrastive_loss(anchor_proj, augmented_proj, similiarity_scores)
            loss = mse_loss + contrastive_loss
            self.log(f'{split}_mse_loss', mse_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            self.log(f'{split}_contrastive_loss', contrastive_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        else:
            raise ValueError(f"Invalid loss type: {self.loss_type}")

        return loss

    def _project_batch(self, batch):
        """Project the anchor and augmented embeddings of ``batch``.

        Returns ``(anchor_proj, augmented_proj, augmented_similarities)`` with shapes
        (batch_size, 128), (batch_size, num_augmented, 128), (batch_size, num_augmented).
        """
        anchor_embeddings = batch['anchor_embedding']  # (batch_size, hidden_size)
        augmented_embeddings = batch['augmented_embeddings']  # (batch_size, num_augmented, hidden_size)
        augmented_similarities = batch['augmented_similarities']  # (batch_size, num_augmented)

        batch_size, num_augmented, hidden_size = augmented_embeddings.size()

        anchor_proj = self.forward(anchor_embeddings)

        # Flatten so the projector sees a 2-D batch (BatchNorm1d expects that);
        # reshape also accepts non-contiguous inputs, unlike view.
        augmented_proj_flat = self.forward(augmented_embeddings.reshape(-1, hidden_size))
        augmented_proj = augmented_proj_flat.view(batch_size, num_augmented, -1)

        return anchor_proj, augmented_proj, augmented_similarities

    def training_step(self, batch, batch_idx):
        anchor_proj, augmented_proj, augmented_similarities = self._project_batch(batch)
        return self.compute_loss(anchor_proj, augmented_proj, augmented_similarities, "train")

    def validation_step(self, batch, batch_idx):
        return self._eval_step(batch, 'val', self.val_outputs)

    def test_step(self, batch, batch_idx):
        return self._eval_step(batch, 'test', self.test_outputs)

    def _shared_eval_step(self, batch, batch_idx, stage):
        """Compute and log the ``stage`` loss for ``batch`` (no output collection)."""
        anchor_proj, augmented_proj, augmented_similarities = self._project_batch(batch)
        return self.compute_loss(anchor_proj, augmented_proj, augmented_similarities, stage)

    def _eval_step(self, batch, stage, outputs):
        """Compute the ``stage`` loss and collect labelled anchor projections into ``outputs``."""
        anchor_proj, augmented_proj, augmented_similarities = self._project_batch(batch)
        loss = self.compute_loss(anchor_proj, augmented_proj, augmented_similarities, stage)
        self._collect_labelled(anchor_proj, batch['anchor_label'], outputs)
        return loss

    @staticmethod
    def _collect_labelled(anchor_proj, anchor_label, outputs):
        """Append ``{'embeddings', 'labels'}`` for every sample whose label is > -0.5."""
        anchor_label = torch.as_tensor(anchor_label, device=anchor_proj.device)
        # NOTE(review): negative labels presumably mark samples without a class — confirm.
        keep = anchor_label > -0.5
        for embedding, label in zip(anchor_proj[keep].detach(), anchor_label[keep]):
            outputs.append({'embeddings': embedding, 'labels': label})

    def calculate_weighted_contrastive_loss(self, anchor_proj, augmented_proj, similiarity_scores):
        """
        Weighted contrastive loss between each anchor and its augmentations.

        Parameters:
        - anchor_proj: Tensor of shape (batch_size, 128)
        - augmented_proj: Tensor of shape (batch_size, num_augmented, 128)
        - similiarity_scores: Tensor of shape (batch_size, num_augmented)

        Returns:
        - loss: scalar tensor, mean over all pairs of ``-score * cos_sim / temperature``

        NOTE(review): the log-softmax normalisation over augmentations is disabled,
        so the "log probability" is just the temperature-scaled cosine similarity.
        """
        anchor_proj = F.normalize(anchor_proj, p=2, dim=-1)
        augmented_proj = F.normalize(augmented_proj, p=2, dim=-1)

        # (batch_size, 1, 128) @ (batch_size, 128, num_augmented) -> (batch_size, num_augmented)
        cosine_similarities = torch.matmul(
            anchor_proj.unsqueeze(1), augmented_proj.transpose(1, 2)
        ).squeeze(1)

        log_prob = cosine_similarities / self.temperature

        log_prob_flat = log_prob.view(-1)
        similarity_scores_flat = similiarity_scores.view(-1).to(log_prob_flat.device)

        loss = -similarity_scores_flat * log_prob_flat
        return loss.mean()

    def calculate_mse_loss(self, anchor_proj, augmented_proj, similarity_scores):
        """
        MSE between (non-negative) cosine similarities and target similarity scores.

        Parameters:
        - anchor_proj: Tensor of shape (batch_size, 128)
        - augmented_proj: Tensor of shape (batch_size, num_augmented, 128)
        - similarity_scores: Tensor with batch_size * num_augmented elements

        Returns:
        - loss: A scalar tensor representing the MSE loss
        """
        cosine_similarities = F.cosine_similarity(anchor_proj.unsqueeze(1), augmented_proj, dim=-1)

        # Targets are compared against similarities clamped to [0, 1]
        cosine_similarities = torch.clamp(cosine_similarities, min=0.0)

        cosine_similarities_flat = cosine_similarities.view(-1)
        similarity_scores_flat = similarity_scores.view(-1).to(cosine_similarities_flat.device)

        return F.mse_loss(cosine_similarities_flat, similarity_scores_flat)

    def on_train_epoch_end(self):
        """Log the learning rate of the first optimizer's first param group."""
        optimizer = self.trainer.optimizers[0]
        lr = optimizer.param_groups[0]['lr']
        self.log('learning_rate', lr)

    def on_validation_epoch_end(self):
        outputs, self.val_outputs = self.val_outputs, []
        self._log_embedding_metrics(outputs, 'val')

    def on_test_epoch_end(self, outputs=None):
        # ``outputs`` is accepted for backward compatibility only; Lightning calls this
        # hook without arguments and the collected test outputs are used instead.
        collected, self.test_outputs = self.test_outputs, []
        self._log_embedding_metrics(collected, 'test')

    def _log_embedding_metrics(self, outputs, stage):
        """Log neighbourhood metrics (and optionally a PCA plot) for collected outputs.

        Skipped when fewer than two labelled samples were collected, because the
        neighbour-based metrics and PCA are undefined in that case.
        """
        if len(outputs) < 2:
            return

        embeddings = torch.stack([x['embeddings'] for x in outputs]).to(self.device)  # (N, embedding_dim)
        labels = torch.tensor([int(x['labels']) for x in outputs], device=self.device)  # (N,)

        self.log(f'{stage}_nn_accuracy', self.compute_nearest_neighbor_accuracy(embeddings, labels))
        self.log(f'{stage}_information_gain', self.compute_information_gain(embeddings, labels))
        self.log(f'{stage}_kl_divergence', self.compute_kl_divergence(embeddings, labels))
        self.log(f'{stage}_jsd', self.compute_jsd(embeddings, labels))

        if config.get("training", "wandb"):
            fig = self.plot_embeddings_pca(embeddings, labels)
            wandb.log({f"{stage}_embeddings": wandb.Image(fig)})
            plt.close(fig)

    def plot_embeddings_pca(self, embeddings, labels):
        """
        Reduces embeddings to 2D using PCA and creates a scatter plot.

        Args:
            embeddings (torch.Tensor | np.ndarray): Shape (N, embedding_dim).
            labels (torch.Tensor | np.ndarray): Shape (N,); 0=Fall, 1=Neutral, 2=Rise.

        Returns:
            plt.Figure: Matplotlib figure with the PCA scatter plot (caller closes it).
        """
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.detach().cpu().numpy()
        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()

        embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

        colors = {0: 'red', 1: '#FFBF00', 2: 'Green'}  # Fall: red, Neutral: amber, Rise: green
        label2category = {0: 'Fall', 1: 'Neutral', 2: 'Rise'}

        fig, ax = plt.subplots(figsize=(8, 6))
        for label in np.unique(labels):
            idx = labels == label
            ax.scatter(embeddings_2d[idx, 0], embeddings_2d[idx, 1],
                       c=colors[int(label)], label=f'Class {label2category[int(label)]}',
                       alpha=0.6, edgecolor='k', linewidth=0.5)

        ax.set_title('PCA of Embeddings')
        ax.set_xlabel('PCA 1')
        ax.set_ylabel('PCA 2')
        ax.legend(loc='best')

        return fig

    def compute_nearest_neighbor_accuracy(self, embeddings, labels):
        """Fraction of samples whose nearest other sample (by dot product) shares their label.

        The dot product is a cosine similarity when the embeddings are L2-normalised,
        as produced by ``forward``.
        """
        similarity_matrix = torch.matmul(embeddings, embeddings.T)  # (N, N)
        similarity_matrix.fill_diagonal_(-float('inf'))  # exclude self-matches

        nn_labels = labels[similarity_matrix.argmax(dim=1)]
        return (nn_labels == labels).float().mean()

    def compute_label_distributions(self, embeddings, labels, k=5):
        """Per-sample label histogram of its k nearest neighbours, normalised to sum to 1.

        ``k`` is clamped to ``N - 1`` so small sets do not make ``topk`` fail.

        Returns:
            Tensor of shape (N, num_classes).

        Raises:
            ValueError: If fewer than two embeddings are given.
        """
        N = embeddings.size(0)
        k = min(k, N - 1)
        if k < 1:
            raise ValueError("At least two embeddings are required to compute neighbour label distributions")

        similarity_matrix = torch.matmul(embeddings, embeddings.T)  # (N, N)
        similarity_matrix.fill_diagonal_(-float('inf'))  # exclude self-matches

        _, topk_indices = torch.topk(similarity_matrix, k=k, dim=1)  # (N, k)
        neighbor_labels = labels[topk_indices]  # (N, k)

        # Count neighbours per class; labels outside [0, num_classes) are not counted.
        classes = torch.arange(self.num_classes, device=neighbor_labels.device)
        counts = (neighbor_labels.unsqueeze(-1) == classes).sum(dim=1)  # (N, num_classes)

        return counts.float().to(embeddings.device) / k

    def compute_global_label_distribution(self, labels):
        """Fraction of ``labels`` equal to each class index; shape (num_classes,)."""
        classes = torch.arange(self.num_classes, device=labels.device)
        counts = (labels.unsqueeze(-1) == classes).sum(dim=0).float()
        return counts / labels.size(0)

    def compute_information_gain(self, embeddings, labels, k=5):
        """log2(num_classes) minus the mean entropy of the k-NN label distributions."""
        label_distributions = self.compute_label_distributions(embeddings, labels, k)
        mean_entropy = self.compute_local_entropies(label_distributions).mean()
        H_max = math.log2(self.num_classes)
        return H_max - mean_entropy

    def compute_local_entropies(self, label_distributions):
        """Shannon entropy (bits) of each row of ``label_distributions``."""
        epsilon = 1e-10  # Avoid log(0)
        p = label_distributions + epsilon
        return -(label_distributions * torch.log2(p)).sum(dim=1)

    def compute_kl_divergence(self, embeddings, labels, k=5):
        """Mean KL(local k-NN label distribution || global label distribution), in bits."""
        label_distributions = self.compute_label_distributions(embeddings, labels, k)
        global_distribution = self.compute_global_label_distribution(labels)

        epsilon = 1e-10
        p = label_distributions + epsilon
        q = global_distribution.unsqueeze(0) + epsilon  # (1, num_classes)
        kl_divs = (p * (p / q).log2()).sum(dim=1)

        return kl_divs.mean()

    def compute_jsd(self, embeddings, labels, k=5):
        """Mean Jensen-Shannon divergence between local k-NN and global label distributions."""
        label_distributions = self.compute_label_distributions(embeddings, labels, k)
        global_distribution = self.compute_global_label_distribution(labels)

        epsilon = 1e-10
        p = label_distributions + epsilon
        q = global_distribution.unsqueeze(0) + epsilon
        m = 0.5 * (p + q)

        kl_p_m = (p * (p / m).log2()).sum(dim=1)
        kl_q_m = (q * (q / m).log2()).sum(dim=1)

        jsd = 0.5 * (kl_p_m + kl_q_m)
        return jsd.mean()

    def configure_optimizers(self):
        """Adam with an ExponentialLR schedule (``training.lr``, ``training.gamma``)."""
        lr = config.get("training", "lr")
        gamma = config.get("training", "gamma")

        optimizer = Adam(self.parameters(), lr=lr)
        scheduler = ExponentialLR(optimizer, gamma=gamma)

        return [optimizer], [scheduler]

class SimilaritySpaceModel(pl.LightningModule):
    """LoRA-adapted encoder plus projector, trained end-to-end on anchor/augmentation pairs.

    The encoder comes from ``load_LORA_model(training.base_model)``. The final
    hidden state of the last attended token is projected to an L2-normalised
    128-d embedding. Training assumes a batch size of 1 (see ``propagate_batch``).
    """

    def __init__(self):
        super().__init__()

        self.encoder_model, self.projector_model = self.initialize_nn()
        self.temperature = config.get("training", "temperature")

        # Trade compute for memory on the encoder
        self.encoder_model.gradient_checkpointing_enable()

    def initialize_nn(self):
        """Load the LoRA encoder and build a projector in the encoder's dtype.

        Returns:
            (encoder_model, projector_model)
        """
        base_model = config.get("training", "base_model")

        encoder_model, _ = load_LORA_model(base_model)

        encoder_dtype = next(encoder_model.parameters()).dtype
        projector_model = nn.Sequential(
            nn.Linear(encoder_model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        ).to(encoder_dtype)

        return encoder_model, projector_model

    def forward(self, input_ids, attention_mask, return_all_hidden_states=False):
        """Embed token sequences.

        Parameters:
        - input_ids, attention_mask: Tensors of shape (batch_size, seq_len)
        - return_all_hidden_states: also return the encoder hidden states and attentions

        Returns:
        - embedding: L2-normalised tensor of shape (batch_size, 128), or
          ``(embedding, hidden_states, attentions)`` when ``return_all_hidden_states``.
        """
        # Mixed precision on CUDA; disabled (without a warning) when CUDA is unavailable.
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            encoder_output = self.encoder_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                # Attention maps are large; only request them when they are returned.
                output_attentions=return_all_hidden_states,
            )

            final_layer = encoder_output.hidden_states[-1]  # (batch_size, seq_len, hidden)
            if attention_mask is None:
                last_hidden_state = final_layer[:, -1, :]
            else:
                # Last attended position per row: the cumulative mask sum first reaches
                # its maximum there, which handles both left and right padding.
                last_idx = attention_mask.long().cumsum(dim=-1).argmax(dim=-1).to(final_layer.device)
                rows = torch.arange(final_layer.size(0), device=final_layer.device)
                last_hidden_state = final_layer[rows, last_idx]

            embedding = self.projector_model(last_hidden_state)
            embedding = F.normalize(embedding, p=2, dim=-1)

        if return_all_hidden_states:
            return embedding, encoder_output.hidden_states, encoder_output.attentions

        # Optionally move the embedding to CPU (the autograd graph is kept).
        if config.get("training", "outputs_to_cpu"):
            embedding = embedding.cpu()

        return embedding

    def propagate_batch(self, batch, return_all_hidden_states=False):
        """Embed the anchor and each augmentation of a batch of size 1.

        Returns:
        - anchor_embedding: (1, emb_dim)
        - augmented_embeddings: (num_augmented, 1, emb_dim)
        - similarity_scores: (num_augmented,), computed as ``1 - augmented_distances``
        """
        anchor_embedding = self(batch["anchor_ids"], batch["anchor_attention_mask"])

        augmented_embeddings = []
        similarity_scores = []

        # Augmentations are embedded one at a time to bound peak memory.
        for i in range(len(batch["augmented_ids"][0])):
            augmented_ids = batch["augmented_ids"][0, i].unsqueeze(0)  # (1, seq_len)
            augmented_attention_mask = batch["augmented_attention_mask"][0, i].unsqueeze(0)

            augmented_embedding = self(augmented_ids, augmented_attention_mask)

            augmented_embeddings.append(augmented_embedding)
            similarity_scores.append(1 - batch["augmented_distances"][0][i])

            del augmented_ids, augmented_attention_mask, augmented_embedding
            torch.cuda.empty_cache()

        augmented_embeddings = torch.stack(augmented_embeddings)
        similarity_scores = torch.stack(similarity_scores)

        return anchor_embedding, augmented_embeddings, similarity_scores

    def calculate_loss(self, anchor_embedding, augmented_embeddings, similarity_scores):
        """
        Contrastive loss with continuous similarity scores, for a single anchor.

        Parameters:
        - anchor_embedding: Tensor with emb_dim elements, e.g. (emb_dim,) or (1, emb_dim)
        - augmented_embeddings: Tensor of shape (num_augmented, emb_dim) or (num_augmented, 1, emb_dim)
        - similarity_scores: Tensor of shape (num_augmented,)

        Returns:
        - loss: scalar cross-entropy of the score-weighted, temperature-scaled cosine
          similarities, with augmentation 0 as the target class (MoCo-style "label 0").
          NOTE(review): confirm that augmentation 0 is meant to be the positive.
        """
        anchor = anchor_embedding.reshape(1, -1)  # (1, emb_dim); batch size 1 assumed
        augmented = augmented_embeddings.reshape(augmented_embeddings.size(0), -1)  # (num_augmented, emb_dim)
        similarity_scores = similarity_scores.to(augmented.device)

        # Cosine similarity along the embedding dimension -> (num_augmented,)
        similarities = F.cosine_similarity(augmented, anchor, dim=-1) / self.temperature

        logits = (similarities * similarity_scores).unsqueeze(0)  # (1, num_augmented)

        # One sample -> one target index; cross_entropy needs a (1,) target here.
        target = torch.zeros(1, dtype=torch.long, device=logits.device)

        return F.cross_entropy(logits, target)

    def training_step(self, batch, batch_idx):
        anchor_embedding, augmented_embeddings, similarity_scores = self.propagate_batch(batch)

        loss = self.calculate_loss(anchor_embedding, augmented_embeddings, similarity_scores)

        del anchor_embedding, augmented_embeddings, similarity_scores
        torch.cuda.empty_cache()

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        """Adam with an ExponentialLR schedule (``training.lr``, ``training.gamma``, default 0.9)."""
        lr = config.get("training", "lr")
        gamma = config.get("training", "gamma", fallback=0.9)

        optimizer = Adam(self.parameters(), lr=lr)
        scheduler = ExponentialLR(optimizer, gamma=gamma)

        return [optimizer], [scheduler]

    def freeze_appropriate_layers(self, epoch):
        """Freeze the encoder, then unfreeze LoRA layers per ``training.lora_freeze_dict``.

        The entry used is the one with the largest key <= ``epoch``. Its value is
        ``'all'`` (every LoRA layer), an int N (the last N LoRA layer indices; 0 means
        none), or anything else (none). With no matching key, all LoRA layers stay frozen.
        NOTE(review): the layer index is read from the 3rd dot-separated component of
        the parameter name — confirm against the model's naming.
        """
        lora_freeze_dict = config.get("training", "lora_freeze_dict")

        for param in self.encoder_model.parameters():
            param.requires_grad = False

        freeze_key = max((key for key in lora_freeze_dict.keys() if key <= epoch), default=None)
        layers_to_unfreeze = [] if freeze_key is None else lora_freeze_dict[freeze_key]

        def layer_index(name):
            return int(name.split(".")[2])

        # Sort AFTER de-duplicating so "last N layers" really selects the highest indices.
        lora_layer_idx = sorted({
            layer_index(name)
            for name, _ in self.encoder_model.named_parameters()
            if "lora" in name
        })

        if layers_to_unfreeze == 'all':
            indicies_to_unfreeze = lora_layer_idx
            print(f"Epoch {epoch}: Unfreezing all LoRA layers")
        elif isinstance(layers_to_unfreeze, int):
            # A bare [-0:] slice would select every layer, so 0 is handled explicitly.
            indicies_to_unfreeze = lora_layer_idx[-layers_to_unfreeze:] if layers_to_unfreeze else []
            print(f"Epoch {epoch}: Unfreezing the last {len(indicies_to_unfreeze)} LoRA layers")
        else:
            indicies_to_unfreeze = []
            print(f"Epoch {epoch}: Freezing all LoRA layers")

        unfreeze = set(indicies_to_unfreeze)
        for name, param in self.encoder_model.named_parameters():
            if "lora" in name and layer_index(name) in unfreeze:
                param.requires_grad = True

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import AutoModel
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
import math
# 'config' and 'wandb' used below are imported at the top of this module.

class EmbeddingModel(pl.LightningModule):
    """Contrastive text-embedding model: a pretrained encoder followed by a projector head.

    Each batch pairs anchor ("base") texts with ``num_augmented`` augmented texts and a
    per-pair target similarity (``similarity_scores``). The training objective is chosen
    by ``config["training"]["loss_type"]``. During validation/test, anchor embeddings
    whose ``anchor_label`` is > -0.5 are collected, and at epoch end neighbourhood
    metrics (nearest-neighbour accuracy, information gain, KL divergence, JSD) and an
    optional PCA plot are produced from them.
    """

    def __init__(self):
        super(EmbeddingModel, self).__init__()

        self.encoder_model, self.projector_model = self.initialize_nn()
        self.temperature = config.get("embedding_training", "temperature")
        # One of "mse", "contrastive", "contrastive_exponential", or "both".
        self.loss_type = config.get("training", "loss_type")
        self.num_classes = 3  # Number of categories/classes -> Rise, Fall, Neutral
        self.margin = 1.0  # Margin for the push term of the contrastive loss
        self.tau = config.get("training", "tau")  # Temperature for the exponential (CWCL) loss

        # Enable gradient checkpointing for memory efficiency
        self.encoder_model.gradient_checkpointing_enable()

        # Define the MSE loss function
        self.mse_loss_fn = nn.MSELoss()

        # Labelled anchor embeddings collected during validation / test epochs.
        self.val_outputs = []
        self.test_outputs = []

    # Model Initialization and Setup
    def initialize_nn(self):
        """Load the pretrained encoder and build the projector head.

        Also stores the encoder's config on ``self.encoder_config``.

        Returns:
            (encoder_model, projector_model)
        """
        base_model = config.get("embedding_training", "encoder_base_model")

        # Load the encoder once and reuse its config (it used to be loaded twice).
        encoder_model = AutoModel.from_pretrained(base_model)
        self.encoder_config = encoder_model.config

        encoder_model.train()

        # Build the projector in the same dtype as the encoder so the two compose.
        encoder_dtype = next(encoder_model.parameters()).dtype
        projector_model = nn.Sequential(
            nn.Linear(encoder_model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        ).to(encoder_dtype)

        return encoder_model, projector_model

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """
        Forward pass through the encoder and projector networks.

        Parameters:
        - input_ids: Tensor of shape (batch_size, seq_length)
        - attention_mask: Tensor of shape (batch_size, seq_length)
        - token_type_ids: Tensor of shape (batch_size, seq_length) or None

        Returns:
        - projected_embeddings: L2-normalized tensor of shape (batch_size, 128)
        """
        encoder_outputs = self.encoder_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=False,
            return_dict=True,
        )

        if getattr(encoder_outputs, 'pooler_output', None) is not None:
            # Use the model's own pooled output when it provides one.
            pooled_output = encoder_outputs.pooler_output  # (batch_size, hidden_size)
        else:
            # NOTE(review): this unweighted mean also averages padding positions. An
            # attention-mask-weighted mean would be more accurate, but changing it alters
            # the embeddings of existing checkpoints, and the same pooling is repeated in
            # NewClassificationModelBERT.forward, so both would need to change together.
            pooled_output = encoder_outputs.last_hidden_state.mean(dim=1)

        projected_embeddings = self.projector_model(pooled_output)  # (batch_size, 128)

        # Unit-normalize so that dot products are cosine similarities.
        return F.normalize(projected_embeddings, p=2, dim=-1)

    def _encode_batch(self, batch):
        """Embed the anchor texts and the augmented texts of ``batch``.

        Returns:
            anchor_proj: (batch_size, 128) normalized anchor embeddings.
            augmented_proj: (batch_size, num_augmented, 128) normalized augmented embeddings.
        """
        anchor_proj = self.forward(
            input_ids=batch['base_input_ids'],
            attention_mask=batch['base_attention_mask'],
            token_type_ids=batch['base_token_type_ids'],
        )

        augmented_input_ids = batch['augmented_input_ids']
        augmented_attention_mask = batch['augmented_attention_mask']
        augmented_token_type_ids = batch['augmented_token_type_ids']

        # Flatten (batch, num_augmented, seq) -> (batch * num_augmented, seq) so all
        # augmented texts pass through the encoder in one call. reshape() also copes
        # with non-contiguous inputs, unlike view().
        batch_size, num_augmented, seq_len = augmented_input_ids.size()
        if augmented_token_type_ids is not None:
            augmented_token_type_ids = augmented_token_type_ids.reshape(-1, seq_len)
        augmented_proj_flat = self.forward(
            input_ids=augmented_input_ids.reshape(-1, seq_len),
            attention_mask=augmented_attention_mask.reshape(-1, seq_len),
            token_type_ids=augmented_token_type_ids,
        )

        # Restore the (batch_size, num_augmented, 128) layout.
        augmented_proj = augmented_proj_flat.view(batch_size, num_augmented, -1)
        return anchor_proj, augmented_proj

    def compute_loss(self, anchor_proj, augmented_proj, similarity_scores, split):
        """Compute and log the configured loss for one batch.

        Parameters:
        - anchor_proj: (batch_size, dim) anchor embeddings
        - augmented_proj: (batch_size, num_augmented, dim) augmented embeddings
        - similarity_scores: (batch_size, num_augmented) target similarities
        - split: prefix used for the logged metric names ("train", "val", "test")

        Raises:
        - ValueError: if ``self.loss_type`` is not a known loss name.
        """
        if self.loss_type == "mse":
            loss = self.calculate_mse_loss(anchor_proj, augmented_proj, similarity_scores)
            self.log(f'{split}_mse_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        elif self.loss_type == "contrastive":
            loss = self.calculate_weighted_contrastive_loss(anchor_proj, augmented_proj, similarity_scores)
            self.log(f'{split}_contrastive_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        elif self.loss_type == "contrastive_exponential":
            loss = self.calculate_weighted_exponential_loss(anchor_proj, augmented_proj, similarity_scores, self.tau)
            self.log(f'{split}_contrastive_exponential_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        elif self.loss_type == "both":
            mse_loss = self.calculate_mse_loss(anchor_proj, augmented_proj, similarity_scores)
            contrastive_loss = self.calculate_weighted_contrastive_loss(anchor_proj, augmented_proj, similarity_scores)
            loss = mse_loss + contrastive_loss
            self.log(f'{split}_mse_loss', mse_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            self.log(f'{split}_contrastive_loss', contrastive_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        else:
            raise ValueError(f"Invalid loss type: {self.loss_type}")

        return loss

    def training_step(self, batch, batch_idx):
        """Embed anchors and augmentations, then return the training loss."""
        anchor_proj, augmented_proj = self._encode_batch(batch)
        return self.compute_loss(anchor_proj, augmented_proj, batch['similarity_scores'], "train")

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and collect labelled anchor embeddings.

        The anchor embeddings used for the loss are reused for the epoch-end
        metrics, so the encoder runs only once per input.
        """
        anchor_proj, augmented_proj = self._encode_batch(batch)
        loss = self.compute_loss(anchor_proj, augmented_proj, batch['similarity_scores'], 'val')
        self._collect_labeled_embeddings(anchor_proj, batch.get('anchor_label', None), self.val_outputs)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss and collect labelled anchor embeddings."""
        anchor_proj, augmented_proj = self._encode_batch(batch)
        loss = self.compute_loss(anchor_proj, augmented_proj, batch['similarity_scores'], 'test')
        self._collect_labeled_embeddings(anchor_proj, batch.get('anchor_label', None), self.test_outputs)
        return loss

    def _shared_eval_step(self, batch, batch_idx, stage):
        """Return the loss for ``batch``, logged under the ``stage`` prefix."""
        anchor_proj, augmented_proj = self._encode_batch(batch)
        return self.compute_loss(anchor_proj, augmented_proj, batch['similarity_scores'], stage)

    @staticmethod
    def _collect_labeled_embeddings(anchor_proj, anchor_label, store):
        """Append ``{'embeddings': ..., 'labels': ...}`` to ``store`` for each labelled anchor.

        Anchors whose label is <= -0.5 (e.g. a -1 "unlabelled" sentinel) are skipped.
        Nothing is collected when ``anchor_label`` is None.
        """
        if anchor_label is None:
            return
        labels = torch.as_tensor(anchor_label, device=anchor_proj.device).reshape(-1)
        keep = labels > -0.5
        for embedding, label in zip(anchor_proj[keep].detach(), labels[keep]):
            store.append({'embeddings': embedding, 'labels': label})

    def calculate_weighted_exponential_loss(self, anchor_proj, augmented_proj, similarity_scores, tau):
        """
        Calculates the Continuously Weighted Contrastive Loss (CWCL).

        loss = mean_ij( -s_ij * log softmax_j(cos(a_i, p_ij) / tau) )

        Parameters:
        - anchor_proj: Tensor of shape (batch_size, embedding_dim)
        - augmented_proj: Tensor of shape (batch_size, num_augmented, embedding_dim)
        - similarity_scores: Tensor of shape (batch_size, num_augmented)
        - tau: Temperature scaling parameter (scalar)

        Returns:
        - loss: A scalar tensor representing the loss
        """
        num_augmented = augmented_proj.size(1)

        # Cosine similarity of every anchor to each of its augmentations: (batch_size, num_augmented)
        anchor_proj_expanded = anchor_proj.unsqueeze(1).expand(-1, num_augmented, -1)
        d_ij = F.cosine_similarity(anchor_proj_expanded, augmented_proj, dim=-1)

        # log(exp(d/tau) / sum_j exp(d/tau)) computed stably. Evaluating exp() directly
        # overflowed for small tau (exp(1/0.01) is inf in float32) and produced NaN losses.
        # The upcast avoids the same problem for bf16/fp16 encoders.
        log_frac = F.log_softmax(d_ij.float() / tau, dim=1)

        # Weight each pair's log-probability by its target similarity.
        loss_per_pair = -similarity_scores * log_frac

        return loss_per_pair.mean()

    def calculate_weighted_contrastive_loss(self, anchor_proj, augmented_proj, similarity_scores):
        """
        Calculates the weighted contrastive loss.

        Pull term s_ij * d_ij^2 plus push term (1 - s_ij) * max(0, margin - d_ij)^2,
        where d_ij is the Euclidean distance between anchor and augmentation.

        Parameters:
        - anchor_proj: Tensor of shape (batch_size, embedding_dim)
        - augmented_proj: Tensor of shape (batch_size, num_augmented, embedding_dim)
        - similarity_scores: Tensor of shape (batch_size, num_augmented)

        Returns:
        - loss: A scalar tensor representing the loss
        """
        num_augmented = augmented_proj.size(1)
        margin = self.margin

        anchor_proj_expanded = anchor_proj.unsqueeze(1).expand(-1, num_augmented, -1)

        # Pairwise Euclidean distances: (batch_size, num_augmented)
        d_ij = torch.norm(anchor_proj_expanded - augmented_proj, p=2, dim=-1)

        pull_loss = similarity_scores * (d_ij ** 2)
        push_loss = (1 - similarity_scores) * (torch.clamp(margin - d_ij, min=0) ** 2)

        return (pull_loss + push_loss).mean()

    def calculate_mse_loss(self, anchor_proj, augmented_proj, similarity_scores):
        """
        Calculates the Mean Squared Error (MSE) loss between the clamped cosine
        similarities of each anchor/augmentation pair and the target similarity scores.

        Parameters:
        - anchor_proj: Tensor of shape (batch_size, 128)
        - augmented_proj: Tensor of shape (batch_size, num_augmented, 128)
        - similarity_scores: Tensor of shape (batch_size, num_augmented)

        Returns:
        - loss: A scalar tensor representing the MSE loss
        """
        # (batch_size, num_augmented); negative similarities are clamped to 0.
        cosine_similarities = F.cosine_similarity(anchor_proj.unsqueeze(1), augmented_proj, dim=-1)
        cosine_similarities = torch.clamp(cosine_similarities, min=0.0)

        cosine_similarities_flat = cosine_similarities.view(-1)
        similarity_scores_flat = similarity_scores.view(-1).to(cosine_similarities_flat.device)

        return F.mse_loss(cosine_similarities_flat, similarity_scores_flat)

    def on_train_epoch_end(self):
        """Log the learning rate of the first optimizer's first parameter group."""
        optimizer = self.trainer.optimizers[0]
        lr = optimizer.param_groups[0]['lr']  # Assuming one parameter group
        self.log('learning_rate', lr)

    def on_validation_epoch_end(self):
        """Compute/log embedding metrics for the collected validation anchors, then reset."""
        outputs, self.val_outputs = self.val_outputs, []
        self._log_embedding_metrics(outputs, 'val')

    def on_test_epoch_end(self):
        """Compute/log embedding metrics for the collected test anchors, then reset."""
        outputs, self.test_outputs = self.test_outputs, []
        self._log_embedding_metrics(outputs, 'test')

    def _log_embedding_metrics(self, outputs, stage):
        """Log neighbourhood metrics (and optionally a PCA plot) under the ``stage`` prefix.

        Skipped when fewer than two labelled embeddings were collected: the metrics
        need at least one neighbour per point, and ``torch.stack([])`` would raise.
        """
        if len(outputs) < 2:
            return

        embeddings = torch.stack([x['embeddings'] for x in outputs])
        # Labels must live on the same device as embeddings: the metric helpers index
        # ``labels`` with indices computed from ``embeddings``.
        labels = torch.tensor([int(x['labels']) for x in outputs], device=embeddings.device)

        self.log(f'{stage}_nn_accuracy', self.compute_nearest_neighbor_accuracy(embeddings, labels))
        self.log(f'{stage}_information_gain', self.compute_information_gain(embeddings, labels))
        self.log(f'{stage}_kl_divergence', self.compute_kl_divergence(embeddings, labels))
        self.log(f'{stage}_jsd', self.compute_jsd(embeddings, labels))

        # Create PCA plot for embeddings
        if config.get("training", "wandb"):
            fig = self.plot_embeddings_pca(embeddings, labels)
            wandb.log({f"{stage}_embeddings": wandb.Image(fig)})
            plt.close(fig)

    def plot_embeddings_pca(self, embeddings, labels):
        """
        Reduces embeddings to 2D using PCA and returns a matplotlib scatter-plot figure.
        """
        # float() first: NumPy cannot represent bfloat16 tensors.
        embeddings = embeddings.detach().float().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        embeddings_2d = PCA(n_components=2).fit_transform(embeddings)

        # Colour and display name for each class id.
        colors = {0: 'red', 1: '#FFBF00', 2: 'green'}
        label2category = {0: 'Fall', 1: 'Neutral', 2: 'Rise'}

        fig, ax = plt.subplots(figsize=(8, 6))
        for label in np.unique(labels):
            idx = labels == label
            key = int(label)
            ax.scatter(embeddings_2d[idx, 0], embeddings_2d[idx, 1],
                       c=colors[key], label=f'Class {label2category[key]}',
                       alpha=0.6, edgecolor='k', linewidth=0.5)

        ax.set_title('PCA of Embeddings')
        ax.set_xlabel('PCA 1')
        ax.set_ylabel('PCA 2')
        ax.legend(loc='best')

        return fig

    def compute_nearest_neighbor_accuracy(self, embeddings, labels):
        """Fraction of embeddings whose most similar other embedding has the same label.

        Dot products equal cosine similarities only for L2-normalized embeddings,
        which is what ``forward`` produces.
        """
        N = embeddings.size(0)

        similarity_matrix = torch.matmul(embeddings, embeddings.T)  # (N, N)

        # Exclude self-similarities
        mask = torch.eye(N, dtype=torch.bool, device=embeddings.device)
        similarity_matrix.masked_fill_(mask, -float('inf'))

        nn_labels = labels[similarity_matrix.argmax(dim=1)]
        return (nn_labels == labels).float().mean()

    def compute_label_distributions(self, embeddings, labels, k=5):
        """Per-embedding label distribution among its k most similar other embeddings.

        Only N - 1 real neighbours exist, so k is capped at N - 1 (but kept at least 1)
        and the counts are normalized by the number of neighbours actually used.

        Returns:
            Float tensor of shape (N, num_classes) whose rows sum to 1 for in-range labels.
        """
        N = embeddings.size(0)

        similarity_matrix = torch.matmul(embeddings, embeddings.T)  # (N, N)

        # Exclude self-similarities
        mask = torch.eye(N, dtype=torch.bool, device=embeddings.device)
        similarity_matrix.masked_fill_(mask, -float('inf'))

        # min(k, N) would pull in the masked self entry whenever N <= k.
        k_eff = max(1, min(k, N - 1))
        _, topk_indices = torch.topk(similarity_matrix, k=k_eff, dim=1)  # (N, k_eff)

        neighbor_labels = labels[topk_indices]  # (N, k_eff)

        counts = torch.stack(
            [(neighbor_labels == c).sum(dim=1) for c in range(self.num_classes)], dim=1
        ).float()

        return counts / k_eff

    def compute_global_label_distribution(self, labels):
        """Empirical class frequencies over all labels, shape (num_classes,)."""
        num_classes = self.num_classes
        N = labels.size(0)

        label_counts = torch.zeros(num_classes, device=labels.device)
        for c in range(num_classes):
            label_counts[c] = (labels == c).sum()

        return label_counts / N

    def compute_information_gain(self, embeddings, labels, k=5):
        """log2(num_classes) minus the mean entropy (bits) of the kNN label distributions."""
        label_distributions = self.compute_label_distributions(embeddings, labels, k)
        mean_entropy = self.compute_local_entropies(label_distributions).mean()
        H_max = math.log2(self.num_classes)
        return H_max - mean_entropy

    def compute_local_entropies(self, label_distributions):
        """Row-wise entropy in bits, with a small epsilon to avoid log(0)."""
        epsilon = 1e-10
        p = label_distributions + epsilon
        return - (p * torch.log2(p)).sum(dim=1)

    def compute_kl_divergence(self, embeddings, labels, k=5):
        """Mean KL(local kNN label distribution || global label distribution), in bits."""
        label_distributions = self.compute_label_distributions(embeddings, labels, k)
        global_distribution = self.compute_global_label_distribution(labels)

        epsilon = 1e-10
        p = label_distributions + epsilon
        q = global_distribution.unsqueeze(0) + epsilon  # (1, num_classes)
        kl_divs = (p * (p / q).log2()).sum(dim=1)

        return kl_divs.mean()

    def compute_jsd(self, embeddings, labels, k=5):
        """Mean Jensen-Shannon divergence (bits) between local kNN and global label distributions."""
        label_distributions = self.compute_label_distributions(embeddings, labels, k)
        global_distribution = self.compute_global_label_distribution(labels)

        epsilon = 1e-10
        p = label_distributions + epsilon
        q = global_distribution.unsqueeze(0) + epsilon
        m = 0.5 * (p + q)

        kl_p_m = (p * (p / m).log2()).sum(dim=1)
        kl_q_m = (q * (q / m).log2()).sum(dim=1)

        return (0.5 * (kl_p_m + kl_q_m)).mean()

    def configure_optimizers(self):
        """Adam over all parameters with an per-epoch ExponentialLR decay."""
        lr = config.get("training", "lr")
        gamma = config.get("training", "gamma")
        weight_decay = config.get("training", "weight_decay")

        optimizer = Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = ExponentialLR(optimizer, gamma=gamma)

        return [optimizer], [scheduler]

class NewClassificationModelBERT(pl.LightningModule):
    """Text classifier over a trainable encoder, a frozen contrastive encoder + projector, or both.

    ``config["training_classification"]["classification_mode"]`` selects the features:
      - "classifier": pooled output of a trainable Hugging Face encoder;
      - "projector": projector-head output of a pretrained ``EmbeddingModel`` checkpoint
        whose encoder is frozen (the projector head stays trainable);
      - "both": the two feature vectors concatenated.
    An MLP head maps the features to ``num_classes`` logits.
    """

    def __init__(self):
        super(NewClassificationModelBERT, self).__init__()

        # Get the classification mode from config
        self.classification_mode = config.get("training_classification", "classification_mode")

        if config.get("dataset", "dataset") == "bigdata22":
            self.num_classes = 2  # Number of classes: Rise, Fall
        else:
            self.num_classes = 3  # Number of classes: Rise, Fall, Neutral

        if self.classification_mode in ["classifier", "both"]:
            # Trainable encoder used directly as a feature extractor.
            classifier_model_name = config.get("training_classification", "classifier_model_name")
            self.classifier_model = AutoModel.from_pretrained(classifier_model_name)
            self.classifier_hidden_size = self.classifier_model.config.hidden_size

            self.classifier_model.train()

        if self.classification_mode in ["projector", "both"]:
            # Load the pretrained contrastive model from its checkpoint.
            projector_checkpoint_path = config.get("training_classification", "projector_checkpoint_path")
            projector_model = EmbeddingModel.load_from_checkpoint(projector_checkpoint_path)

            # Freeze the encoder; only the projector head is fine-tuned.
            for param in projector_model.encoder_model.parameters():
                param.requires_grad = False
            for param in projector_model.projector_model.parameters():
                param.requires_grad = True

            self.projector_encoder = projector_model.encoder_model
            self.projector_encoder.eval()
            self.projector_projector = projector_model.projector_model
            self.projector_projector.train()
            # Assumes the projector's last layer is the Linear that defines its output size.
            self.projector_hidden_size = self.projector_projector[-1].out_features

        # Define the input size for the classifier network
        if self.classification_mode == "classifier":
            input_size = self.classifier_hidden_size
        elif self.classification_mode == "projector":
            input_size = self.projector_hidden_size
        elif self.classification_mode == "both":
            input_size = self.classifier_hidden_size + self.projector_hidden_size
        else:
            raise ValueError(f"Invalid classification_mode: {self.classification_mode}")

        # MLP head configuration
        hidden_sizes = config.get("training_classification", "hidden_sizes", [512, 256])
        dropout_rate = config.get("training_classification", "dropout_rate", 0.1)
        self.classifier = self._build_classifier(input_size, hidden_sizes, dropout_rate)

        self.loss_fn = nn.CrossEntropyLoss()

        # One metric object per stage. They are logged as objects (see _shared_step) so
        # Lightning aggregates them over the epoch and resets them afterwards.
        self.train_accuracy = MulticlassAccuracy(num_classes=self.num_classes)
        self.val_accuracy = MulticlassAccuracy(num_classes=self.num_classes)
        self.test_accuracy = MulticlassAccuracy(num_classes=self.num_classes)

        self.train_f1 = MulticlassF1Score(num_classes=self.num_classes, average='macro')
        self.val_f1 = MulticlassF1Score(num_classes=self.num_classes, average='macro')
        self.test_f1 = MulticlassF1Score(num_classes=self.num_classes, average='macro')

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen projector encoder in eval mode.

        Lightning calls ``model.train()`` when fitting starts and after every validation
        run. Without this override that would re-enable dropout (and gradient
        checkpointing) inside the frozen encoder that ``__init__`` put in eval mode.
        """
        super().train(mode)
        if hasattr(self, 'projector_encoder'):
            self.projector_encoder.eval()
        return self

    def _build_classifier(self, input_size, hidden_sizes, dropout_rate):
        """Build an MLP: [Linear -> ReLU -> Dropout] per hidden size, then Linear to logits."""
        layers = []
        in_features = input_size

        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            in_features = hidden_size

        layers.append(nn.Linear(in_features, self.num_classes))
        return nn.Sequential(*layers)

    def forward(self, base_input_ids, base_attention_mask, base_token_type_ids=None):
        """Return class logits of shape (batch_size, num_classes) for the given token batch."""
        embeddings = []

        if self.classification_mode in ["classifier", "both"]:
            classifier_outputs = self.classifier_model(
                input_ids=base_input_ids,
                attention_mask=base_attention_mask,
                token_type_ids=base_token_type_ids,
                output_hidden_states=False,
                return_dict=True
            )
            if getattr(classifier_outputs, 'pooler_output', None) is not None:
                classifier_embedding = classifier_outputs.pooler_output
            else:
                # Fall back to the first token's hidden state.
                classifier_embedding = classifier_outputs.last_hidden_state[:, 0, :]
            embeddings.append(classifier_embedding)

        if self.classification_mode in ["projector", "both"]:
            # The encoder is frozen, so don't build an autograd graph through it.
            with torch.no_grad():
                encoder_outputs = self.projector_encoder(
                    input_ids=base_input_ids,
                    attention_mask=base_attention_mask,
                    token_type_ids=base_token_type_ids,
                    output_hidden_states=False,
                    return_dict=True
                )
            # Same pooling rule as EmbeddingModel.forward, so features match pre-training.
            if getattr(encoder_outputs, 'pooler_output', None) is not None:
                pooled_output = encoder_outputs.pooler_output
            else:
                pooled_output = encoder_outputs.last_hidden_state.mean(dim=1)

            # Unlike EmbeddingModel.forward, the projector output is NOT L2-normalized here.
            embeddings.append(self.projector_projector(pooled_output))

        if len(embeddings) == 1:
            combined_embedding = embeddings[0]
        else:
            combined_embedding = torch.cat(embeddings, dim=-1)

        return self.classifier(combined_embedding)

    def _shared_step(self, batch, stage, accuracy, f1, on_step):
        """Forward one batch, update the stage's metrics, log everything, and return the loss.

        The metric objects (not their per-batch values) are passed to ``self.log``, so the
        epoch value is computed over the whole epoch. Averaging per-batch macro-F1
        values would not give the true epoch F1.
        """
        labels = batch['anchor_label'].long()  # CrossEntropyLoss needs int64 class indices

        logits = self.forward(
            base_input_ids=batch['base_input_ids'],
            base_attention_mask=batch['base_attention_mask'],
            base_token_type_ids=batch.get('base_token_type_ids', None),
        )

        loss = self.loss_fn(logits, labels)
        preds = torch.argmax(logits, dim=1)

        # Calling the metric updates its epoch state and caches the batch value for on_step logging.
        accuracy(preds, labels)
        f1(preds, labels)

        self.log(f'{stage}_loss', loss, on_step=on_step, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_acc', accuracy, on_step=on_step, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_f1', f1, on_step=on_step, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Training loss; metrics are logged per step and per epoch."""
        return self._shared_step(batch, 'train', self.train_accuracy, self.train_f1, on_step=True)

    def validation_step(self, batch, batch_idx):
        """Validation loss; metrics are logged per epoch only."""
        return self._shared_step(batch, 'val', self.val_accuracy, self.val_f1, on_step=False)

    def test_step(self, batch, batch_idx):
        """Test loss; metrics are logged per epoch only."""
        return self._shared_step(batch, 'test', self.test_accuracy, self.test_f1, on_step=False)

    def configure_optimizers(self):
        """AdamW with a linear warmup/decay schedule that is stepped every optimizer step."""
        lr = config.get("training_classification", "lr")
        weight_decay = config.get("training_classification", "weight_decay")
        # Frozen parameters never receive gradients, so AdamW leaves them untouched.
        optimizer = AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)

        total_steps = config.get("training_classification", "total_steps")
        warmup_steps = config.get("training_classification", "warmup_steps")
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]


class NewClassificationModel(pl.LightningModule):
    """Sequence classifier over GPT-2 features, projector features, or both.

    ``training_classification.classification_mode`` selects the feature source:

    * ``"classifier"`` -- the incoming BERT token ids are decoded back to text,
      prefixed with a task prompt, re-tokenised for GPT-2, and the GPT-2
      hidden state of each sequence's last real (non-padding) token is used.
    * ``"projector"`` -- the frozen encoder of a pretrained ``EmbeddingModel``
      is mean-pooled, passed through its trainable projector and
      L2-normalised.
    * ``"both"`` -- the two feature vectors concatenated (GPT-2 first).

    The features go through an MLP head that outputs ``num_classes`` logits.
    Batches are dicts with ``base_input_ids``, ``base_attention_mask``,
    optional ``base_token_type_ids`` and ``anchor_label``.
    """

    def __init__(self):
        super(NewClassificationModel, self).__init__()

        # Get the classification mode from config
        self.classification_mode = config.get("training_classification", "classification_mode")

        # Set up the prompt based on the dataset
        dataset_name = config.get("dataset", "dataset")
        if dataset_name == "bigdata22":
            self.num_classes = 2  # Number of classes: Rise, Fall
            self.prompt = "Read these lists of financial tweets and predict whether the market will Rise or Fall. "
        else:
            self.num_classes = 3  # Number of classes: Good, Bad, Neutral
            self.prompt = "Read these movie reviews and predict whether the movie was good, bad, or neutral. "

        # BERT tokenizer, used only to decode incoming BERT token ids to text.
        bert_model_name = "bert-base-cased"
        self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

        # GPT-2 tokenizer and backbone.
        gpt_model_name = 'gpt2'  # You can choose 'gpt2', 'distilgpt2', or another GPT-2 variant
        self.gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
        # GPT-2 has no pad token, so EOS is reused. Right padding is pinned
        # explicitly because _gpt_embedding locates each sequence's last real
        # token from the attention mask under that assumption.
        self.gpt_tokenizer.pad_token = self.gpt_tokenizer.eos_token
        self.gpt_tokenizer.padding_side = "right"
        self.gpt_model = AutoModel.from_pretrained(gpt_model_name)
        self.gpt_hidden_size = self.gpt_model.config.hidden_size

        self.gpt_model.train()

        # If using the projector model
        if self.classification_mode in ["projector", "both"]:
            # Load the pretrained projector model from checkpoint
            projector_checkpoint_path = config.get("training_classification", "projector_checkpoint_path")
            projector_model = EmbeddingModel.load_from_checkpoint(projector_checkpoint_path)

            # Freeze the encoder; keep the projector trainable.
            for param in projector_model.encoder_model.parameters():
                param.requires_grad = False
            for param in projector_model.projector_model.parameters():
                param.requires_grad = True

            self.projector_encoder = projector_model.encoder_model
            # Kept in eval mode permanently; see train() below.
            self.projector_encoder.eval()
            self.projector_projector = projector_model.projector_model
            self.projector_projector.train()
            # NOTE(review): assumes the projector is indexable and its last
            # element exposes ``out_features`` (e.g. an nn.Linear) -- confirm.
            self.projector_hidden_size = self.projector_projector[-1].out_features

        # Define the input size for the classifier network
        if self.classification_mode == "classifier":
            input_size = self.gpt_hidden_size
        elif self.classification_mode == "projector":
            input_size = self.projector_hidden_size
        elif self.classification_mode == "both":
            input_size = self.gpt_hidden_size + self.projector_hidden_size
        else:
            raise ValueError(f"Invalid classification_mode: {self.classification_mode}")

        # Define the classifier network with multiple layers
        hidden_sizes = config.get("training_classification", "hidden_sizes", [512, 256])
        dropout_rate = config.get("training_classification", "dropout_rate", 0.1)
        self.classifier = self._build_classifier(input_size, hidden_sizes, dropout_rate)

        self.loss_fn = nn.CrossEntropyLoss()

        # MulticlassAccuracy defaults to average='macro' (mean per-class
        # recall); 'micro' gives plain accuracy, matching the
        # Accuracy(task="multiclass") default used elsewhere in this file.
        self.train_accuracy = MulticlassAccuracy(num_classes=self.num_classes, average='micro')
        self.val_accuracy = MulticlassAccuracy(num_classes=self.num_classes, average='micro')
        self.test_accuracy = MulticlassAccuracy(num_classes=self.num_classes, average='micro')

        self.train_f1 = MulticlassF1Score(num_classes=self.num_classes, average='macro')
        self.val_f1 = MulticlassF1Score(num_classes=self.num_classes, average='macro')
        self.test_f1 = MulticlassF1Score(num_classes=self.num_classes, average='macro')

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen projector encoder in eval mode.

        Lightning switches the whole model to train mode during fitting, which
        would otherwise re-enable dropout inside the frozen encoder.
        """
        super().train(mode)
        if hasattr(self, "projector_encoder"):
            self.projector_encoder.eval()
        return self

    def _build_classifier(self, input_size, hidden_sizes, dropout_rate):
        """Build an MLP: [Linear -> ReLU -> Dropout] per hidden size, then a Linear to num_classes."""
        layers = []
        in_features = input_size

        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            in_features = hidden_size

        layers.append(nn.Linear(in_features, self.num_classes))
        return nn.Sequential(*layers)

    def _gpt_embedding(self, base_input_ids):
        """Return GPT-2's last hidden state at each sequence's last non-padding token."""
        # Decode BERT ids back to text (on CPU) and prepend the task prompt.
        decoded = self.bert_tokenizer.batch_decode(base_input_ids.cpu(), skip_special_tokens=True)
        texts = [self.prompt + text for text in decoded]

        encoding = self.gpt_tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        gpt_outputs = self.gpt_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=False,
            return_dict=True
        )
        hidden = gpt_outputs.last_hidden_state

        # With right padding, position -1 is a pad (EOS) token for every
        # sequence shorter than the longest in the batch; select the last
        # attended position per row instead.
        last_idx = (attention_mask.sum(dim=1) - 1).clamp(min=0)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[rows, last_idx]

    def _projector_embedding(self, base_input_ids, base_attention_mask, base_token_type_ids):
        """Return the L2-normalised projector output over the frozen encoder's mean-pooled states."""
        with torch.no_grad():  # The encoder is frozen; no graph needed.
            encoder_outputs = self.projector_encoder(
                input_ids=base_input_ids,
                attention_mask=base_attention_mask,
                token_type_ids=base_token_type_ids,
                output_hidden_states=False,
                return_dict=True
            )
            # NOTE(review): this mean includes padding positions; kept as-is
            # since it presumably must match how the projector was trained.
            pooled_output = encoder_outputs.last_hidden_state.mean(dim=1)

        projector_embedding = self.projector_projector(pooled_output)
        return F.normalize(projector_embedding, p=2, dim=-1)

    def forward(self, base_input_ids, base_attention_mask, base_token_type_ids=None):
        """Compute class logits for a batch of BERT-tokenised inputs.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        embeddings = []

        # GPT-2 decoding/tokenisation only happens when its features are used.
        if self.classification_mode in ["classifier", "both"]:
            embeddings.append(self._gpt_embedding(base_input_ids))

        if self.classification_mode in ["projector", "both"]:
            embeddings.append(
                self._projector_embedding(base_input_ids, base_attention_mask, base_token_type_ids)
            )

        if len(embeddings) == 1:
            combined_embedding = embeddings[0]
        else:
            combined_embedding = torch.cat(embeddings, dim=-1)

        return self.classifier(combined_embedding)

    def _shared_step(self, batch):
        """Run forward + loss on a batch; return ``(loss, preds, labels)``."""
        labels = batch['anchor_label']
        logits = self.forward(
            base_input_ids=batch['base_input_ids'],
            base_attention_mask=batch['base_attention_mask'],
            base_token_type_ids=batch.get('base_token_type_ids', None)
        )
        loss = self.loss_fn(logits, labels)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, labels

    def _eval_step(self, batch, stage, accuracy, f1):
        """Shared validation/test step: update epoch metrics and log ``{stage}_*``."""
        loss, preds, labels = self._shared_step(batch)
        # Accumulate state only; logging the Metric objects lets Lightning
        # compute() over the whole epoch (correct macro-F1) and reset() them.
        accuracy.update(preds, labels)
        f1.update(preds, labels)
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_f1', f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log training loss, accuracy and macro-F1 (per step and per epoch)."""
        loss, preds, labels = self._shared_step(batch)
        # Calling the metric updates its epoch state and caches the batch value;
        # logging the Metric object reports that step value and, at epoch end,
        # the metric computed over all batches (then resets it).
        self.train_accuracy(preds, labels)
        self.train_f1(preds, labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss, accuracy and macro-F1."""
        return self._eval_step(batch, 'val', self.val_accuracy, self.val_f1)

    def test_step(self, batch, batch_idx):
        """Log epoch-level test loss, accuracy and macro-F1."""
        return self._eval_step(batch, 'test', self.test_accuracy, self.test_f1)

    def configure_optimizers(self):
        """AdamW over trainable parameters with a per-step linear warmup/decay schedule."""
        lr = config.get("training_classification", "lr")
        weight_decay = config.get("training_classification", "weight_decay")
        # Frozen parameters (the projector encoder) never receive gradients,
        # so they are left out of the optimizer.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = AdamW(trainable_params, lr=lr, weight_decay=weight_decay)

        total_steps = config.get("training_classification", "total_steps")
        warmup_steps = config.get("training_classification", "warmup_steps")
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
