
import torch
import torch.nn as nn
import random
from tqdm import tqdm  
import torch.nn.functional as F


import os
import torch.distributed as dist
from torch.utils.data import DistributedSampler, DataLoader


def label_to_real_embedding_dict(real_embeddings, labels):
    """
    Build a lookup from each distinct label to the mean of its real embeddings.

    Args:
        real_embeddings (torch.Tensor): N*D tensor of real embedding vectors.
        labels (torch.Tensor): N tensor of integer class labels.

    Returns:
        dict: Maps every label value present in `labels` (as a Python scalar)
            to the mean over the rows of `real_embeddings` that carry that
            label. Keys appear in the order produced by `labels.unique()`.
    """
    return {
        class_id.item(): real_embeddings[labels == class_id].mean(dim=0)
        for class_id in labels.unique()
    }

def compute_cosine_similarity(reconstructed_embedding, real_embeddings):
    """
    Computes cosine similarity between reconstructed embedding(s) and all real embeddings.

    Args:
        reconstructed_embedding (torch.Tensor): A single reconstructed embedding,
            given either as a D vector or as a 1*D tensor (an M*D batch also works).
        real_embeddings (torch.Tensor): C*D tensor of real embedding vectors (C is number of classes).

    Returns:
        similarity_scores (torch.Tensor): Cosine similarity against each class.
            Shape is C for a D input and 1*C (or M*C) for a 2-D input.
    """
    # Normalize along the feature (last) dimension. Using dim=0 here would be
    # correct only for a 1-D input: on a 1*D tensor it would divide every
    # element by its own magnitude instead of by the vector norm.
    reconstructed_embedding = F.normalize(reconstructed_embedding, p=2, dim=-1)
    real_embeddings = F.normalize(real_embeddings, p=2, dim=1)

    # Dot products of unit vectors are cosine similarities.
    similarity_scores = torch.matmul(reconstructed_embedding, real_embeddings.T)

    return similarity_scores

