import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import re


class TabularDataset(Dataset):
    """Map-style PyTorch Dataset pairing tabular feature rows with integer labels.

    Args:
        features: Array-like (nested list, NumPy array, or tensor) of feature
            rows. Stored as a ``torch.float32`` tensor.
        labels: Array-like of class labels, one per feature row. Stored as a
            ``torch.long`` tensor.

    Raises:
        ValueError: If ``features`` and ``labels`` do not contain the same
            number of rows.
    """
    def __init__(self, features, labels):
        self.features = self._as_owned_tensor(features, torch.float32)
        self.labels = self._as_owned_tensor(labels, torch.long)

        # A length mismatch would otherwise only surface later as an
        # IndexError (or as silently unused labels) while iterating.
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(self.features)} and {len(self.labels)}."
            )

    @staticmethod
    def _as_owned_tensor(values, dtype):
        """Return an independent tensor copy of ``values`` with the given dtype."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(existing_tensor) emits a UserWarning; this is the
            # recommended copy-construct form and keeps the copy semantics.
            return values.detach().to(dtype=dtype, copy=True)
        return torch.tensor(values, dtype=dtype)

    def __len__(self):
        """Return the number of samples (rows of ``features``)."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at position ``idx``."""
        return self.features[idx], self.labels[idx]

