from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import os
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import Parameter
from torch.optim import Adam
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go

from src.syngen.baselines.ttvae.util import DataTransformer, reparameterize, _loss_function_MMD, z_gen

class Encoder_T(nn.Module):
    """Transformer encoder producing the parameters of a diagonal Gaussian posterior.

    The input is projected to ``embedding_dim``, passed through a stack of two
    ``nn.TransformerEncoderLayer`` blocks, and mapped by two separate linear
    heads to a mean and a log-variance of size ``latent_dim``.

    Args:
        input_dim: Size of the last axis of the input.
        latent_dim: Size of the latent mean / log-variance vectors.
        embedding_dim: Model width of the transformer layers.
        nhead: Number of attention heads.
        dim_feedforward: Width of the feed-forward sub-layer.
        dropout: Dropout probability used inside the transformer layers.
    """

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Input projection into the transformer width.
        self.linear = nn.Linear(input_dim, embedding_dim)
        self.transformerencoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(self.transformerencoder_layer, num_layers=2)
        # Separate heads for the posterior mean and log-variance.
        self.fc_mu = nn.Linear(embedding_dim, latent_dim)
        self.fc_log_var = nn.Linear(embedding_dim, latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, std, logvar, enc_output)``.

        A singleton axis is inserted at position 1 before the projection so the
        batch-first transformer sees a sequence of length one (assumes ``x`` is
        2-D, ``(batch, input_dim)`` -- confirm against callers). That axis is
        squeezed out of ``mu``, ``std`` and ``logvar`` again, while
        ``enc_output`` is returned as the transformer produced it.
        """
        hidden = self.encoder(self.linear(x.unsqueeze(1)))
        mu = self.fc_mu(hidden).squeeze(1)
        logvar = self.fc_log_var(hidden).squeeze(1)
        # std = exp(logvar / 2)
        std = torch.exp(0.5 * logvar)
        return mu, std, logvar, hidden

class Decoder_T(nn.Module):
    """Transformer decoder mapping a latent vector back to the data space.

    The latent vector is projected to ``embedding_dim``, decoded by a stack of
    two ``nn.TransformerDecoderLayer`` blocks that attend to a memory tensor,
    and projected to ``input_dim``. A learnable per-output ``sigma`` vector
    (initialised to 0.1) is returned alongside the reconstruction.

    Args:
        input_dim: Size of the reconstructed output's last axis.
        latent_dim: Size of the latent vector.
        embedding_dim: Model width of the transformer layers.
        nhead: Number of attention heads.
        dim_feedforward: Width of the feed-forward sub-layer.
        dropout: Dropout probability used inside the transformer layers.
    """

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.latent_to_decoder_input = nn.Linear(latent_dim, embedding_dim)
        self.transformerdecoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(self.transformerdecoder_layer, num_layers=2)
        self.linear = nn.Linear(embedding_dim, input_dim)
        # One learnable scale per output dimension, starting at 0.1.
        self.sigma = Parameter(torch.ones(input_dim) * 0.1)

    def forward(self, z, enc_output):
        """Decode ``z`` using ``enc_output`` as memory; return ``(reconstruction, sigma)``.

        A singleton axis is inserted at position 1 of ``z`` so the batch-first
        decoder sees a target sequence of length one; it is squeezed out of the
        reconstruction before returning.
        """
        target = self.latent_to_decoder_input(z.unsqueeze(1))
        decoded = self.decoder(target, enc_output)
        reconstruction = self.linear(decoded).squeeze(1)
        return reconstruction, self.sigma

class TTVAETBS:
    def __init__(self, l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, latent_dim=32, verbose=True,
                 embedding_dim=128, nhead=8, dim_feedforward=1028, dropout=0.1, tbs_strategy='linear_mix', lambda_scale=0.4,cuda=True, device='cuda'):
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.l2scale = l2scale
        self.batch_size = batch_size
        self.loss_factor = loss_factor
        self.tbs_strategy = tbs_strategy
        self.lambda_scale = lambda_scale
        self.epochs = epochs
        self.verbose = verbose
        self._device = torch.device(device if cuda and torch.cuda.is_available() else 'cpu')

        self.loss_values = None
        self.labels = None

    def _cond_loss(self, data, c, m):
        loss = []
        st = 0
        st_c = 0
        for column_info in self.transformer.output_info_list:
            for span_info in column_info:
                if len(column_info) != 1 or span_info.activation_fn != 'softmax':
                    st += span_info.dim
                else:
                    ed = st + span_info.dim
                    ed_c = st_c + span_info.dim
                    tmp = F.cross_entropy(data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none')
                    loss.append(tmp)
                    st = ed
                    st_c = ed_c
        loss = torch.stack(loss, dim=1)
        return (loss * m).sum() / data.size(0)

    def save(self, save_path: str):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(self, save_path)

    def visualize_latent_space(self, train_data, condition_column):
        """Return a 2-D PCA scatter plot of the encoder's latent means.

        ``train_data`` is run through the fitted transformer and the encoder
        (switched to eval mode, gradients disabled). The latent means are
        projected onto their first two principal components and plotted,
        coloured by the string form of ``train_data[condition_column]``.

        Args:
            train_data: Table accepted by ``self.transformer.transform`` and
                indexable by ``condition_column``.
            condition_column: Name of the column used to colour the points.

        Returns:
            The plotly figure produced by ``px.scatter``.
        """
        column = f'{condition_column}'

        encoded_rows = self.transformer.transform(train_data).astype('float32')
        encoder_input = torch.from_numpy(encoded_rows).to(self._device)

        self.encoder.eval()
        with torch.no_grad():
            mean, std, logvar, enc_output = self.encoder(encoder_input)
            latent_means = mean.cpu().numpy()

        # Reduce the latent means to two principal components for plotting.
        components = PCA(n_components=2).fit_transform(latent_means)

        df_plot = pd.DataFrame(components, columns=['PC1', 'PC2'])
        df_plot[column] = train_data[column].values

        fig = px.scatter(
            df_plot,
            x='PC1',
            y='PC2',
            color=df_plot[column].astype(str),
            title='Latent Space (PCA) - TTVAE',
            opacity=0.7,
            labels={'color': column},
        )
        fig.update_layout(legend_title_text=column)
        return fig

    def fit(self, train_data, discrete_columns=(), condition_column=None, save_path=''):
        mix_lambda = self.lambda_scale
        tbs_strategy = self.tbs_strategy
        self.train_data_copy = train_data.copy()
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns=discrete_columns)
        self.train_data = self.transformer.transform(train_data).astype('float32')
        self.labels = train_data[condition_column].values.astype('int64')

        data_dim = self.transformer.output_dimensions

        dataset = TensorDataset(torch.from_numpy(self.train_data).float())
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self.encoder = Encoder_T(data_dim, self.latent_dim, self.embedding_dim, self.nhead, self.dim_feedforward, self.dropout).to(self._device)
        self.decoder = Decoder_T(data_dim, self.latent_dim, self.embedding_dim, self.nhead, self.dim_feedforward, self.dropout).to(self._device)

        optimizer = Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), weight_decay=self.l2scale)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=20, verbose=True)

        self.encoder.train()
        self.decoder.train()

        best_loss = float('inf')
        patience = 0
        self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])

        labels_np = self.labels
        class_counts = np.bincount(labels_np)
        orig_dist = class_counts / class_counts.sum()
        uniform_dist = np.ones_like(orig_dist) / len(orig_dist)

        if tbs_strategy == "linear_mix":
            pmf = mix_lambda * orig_dist + (1.0 - mix_lambda) * uniform_dist
        elif tbs_strategy == "uniform":
            pmf = uniform_dist
        elif tbs_strategy == "original":
            pmf = orig_dist
        else:
            raise ValueError(f"Unsupported TBS strategy: {tbs_strategy}")

        pmf = pmf / pmf.sum()  # ensure it sums to 1
        print(f"TBS strategy '{tbs_strategy}' → class PMF: {pmf}")

        for epoch in range(self.epochs):
            pbar = tqdm(loader, total=len(loader))
            pbar.set_description(f"Epoch {epoch + 1}/{self.epochs}")

            batch_loss = 0.0
            len_input = 0

            for _ in pbar:
                # Sampling with class PMF 
                sampled_classes = np.random.choice(len(pmf), size=self.batch_size, p=pmf)

                real_x_list, batch_labels = [], []
                for cls in sampled_classes:
                    cls_indices = np.where(labels_np == cls)[0]
                    sampled_row_idx = np.random.choice(cls_indices)
                    real_x_list.append(self.train_data[sampled_row_idx])
                    batch_labels.append(cls)

                real_x = torch.tensor(np.stack(real_x_list), dtype=torch.float32).to(self._device)
                batch_labels = torch.tensor(batch_labels, dtype=torch.long).to(self._device)

                optimizer.zero_grad()
                mean, std, logvar, enc_output = self.encoder(real_x)
                z = reparameterize(mean, logvar)
                recon_x, sigmas = self.decoder(z, enc_output)

                loss = _loss_function_MMD(recon_x, real_x, sigmas, mean, logvar,
                                                self.transformer.output_info_list, self.loss_factor)
                

                batch_loss += loss.item() * len(real_x)
                len_input += len(real_x)

                loss.backward()
                optimizer.step()
                self.decoder.sigma.data.clamp_(0.01, 1.0)
                pbar.set_postfix({"Loss": loss.item()})

            curr_loss = batch_loss / len_input
            scheduler.step(curr_loss)

            if curr_loss < best_loss:
                best_loss = curr_loss
                patience = 0
                torch.save(self, save_path)
            else:
                patience += 1
                if patience == 20:
                    print('Early stopping')
                    break

            self.loss_values = pd.concat([self.loss_values,
                                          pd.DataFrame([{'Epoch': epoch + 1, 'Loss': curr_loss}])],
                                         ignore_index=True)

        # self.latent_space = self.visualize_latent_space(self.train_data_copy, condition_column)

    def sample(self, n_samples=100):
        """Generate synthetic rows in the original data format.

        The stored (transformed) training data is encoded in eval mode, one
        latent point per row is drawn from ``Normal(mean, std)``, and ``z_gen``
        derives ``n_samples`` synthetic latent points from them. These are
        decoded, squashed with ``tanh`` and mapped back through
        ``self.transformer.inverse_transform``.

        NOTE(review): ``fit`` decodes with the encoder output as memory, while
        sampling decodes against an all-zero memory tensor -- confirm that this
        mismatch is intended.

        Args:
            n_samples: Number of synthetic latent points requested from ``z_gen``.

        Returns:
            Whatever ``self.transformer.inverse_transform`` returns for the
            decoded rows.
        """
        self.encoder.eval()
        means, stds = [], []
        with torch.no_grad():
            # Encode in chunks so that peak activation memory does not grow with
            # the size of the training set.
            for start in range(0, len(self.train_data), self.batch_size):
                chunk = self.train_data[start:start + self.batch_size]
                enc_input = torch.from_numpy(chunk).float().to(self._device)
                chunk_mean, chunk_std, _, _ = self.encoder(enc_input)
                means.append(chunk_mean)
                stds.append(chunk_std)
        mean = torch.cat(means)
        std = torch.cat(stds)

        embeddings = torch.normal(mean=mean, std=std).cpu().numpy()
        synthetic_embeddings = z_gen(embeddings, n_to_sample=n_samples, metric='minkowski',
                                     interpolation_method='triangle')
        noise = torch.as_tensor(synthetic_embeddings, dtype=torch.float32, device=self._device)

        self.decoder.eval()
        with torch.no_grad():
            # Size the memory from the rows actually produced, so decoder inputs
            # always agree on the batch dimension.
            dummy_memory = torch.zeros((noise.size(0), 1, self.embedding_dim), device=self._device)
            fake, _ = self.decoder(noise, dummy_memory)
            fake = torch.tanh(fake).cpu().numpy()

        return self.transformer.inverse_transform(fake)

    def set_device(self, device):
        self._device = torch.device(device)
        self.decoder.to(self._device)