def custom_loss_function(reconstructed_embeddings, real_embeddings, labels):
    """
    Custom loss function that computes cross-entropy loss based on cosine similarity.

    Each class is represented by the mean of the real embeddings carrying that
    label. Every reconstructed embedding is scored against all class means by
    cosine similarity, and those scores are used as logits for cross-entropy
    against the sample's own class.

    Args:
        reconstructed_embeddings (torch.Tensor): N*D tensor of reconstructed embedding vectors.
        real_embeddings (torch.Tensor): N*D tensor of real embedding vectors.
        labels (torch.Tensor): N tensor of integer class labels.

    Returns:
        loss (torch.Tensor): The cross-entropy loss (scalar tensor).
    """
    # Step 1: Sorted unique classes plus, for every sample, the position of its
    # class in that sorted list (i.e. a target index in [0, C-1]).
    class_labels, target_class_indices = torch.unique(labels, return_inverse=True)

    # Step 2: One mean real embedding per class, stacked into a C*D tensor.
    class_real_embeddings = torch.stack(
        [real_embeddings[labels == class_label].mean(dim=0) for class_label in class_labels]
    )

    # Step 3: Cosine similarity of every reconstruction against every class
    # mean, computed for the whole batch at once (N*C).
    unit_reconstructed = F.normalize(reconstructed_embeddings, p=2, dim=1)
    unit_class_means = F.normalize(class_real_embeddings, p=2, dim=1)
    similarity_scores = torch.matmul(unit_reconstructed, unit_class_means.T)

    # Step 4: Cross-entropy. F.cross_entropy applies log-softmax internally, so
    # it must receive the raw similarity scores; passing softmax probabilities
    # would apply softmax twice and distort both the loss and its gradients.
    target_class_indices = target_class_indices.to(similarity_scores.device)
    loss = F.cross_entropy(similarity_scores, target_class_indices)

    return loss


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        """
        Args:
            temperature: divisor applied to the anchor/contrast dot products.
            contrast_mode: 'all' uses every view as an anchor, 'one' uses only
                the first view of each sample.
            base_temperature: the loss is scaled by temperature / base_temperature.
        """
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None, **kwargs):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
            **kwargs: accepted and ignored.
        Returns:
            A loss scalar (always a tensor, including for batches of fewer
            than two samples, where it is zero).
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does not
                match the batch size, or if `contrast_mode` is unknown.
        """
        if len(features) < 2:
            # Nothing to contrast against. Return a zero that is still attached
            # to the autograd graph so callers can invoke `.backward()` and
            # `.item()` exactly as for a regular batch (a plain int 0 breaks both).
            return features.sum() * 0.0
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            # Flatten any trailing dimensions into a single feature dimension.
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # Unsupervised case: a sample is positive only with its own views.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack views along the batch axis: [view0 samples..., view1 samples..., ...].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability: subtract the per-row max, then bound the range
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        logits = torch.clamp(logits, min=-100, max=100)

        # tile mask so it lines up with the [anchors, contrasts] logits matrix
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases (zero on each anchor's own column)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        # Add epsilon to avoid log(0)
        epsilon = 1e-12
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + epsilon)

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        # Replace a zero positive count by 1 so the division below yields 0
        # rather than nan for anchors without positives.
        mask_pos_pairs = torch.where(
            mask_pos_pairs < 1e-6, torch.ones_like(mask_pos_pairs), mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
    

class LatentTrainer:
    """Base trainer that stores the training components and can persist the model."""

    def __init__(self, model, dataloader, optimizer, criterion, scheduler, device="cuda", world_size=1, rank=0, distributed=False,fix_batches=False,**kwargs):
        """
        Keep references to everything a training loop needs.

        Args:
            model: The model to train.
            dataloader: Source of training batches.
            optimizer: Optimizer for the model parameters.
            criterion: Loss callable.
            scheduler: Learning-rate scheduler.
            device: Device identifier, "cuda" by default.
            world_size (int): Number of processes taking part in training.
            rank (int): Rank of this process.
            distributed (bool): Whether distributed training is enabled.
            fix_batches (bool): Flag stored for use by subclasses.
            **kwargs: Accepted and ignored.
        """
        # Training components.
        self.model, self.dataloader = model, dataloader
        self.optimizer, self.criterion, self.scheduler = optimizer, criterion, scheduler

        # Placement and distributed-run settings.
        self.device = device
        self.world_size, self.rank = world_size, rank
        self.distributed = distributed

        self.fix_batches = fix_batches

    def save_model(self, save_path):
        """
        Write the model's state dictionary to disk and report the location.

        Args:
            save_path (str): The file path where the model should be saved.
        """
        weights = self.model.state_dict()
        torch.save(weights, save_path)
        print(f"Model saved to {save_path}")

class ContrastiveTrainer(LatentTrainer):
    """Trainer for self-supervised contrastive learning on masked views of each batch."""

    def __init__(self, model, dataloader, optimizer, criterion, scheduler,device="cuda", kl_reg_factor = 0,mask_ratio=0.15,num_views=2,world_size=1, rank=0, distributed=False, fix_batches=False):
        """
        Args:
            model, dataloader, optimizer, criterion, scheduler, device,
            world_size, rank, distributed, fix_batches: forwarded unchanged to
                `LatentTrainer.__init__`.
            kl_reg_factor (float): Weight of the KL regularization term; the
                term is only computed when this is > 0.
            mask_ratio (float): Ratio passed to `generate_views` as `drop_ratio`.
            num_views (int): Number of masked views generated per sample.
        """
        super().__init__(model, dataloader, optimizer, criterion, scheduler, device,world_size, rank, distributed, fix_batches)
        self.kl_reg_factor = kl_reg_factor
        self.mask_ratio = mask_ratio
        self.num_views = num_views


    def generate_views(self, batch, num_views=2, drop_ratio=0.25):
        """
        Generate multiple views of the input batch by replacing sequence steps with -1000.

        Steps are dropped in adjacent pairs (2k, 2k+1); the number of pairs
        dropped per sample is int((seq_len // 2) * drop_ratio), chosen
        independently for every sample and every view.

        Args:
            batch (Tensor): Input tensor of shape [batch_size, seq_len, embedding_dim].
            num_views (int): Number of views to generate for each sequence.
            drop_ratio (float): The ratio of sequence steps to drop (0 < drop_ratio < 1).

        Returns:
            Tensor: Tensor of shape [batch_size, num_views, seq_len, embedding_dim].
        """
        batch_size, seq_len, embedding_dim = batch.shape
        views = []
        step_pairs = seq_len // 2  # Number of column name-value pairs
        steps_to_drop = int(step_pairs * drop_ratio)  # Calculate number of pairs to drop

        for _ in range(num_views):
            # Copy the original batch so the input is never modified
            dropped_batch = batch.clone()

            for i in range(batch_size):
                # Randomly select pairs to drop
                pairs_to_drop = random.sample(range(step_pairs), k=steps_to_drop)
                steps_to_replace = []
                for pair_idx in pairs_to_drop:
                    steps_to_replace.extend([2 * pair_idx, 2 * pair_idx + 1])  # Select both the column name and value

                # Replace the selected steps with -1000 (skip when nothing was selected)
                if steps_to_replace:
                    dropped_batch[i, steps_to_replace, :] = -1000

            views.append(dropped_batch)

        # Stack the views along a new dimension
        views_tensor = torch.stack(views, dim=1)  # Shape: [batch_size, num_views, seq_len, embedding_dim]
        return views_tensor

    def compute_kl_divergence(self, features, epsilon=1e-6):
        """
        Compute the KL divergence between the features and a standard normal distribution.

        Mean and variance are taken over dim 1 of `features`.

        Args:
            features (Tensor): Input tensor of shape [batch_size, num_views, feature_dim].
            epsilon (float): A small constant to ensure numerical stability in the variance.

        Returns:
            Tensor: Scalar tensor representing the KL divergence loss.
        """
        # Calculate the mean and variance of the features
        mean = features.mean(dim=1)
        var = features.var(dim=1)

        # Add epsilon to the variance for numerical stability and calculate logvar
        var = var + epsilon
        logvar = var.log()

        # Compute the KL divergence between N(mean, var) and N(0, 1)
        kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - var)

        # Normalize by the number of features (not batch size)
        kl_div = kl_div / features.shape[-1]  # Normalize by feature dimension

        return kl_div

    def train(self, num_epochs=10, checkpoint_dir="checkpoints", save_interval=30):
        """
        Train the model using self-supervised contrastive learning with distributed training.

        After the last epoch the weights of the epoch with the lowest average
        loss are loaded back into the model and, on rank 0, saved to
        `checkpoint_dir`.

        Args:
            num_epochs (int): Number of epochs to train.
            checkpoint_dir (str): Directory to save model checkpoints.
            save_interval (int): The interval (in epochs) to save the model checkpoint.
        """
        import copy  # local import: used only to snapshot the best weights

        print("Training started!")
        # Move model to device
        self.model.to(self.device)
        self.model.train()

        # Create directory for saving checkpoints if it doesn't exist
        # (exist_ok avoids the check-then-create race).
        if self.rank == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)

        # Use a DistributedSampler when running distributed training
        if self.distributed:
            sampler = DistributedSampler(self.dataloader.dataset, num_replicas=self.world_size, rank=self.rank)
            self.dataloader = DataLoader(self.dataloader.dataset, sampler=sampler, batch_size=1,num_workers=4) #

        dataloader = self.dataloader

        # Initialize variables to track the lowest loss, avg_loss history, and best model state
        best_loss = float('inf')
        best_epoch = -1
        avg_loss_history = []
        best_model_state = None
        # Defined up front so the final report also works when num_epochs == 0.
        avg_loss = float('nan')
        current_lr = self.optimizer.param_groups[0]['lr']

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            if not self.fix_batches:
                # Shuffle batches for more diverse comparison
                dataloader.dataset.shuffle_batches()

            # Set epoch for DistributedSampler (for proper shuffling between epochs)
            if self.distributed:
                dataloader.sampler.set_epoch(epoch)

            # New progress bar for each epoch, showing batch loss
            if self.rank == 0:
                batch_progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1} Progress")

            for batch, _, dtype, meta in dataloader:
                # Remove only the leading DataLoader dimension (always 1). A bare
                # squeeze() would also collapse any other size-1 dimension and
                # break the 3-D shape that generate_views expects.
                batch = batch.squeeze(0).to(self.device)
                dtype, meta = dtype.to(self.device).long(), meta.to(self.device).float()

                # Generate multiple views of the batch
                views = self.generate_views(batch, self.num_views, self.mask_ratio)
                batch_size, num_views, num_columns_times_2, encode_dim = views.shape

                # Forward pass through the model
                # Attention mask ignores masked value/name embeddings (-1000)
                batched_view = views.view(batch_size * num_views, num_columns_times_2, encode_dim)
                attention_mask = torch.all(batched_view != -1000, dim=-1).long().to(self.device)
                features = self.model(batched_view, attention_mask, dtype, meta)

                # Calculate the contrastive loss
                reshaped_feature = features.view(batch_size, num_views, -1)
                loss = self.criterion(reshaped_feature)

                # Add KL divergence loss if kl_reg is enabled
                if self.kl_reg_factor > 0:
                    kl_loss = self.compute_kl_divergence(reshaped_feature)
                    loss = loss + kl_loss * self.kl_reg_factor

                # Backward pass and optimization
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # Update the learning rate using the scheduler
                self.scheduler.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                # Update progress bar with the current batch loss
                if self.rank == 0:
                    batch_progress_bar.set_postfix(batch_loss=batch_loss)
                    batch_progress_bar.update(1)

            # Close the batch progress bar at the end of the epoch
            if self.rank == 0:
                batch_progress_bar.close()

            # Average loss for the epoch; nan (never "best") if no batch was seen
            avg_loss = total_loss / num_batches if num_batches else float('nan')
            avg_loss_history.append(avg_loss)

            # Optionally print or log the learning rate
            current_lr = self.optimizer.param_groups[0]['lr']

            # Update the best model if current epoch's avg_loss is the best
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_epoch = epoch + 1
                # state_dict() returns references to the live parameters, so it
                # must be copied; otherwise later optimizer steps overwrite the
                # "best" snapshot and the final restore becomes a no-op.
                best_model_state = copy.deepcopy(self.model.state_dict())

            # Only print at the end of each epoch (rank 0)
            if self.rank == 0:
                print(f"Epoch {epoch + 1} completed. Avg Loss: {avg_loss:.4f}, Best Epoch Loss: {best_loss:.4f}. Current lr: {current_lr:.4f}")

            # Save the model at intervals (only rank 0)
            if self.rank == 0 and (epoch + 1) % save_interval == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pt")
                torch.save(self.model.state_dict(), checkpoint_path)
                print(f"Model checkpoint saved at {checkpoint_path}")

            # Barrier to synchronize processes (optional, but useful for debugging)
            if self.distributed:
                dist.barrier()

        # After training, restore the best model state (on all processes)
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        # Save the best model (only rank 0)
        if self.rank == 0 and best_model_state is not None:
            final_model_path = os.path.join(checkpoint_dir, f"best_model_epoch_{best_epoch}.pt")
            torch.save(best_model_state, final_model_path)
            print(f"Best model saved at {final_model_path}")

        # Print final results on rank 0
        if self.rank == 0:
            print(f"Total training epochs: {num_epochs}. Final training loss: {avg_loss:.4f}. Final learning rate: {current_lr:.6f}.")
            print(f"Best training loss: {best_loss:.4f}. Best epoch: {best_epoch}")
            print(f"Final learning rate is: {current_lr:.4f}")




          