class LGMVAE(nn.Module):
    """Conditional VAE with a categorical cluster variable ``c`` and Gaussian latent ``z``.

    The ``c_dim`` clusters are split evenly and contiguously between the
    ``y_dim`` classes: class ``k`` owns clusters
    ``[k * clusters_per_class, (k + 1) * clusters_per_class)``.

    Args:
        input_dim: Number of input features.
        z_dim: Dimensionality of the continuous latent ``z``.
        c_dim: Total number of clusters (must be a multiple of ``y_dim``).
        y_dim: Number of class labels.
        beta_1: Weight applied to the KL term on ``z``.
        beta_2: Weight applied to the KL term on ``c``.

    Raises:
        ValueError: If ``c_dim`` is not a multiple of ``y_dim``.
    """

    def __init__(self, input_dim, z_dim, c_dim, y_dim, beta_1=1.0, beta_2=1.0):
        super().__init__()
        self.input_dim, self.z_dim, self.c_dim, self.y_dim = input_dim, z_dim, c_dim, y_dim
        self.beta_1, self.beta_2 = beta_1, beta_2

        if c_dim % y_dim != 0:
            raise ValueError("c_dim must be a multiple of y_dim for even cluster distribution.")
        self.clusters_per_class = c_dim // y_dim

        # Inference Network: q(c|x,y) and q(z|x,c,y)
        self.qc_net = nn.Sequential(nn.Linear(input_dim + y_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, c_dim))
        self.qz_net = nn.Sequential(nn.Linear(input_dim + c_dim + y_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU())
        self.qz_mean = nn.Linear(512, z_dim)
        self.qz_logvar = nn.Linear(512, z_dim)

        # Generative Network: p(z|c) and p(x|z)
        self.pc_mean = nn.Linear(c_dim, z_dim)
        self.pc_logvar = nn.Linear(c_dim, z_dim)
        self.px_net = nn.Sequential(nn.Linear(z_dim, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, input_dim))

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _class_prior(self, y_labels, dtype):
        """Return p(c|y): uniform over each sample's own clusters, zero elsewhere.

        Args:
            y_labels: Integer class index per sample, shape ``(batch,)``.
            dtype: Floating dtype of the returned ``(batch, c_dim)`` tensor.
        """
        # cluster_owner[j] is the class that owns cluster j.
        cluster_owner = torch.arange(self.y_dim, device=y_labels.device).repeat_interleave(self.clusters_per_class)
        in_class = y_labels.unsqueeze(1) == cluster_owner.unsqueeze(0)
        return in_class.to(dtype) * (1.0 / self.clusters_per_class)

    def forward(self, x, y_one_hot):
        """Compute the training loss for a batch.

        Every sample is encoded and decoded once per cluster; the per-cluster
        reconstruction and KL(z) terms are then averaged under q(c|x,y).

        Args:
            x: Input features, shape ``(batch, input_dim)``.
            y_one_hot: One-hot class labels, shape ``(batch, y_dim)``.

        Returns:
            Tuple ``(final_loss, recon, kl_z, kl_c)`` of scalar tensors:
            the loss to optimize, the plain mean reconstruction error over all
            samples and clusters (not weighted by q(c|x,y)), ``beta_1`` times
            the mean KL on ``z``, and ``beta_2`` times the mean KL on ``c``.
        """
        batch_size = x.size(0)
        x_y = torch.cat([x, y_one_hot], dim=1)
        qc_logits = self.qc_net(x_y)
        qc_probs = F.softmax(qc_logits, dim=1)

        # Pair every sample with every one-hot cluster assignment.
        c_cats = torch.eye(self.c_dim, device=x.device).unsqueeze(0).expand(batch_size, -1, -1)
        x_expanded = x.unsqueeze(1).expand(-1, self.c_dim, -1)
        y_expanded = y_one_hot.unsqueeze(1).expand(-1, self.c_dim, -1)

        x_c_y = torch.cat([x_expanded, c_cats, y_expanded], dim=2).reshape(-1, self.input_dim + self.c_dim + self.y_dim)
        c_flat = c_cats.reshape(-1, self.c_dim)

        qz_hidden = self.qz_net(x_c_y)
        qz_mu, qz_logvar = self.qz_mean(qz_hidden), self.qz_logvar(qz_hidden)
        z_sample = self.reparameterize(qz_mu, qz_logvar)

        pc_mu, pc_logvar = self.pc_mean(c_flat), self.pc_logvar(c_flat)
        px_recon = self.px_net(z_sample)

        # Back to (batch, cluster, feature) layout for the per-cluster losses.
        qz_mu, qz_logvar = qz_mu.reshape(batch_size, self.c_dim, -1), qz_logvar.reshape(batch_size, self.c_dim, -1)
        pc_mu, pc_logvar = pc_mu.reshape(batch_size, self.c_dim, -1), pc_logvar.reshape(batch_size, self.c_dim, -1)
        px_recon = px_recon.reshape(batch_size, self.c_dim, -1)

        recon_loss = F.mse_loss(px_recon, x_expanded, reduction='none').sum(dim=2)

        # KL( q(z|x,c,y) || p(z|c) ) between diagonal Gaussians, per (sample, cluster).
        kl_z = 0.5 * torch.sum(pc_logvar - qz_logvar - 1 + (qz_logvar.exp() + (qz_mu - pc_mu).pow(2)) / pc_logvar.exp(), dim=2)

        # Vectorized prior: avoids a Python loop (and a device sync) per sample.
        y_labels = torch.argmax(y_one_hot, dim=1)
        pc_prior = self._class_prior(y_labels, qc_probs.dtype)

        # KL( q(c|x,y) || p(c|y) ); the epsilon keeps log() finite where the prior is 0.
        kl_c = torch.sum(qc_probs * (torch.log(qc_probs + 1e-10) - torch.log(pc_prior + 1e-10)), dim=1)

        loss_per_c = torch.sum(qc_probs * (recon_loss + self.beta_1 * kl_z), dim=1)
        final_loss = torch.mean(loss_per_c + self.beta_2 * kl_c)

        return final_loss, torch.mean(recon_loss), self.beta_1 * torch.mean(kl_z), self.beta_2 * torch.mean(kl_c)

    def sample(self, y_label, num_samples=1):
        """Generate feature vectors for class ``y_label`` from the prior p(z|c).

        Samples are spread as evenly as possible over the clusters owned by the
        class (the first clusters receive the remainder), then shuffled.

        Note: this switches the module to eval mode and leaves it there.

        Args:
            y_label: Class index in ``[0, y_dim)``.
            num_samples: Number of samples to draw.

        Returns:
            Tensor of shape ``(num_samples, input_dim)`` on the model's device.

        Raises:
            ValueError: If ``y_label`` is outside ``[0, y_dim)``.
        """
        self.eval()
        y_label = int(y_label)
        if not 0 <= y_label < self.y_dim:
            # Previously surfaced as an opaque RuntimeError inside F.one_hot.
            raise ValueError(f"y_label must be between 0 and {self.y_dim - 1}, got {y_label}.")

        device = next(self.parameters()).device
        with torch.no_grad():
            # Identify the clusters for the given class label
            start_cluster = y_label * self.clusters_per_class
            cluster_range = range(start_cluster, start_cluster + self.clusters_per_class)
            num_clusters_in_class = len(cluster_range)

            if num_clusters_in_class == 0:
                return torch.empty(0, self.input_dim, device=device)

            # Balanced allocation: the first `remainder` clusters get one extra sample.
            samples_per_cluster, remainder = divmod(num_samples, num_clusters_in_class)
            c_idx_list = []
            for i, cluster_idx in enumerate(cluster_range):
                num_to_sample = samples_per_cluster + (1 if i < remainder else 0)
                c_idx_list.extend([cluster_idx] * num_to_sample)

            c_idx = torch.tensor(c_idx_list, dtype=torch.long, device=device)
            c_idx = c_idx[torch.randperm(len(c_idx))]

            c_cats = F.one_hot(c_idx, num_classes=self.c_dim).float()
            pc_mu, pc_logvar = self.pc_mean(c_cats), self.pc_logvar(c_cats)
            z = self.reparameterize(pc_mu, pc_logvar)

            return self.px_net(z)

# util function
def validate_epoch(model, val_loader, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch contributes equally to the average, whatever its size.

    Args:
        model: Callable taking ``(data, y_one_hot)`` and returning a 4-tuple
            whose first element is the scalar loss; must expose ``y_dim``.
        val_loader: Sized iterable of ``(data, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            targets = F.one_hot(batch_y, num_classes=model.y_dim).float()
            batch_loss, _recon, _kl_z, _kl_c = model(batch_x, targets)
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(val_loader)

def get_cluster_centroids(model):
    """Return the prior mean in latent space for every cluster.

    Feeds the ``c_dim`` x ``c_dim`` identity matrix (one one-hot row per
    cluster) through ``model.pc_mean`` in eval mode, without gradients.
    Row ``i`` of the result is the centroid of cluster ``i``.
    """
    model.eval()
    with torch.no_grad():
        model_device = next(model.parameters()).device
        one_hot_clusters = torch.eye(model.c_dim).to(model_device)
        return model.pc_mean(one_hot_clusters)

def reconstruct_from_centroids(model, centroids):
    """Decode latent centroids into feature space.

    Runs ``centroids`` through ``model.px_net`` in eval mode without
    gradients and returns the result as a NumPy array on the host.
    """
    model.eval()
    with torch.no_grad():
        decoded = model.px_net(centroids)
    decoded_on_host = decoded.cpu()
    return decoded_on_host.numpy()

def get_clusters_for_label(class_label, c_dim, y_dim):
    """List the cluster indices that belong to ``class_label``.

    Clusters are assigned to classes in equal, contiguous runs, so class
    ``k`` owns ``[k * (c_dim // y_dim), (k + 1) * (c_dim // y_dim))``.

    Raises:
        ValueError: If ``c_dim`` is not a multiple of ``y_dim`` or if
            ``class_label`` lies outside ``[0, y_dim)``.
    """
    per_class, leftover = divmod(c_dim, y_dim)
    if leftover != 0:
        raise ValueError("c_dim must be a multiple of y_dim for even cluster distribution.")

    label_in_range = 0 <= class_label < y_dim
    if not label_in_range:
        raise ValueError(f"class_label must be an integer between 0 and {y_dim - 1}.")

    first = class_label * per_class
    return list(range(first, first + per_class))

def predict_cluster_for_input(model, input_point, input_label, device):
    """Infer the cluster assignment q(c|x,y) for one tabular data point.

    Args:
        model: Model exposing ``qc_net`` and ``y_dim``.
        input_point: Tensor holding a single sample; it is viewed as one row.
        input_label: Integer class label of the sample.
        device: Device on which inference runs.

    Returns:
        Tuple ``(cluster_probs, predicted_cluster)``: the 1-D tensor of
        cluster probabilities and the index (Python int) of the largest one.
    """
    model.eval()
    with torch.no_grad():
        label_tensor = torch.tensor([input_label])
        label_one_hot = F.one_hot(label_tensor, num_classes=model.y_dim).float().to(device)
        sample_row = input_point.view(1, -1).to(device)

        logits = model.qc_net(torch.cat([sample_row, label_one_hot], dim=1))
        cluster_probs = F.softmax(logits, dim=1).squeeze(0)
        predicted_cluster = torch.argmax(cluster_probs).item()

    return cluster_probs, predicted_cluster

def generate_counterfactuals_with_steps_centroid(model, input_point, input_label, target_label, all_centroids, device):
    """Decode straight-line latent paths from an input to each target-class centroid.

    The input is encoded to its posterior mean under its most probable cluster.
    For every cluster owned by ``target_label``, the latent is then linearly
    interpolated towards that cluster's centroid over the tau grid
    ``arange(0.0, 1.01, 0.05)``, and every interpolated latent is decoded.

    Args:
        model: Model exposing ``qz_net``, ``qz_mean``, ``px_net``, ``c_dim``,
            ``y_dim`` and ``input_dim``.
        input_point: Tensor holding a single sample.
        input_label: Integer class label of the sample.
        target_label: Class whose clusters are used as path end points.
        all_centroids: Indexable collection of latent centroids, one per cluster.
        device: Device on which computation runs.

    Returns:
        Tensor of shape ``(num_target_clusters, num_tau_steps, input_dim)``.
    """
    model.eval()
    with torch.no_grad():
        _, src_cluster = predict_cluster_for_input(model, input_point, input_label, device)

        def _one_hot_row(index, width):
            # Single-row one-hot encoding placed on the working device.
            return F.one_hot(torch.tensor([index]), num_classes=width).float().to(device)

        # Encoder input layout is [x, c, y]; only the posterior mean is used.
        encoder_input = torch.cat(
            [
                input_point.view(1, -1).to(device),
                _one_hot_row(src_cluster, model.c_dim),
                _one_hot_row(input_label, model.y_dim),
            ],
            dim=1,
        )
        z_start = model.qz_mean(model.qz_net(encoder_input))

        target_clusters = get_clusters_for_label(target_label, model.c_dim, model.y_dim)
        taus = torch.arange(0.0, 1.01, 0.05).to(device)
        tau_column = taus.view(-1, 1)

        # One interpolated latent trajectory per target cluster.
        latent_paths = [
            (1.0 - tau_column) * z_start + tau_column * all_centroids[cluster]
            for cluster in target_clusters
        ]

        # Decode every latent in one batch, then restore the path layout.
        decoded = model.px_net(torch.cat(latent_paths, dim=0))
        return decoded.view((len(target_clusters), len(taus), model.input_dim))

def plot_counterfactual_trajectories(original_point, cf_trajectories, feature_names=None, n_clusters_per_label=5):
    """Plot, per feature, the counterfactual paths against the original value.

    One subplot is drawn per feature (four per row). Each subplot shows the
    original feature value as a dashed horizontal line and one curve per path.

    Args:
        original_point: Original sample as a tensor or array-like; it is
            flattened to one value per feature.
        cf_trajectories: Tensor or array of shape
            ``(num_paths, num_tau_steps, num_features)``.
        feature_names: Optional subplot titles, one per feature. Defaults to
            ``'Feature i'``.
        n_clusters_per_label: Kept for backward compatibility. Paths are
            labelled by their position along the first axis, so this value no
            longer limits how many paths can be drawn.

    Raises:
        ValueError: If ``cf_trajectories`` is not 3-D or has no features.
    """
    def _as_numpy(value):
        # Explicit conversion replaces the previous bare `except:` fallbacks,
        # which hid real errors (e.g. tensors that still require grad).
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    original_np = _as_numpy(original_point).reshape(-1)
    trajectories_np = _as_numpy(cf_trajectories)

    if trajectories_np.ndim != 3:
        raise ValueError(
            "cf_trajectories must have shape (num_paths, num_tau_steps, num_features), "
            f"got {trajectories_np.shape}."
        )
    num_clusters, num_tau_steps, num_features = trajectories_np.shape
    if num_features == 0:
        raise ValueError("cf_trajectories contains no features to plot.")

    # The x-axis follows the actual path length rather than a fixed 21 steps.
    # Assumes the path spans tau in [0, 1] with even spacing, which matches
    # arange(0.0, 1.01, 0.05) for 21 steps - TODO confirm for other producers.
    tau_values = np.linspace(0.0, 1.0, num_tau_steps)

    if feature_names is None:
        feature_names = [f'Feature {i}' for i in range(num_features)]

    num_rows = int(np.ceil(num_features / 4))
    fig, axes = plt.subplots(num_rows, 4, figsize=(16, num_rows * 3.5), squeeze=False)

    for i in range(num_features):
        row, col = divmod(i, 4)
        ax = axes[row, col]

        ax.axhline(original_np[i], color='black', linestyle='--', label='Original Value')

        for j in range(num_clusters):
            ax.plot(tau_values, trajectories_np[j, :, i], marker='o', linestyle='-', markersize=4, label=f'Path to Cluster {j}')

        ax.set_title(feature_names[i])
        ax.set_xlabel('τ (Interpolation Step)')
        ax.set_ylabel('Feature Value')
        ax.grid(True, linestyle=':', alpha=0.6)

    # Every subplot carries the same legend entries, so one shared legend suffices.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=num_clusters + 1, bbox_to_anchor=(0.5, 1.05))

    # Hide the unused cells of the last grid row.
    for i in range(num_features, num_rows * 4):
        row, col = divmod(i, 4)
        axes[row, col].set_visible(False)

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

def _parse_constraint_rule(rule_string):
    """Parse one rule string, returning ``None`` for rules that are skipped.

    A rule has the form ``<lhs> >= <number>`` or ``<lhs> > <number>``, where
    ``<lhs>`` holds one or two explicitly signed terms such as ``+ y_3`` or
    ``- y_0``.

    Returns:
        ``(strict, rhs, rhs_from_reference, negate_reference_rhs, terms)``
        where ``terms`` is a tuple of ``(negated, index)`` pairs, or ``None``
        when the rule does not have exactly one operator or has a number of
        signed terms other than one or two.

    Raises:
        ValueError: If the right-hand side is not a valid float.
    """
    parts = re.split(r'(>=|>)', rule_string)
    if len(parts) != 3:
        return None

    lhs_str, operator, rhs_str = parts
    rhs_text = rhs_str.strip()
    rhs = float(rhs_text)

    # NOTE(review): only terms with an explicit sign are recognized, so a
    # leading unsigned "y_0" is silently ignored - confirm the rule format.
    variable_terms = re.findall(r'([+\-]\s*y_(\d+))', lhs_str)
    if not (1 <= len(variable_terms) <= 2):
        return None

    terms = tuple(('-' in term_str, int(index_str)) for term_str, index_str in variable_terms)

    # Single-variable rules with a zero RHS are compared against the reference
    # value of that variable instead of against zero.
    rhs_from_reference = len(terms) == 1 and rhs == 0.0
    # NOTE(review): the sign is read from the RHS text (only "-0" matches),
    # not from the variable term - confirm this is intended.
    negate_reference_rhs = '-' in rhs_text

    return operator == '>', rhs, rhs_from_reference, negate_reference_rhs, terms


def create_multifaceted_constraint_fn(reference_array, rules_list):
    """Build a differentiable penalty function from textual inequality rules.

    Rules are parsed once, here, so the returned function does no string
    processing when it is called repeatedly (e.g. inside an optimization
    loop). ``rules_list`` is therefore snapshotted at creation time.

    Args:
        reference_array: Reference sample (tensor or array-like). It is
            copied and flattened to one value per feature.
        rules_list: Iterable of rule strings; malformed rules are skipped.

    Returns:
        ``constraint_fn(reconstructed_tensor)`` returning a scalar tensor: the
        sum over rules of ``relu(rhs - lhs)`` (plus ``1e-6`` for strict ``>``).
        It is zero when every rule is satisfied.

    Raises:
        ValueError: If a rule with a single operator has a non-numeric RHS.
    """
    reference_tensor = torch.as_tensor(reference_array).clone().flatten()
    parsed_rules = [
        parsed
        for parsed in (_parse_constraint_rule(rule) for rule in rules_list)
        if parsed is not None
    ]

    def constraint_fn(reconstructed_tensor):
        device = reconstructed_tensor.device
        ref_tensor = reference_tensor.to(device)
        reconstructed_flat = reconstructed_tensor.flatten()
        total_penalty = torch.tensor(0.0, device=device)

        for strict, rhs, rhs_from_reference, negate_reference_rhs, terms in parsed_rules:
            negate1, index1 = terms[0]

            if rhs_from_reference:
                rhs = -1 * ref_tensor[index1] if negate_reference_rhs else ref_tensor[index1]

            # Construct the LHS using torch tensors for differentiability
            lhs_value = torch.tensor(0.0, device=device)
            value1 = reconstructed_flat[index1]
            lhs_value += -value1 if negate1 else value1

            if len(terms) == 2:
                negate2, index2 = terms[1]
                # NOTE(review): the second variable is read from the reference,
                # not from the reconstruction - confirm this is intended.
                value2 = ref_tensor[index2]
                lhs_value += -value2 if negate2 else value2

            if strict:
                # Violation occurs if RHS - LHS >= 0
                violation = torch.relu(rhs - lhs_value + 1e-6)
            else:
                # Violation occurs if RHS - LHS > 0
                violation = torch.relu(rhs - lhs_value)
            total_penalty += violation
        return total_penalty

    return constraint_fn

def _optimize_latent_for_constraints(model, z_init, constraint_fn, steps, lr):
    """Run up to ``steps`` Adam updates on a copy of ``z_init`` to reduce the penalty.

    Returns the optimized latent, detached from the autograd graph.
    """
    z = z_init.detach().clone().requires_grad_(True)
    optimizer_z = torch.optim.Adam([z], lr=lr)

    # enable_grad makes the correction work even if the caller wrapped the
    # whole generation in torch.no_grad().
    with torch.enable_grad():
        for _ in range(steps):
            optimizer_z.zero_grad()
            loss = constraint_fn(model.px_net(z))

            if loss <= 0:
                break

            # Restricting backward to z keeps gradients from accumulating in
            # the model's parameters, which are not being trained here.
            loss.backward(inputs=[z])
            optimizer_z.step()

    return z.detach()


def generate_cf_with_correction(
    model,
    input_point,
    input_label,
    target_label,
    all_centroids,
    constraint_fn,
    device,
    clf,
    correction_steps=5,
    correction_lr=0.05
):
    """
    Generates a constrained counterfactual path by interpolating from the
    source latent vector directly to the target cluster centroid. Reverts
    to the uncorrected path if a correction step fails.

    For every cluster owned by ``target_label`` and every tau in
    ``arange(0.0, 1.01, 0.05)``, the interpolated latent is decoded and checked
    with ``constraint_fn``. If the penalty is positive, the latent is optimized
    for a few Adam steps. The corrected latent is kept only if the penalty is
    no longer positive AND ``clf`` predicts ``target_label`` for the decoded
    point. Otherwise the uncorrected latent is used.

    Args:
        model: Model exposing ``qz_net``, ``qz_mean``, ``px_net``, ``c_dim``
            and ``y_dim``.
        input_point: Tensor holding a single sample.
        input_label: Integer class label of the sample.
        target_label: Class whose clusters are used as path end points.
        all_centroids: Indexable collection of latent centroids, one per cluster.
        constraint_fn: Callable mapping a decoded tensor to a scalar penalty
            tensor that is ``<= 0`` when all constraints hold.
        device: Device on which computation runs.
        clf: Classifier exposing ``predict`` on a ``(1, n_features)`` array.
        correction_steps: Maximum optimizer steps per violating point.
        correction_lr: Adam learning rate used for the correction.

    Returns:
        Tensor of shape ``(num_target_clusters, num_tau_steps, n_features)``.
    """
    model.eval()

    # Get the source latent vector (z_source) once
    with torch.no_grad():
        _, source_cluster_idx = predict_cluster_for_input(model, input_point, input_label, device)
        x_flat = input_point.view(1, -1).to(device)
        y_source_one_hot = F.one_hot(torch.tensor([input_label]), num_classes=model.y_dim).float().to(device)
        c_source_one_hot = F.one_hot(torch.tensor([source_cluster_idx]), num_classes=model.c_dim).float().to(device)
        x_c_y_source = torch.cat([x_flat, c_source_one_hot, y_source_one_hot], dim=1)
        z_source = model.qz_mean(model.qz_net(x_c_y_source))

    # Get target cluster indices and define tau values
    target_cluster_indices = get_clusters_for_label(target_label, model.c_dim, model.y_dim)
    tau_values = torch.arange(0.0, 1.01, 0.05).to(device)

    all_corrected_paths = []

    # Loop through each target cluster to generate a path
    for target_cluster_idx in target_cluster_indices:
        target_centroid = all_centroids[target_cluster_idx]
        path_points_for_this_cluster = []

        # Walk the path, correcting as needed
        for tau in tau_values:
            # Detached so the point is a plain leaf even if the centroids
            # were computed with gradient tracking enabled.
            z_step_uncorrected = ((1.0 - tau) * z_source + tau * target_centroid).detach()
            z_step_current = z_step_uncorrected

            # Check for constraint violation
            with torch.no_grad():
                violation_loss = constraint_fn(model.px_net(z_step_uncorrected))

            # If constraint is violated, attempt correction
            if violation_loss > 0:
                z_candidate = _optimize_latent_for_constraints(
                    model, z_step_uncorrected, constraint_fn, correction_steps, correction_lr
                )

                with torch.no_grad():
                    candidate_decoded = model.px_net(z_candidate)
                    # `<= 0` is False for NaN, so a diverged correction is
                    # rejected instead of being accepted as "not violated".
                    accepted = bool(constraint_fn(candidate_decoded) <= 0)
                    if accepted:
                        predicted = clf.predict(candidate_decoded.cpu().numpy().reshape(1, -1))[0]
                        accepted = bool(predicted == target_label)

                # Keep the correction only if it satisfies the constraints and
                # is classified as the target; otherwise fall back.
                if accepted:
                    z_step_current = z_candidate

            # Decode the final z and add to this cluster's path
            with torch.no_grad():
                path_points_for_this_cluster.append(model.px_net(z_step_current))

        # Stack the points for this path and add to the list of all paths
        all_corrected_paths.append(torch.cat(path_points_for_this_cluster, dim=0))

    return torch.stack(all_corrected_paths, dim=0)
