import torch
import torch.nn as nn
import random
from tqdm import tqdm  
import torch.nn.functional as F


import os
import torch.distributed as dist
from torch.utils.data import DistributedSampler, DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import ReduceLROnPlateau

from tableLatent.feature.VAE import compute_loss
from tableLatent.perceive.contrastTrainer import LatentTrainer

import torch

import time

def is_main_process():
    """Return True when this process should handle logging/checkpointing.

    A process counts as "main" when torch.distributed is not available in the
    current build, when no process group has been initialized (plain
    single-process run), or when its rank in the default group is 0.
    """
    # Check is_available() first: builds without distributed support may not
    # expose is_initialized()/get_rank() at all.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0

def process_column_embeddings(embeddings, labels):
    """
    Re-index each column's labels to 0..K-1 and gather one embedding per distinct label.

    Args:
    - embeddings (torch.Tensor): 3-D tensor unpacked as (B, num_cols, D).
    - labels (torch.Tensor): Tensor indexed as labels[:, col]; one label per row and column.

    Returns:
    - normalized_labels (torch.Tensor): Same shape/dtype as `labels`; every column holds the
      index of each label within that column's sorted set of distinct labels.
    - unique_embedding_list (list): One tensor per column of shape (num_unique_levels, D),
      whose row i is the embedding chosen for normalized label i.
    """
    # Unpacking into three names doubles as a check that the input is 3-D.
    _, n_columns, _ = embeddings.shape
    remapped = torch.zeros_like(labels)
    per_column_embeddings = []

    for c in range(n_columns):
        column_labels = labels[:, c]
        column_vectors = embeddings[:, c, :]  # (B, D)

        # torch.unique hands back the distinct levels plus, for every row, the
        # position of its level in that list -- i.e. the normalized label.
        levels, level_index = torch.unique(column_labels, return_inverse=True)
        remapped[:, c] = level_index

        # One representative vector per level: the first row of the
        # de-duplicated rows carrying that label. When all rows of a level
        # share the same embedding this is simply that embedding.
        representatives = []
        for level in levels:
            rows_for_level = column_vectors[column_labels == level]
            representatives.append(torch.unique(rows_for_level, dim=0)[0])

        # Stacking follows the order of `levels`, matching the normalized labels.
        per_column_embeddings.append(torch.stack(representatives))

    return remapped, per_column_embeddings


