from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import os
import time
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import Linear, Module, Parameter, ReLU, Sequential, functional, TripletMarginLoss
from torch.nn.functional import cross_entropy
from torch.optim import Adam
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go

from torch.autograd import Variable
from src.syngen.cttvae.util import DataTransformer, reparameterize, _loss_function_MMD,z_gen, triplet_loss_margin # src.syngen.cttvae.


class Encoder_T(nn.Module):
    """Transformer encoder producing Gaussian latent parameters for each input row.

    Rows are projected to ``embedding_dim`` by a linear layer, run through a
    2-layer ``nn.TransformerEncoder``, and fed to two linear heads that give the
    latent mean and log-variance.
    """

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are intentionally unchanged: they
        # define the state_dict keys and the order in which weights are initialised.
        self.linear = nn.Linear(input_dim, embedding_dim)
        # NOTE(review): nn.TransformerEncoder builds its own copies of this layer,
        # so this registered template appears unused in forward — confirm before
        # removing (doing so would also change the state_dict keys).
        self.transformerencoder_layer = nn.TransformerEncoderLayer(
            embedding_dim, nhead, dim_feedforward, dropout
        )
        self.encoder = nn.TransformerEncoder(self.transformerencoder_layer, num_layers=2)
        self.fc_mu = nn.Linear(embedding_dim, latent_dim)
        self.fc_log_var = nn.Linear(embedding_dim, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, std, logvar, enc_output)``.

        ``std`` equals ``exp(0.5 * logvar)``. ``enc_output`` is the transformer
        output, which the decoder receives as its memory.

        NOTE(review): callers in this file pass 2-D ``(rows, input_dim)`` tensors;
        with the default ``batch_first=False`` layers this presumably makes the
        transformer attend across rows — confirm that this is intended.
        """
        hidden = self.encoder(self.linear(x))
        mu, logvar = self.fc_mu(hidden), self.fc_log_var(hidden)
        return mu, torch.exp(0.5 * logvar), logvar, hidden


class Decoder_T(nn.Module):
    """Transformer decoder mapping latent points back to the transformed data space.

    Latent vectors are projected to ``embedding_dim``, decoded by a 2-layer
    ``nn.TransformerDecoder`` that takes the encoder output as memory, and
    projected to ``input_dim``. A learnable per-feature ``sigma`` (initialised to
    0.1) is returned alongside the reconstruction; ``CTTVAETBS.fit`` clamps it
    to ``[0.01, 1.0]`` after every optimiser step.
    """

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are intentionally unchanged: they
        # define the state_dict keys and the order in which weights are initialised.
        self.latent_to_decoder_input = nn.Linear(latent_dim, embedding_dim)
        # NOTE(review): nn.TransformerDecoder builds its own copies of this layer,
        # so this registered template appears unused in forward — confirm before removing.
        self.transformerdecoder_layer = nn.TransformerDecoderLayer(
            embedding_dim, nhead, dim_feedforward, dropout
        )
        self.decoder = nn.TransformerDecoder(self.transformerdecoder_layer, num_layers=2)
        self.linear = nn.Linear(embedding_dim, input_dim)
        self.sigma = Parameter(torch.ones(input_dim) * 0.1)

    def forward(self, z, enc_output):
        """Decode latent points ``z`` attending to ``enc_output``.

        Returns ``(reconstruction, sigma)``, where ``reconstruction`` has
        ``input_dim`` features per row of ``z``.
        """
        target = self.latent_to_decoder_input(z)
        decoded = self.decoder(target, enc_output)
        return self.linear(decoded), self.sigma


class CTTVAETBS:
    """Conditional transformer TVAE trained with Training-by-Sampling (TBS).

    ``fit`` draws every batch by first sampling class labels of
    ``condition_column`` from a PMF chosen by ``tbs_strategy``. It then samples
    one training row per drawn label and minimises ``_loss_function_MMD`` plus
    ``triplet_loss_margin`` on the latent means. ``sample`` encodes the training
    rows, generates new latent points from them with ``z_gen`` and decodes them.
    """

    def __init__(
        self,
        l2scale=1e-5,
        batch_size=500,
        epochs=300,
        loss_factor=2,
        triplet_factor=1,
        latent_dim=32,
        embedding_dim=128,
        nhead=8,
        dim_feedforward=1028,
        dropout=0.1,
        triplet_margin=0.2,
        tbs_strategy='linear_mix',
        lambda_scale=0.4,
        cuda=True,
        verbose=False,
        device='cuda'
    ):
        """Store hyper-parameters and select the torch device.

        Args:
            l2scale: Adam ``weight_decay``.
            batch_size: rows per TBS batch; each epoch draws
                ``max(1, len(train_data) // batch_size)`` batches.
            epochs: maximum number of epochs (training may stop early).
            loss_factor: forwarded to ``_loss_function_MMD``.
            triplet_factor: forwarded to ``triplet_loss_margin`` as ``factor``.
            latent_dim, embedding_dim, nhead, dim_feedforward, dropout:
                architecture of ``Encoder_T`` / ``Decoder_T``.
            triplet_margin: forwarded to ``triplet_loss_margin`` as ``margin``.
            tbs_strategy: class PMF used by TBS: ``'linear_mix'``, ``'uniform'``
                or ``'original'``.
            lambda_scale: weight of the empirical class distribution in
                ``'linear_mix'``; the remaining weight goes to the uniform one.
            cuda: when False, or when CUDA is unavailable, the CPU is used.
            verbose: stored; not read by this class.
            device: device used when ``cuda`` is True and CUDA is available.
        """
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.triplet_margin = triplet_margin
        self.l2scale = l2scale
        self.batch_size = batch_size
        self.loss_factor = loss_factor
        self.triplet_factor = triplet_factor
        self.epochs = epochs
        self.verbose = verbose
        self.lambda_scale = lambda_scale
        self.tbs_strategy = tbs_strategy

        self._device = torch.device(device if cuda and torch.cuda.is_available() else 'cpu')
        print(f"USING DEVICE: {self._device}")

        self.data_sampler = None
        self.loss_values = None
        self.labels = None

    def save(self, save_path: str):
        """Save the whole model object to ``save_path`` with ``torch.save``.

        Parent directories are created when ``save_path`` contains any; a bare
        file name is written to the current working directory.
        """
        directory = os.path.dirname(save_path)
        if directory:  # os.makedirs('') raises FileNotFoundError
            os.makedirs(directory, exist_ok=True)
        torch.save(self, save_path)

    def visualize_latent_space(self, train_data, condition_column):
        """Plot the encoder's latent means for ``train_data`` projected to 2-D by PCA.

        Points are coloured by the values of ``condition_column``. Returns the
        plotly figure. Leaves the encoder in eval mode.
        """
        transformed_data = self.transformer.transform(train_data).astype('float32')
        tensor_data = torch.from_numpy(transformed_data).to(self._device)

        self.encoder.eval()
        with torch.no_grad():
            mean, _, _, _ = self.encoder(tensor_data)
            latent = mean.cpu().numpy()

        projected = PCA(n_components=2).fit_transform(latent)

        # Look the column up by its real name (it need not be a string) but
        # label the plot with its string form.
        label_name = str(condition_column)
        df_plot = pd.DataFrame(projected, columns=['PC1', 'PC2'])
        df_plot[label_name] = train_data[condition_column].values

        fig = px.scatter(
            df_plot, x='PC1', y='PC2', color=df_plot[label_name].astype(str),
            title='Latent Space (PCA) - CTTVAE+TBS with Triplet Loss',
            opacity=0.7,
            labels={'color': label_name}
        )
        fig.update_layout(legend_title_text=label_name)
        return fig

    @staticmethod
    def _tbs_class_pmf(labels, strategy, mix_lambda):
        """Return ``(class_values, pmf)`` over the classes present in ``labels``.

        ``class_values`` is the sorted array of distinct labels and ``pmf[k]``
        is the probability of drawing ``class_values[k]``:

        * ``'original'``: empirical class frequencies;
        * ``'uniform'``: equal mass per class;
        * ``'linear_mix'``: ``mix_lambda * original + (1 - mix_lambda) * uniform``.

        Only classes that occur are included, so gaps in the label values
        (e.g. ``{0, 2}``) never give mass to a class that has no rows to sample.

        Raises:
            ValueError: for an unknown ``strategy`` or empty ``labels``.
        """
        class_values, class_counts = np.unique(np.asarray(labels), return_counts=True)
        if class_values.size == 0:
            raise ValueError("Cannot build a TBS class PMF from empty labels.")
        orig_dist = class_counts / class_counts.sum()
        uniform_dist = np.full(class_values.size, 1.0 / class_values.size)

        if strategy == "linear_mix":
            pmf = mix_lambda * orig_dist + (1.0 - mix_lambda) * uniform_dist
        elif strategy == "uniform":
            pmf = uniform_dist
        elif strategy == "original":
            pmf = orig_dist
        else:
            raise ValueError(f"Unsupported TBS strategy: {strategy}")

        return class_values, pmf / pmf.sum()  # ensure it sums to 1

    @staticmethod
    def _sample_tbs_batch(class_rows, pmf, batch_size):
        """Draw one TBS batch; return ``(row_indices, class_positions)``.

        A class position ``k`` is drawn from ``pmf`` for every slot. A row is
        then drawn uniformly, with replacement, from ``class_rows[k]``, which
        holds the training-row indices of that class.
        """
        class_pos = np.random.choice(len(pmf), size=batch_size, p=pmf)
        row_idx = np.empty(batch_size, dtype=np.int64)
        for k in np.unique(class_pos):
            slots = class_pos == k
            row_idx[slots] = np.random.choice(class_rows[k], size=int(slots.sum()))
        return row_idx, class_pos

    def fit(self, train_data, discrete_columns=(), condition_column=None, save_path=''):
        """Fit the data transformer and train encoder/decoder on TBS batches.

        Args:
            train_data: ``pandas.DataFrame`` of training rows.
            discrete_columns: columns passed to ``DataTransformer.fit`` as discrete.
            condition_column: column whose values (cast to ``int64``) are the
                class labels for TBS sampling and for the triplet loss.
            save_path: the model is saved here (via ``save``) whenever the epoch
                loss improves; when empty (the default), no checkpoint is written.

        Training stops early after 20 consecutive epochs without improvement.
        ``self.loss_values`` holds the mean loss of every completed epoch.

        Raises:
            ValueError: if ``tbs_strategy`` is unsupported or there are no labels.
        """
        self.labels = train_data[condition_column].values.astype('int64')
        # Validate the strategy and build the PMF before the (possibly slow)
        # transformer fit, so a bad configuration fails fast.
        class_values, pmf = self._tbs_class_pmf(self.labels, self.tbs_strategy, self.lambda_scale)
        print(f"TBS strategy '{self.tbs_strategy}' → class PMF: {pmf}")
        # Training-row indices of each class, computed once instead of per sampled row.
        class_rows = [np.flatnonzero(self.labels == value) for value in class_values]

        self.train_data_copy = train_data.copy()
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        self.train_data = self.transformer.transform(train_data).astype('float32')

        data_dim = self.transformer.output_dimensions

        self.encoder = Encoder_T(data_dim, self.latent_dim, self.embedding_dim, self.nhead,
                                 self.dim_feedforward, self.dropout).to(self._device)
        self.decoder = Decoder_T(data_dim, self.latent_dim, self.embedding_dim, self.nhead,
                                 self.dim_feedforward, self.dropout).to(self._device)

        optimizer = Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20, verbose=True)

        self.encoder.train()
        self.decoder.train()

        # At least one batch per epoch: with fewer rows than batch_size the
        # floor division is 0, which would leave nothing to average the loss over.
        n_batches = max(1, len(self.train_data) // self.batch_size)
        early_stop_patience = 20
        best_loss = float('inf')
        patience = 0
        loss_records = []
        self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])

        for epoch in range(self.epochs):
            pbar = tqdm(range(n_batches), total=n_batches)
            pbar.set_description(f"Epoch {epoch+1}/{self.epochs}")

            epoch_loss_sum = 0.0
            n_seen = 0

            for _ in pbar:
                row_idx, class_pos = self._sample_tbs_batch(class_rows, pmf, self.batch_size)
                real_x = torch.as_tensor(self.train_data[row_idx], dtype=torch.float32, device=self._device)
                batch_labels = torch.as_tensor(class_values[class_pos], dtype=torch.long, device=self._device)

                optimizer.zero_grad()
                mean, std, logvar, enc_output = self.encoder(real_x)
                z = reparameterize(mean, logvar)
                recon_x, sigmas = self.decoder(z, enc_output)

                recon_loss = _loss_function_MMD(recon_x, real_x, sigmas, mean, logvar,
                                                self.transformer.output_info_list, self.loss_factor)
                triplet_loss_val = triplet_loss_margin(mean, batch_labels, factor=self.triplet_factor,
                                                       margin=self.triplet_margin)
                total_loss = recon_loss + triplet_loss_val

                total_loss.backward()
                optimizer.step()
                # Keep the decoder's learnable sigma within [0.01, 1.0] after every step.
                self.decoder.sigma.data.clamp_(0.01, 1.0)

                step_loss = total_loss.item()
                epoch_loss_sum += step_loss * len(real_x)
                n_seen += len(real_x)
                pbar.set_postfix({"Loss": f"{step_loss:.4f}"})

            curr_loss = epoch_loss_sum / n_seen
            scheduler.step(curr_loss)

            # Record every completed epoch, including the one that triggers early stopping.
            loss_records.append({'Epoch': epoch + 1, 'Loss': curr_loss})
            self.loss_values = pd.DataFrame(loss_records, columns=['Epoch', 'Loss'])

            if curr_loss < best_loss:
                best_loss = curr_loss
                patience = 0
                if save_path:
                    self.save(save_path)
            else:
                patience += 1
                if patience >= early_stop_patience:
                    print('Early stopping')
                    break

    def sample(self, n_samples=100, condition_column=None, condition_value=None):
        """Generate ``n_samples`` synthetic rows, optionally restricted to one class.

        All training rows are encoded and one latent point per row is drawn
        from ``N(mean, std)``. ``z_gen`` then generates ``n_samples`` latent
        points from them by interpolation (``metric='minkowski'``,
        ``interpolation_method='triangle'``). The decoder maps those points back
        with the encoder outputs of the same rows as memory. ``tanh`` is applied
        and the result is inverse-transformed.

        When both ``condition_column`` and ``condition_value`` are given, only
        training rows whose ``condition_column`` equals ``condition_value`` are
        used; otherwise all training rows are used.

        Raises:
            ValueError: if no training row matches the condition.
        """
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            enc_input = torch.Tensor(self.train_data).to(self._device)
            mean, std, logvar, enc_embed = self.encoder(enc_input)
            embeddings = torch.normal(mean=mean, std=std).cpu().numpy()

        if condition_column is not None and condition_value is not None:
            class_mask = (self.train_data_copy[condition_column] == condition_value).to_numpy()
            if not class_mask.any():
                raise ValueError(f"No samples found for: {condition_column} = {condition_value}")
            embeddings = embeddings[class_mask]
            # Index the device tensor with a mask on the same device.
            enc_embed = enc_embed[torch.as_tensor(class_mask, device=enc_embed.device)]

        synthetic_embeddings = z_gen(
            embeddings, n_to_sample=n_samples,
            metric='minkowski', interpolation_method='triangle'
        )
        noise = torch.Tensor(synthetic_embeddings).to(self._device)

        with torch.no_grad():
            fake, _ = self.decoder(noise, enc_embed)
            fake = torch.tanh(fake).cpu().numpy()

        return self.transformer.inverse_transform(fake)

    def set_device(self, device):
        """Set the device (e.g. ``'cuda'``, ``'cpu'`` or a ``torch.device``) and move both networks.

        ``sample`` runs the encoder as well as the decoder, so both must live on
        the device the input tensors are created on.
        """
        self._device = torch.device(device)
        self.encoder.to(self._device)
        self.decoder.to(self._device)
