from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import os
import time
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import Linear, Module, Parameter, ReLU, Sequential, functional, TripletMarginLoss
from torch.nn.functional import cross_entropy
from torch.optim import Adam
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go

from torch.autograd import Variable
from src.syngen.cttvae.util import DataTransformer, reparameterize, _loss_function_MMD,z_gen, triplet_loss_margin 


class Encoder_T(nn.Module):
    """Transformer encoder mapping transformed rows to latent Gaussian parameters.

    Rows are first projected to ``embedding_dim`` by a linear layer, passed
    through a 2-layer ``nn.TransformerEncoder``, and then mapped to a mean and a
    log-variance of size ``latent_dim``.
    """

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Projection of the input row into the transformer width.
        self.linear = nn.Linear(input_dim, embedding_dim)
        # Template layer; nn.TransformerEncoder builds its 2 layers from it.
        self.transformerencoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(self.transformerencoder_layer, num_layers=2)
        # Heads producing the latent Gaussian parameters.
        self.fc_mu = nn.Linear(embedding_dim, latent_dim)
        self.fc_log_var = nn.Linear(embedding_dim, latent_dim)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            Tuple ``(mu, std, logvar, enc_output)`` where ``std = exp(0.5 * logvar)``
            and ``enc_output`` is the transformer encoder output (used by the
            decoder as cross-attention memory).
        """
        hidden = self.encoder(self.linear(x))
        mu, logvar = self.fc_mu(hidden), self.fc_log_var(hidden)
        return mu, (0.5 * logvar).exp(), logvar, hidden


class Decoder_T(nn.Module):
    """Transformer decoder mapping latent codes back to transformed-row space.

    The latent code is projected to ``embedding_dim``, decoded by a 2-layer
    ``nn.TransformerDecoder`` that cross-attends to the encoder output, and
    projected back to ``input_dim``. A learnable per-feature ``sigma``
    (initialised to 0.1) is returned alongside the reconstruction.
    """

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Lift latent codes to the transformer width.
        self.latent_to_decoder_input = nn.Linear(latent_dim, embedding_dim)
        # Template layer; nn.TransformerDecoder builds its 2 layers from it.
        self.transformerdecoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.decoder = nn.TransformerDecoder(self.transformerdecoder_layer, num_layers=2)
        # Project decoder output back to the transformed data dimension.
        self.linear = nn.Linear(embedding_dim, input_dim)
        self.sigma = Parameter(0.1 * torch.ones(input_dim))

    def forward(self, z, enc_output):
        """Decode latent codes ``z`` using ``enc_output`` as cross-attention memory.

        Returns:
            Tuple ``(reconstruction, sigma)``.
        """
        target = self.latent_to_decoder_input(z)
        decoded = self.decoder(target, enc_output)
        reconstruction = self.linear(decoded)
        return reconstruction, self.sigma


class CTTVAE():
    """Conditional transformer-based tabular VAE.

    Wraps an ``Encoder_T``/``Decoder_T`` pair trained on rows encoded by
    ``DataTransformer``. The training loss is the value returned by
    ``_loss_function_MMD`` plus ``triplet_loss_margin`` computed on the latent
    means with the labels of ``condition_column``. Sampling draws latent codes
    for (optionally filtered) training rows, interpolates new codes with
    ``z_gen`` and decodes them.
    """

    def __init__(
        self,
        l2scale=1e-5,
        batch_size=500,
        epochs=300,
        loss_factor=2,
        triplet_factor=1,
        latent_dim =32,
        embedding_dim=128,
        nhead=8,
        dim_feedforward=1028,
        dropout=0.1,
        triplet_margin=0.2,
        cuda=True,
        verbose=False,
        device='cuda'
    ):
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.triplet_margin = triplet_margin
        self.l2scale = l2scale
        self.batch_size = batch_size
        self.loss_factor = loss_factor
        self.triplet_factor = triplet_factor
        self.epochs = epochs
        self.verbose = verbose
        # Fall back to CPU when CUDA is disabled or unavailable.
        self._device = torch.device(device if cuda and torch.cuda.is_available() else 'cpu')

        self.loss_values = None   # per-epoch mean loss, filled by fit()
        self.latent_space = None  # PCA figure of the latent means, filled by fit()
        self.labels = None        # int64 labels of condition_column, filled by fit()

    def save(self, save_path: str):
        """Pickle the whole CTTVAE object to ``save_path`` with ``torch.save``.

        Parent directories in ``save_path`` are created if missing; a bare
        filename is written relative to the current working directory.
        """
        directory = os.path.dirname(save_path)
        # os.makedirs('') raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self, save_path)

    def visualize_latent_space(self, train_data, condition_column):
        """Return a plotly scatter of the latent means projected to 2D with PCA.

        ``train_data`` is transformed with the fitted ``DataTransformer`` and
        encoded; points are coloured by ``train_data[condition_column]``.
        Leaves the encoder in eval mode.
        """
        transformed_data = self.transformer.transform(train_data).astype('float32')
        tensor_data = torch.from_numpy(transformed_data).to(self._device)

        self.encoder.eval()
        with torch.no_grad():
            mean, _, _, _ = self.encoder(tensor_data)
            latent = mean.cpu().numpy()

        projected = PCA(n_components=2).fit_transform(latent)

        df_plot = pd.DataFrame(projected, columns=['PC1', 'PC2'])
        df_plot[f'{condition_column}'] = train_data[f'{condition_column}'].values

        fig = px.scatter(
            df_plot, x='PC1', y='PC2', color=df_plot[f'{condition_column}'].astype(str),
            title='Latent Space (PCA) - CTTVAE with Triplet Loss',
            opacity=0.7,
            labels={'color': f'{condition_column}'}
        )

        fig.update_layout(legend_title_text=f'{condition_column}')
        return fig

    def fit(self, train_data, discrete_columns=(), cond_strategy='all_categories', condition_column=None, save_path=''):
        """Train the encoder/decoder on ``train_data``.

        Args:
            train_data: DataFrame of training rows.
            discrete_columns: columns passed to ``DataTransformer.fit`` as discrete.
            cond_strategy: accepted for interface compatibility; not used here.
            condition_column: column whose values (cast to int64) are the labels
                for the triplet loss. Required in practice: ``None`` is used as a
                column key and will fail the lookup.
            save_path: where to checkpoint the model whenever the epoch loss
                improves. An empty string disables checkpointing.

        Side effects:
            Sets ``transformer``, ``train_data``, ``train_data_copy``, ``labels``,
            ``encoder``, ``decoder``, ``loss_values`` and ``latent_space``.
        """
        self.train_data_copy = train_data.copy()
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        self.train_data = self.transformer.transform(train_data).astype('float32')
        self.labels = train_data[condition_column].values.astype('int64')

        data_tensor = torch.from_numpy(self.train_data).to(self._device)
        label_tensor = torch.tensor(self.labels).to(self._device)
        dataset = TensorDataset(data_tensor, label_tensor)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)

        data_dim = self.transformer.output_dimensions

        self.encoder = Encoder_T(data_dim, self.latent_dim, self.embedding_dim, self.nhead, self.dim_feedforward, self.dropout).to(self._device)
        self.decoder = Decoder_T(data_dim, self.latent_dim, self.embedding_dim, self.nhead, self.dim_feedforward, self.dropout).to(self._device)

        optimizer = Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20, verbose=True)

        self.encoder.train()
        self.decoder.train()

        best_loss = float('inf')
        epochs_without_improvement = 0
        # Collect rows in a list instead of repeatedly concatenating onto an
        # empty DataFrame (deprecated in recent pandas, yields object dtypes).
        loss_records = []
        self.loss_values = pd.DataFrame(loss_records, columns=['Epoch', 'Loss'])

        for epoch in range(self.epochs):
            pbar = tqdm(loader, total=len(loader))
            pbar.set_description(f"Epoch {epoch+1}/{self.epochs}")

            weighted_loss_sum = 0.0
            n_rows = 0

            for real_x, batch_labels in pbar:
                optimizer.zero_grad()

                mean, std, logvar, enc_output = self.encoder(real_x)
                z = reparameterize(mean, logvar)

                recon_x, sigmas = self.decoder(z, enc_output)

                # Loss from _loss_function_MMD on the reconstruction.
                loss = _loss_function_MMD(recon_x, real_x, sigmas, mean, logvar,
                                          self.transformer.output_info_list, self.loss_factor)

                # Triplet loss on the latent means, using the condition labels.
                triplet_loss = triplet_loss_margin(mean, batch_labels, factor=self.triplet_factor, margin=self.triplet_margin)
                loss = loss + triplet_loss

                # Weight by batch size so the epoch value is a per-row mean.
                weighted_loss_sum += loss.item() * len(real_x)
                n_rows += len(real_x)

                loss.backward()
                optimizer.step()
                self.decoder.sigma.data.clamp_(0.01, 1.0)

                pbar.set_postfix({"Loss": loss.item()})

            curr_loss = weighted_loss_sum / n_rows
            scheduler.step(curr_loss)

            # Record before the early-stopping check so the final epoch is kept.
            loss_records.append({'Epoch': epoch + 1, 'Loss': curr_loss})
            self.loss_values = pd.DataFrame(loss_records, columns=['Epoch', 'Loss'])

            if curr_loss < best_loss:
                # Compare epoch means with epoch means (not the last batch loss).
                best_loss = curr_loss
                epochs_without_improvement = 0
                if save_path:
                    self.save(save_path)
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement == 20:
                    print('Early stopping')
                    break

        self.latent_space = self.visualize_latent_space(self.train_data_copy, condition_column)

    def sample(self, n_samples=100, condition_column=None, condition_value=None):
        """
        Sample data similar to training data, optionally conditioned on a discrete column value.

        Latent codes are drawn from N(mean, std) for every training row. If both
        ``condition_column`` and ``condition_value`` are given, only rows with
        that value are used; otherwise all rows are used. ``z_gen`` interpolates
        ``n_samples`` new codes, which are decoded (with the selected encoder
        outputs as memory), passed through ``tanh`` and inverse-transformed.

        Raises:
            ValueError: if no training row matches the condition.
        """
        self.encoder.eval()
        with torch.no_grad():
            enc_input = torch.Tensor(self.train_data).to(self._device)
            mean, std, _, enc_embed = self.encoder(enc_input)

        embeddings = torch.normal(mean=mean, std=std).cpu().numpy()
        if condition_column is not None and condition_value is not None:
            class_mask = self.train_data_copy[condition_column] == condition_value
            if class_mask.sum() == 0:
                raise ValueError(f"No samples found for: {condition_column} = {condition_value}")
            filtered_embeddings = embeddings[class_mask.values]
            filtered_enc_embed = enc_embed[class_mask.values]
        else:
            filtered_embeddings = embeddings
            filtered_enc_embed = enc_embed

        # Generate synthetic latent points by interpolation.
        synthetic_embeddings = z_gen(
            filtered_embeddings, n_to_sample=n_samples,
            metric='minkowski', interpolation_method='triangle'
        )
        noise = torch.Tensor(synthetic_embeddings).to(self._device)

        self.decoder.eval()
        with torch.no_grad():
            fake, _ = self.decoder(noise, filtered_enc_embed)
            fake = torch.tanh(fake).cpu().numpy()

        return self.transformer.inverse_transform(fake)

    def set_device(self, device):
        """Set the device (e.g. 'cuda' or 'cpu') and move the encoder and decoder to it.

        Both modules are moved because ``sample`` runs the encoder as well as the
        decoder on ``self._device``. Modules not yet created (before ``fit``) are
        skipped; ``fit`` then builds them on the new device.
        """
        self._device = torch.device(device)
        for module in (getattr(self, 'encoder', None), getattr(self, 'decoder', None)):
            if module is not None:
                module.to(self._device)