class VAETrainer:
    """Train/validate a VAE-style table model, optionally under DistributedDataParallel.

    The trainer owns the model, the train/validation loaders, the optimizer,
    the reconstruction criterion and an LR scheduler that is stepped with the
    validation loss. Each loader item is unpacked as
    (batch_emb, batch_label, categories, meta, unq_embeddings).
    """

    def __init__(self, model, dataloader, test_dataloader, optimizer, criterion, scheduler, device="cuda", kl_reg=1e-4, distributed=False, rank=0, world_size=1):
        """Store collaborators, optionally start the process group, and place the model.

        Args:
        - model: Module exposing `encode(...)` and `decode(...)`.
        - dataloader / test_dataloader: Train and validation loaders.
        - optimizer: Optimizer stepped once per training batch.
        - criterion: Callable returning (mse, ce, _, accuracy) for a batch.
        - scheduler: Scheduler whose `step` is called with the validation loss.
        - device: Device used in the non-distributed case.
        - kl_reg: Weight applied to the KL term.
        - distributed: Whether to initialize a process group and wrap the model in DDP.
        - rank / world_size: This process's rank and the number of processes.
        """
        self.model = model
        self.dataloader = dataloader
        self.test_dataloader = test_dataloader
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.device = device
        self.kl_reg = kl_reg
        self.distributed = distributed
        self.rank = rank
        self.world_size = world_size

        # Handle distributed setup
        if distributed:
            self.setup_distributed()

        # Move model to device and wrap with DDP if needed
        self.setup_model()

    def setup_distributed(self):
        """
        Initialize the NCCL process group and re-wrap both loaders with distributed samplers.
        """
        # setdefault keeps a rendezvous address/port that a launcher already
        # exported, and falls back to the single-node defaults otherwise.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')
        dist.init_process_group(backend='nccl', rank=self.rank, world_size=self.world_size)

        # Modify dataloaders for distributed training
        self.dataloader = self.get_distributed_dataloader(self.dataloader, self.rank, self.world_size)
        self.test_dataloader = self.get_distributed_dataloader(self.test_dataloader, self.rank, self.world_size)

    def get_distributed_dataloader(self, loader, rank, world_size, batch_size=1, num_workers=4):
        """
        Build a new DataLoader over `loader.dataset` driven by a DistributedSampler.

        Args:
        - loader: Existing DataLoader; only its `dataset` is reused.
        - rank / world_size: Forwarded to the DistributedSampler.
        - batch_size / num_workers: Settings of the new loader (defaults keep
          the previous hard-coded values).

        Returns:
        - DataLoader sharded across `world_size` replicas.

        NOTE(review): other settings of the original loader (e.g. a custom
        collate_fn) are not carried over -- confirm that is intended.
        """
        sampler = DistributedSampler(loader.dataset, num_replicas=world_size, rank=rank)
        return DataLoader(loader.dataset, sampler=sampler, batch_size=batch_size, num_workers=num_workers)

    def setup_model(self):
        """
        Move model to device and wrap with DistributedDataParallel if using distributed training.
        """
        if self.distributed:
            # In distributed mode the rank doubles as the CUDA device index.
            torch.cuda.set_device(self.rank)
            self.model = self.model.to(self.rank)
            self.model = DDP(self.model, device_ids=[self.rank])
        else:
            self.model = self.model.to(self.device)

    def save_checkpoint(self, checkpoint_path, encoder_params, decoder_params):
        """
        Save the model weights together with dataset transformers and constructor params.

        Args:
        - checkpoint_path: Destination file.
        - encoder_params / decoder_params: Stored verbatim and returned by `load_checkpoint`.
        """
        if self.distributed:
            model_state = self.model.module.state_dict()  # Use the underlying model in DDP
        else:
            model_state = self.model.state_dict()

        checkpoint = {
            'vae_state': model_state,
            'transformers_per_table': self.dataloader.dataset.transformers_dict,
            'encoder_params': encoder_params,
            'decoder_params': decoder_params
        }

        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path):
        """
        Load model weights from a file written by `save_checkpoint`.

        Returns:
        - (encoder_params, decoder_params) as stored in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Load model state
        if self.distributed:
            self.model.module.load_state_dict(checkpoint['vae_state'])
        else:
            self.model.load_state_dict(checkpoint['vae_state'])
        print(f"Model loaded from {checkpoint_path}")

        # Additional parameters (if needed)
        return checkpoint['encoder_params'], checkpoint['decoder_params']

    def cleanup_distributed(self):
        """
        Tear down the process group; does nothing when none was initialized.
        """
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

    def compute_kl_divergence(self, features, epsilon=1e-6):
        """
        Compute a KL divergence between per-sample feature statistics and N(0, 1).

        Mean and variance are taken along dim 1 of `features`; the closed-form
        Gaussian KL is summed over all remaining entries and divided by the
        size of the last dimension.

        Args:
            features (Tensor): Input tensor; documented by the original author as
                [batch_size, num_views, feature_dim].
            epsilon (float): Added to the variance before taking the log.

        Returns:
            Tensor: Scalar tensor with the KL term (not yet scaled by `kl_reg`).
        """
        # Calculate the mean and variance of the features
        mean = features.mean(dim=1)
        var = features.var(dim=1)

        # Add epsilon to the variance for numerical stability and calculate logvar
        var = var + epsilon
        logvar = var.log()

        # Compute the KL divergence between N(mean, var) and N(0, 1)
        kl_div = -0.5 * torch.sum(1 + logvar - mean.pow(2) - var)

        # Normalize by the number of features (not batch size)
        kl_div = kl_div / features.shape[-1]

        return kl_div

    def get_model(self):
        """Return the (possibly DDP-wrapped) model held by the trainer."""
        return self.model

    def _run_batch(self, batch_emb, batch_label, categories, meta, unq_embeddings, model, device, is_train=True):
        """
        Run the forward pass for one loader item, compute losses and, when training, take an optimizer step.

        Args:
        - batch_emb: Batch of embeddings; entries equal to -1000 mark masked positions.
        - batch_label: Corresponding labels, one column per table column.
        - categories: Per-column category counts; 0 marks a numerical column.
        - meta: Metadata embeddings.
        - unq_embeddings: Unique category embeddings, concatenated column after column.
        - model: The model being trained.
        - device: The device (CPU/GPU).
        - is_train: Whether this is a training step or not.

        Returns:
        - loss_mse: MSE loss for numerical columns (float).
        - loss_ce: Cross-entropy loss for categorical columns (float).
        - loss_kl: KL divergence loss already scaled by `kl_reg` (float).
        - accuracy: Accuracy reported by the criterion (float).
        """
        if is_train:
            model.train()
        else:
            model.eval()

        # Autograd bookkeeping is only needed when we are going to backpropagate;
        # disabling it for validation avoids building a graph that is never used.
        with torch.set_grad_enabled(bool(is_train)):
            # NOTE(review): the bare squeeze() calls drop every size-1 dim, not
            # only the loader's leading batch dim -- confirm inputs never have
            # a single row or a single column.
            batch_emb = batch_emb.squeeze().to(device)
            batch_length = batch_emb.shape[0]
            attention_mask = torch.all(batch_emb != -1000, dim=-1).long().to(device)
            # Even positions along dim 1 are used as the column-name embeddings.
            column_names_emb = batch_emb[:, ::2, :]

            categories = categories.squeeze()
            # 1 marks a numerical column (category count 0), 0 a categorical one.
            dtype = (categories == 0).long()
            cat_idx, num_idx = torch.where(dtype == 0)[0], torch.where(dtype == 1)[0]
            dtype, meta = dtype.to(device).long(), meta.to(device).float()

            batch_label = batch_label.squeeze()
            batch_num = batch_label[:, num_idx].to(device)
            batch_cat = batch_label[:, cat_idx].to(device).long()

            metadata_emb = meta.repeat(batch_length, 1, 1)

            # Slice the concatenated unique embeddings back into one chunk per
            # categorical column, using each column's category count as length.
            unique_embedding_list = []
            unq_embeddings = unq_embeddings.squeeze()
            cur_cat_idx = 0
            for column_category_count in categories:
                column_category_count = int(column_category_count.item())
                if column_category_count == 0:
                    continue
                unique_embedding_list.append(unq_embeddings[cur_cat_idx:cur_cat_idx+column_category_count].clone().to(device))
                cur_cat_idx += column_category_count

            # NOTE(review): in distributed mode encode/decode are called on
            # model.module, bypassing DDP.forward -- confirm gradients are
            # still synchronized across ranks.
            core_model = model.module if self.distributed else model
            batch_latent = core_model.encode(batch_emb, attention_mask, dtype, meta).squeeze()
            Recon_X_num, Recon_X_cat = core_model.decode(column_names_emb, metadata_emb, batch_latent, dtype, unique_embedding_list)

            # Compute losses
            loss_mse, loss_ce, _, accuracy = self.criterion(batch_num, batch_cat, Recon_X_num, Recon_X_cat)
            loss_kl = self.compute_kl_divergence(batch_latent) * self.kl_reg

            total_loss = loss_mse + loss_ce + loss_kl

        # Backpropagation for training
        if is_train:
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

        return loss_mse.item(), loss_ce.item(), loss_kl.item(), accuracy.item()

    @staticmethod
    def _summarize_epoch(mse_sum, ce_sum, kl_sum, acc_sum, mse_count, ce_count, total_count):
        """Turn size-weighted sums into averages, yielding 0.0 for an empty denominator.

        Returns:
        - (avg_mse, avg_ce, avg_kl, avg_accuracy)
        """
        avg_mse = mse_sum / mse_count if mse_count > 0 else 0.0
        avg_ce = ce_sum / ce_count if ce_count > 0 else 0.0
        avg_accuracy = acc_sum / ce_count if ce_count > 0 else 0.0
        # KL is accumulated for every batch, so it is averaged over all
        # samples seen rather than over mse_count + ce_count (which would
        # count batches with both column kinds twice).
        avg_kl = kl_sum / total_count if total_count > 0 else 0.0
        return avg_mse, avg_ce, avg_kl, avg_accuracy

    def _epoch_loop(self, loader, model, device, is_train=True):
        """
        Run an epoch for training or validation with a progress bar.

        Args:
        - loader: DataLoader for train/validation data.
        - model: Model being trained/evaluated.
        - device: CPU or GPU.
        - is_train: Flag to indicate training or validation.

        Returns:
        - avg_mse: Average MSE loss over batches that contain numerical columns.
        - avg_ce: Average Cross-Entropy loss over batches that contain categorical columns.
        - avg_kl: Average KL divergence loss over all batches.
        - avg_accuracy: Average accuracy over batches that contain categorical columns.
        """
        epoch_mse, epoch_ce, epoch_kl, epoch_accuracy = 0.0, 0.0, 0.0, 0.0
        mse_count, ce_count, total_count = 0, 0, 0

        # Initialize tqdm progress bar
        pbar = tqdm(loader, total=len(loader), desc="Training" if is_train else "Validation", leave=False)

        for batch_emb, batch_label, categories, meta, unq_embeddings in pbar:
            # Run batch forward pass
            loss_mse, loss_ce, loss_kl, accuracy = self._run_batch(
                batch_emb, batch_label, categories, meta, unq_embeddings, model, device, is_train
            )

            # Metrics are weighted by the leading dimension of the loader item.
            batch_size = batch_emb.shape[0]
            has_numerical = bool((categories == 0).any())
            has_categorical = bool((categories > 0).any())

            # Only accumulate MSE loss if the batch contains numerical columns (categories == 0)
            if has_numerical:
                epoch_mse += loss_mse * batch_size
                mse_count += batch_size

            # Only accumulate CE loss if the batch contains categorical columns (categories > 0)
            if has_categorical:
                epoch_ce += loss_ce * batch_size
                epoch_accuracy += accuracy * batch_size
                ce_count += batch_size

            # KL loss is always accumulated
            epoch_kl += loss_kl * batch_size
            total_count += batch_size

            # Update progress bar with batch losses
            pbar.set_postfix({
                "MSE": f"{loss_mse:.4f}" if has_numerical else "N/A",
                "CE": f"{loss_ce:.4f}" if has_categorical else "N/A",
                "KL": f"{loss_kl:.4f}",
                "Acc": f"{accuracy:.4f}" if has_categorical else "N/A",
            })

        # Close progress bar
        pbar.close()

        return self._summarize_epoch(
            epoch_mse, epoch_ce, epoch_kl, epoch_accuracy, mse_count, ce_count, total_count
        )

    def train(self, num_epochs=1000, checkpoint_path="best_model.pth"):
        """
        Alternate training and validation epochs, checkpointing the best validation loss.

        After every epoch the scheduler is stepped with the validation loss
        (MSE + CE + KL). Whenever that loss improves, the main process writes
        the model's state dict to `checkpoint_path`; at the end the main
        process reloads that file into the model.

        Args:
        - num_epochs: Number of train/validation rounds.
        - checkpoint_path: File receiving the best state dict.
        """
        device = self.device
        model = self.model.to(device)

        best_val_loss = float('inf')
        best_found = False  # Becomes True once any epoch improved on +inf
        start_time = time.time()

        for epoch in range(num_epochs):
            if is_main_process():
                print(f"Epoch {epoch+1}/{num_epochs}")

            # A DistributedSampler reuses the same shuffle every epoch unless
            # it is told which epoch is running.
            train_sampler = getattr(self.dataloader, 'sampler', None)
            if isinstance(train_sampler, DistributedSampler):
                train_sampler.set_epoch(epoch)

            # Training step
            train_mse, train_ce, train_kl, train_acc = self._epoch_loop(
                self.dataloader, model, device, is_train=True
            )

            # Validation step
            val_mse, val_ce, val_kl, val_acc = self._epoch_loop(
                self.test_dataloader, model, device, is_train=False
            )

            # Compute total validation loss for comparison
            # NOTE(review): in distributed mode every rank computes this on its
            # own shard without a reduction, so ranks may disagree -- confirm.
            val_loss = val_mse + val_ce + val_kl

            # Learning rate scheduler
            self.scheduler.step(val_loss)

            # Logging results (only main process should print)
            if is_main_process():
                print(f'Epoch {epoch+1}: Train MSE: {train_mse:.6f}, Train CE: {train_ce:.6f}, Train KL: {train_kl:.6f}, '
                    f'Val MSE: {val_mse:.6f}, Val CE: {val_ce:.6f}, Val KL: {val_kl:.6f}, '
                    f'Train ACC: {train_acc:.6f}, Val ACC: {val_acc:.6f}')

            # Check for new best validation loss and save the model (only in the main process)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_found = True
                if is_main_process():
                    # The best weights live on disk; keeping only a reference
                    # to state_dict() in memory would track the live weights.
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"Best model saved with validation loss: {best_val_loss:.6f}")

        end_time = time.time()
        if is_main_process():
            print(f'Training completed in {(end_time - start_time) / 60:.4f} minutes')

        # Restore model to the best validation loss checkpoint (only in main process)
        if best_found and is_main_process():
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print(f"Model restored to best validation loss state: {best_val_loss:.6f}")



class DecoderTrainer(LatentTrainer):
    """Train a single-column decoder on top of a fixed row aggregator.

    For every loader item one column of the configured kind ('categorical'
    or 'numerical') is picked at random, the aggregator turns the row
    embeddings into features, and the decoder is trained to reproduce that
    column's labels from the features and the column-name embedding.
    """
    _ACCEPTED_DECODER_TYPES = ['categorical', 'numerical']
    def __init__(self, model,aggregator,dataloader, test_dataloader, optimizer, criterion, scheduler,device="cuda", decoder_type='categorical', eval_interval=10, pct_of_col_to_train=1):
        """
        Args:
            model: Decoder called as `model(features, column_name_embedding)`.
            aggregator: Module called as `aggregator(batch, attention_mask)`; put in eval mode during training.
            dataloader / test_dataloader: Loaders yielding (batch, label, dtypes, meta).
            optimizer, criterion, scheduler, device: Forwarded to / shared with LatentTrainer.
            decoder_type (str): One of `_ACCEPTED_DECODER_TYPES`.
            eval_interval (int): Evaluate on the test loader every this many epochs.
            pct_of_col_to_train: Stored, but currently not used by `_run_epoch`
                (one column per batch is hard-coded there).
        """
        super().__init__(model, dataloader, optimizer, criterion, scheduler, device)
        self.aggregator = aggregator
        assert decoder_type in self._ACCEPTED_DECODER_TYPES, f"Decoder type {decoder_type} not supported! Available ones: {self._ACCEPTED_DECODER_TYPES}."
        self.decoder_type = decoder_type
        self.test_dataloader = test_dataloader
        self.eval_interval = eval_interval
        self.pct_of_col_to_train = pct_of_col_to_train # Proportion of columns to be trained for each batch.

    def train(self, num_epochs=10):
        """
        Train the decoder, evaluate periodically, and restore the best-evaluated weights.

        Args:
            num_epochs (int): Number of epochs to train.
        """
        self.model.to(self.device)
        self.aggregator.to(self.device)
        self.aggregator.eval() # Aggregator is kept in eval mode while the decoder trains

        # One progress-bar tick per epoch
        progress_bar = tqdm(total=num_epochs, desc="Training Progress")

        best_loss = float('inf')
        best_model_state = None
        test_loss = None  # Most recent evaluation loss; None until the first evaluation

        for epoch in range(num_epochs):
            avg_loss = self._run_epoch(self.dataloader)

            if epoch % self.eval_interval == 0:
                test_loss = self._run_epoch(self.test_dataloader, eval=True)

                # Only a fresh evaluation can produce a new best. The tensors
                # are cloned because state_dict() returns references to the
                # live parameters, which keep changing as training continues.
                if test_loss < best_loss:
                    best_loss = test_loss
                    best_model_state = {
                        key: value.detach().clone()
                        for key, value in self.model.state_dict().items()
                    }

            # Update progress bar with the latest epoch losses
            progress_bar.set_postfix(epoch=epoch + 1, loss=f"{avg_loss:.4f}", test_loss=f"{test_loss:.4f}" if test_loss is not None else 'N/A')
            progress_bar.update(1)  # Increment the progress bar by 1 epoch

        # Close the progress bar at the end of training
        progress_bar.close()

        print(f"Model type: {self.decoder_type}.Total training epochs: {num_epochs}. Best test loss: {best_loss}.")
        # No snapshot exists when no evaluation ever produced a comparable loss
        # (e.g. num_epochs == 0); keep the current weights in that case.
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        # Put models back to cpu
        self.model.cpu()
        self.aggregator.cpu()

    def _run_epoch(self, dataloader, eval=False):
        """
        Run one pass over `dataloader`, training the decoder unless `eval` is set.

        Args:
            dataloader (dataloader): dataloader with train/test data
            eval (bool): When True, the model is put in eval mode, gradients
                are disabled, and neither the optimizer nor the scheduler is stepped.

        Returns:
            float: Mean loss over the batches that had at least one column of
            the configured decoder type; NaN when there was no such batch.
        """
        if eval:
            self.model.eval()
        else:
            self.model.train()

        total_loss = 0.0
        num_batches = 0

        # In `dtypes`, 0 selects columns for the categorical decoder and 1 for the numerical one.
        wanted_dtype = 0 if self.decoder_type == 'categorical' else 1
        # One column per batch is trained; `pct_of_col_to_train` is not applied here.
        num_of_col_to_train = 1

        # Gradients are only needed for training passes.
        with torch.set_grad_enabled(not eval):
            for batch,label,dtypes,meta in dataloader:

                # Batch has shape (B, num_cols*2, embed_dim) after the leading dim of 1 is removed
                batch_loss = 0
                batch = batch.squeeze().to(self.device)  # Move batch to device and remove first dim which is always 1
                label = label.squeeze().to(self.device)
                # Unpacking doubles as a check that `label` is 2-D.
                batch_size, num_columns = label.shape

                # Only train on columns with corresponding dtypes.
                # Note that some dataset might have no numerical/categorical column.
                col_to_train = torch.where(dtypes[0,:]==wanted_dtype)[0]
                perturbed_cols = col_to_train[torch.randperm(len(col_to_train))][:num_of_col_to_train]

                for random_column_index in perturbed_cols:
                    selected_column_label = label[:, random_column_index] # Only train decoding of one column at a time.
                    # Column i's name embedding sits at position 2*i of the first row.
                    selected_column_name_emb = batch[0, 2*random_column_index, :]

                    # Forward pass through the aggregator to get row latent
                    # Attention mask ignores masked embeddings(-1000)
                    attention_mask = torch.all(batch != -1000, dim=-1).long().to(self.device)
                    features = self.aggregator(batch, attention_mask)  # shape (B, num_latent, latent_dim)

                    # Pass through row latents to get decoded values
                    decoded_values = self.model(features, selected_column_name_emb).squeeze() # shape (B, output_dim) for categorical or (B,) for numerical
                    if self.decoder_type == 'categorical':
                        decoded_values = decoded_values.unsqueeze(1) # reshape to shape (B, 1, output_dim) for contrastive loss

                        # Decode the selected column on its own and use the result as an anchor
                        batch_this_col_only = batch[:,2*random_column_index:2*random_column_index+2, :]
                        attention_mask_this_col = torch.all(batch_this_col_only != -1000, dim=-1).long().to(self.device)
                        # NOTE(review): casting the aggregator output to bool looks
                        # unintended (the first decoder call receives the raw
                        # features) -- confirm before relying on this branch.
                        aggregated_batch_this_col = self.aggregator(batch_this_col_only, attention_mask_this_col).to(torch.bool).to(self.device)
                        reconstructed_this_col = self.model(aggregated_batch_this_col, selected_column_name_emb) # also have shape (B, 1, output_dim)
                        decoded_values = torch.cat([decoded_values, reconstructed_this_col], dim=1)
                        loss = self.criterion(features=decoded_values,label=selected_column_label)
                    else:
                        loss = self.criterion(decoded_values,selected_column_label)

                    if not eval:
                        # Backward pass and optimization
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                    batch_loss += loss.item()

                # Batches without any column of the wanted type do not count.
                if len(perturbed_cols) > 0:
                    total_loss += batch_loss / len(perturbed_cols)
                    num_batches += 1

        # NaN (rather than a crash or a misleading 0.0) signals "no usable batch";
        # it never compares as smaller than a previous best loss.
        avg_loss = total_loss / num_batches if num_batches > 0 else float('nan')

        # The learning-rate schedule advances once per training epoch only;
        # evaluation passes must not move it.
        if not eval:
            self.scheduler.step()

        return avg_loss
