import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import re

class TabularDataset(Dataset):
    """Map-style dataset over an in-memory feature matrix and integer class labels.

    Both inputs are copied into new tensors at construction time.

    Args:
        features: 2-D array-like (NumPy array, pandas DataFrame, nested lists,
            list of 1-D arrays) or tensor with one row per sample. Stored as
            float32.
        labels: 1-D array-like or tensor of integer class ids, one per sample.
            Stored as int64 (``torch.long``), as required by ``F.one_hot``.

    Raises:
        ValueError: If ``features`` and ``labels`` have different lengths.
    """

    def __init__(self, features, labels):
        self.features = self._to_tensor(features, torch.float32)
        self.labels = self._to_tensor(labels, torch.long)
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"features has {len(self.features)} rows but labels has {len(self.labels)} entries."
            )

    @staticmethod
    def _to_tensor(values, dtype):
        """Copy ``values`` into a new tensor of ``dtype``.

        Non-tensor inputs are routed through ``np.asarray`` first: building a
        tensor directly from a list of ndarrays is very slow, and containers
        such as DataFrames are not reliably accepted by ``torch.tensor``.
        Tensor inputs are detached and cloned (same device), which avoids the
        copy-construct warning ``torch.tensor(tensor)`` emits.
        """
        if isinstance(values, torch.Tensor):
            return values.detach().clone().to(dtype)
        return torch.tensor(np.asarray(values), dtype=dtype)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


class LGMVAE(nn.Module):
    """Variational autoencoder with a label-conditioned Gaussian-mixture latent prior.

    The ``c_dim`` mixture components ("clusters") are split into ``y_dim``
    contiguous blocks of ``c_dim // y_dim`` clusters, one block per class label.
    Each cluster c has a Gaussian prior p(z | c) whose mean and log-variance are
    linear functions of the one-hot code of c. The cluster prior p(c | y) is
    uniform over the block owned by label y and zero elsewhere.

    Inputs may mix continuous columns (reconstructed with a summed squared
    error) and binary columns (reconstructed with a summed binary cross-entropy
    on logits).

    Args:
        input_dim: Total number of input columns.
        continuous_indices: Column indices treated as continuous.
        binary_indices: Column indices treated as binary (0/1).
        z_dim: Latent dimensionality.
        c_dim: Total number of clusters; must be a multiple of ``y_dim``.
        y_dim: Number of class labels.
        beta_1: Weight on KL(q(z|x,c,y) || p(z|c)).
        beta_2: Weight on KL(q(c|x,y) || p(c|y)).

    Raises:
        ValueError: If ``c_dim`` is not a multiple of ``y_dim``.
    """

    def __init__(self, input_dim, continuous_indices, binary_indices, z_dim, c_dim, y_dim, beta_1=1.0, beta_2=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.continuous_indices = continuous_indices
        self.binary_indices = binary_indices
        self.continuous_dim = len(continuous_indices)
        self.binary_dim = len(binary_indices)

        self.z_dim, self.c_dim, self.y_dim = z_dim, c_dim, y_dim
        self.beta_1, self.beta_2 = beta_1, beta_2

        if c_dim % y_dim != 0:
            raise ValueError("c_dim must be a multiple of y_dim")
        self.clusters_per_class = c_dim // y_dim

        # Inference network: q(c | x, y) and q(z | x, c, y).
        self.qc_net = nn.Sequential(
            nn.Linear(input_dim + y_dim, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, c_dim),
        )
        self.qz_net = nn.Sequential(
            nn.Linear(input_dim + c_dim + y_dim, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
        )
        self.qz_mean, self.qz_logvar = nn.Linear(512, z_dim), nn.Linear(512, z_dim)

        # Generative network: p(z | c) from the one-hot cluster code, then p(x | z).
        self.pc_mean, self.pc_logvar = nn.Linear(c_dim, z_dim), nn.Linear(c_dim, z_dim)
        self.px_net_core = nn.Sequential(nn.Linear(z_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU())

        # Separate output heads so each feature type gets its own likelihood.
        if self.continuous_dim > 0:
            self.px_head_continuous = nn.Linear(512, self.continuous_dim)
        if self.binary_dim > 0:
            self.px_head_binary = nn.Linear(512, self.binary_dim)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _class_cluster_prior(self, y_labels, dtype=torch.float32):
        """Return p(c | y) for a batch of integer labels.

        Each row is uniform (1 / clusters_per_class) over the contiguous block
        of clusters owned by its label and 0 elsewhere. A label outside
        ``[0, y_dim)`` yields an all-zero row.

        Args:
            y_labels: Integer label tensor of shape (batch,).
            dtype: Floating dtype of the result.

        Returns:
            Tensor of shape (batch, c_dim) on ``y_labels.device``.
        """
        cluster_owner = torch.arange(self.c_dim, device=y_labels.device) // self.clusters_per_class
        in_block = cluster_owner.unsqueeze(0) == y_labels.unsqueeze(1)
        return in_block.to(dtype) * (1.0 / self.clusters_per_class)

    def _decode_full(self, z):
        """Decode latent codes ``z`` of shape (n, z_dim) into rows of shape (n, input_dim).

        Continuous columns hold the raw head output. Binary columns hold
        sigmoid(logits) rounded to 0/1. Columns in neither index list stay 0.
        """
        decoder_hidden = self.px_net_core(z)
        recon_full = torch.zeros(z.size(0), self.input_dim, device=z.device)
        if self.continuous_dim > 0:
            recon_full[:, self.continuous_indices] = self.px_head_continuous(decoder_hidden)
        if self.binary_dim > 0:
            binary_logits = self.px_head_binary(decoder_hidden)
            recon_full[:, self.binary_indices] = torch.round(torch.sigmoid(binary_logits))
        return recon_full

    def forward(self, x, y_one_hot):
        """Compute the training loss for a batch.

        q(z | x, c, y) is evaluated for every cluster c, using one reparameterized
        z sample per (sample, cluster) pair. The per-cluster terms are then
        weighted by q(c | x, y).

        Args:
            x: Input tensor of shape (batch, input_dim).
            y_one_hot: One-hot labels of shape (batch, y_dim).

        Returns:
            Tuple ``(loss, mean_recon, weighted_kl_z, weighted_kl_c)``:
            ``loss`` is the batch mean of
            sum_c q(c)[recon + beta_1 * kl_z] + beta_2 * kl_c.
            ``mean_recon`` is the unweighted mean reconstruction loss over
            samples and clusters (0.0 if there are no feature columns).
            The last two are ``beta_1 * mean(kl_z)`` and ``beta_2 * mean(kl_c)``.
        """
        batch_size = x.size(0)

        # Inference: cluster posterior q(c | x, y).
        qc_logits = self.qc_net(torch.cat([x, y_one_hot], dim=1))
        qc_probs = F.softmax(qc_logits, dim=1)

        # Pair every sample with every one-hot cluster code -> batch * c_dim rows.
        c_cats = torch.eye(self.c_dim, device=x.device).unsqueeze(0).expand(batch_size, -1, -1)
        x_expanded = x.unsqueeze(1).expand(batch_size, self.c_dim, -1)
        y_expanded = y_one_hot.unsqueeze(1).expand(batch_size, self.c_dim, -1)
        x_c_y = torch.cat([x_expanded, c_cats, y_expanded], dim=2).reshape(
            -1, self.input_dim + self.c_dim + self.y_dim
        )
        c_flat = c_cats.reshape(-1, self.c_dim)

        qz_hidden = self.qz_net(x_c_y)
        qz_mu, qz_logvar = self.qz_mean(qz_hidden), self.qz_logvar(qz_hidden)
        z_sample = self.reparameterize(qz_mu, qz_logvar)

        # Generative: Gaussian prior p(z | c) and decoder p(x | z).
        pc_mu, pc_logvar = self.pc_mean(c_flat), self.pc_logvar(c_flat)
        decoder_hidden = self.px_net_core(z_sample)

        # Reconstruction loss per (sample, cluster): shape (batch, c_dim).
        recon_loss = 0.0
        if self.continuous_dim > 0:
            px_recon_continuous = self.px_head_continuous(decoder_hidden).reshape(batch_size, self.c_dim, -1)
            x_cont_expanded = x[:, self.continuous_indices].unsqueeze(1).expand(-1, self.c_dim, -1)
            recon_loss += F.mse_loss(px_recon_continuous, x_cont_expanded, reduction='none').sum(dim=2)
        if self.binary_dim > 0:
            px_recon_binary_logits = self.px_head_binary(decoder_hidden).reshape(batch_size, self.c_dim, -1)
            x_binary_expanded = x[:, self.binary_indices].unsqueeze(1).expand(-1, self.c_dim, -1)
            recon_loss += F.binary_cross_entropy_with_logits(
                px_recon_binary_logits, x_binary_expanded, reduction='none'
            ).sum(dim=2)

        # Closed-form KL(N(q_mu, exp(q_logvar)) || N(p_mu, exp(p_logvar))),
        # summed over latent dims -> (batch, c_dim).
        grid = (batch_size, self.c_dim, -1)
        q_mu, q_logvar = qz_mu.reshape(grid), qz_logvar.reshape(grid)
        p_mu, p_logvar = pc_mu.reshape(grid), pc_logvar.reshape(grid)
        kl_z = 0.5 * torch.sum(
            p_logvar - q_logvar - 1 + (q_logvar.exp() + (q_mu - p_mu).pow(2)) / p_logvar.exp(),
            dim=2,
        )

        # KL(q(c | x, y) || p(c | y)); the 1e-10 guards log(0) for zero-prior clusters.
        y_labels = torch.argmax(y_one_hot, dim=1)
        pc_prior = self._class_cluster_prior(y_labels, dtype=qc_probs.dtype)
        kl_c = torch.sum(qc_probs * (torch.log(qc_probs + 1e-10) - torch.log(pc_prior + 1e-10)), dim=1)

        loss_per_c = torch.sum(qc_probs * (recon_loss + self.beta_1 * kl_z), dim=1)
        final_loss = torch.mean(loss_per_c + self.beta_2 * kl_c)

        mean_recon_loss = torch.mean(recon_loss) if isinstance(recon_loss, torch.Tensor) else 0.0
        return final_loss, mean_recon_loss, self.beta_1 * torch.mean(kl_z), self.beta_2 * torch.mean(kl_c)

    def sample(self, y_label, num_samples=1):
        """Generate ``num_samples`` feature rows for class ``y_label``.

        Samples are spread as evenly as possible over the class's clusters;
        lower-numbered clusters absorb the remainder. Rows are shuffled so they
        are not grouped by cluster. For each row, z is drawn from p(z | c) and
        decoded; binary columns are rounded to 0/1.

        Side effect: switches the module to eval mode and leaves it there.

        Args:
            y_label: Integer class label in ``[0, y_dim)``.
            num_samples: Number of rows to generate (>= 0).

        Returns:
            Tensor of shape (num_samples, input_dim) on the model's device.

        Raises:
            ValueError: If ``y_label`` is out of range or ``num_samples`` is negative.
        """
        self.eval()
        if not (0 <= y_label < self.y_dim):
            raise ValueError(f"y_label must be between 0 and {self.y_dim - 1}.")
        if num_samples < 0:
            raise ValueError("num_samples must be non-negative.")
        device = next(self.parameters()).device
        with torch.no_grad():
            k = self.clusters_per_class
            if k == 0:
                return torch.empty(0, self.input_dim, device=device)
            start_cluster = y_label * k
            # i % k assigns floor(n/k) samples to every cluster, plus one more to
            # each of the first n % k clusters.
            c_idx = start_cluster + torch.arange(num_samples, device=device) % k
            c_idx = c_idx[torch.randperm(len(c_idx)).to(device)]

            c_cats = F.one_hot(c_idx, num_classes=self.c_dim).float()
            pc_mu, pc_logvar = self.pc_mean(c_cats), self.pc_logvar(c_cats)
            z = self.reparameterize(pc_mu, pc_logvar)
            return self._decode_full(z)

# Utility functions: validation, centroid decoding and counterfactual generation
def validate_epoch(model, val_loader, device):
    """Return the per-sample mean validation loss of ``model`` over ``val_loader``.

    Each batch's loss is weighted by its batch size. The model's loss is
    assumed to be a batch mean, as ``LGMVAE.forward`` returns. Weighting avoids
    over-counting a smaller final batch, which a plain mean of batch losses
    would do.

    Args:
        model: Module with a ``y_dim`` attribute whose forward takes
            ``(data, y_one_hot)`` and returns a 4-tuple whose first element is
            the scalar loss. Switched to eval mode.
        val_loader: Iterable of ``(data, labels)`` batches with integer labels.
        device: Device to move each batch to.

    Returns:
        The sample-weighted mean loss as a Python float.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for data, labels in val_loader:
            data, labels = data.to(device), labels.to(device)
            y_one_hot = F.one_hot(labels, num_classes=model.y_dim).float()
            loss, _, _, _ = model(data, y_one_hot)
            batch_size = data.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("val_loader yielded no samples; cannot compute a validation loss.")
    return total_loss / total_samples

def get_cluster_centroids(model):
    """Return the prior mean of z for every cluster.

    Feeds each one-hot cluster code through ``model.pc_mean``. Row ``c`` of the
    result is the centroid of cluster ``c``. The computation runs without
    gradient tracking, and ``model`` is switched to eval mode.

    Args:
        model: Module exposing ``c_dim`` and a ``pc_mean`` layer (e.g. ``LGMVAE``).

    Returns:
        Tensor with one row per cluster, on the device of the model's parameters.
    """
    model.eval()
    param_device = next(model.parameters()).device
    with torch.no_grad():
        cluster_codes = torch.eye(model.c_dim, device=param_device)
        return model.pc_mean(cluster_codes)

def reconstruct_from_centroids(model, centroids):
    """Decode latent vectors (typically cluster centroids) into feature rows.

    Args:
        model: ``LGMVAE``-like module exposing ``px_net_core``, the feature
            heads, ``input_dim`` and the continuous/binary index lists.
            Switched to eval mode.
        centroids: Latent tensor with one row per vector to decode. Usually the
            output of ``get_cluster_centroids`` (one row per cluster), but any
            number of rows is accepted.

    Returns:
        NumPy array of shape (centroids.size(0), input_dim). Continuous columns
        hold the raw decoder output. Binary columns hold sigmoid(logits) rounded
        to 0/1. Any other column is 0.
    """
    model.eval()
    with torch.no_grad():
        decoder_hidden = model.px_net_core(centroids)

        # Size the output from the input rows, not model.c_dim, so that a
        # subset of centroids (or arbitrary latent points) decodes correctly.
        recon_full = torch.zeros(centroids.size(0), model.input_dim, device=centroids.device)

        if model.continuous_dim > 0:
            recon_full[:, model.continuous_indices] = model.px_head_continuous(decoder_hidden)

        if model.binary_dim > 0:
            recon_binary_logits = model.px_head_binary(decoder_hidden)
            # Convert logits to probabilities, then round to the nearest of 0 / 1.
            recon_full[:, model.binary_indices] = torch.round(torch.sigmoid(recon_binary_logits))

    return recon_full.cpu().numpy()


def get_clusters_for_label(class_label, c_dim, y_dim):
    """List the cluster indices assigned to ``class_label``.

    Clusters are split into ``y_dim`` equal, contiguous blocks of
    ``c_dim // y_dim`` indices; label ``k`` owns the ``k``-th block.

    Args:
        class_label: Class index in ``[0, y_dim)``.
        c_dim: Total number of clusters.
        y_dim: Number of classes.

    Returns:
        List of consecutive cluster indices for the label.

    Raises:
        ValueError: If ``c_dim`` is not divisible by ``y_dim`` or
            ``class_label`` is out of range.
    """
    per_class, leftover = divmod(c_dim, y_dim)
    if leftover != 0:
        raise ValueError("c_dim must be a multiple of y_dim for even cluster distribution.")
    label_in_range = 0 <= class_label < y_dim
    if not label_in_range:
        raise ValueError(f"class_label must be an integer between 0 and {y_dim - 1}.")
    first = class_label * per_class
    return list(range(first, first + per_class))

def predict_cluster_for_input(model, input_point, input_label, device):
    """Return q(c | x, y) for one input and the index of its most probable cluster.

    Args:
        model: Module exposing ``qc_net`` and ``y_dim`` (e.g. ``LGMVAE``).
            Switched to eval mode.
        input_point: Feature tensor for a single sample; flattened to one row.
        input_label: Integer class label of the sample.
        device: Device on which ``model.qc_net`` is run.

    Returns:
        ``(probs, cluster)``: ``probs`` is the 1-D tensor of cluster
        probabilities on ``device``; ``cluster`` is the Python int argmax.
    """
    model.eval()
    with torch.no_grad():
        row = input_point.view(1, -1).to(device)
        label_code = F.one_hot(torch.tensor([input_label]), num_classes=model.y_dim).float().to(device)
        logits = model.qc_net(torch.cat((row, label_code), dim=1))
        probs = logits.softmax(dim=1).squeeze(0)
        cluster = probs.argmax().item()
    return probs, cluster

def generate_counterfactuals_with_steps_centroid(model, input_point, input_label, target_label, all_centroids, device, num_steps=21):
    """Decode straight-line latent paths from an input to each target-class centroid.

    The input is encoded as the mean of q(z | x, c*, y), where c* is its most
    probable cluster under q(c | x, y). For every cluster of ``target_label``,
    z is interpolated linearly toward that cluster's centroid and each point is
    decoded.

    Args:
        model: ``LGMVAE``-like module; switched to eval mode.
        input_point: Feature tensor for one sample; flattened to one row.
        input_label: Integer class label of ``input_point``.
        target_label: Class whose clusters are the path endpoints.
        all_centroids: Tensor of per-cluster centroids indexed by cluster id,
            e.g. from ``get_cluster_centroids``. Moved to ``device`` if needed.
        device: Device for all computation.
        num_steps: Number of interpolation points, evenly spaced over
            tau in [0, 1] including both ends. The default of 21 gives
            tau = 0, 0.05, ..., 1.

    Returns:
        Tensor of shape (clusters_per_class, num_steps, input_dim) on
        ``device``. Entry ``[k, t]`` decodes
        ``(1 - tau_t) * z_source + tau_t * centroid_k``. Binary features are
        rounded to 0/1.

    Raises:
        ValueError: If ``num_steps < 1``.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1.")
    model.eval()
    with torch.no_grad():
        # Encode the input with its most likely source cluster (posterior mean, no sampling).
        _, source_cluster_idx = predict_cluster_for_input(model, input_point, input_label, device)

        x_flat = input_point.view(1, -1).to(device)
        y_source_one_hot = F.one_hot(torch.tensor([input_label]), num_classes=model.y_dim).float().to(device)
        c_source_one_hot = F.one_hot(torch.tensor([source_cluster_idx]), num_classes=model.c_dim).float().to(device)
        x_c_y_source = torch.cat([x_flat, c_source_one_hot, y_source_one_hot], dim=1)
        z_source = model.qz_mean(model.qz_net(x_c_y_source))  # (1, z_dim)

        target_cluster_indices = get_clusters_for_label(target_label, model.c_dim, model.y_dim)

        # linspace yields exactly num_steps points including both endpoints. A
        # float-step arange (e.g. arange(0, 1.01, 0.05)) leaves the element count
        # to rounding.
        tau = torch.linspace(0.0, 1.0, num_steps, device=device).view(1, -1, 1)  # (1, T, 1)
        target_centroids = all_centroids[target_cluster_indices].to(device).unsqueeze(1)  # (K, 1, z_dim)

        # Broadcast all K paths at once; rows are cluster-major, then tau.
        z_paths = (1.0 - tau) * z_source.unsqueeze(0) + tau * target_centroids  # (K, T, z_dim)
        batch_z = z_paths.reshape(-1, z_paths.size(-1))

        decoder_hidden = model.px_net_core(batch_z)
        cf_points = torch.zeros(batch_z.size(0), model.input_dim, device=device)

        if model.continuous_dim > 0:
            cf_points[:, model.continuous_indices] = model.px_head_continuous(decoder_hidden)

        if model.binary_dim > 0:
            recon_binary_logits = model.px_head_binary(decoder_hidden)
            cf_points[:, model.binary_indices] = torch.round(torch.sigmoid(recon_binary_logits))

        return cf_points.view(len(target_cluster_indices), num_steps, model.input_dim)