import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist, pdist, squareform
import scoringrules as sr
import numpy as np
import pandas as pd
import os
import rpy2.robjects as ro
import rpy2.robjects.numpy2ri
import time
import random

from Real_Gen.ctgan_mod import CTGAN
from Real_Gen.tvae_mod import TVAE
from Real_Gen.arfpy import arf
from Real_Gen.hard_decision import lim_cluster_tabular_data
from sklearn.datasets import load_breast_cancer
from ucimlrepo import fetch_ucirepo 

from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader
from synthcity.metrics.eval_statistical import MaximumMeanDiscrepancy, JensenShannonDistance, WassersteinDistance, InverseKLDivergence, PRDCScore
from Metrics.ppr import compute_pprecision_precall

# Activate automatic conversion between R and numpy arrays.
# This is a module-level side effect: it runs once, when this file is imported.
rpy2.robjects.numpy2ri.activate()

# Load the R packages used through rpy2. scoringRules supplies the
# `es_sample` function that compute_energy_score_r() calls; mlbench is
# loaded here as well (its use is not visible in this part of the file).
ro.r('library(scoringRules)')
ro.r('library(mlbench)')

# Generators
class RBF:
    """Gaussian (radial basis function) kernel evaluated on all row pairs."""

    def __init__(self, gamma=1.0):
        # Multiplier of the squared Euclidean distance inside the exponential:
        # k(a, b) = exp(-gamma * ||a - b||^2).
        self.gamma = gamma

    def __call__(self, X):
        """Return the Gram matrix of the rows of ``X`` against themselves.

        Args:
            X: tensor accepted by ``torch.cdist`` (rows are points).

        Returns:
            Tensor holding ``exp(-gamma * d^2)`` for every pair of rows,
            where ``d`` is the distance returned by ``torch.cdist(X, X)``.
        """
        squared_distances = torch.cdist(X, X) ** 2
        return (squared_distances * -self.gamma).exp()

class MMDLoss(nn.Module):
    """Squared maximum mean discrepancy between two samples.

    The estimate is ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)`` where the
    means run over full kernel sub-blocks (diagonal entries included).
    """

    def __init__(self, kernel=None):
        """Create the loss.

        Args:
            kernel: callable mapping an ``(n, d)`` tensor to its ``(n, n)``
                Gram matrix. When omitted, a fresh ``RBF()`` is built for
                this instance, so that changing one loss object's kernel
                (e.g. its ``gamma``) can never leak into another one through
                a shared default argument.
        """
        super().__init__()
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, X, Y):
        """Return the MMD^2 estimate between the rows of ``X`` and ``Y``.

        Args:
            X: tensor whose rows are the first sample.
            Y: tensor whose rows are the second sample (stackable with ``X``).

        Returns:
            Scalar tensor ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)``.
        """
        n_x = X.shape[0]
        # One kernel evaluation on the stacked sample yields all three blocks.
        K = self.kernel(torch.vstack([X, Y]))
        XX = K[:n_x, :n_x].mean()
        XY = K[:n_x, n_x:].mean()
        YY = K[n_x:, n_x:].mean()
        return XX - 2 * XY + YY

class GaussianGenerator(nn.Module):
    """Learnable multivariate Gaussian sampler.

    Samples are ``mean + eps @ L.T`` with ``eps ~ N(0, I)`` and
    ``L = cholesky_factor``, so their covariance is ``L @ L.T``.
    """

    def __init__(self, dim):
        """Initialise a standard normal in ``dim`` dimensions."""
        super().__init__()
        self.mean = nn.Parameter(torch.zeros(dim))
        self.cholesky_factor = nn.Parameter(torch.eye(dim))  # Start with identity for stability
        self.type = 'Gaussian'

    def forward(self, num_samples=1):
        """Draw ``num_samples`` reparameterised samples.

        Args:
            num_samples: number of rows to generate.

        Returns:
            Tensor of shape ``(num_samples, dim)``, differentiable with
            respect to ``mean`` and ``cholesky_factor``.
        """
        eps = torch.randn(num_samples, self.mean.size(0), device=self.mean.device)
        # Reparameterization trick: multiply by the factor L itself. The
        # previous code multiplied by L @ L.T, which made the sample
        # covariance (L L^T)^2 rather than L L^T and disagreed with
        # assign_parameters(), which stores cholesky(cov) in this parameter.
        return self.mean + eps @ self.cholesky_factor.T

class BetaGenerator(nn.Module):
    """Learnable sampler on the unit interval, one independent pair of shape
    parameters per dimension.

    Despite the name, ``forward`` draws from a Kumaraswamy distribution via
    its inverse CDF, which is differentiable in both shape parameters.
    """

    def __init__(self, dim):
        """Create ``dim`` pairs of raw (pre-softplus) shape parameters set to 1."""
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.ones(dim))
        self.type = 'Beta'

    def forward(self, num_samples=1):
        """Draw ``num_samples`` rows of shape ``(dim,)``.

        Args:
            num_samples: number of rows to generate.

        Returns:
            Tensor of shape ``(num_samples, dim)``.
        """
        # softplus keeps the effective shape parameters strictly positive.
        shape_a = F.softplus(self.alpha) + 1e-6
        shape_b = F.softplus(self.beta) + 1e-6
        uniform = torch.rand(num_samples, shape_a.shape[0], device=self.alpha.device)
        # Kumaraswamy inverse CDF: x = (1 - (1 - u)^(1/b))^(1/a).
        inner = 1 - torch.pow(1 - uniform, 1 / shape_b)
        return torch.pow(inner, 1 / shape_a)
    
class MultivariateStudentTGenerator(nn.Module):
    """Learnable multivariate Student-t sampler.

    A sample is ``mean + (eps / sqrt(chi2 / dof)) @ L.T`` with
    ``eps ~ N(0, I)``, ``chi2 ~ Chi2(dof)`` and ``L = cholesky_factor``,
    i.e. the scale matrix of the distribution is ``L @ L.T``.
    """

    def __init__(self, dim):
        """Initialise with zero mean, identity factor and 8 degrees of freedom."""
        super().__init__()
        self.mean = nn.Parameter(torch.zeros(dim))
        self.cholesky_factor = nn.Parameter(torch.eye(dim))
        self.dof = nn.Parameter(torch.tensor(8.0))  # Default degrees of freedom
        self.type = 'StudentT'

    def forward(self, num_samples=1):
        """Draw ``num_samples`` reparameterised samples.

        Args:
            num_samples: number of rows to generate.

        Returns:
            Tensor of shape ``(num_samples, dim)``, differentiable with
            respect to ``mean``, ``cholesky_factor`` and ``dof``.
        """
        eps = torch.randn(num_samples, self.mean.size(0), device=self.mean.device)

        # One chi-squared draw per sample. rsample() (instead of sample())
        # keeps the pathwise gradient with respect to `dof`; with sample()
        # only the explicit division below contributed to d/d(dof).
        chi_samples = torch.distributions.Chi2(self.dof).rsample((num_samples,)).unsqueeze(-1)
        scale = torch.sqrt(chi_samples / self.dof)

        # Multiply by the factor L itself. The previous code used L @ L.T as
        # the linear map, which squared the intended scale matrix.
        return self.mean + (eps / scale) @ self.cholesky_factor.T

class GumbelGenerator(nn.Module):
    """Learnable sampler of independent Gumbel variables, one per dimension."""

    def __init__(self, dim):
        """Initialise ``dim`` components with location 0 and scale 1."""
        super().__init__()
        self.location = nn.Parameter(torch.zeros(dim))  # Location parameter
        self.scale = nn.Parameter(torch.ones(dim))  # Scale parameter
        self.type = 'Gumbel'

    def forward(self, num_samples=1):
        """Draw ``num_samples`` rows by inverse-CDF sampling.

        Args:
            num_samples: number of rows to generate.

        Returns:
            Tensor of shape ``(num_samples, dim)`` equal to
            ``location - scale * log(-log(u))`` with ``u`` uniform.
        """
        u = torch.rand(num_samples, len(self.location), device=self.location.device)
        # torch.rand samples from [0, 1); an exact 0 would give
        # log(-log(0)) = +inf and an infinite sample. Clamping to the
        # smallest positive normal number leaves every non-zero draw unchanged.
        u = u.clamp_min(torch.finfo(u.dtype).tiny)
        return self.location - self.scale * torch.log(-torch.log(u))


# Support Functions
def compute_energy_score_r(predictions, observations):
    """
    Evaluate the energy score by calling R's ``es_sample`` through rpy2.

    Args:
        predictions (numpy array): 2-D array. It is transposed before being
            handed to R as ``dat``, so each row here becomes a column there.
            Presumably each row is one predictive sample - confirm against
            the callers.
        observations (numpy array): 1-D array passed to R as ``y``.

    Returns:
        The first element of the value returned by ``es_sample``.
    """
    n_rows, n_cols = predictions.shape[0], predictions.shape[1]

    # Hand both inputs over to R.
    r_obs = ro.FloatVector(observations)
    r_samples = ro.r.matrix(predictions, nrow=n_rows, ncol=n_cols)

    # R-side transpose: rows of `predictions` become columns of `dat`.
    r_samples_t = ro.r('t')(r_samples)

    score = ro.r("es_sample")(y=r_obs, dat=r_samples_t)

    # Back to Python: unwrap the length-one R result.
    return np.array(score)[0]

def energy_score(y, X):
    """
    Computes the energy score for a given observation and predictive samples.

    The score is ``mean_i ||X_i - y|| - 0.5 * mean_{i<j} ||X_i - X_j||``,
    where the second mean runs over distinct pairs of predictive samples.

    Parameters:
    - y: ndarray of shape (d,), observed value
    - X: ndarray of shape (M, d), predictive samples (a 1-D input is treated
      as a single sample)

    Returns:
    - Energy score (float)
    """
    X = np.atleast_2d(X)
    y = np.atleast_1d(y)

    # First term: average L2 distance between X and y
    d1 = np.mean(np.linalg.norm(X - y, axis=1))

    # Second term: average pairwise L2 distance among predictive samples.
    # A single sample has no pairs; pdist would return an empty array whose
    # mean is NaN, so the spread term is defined as 0 in that case.
    d2 = np.mean(pdist(X, metric='euclidean')) if X.shape[0] > 1 else 0.0

    return d1 - 0.5 * d2

def gaussian_distance(y, X, sigma=1.0):
    """
    Computes the average Gaussian (RBF) kernel value between a point y and
    each row in X.

    Note that, being a kernel value, the result is 1 when every row equals y
    and decreases towards 0 as the rows move away from y.

    Parameters:
    - y: ndarray of shape (d,), observed value
    - X: ndarray of shape (M, d), predictive samples
    - sigma: float, bandwidth parameter for the RBF kernel

    Returns:
    - Mean of exp(-||X_i - y||^2 / (2 * sigma^2)) over the rows of X (float)
    """
    X = np.atleast_2d(X)
    y = np.atleast_1d(y)
    dists = np.linalg.norm(X - y, axis=1)
    # The result of np.exp is already a NumPy array; the former `.numpy()`
    # call on it raised AttributeError on every invocation.
    gauss_dists = np.exp(-dists**2 / (2 * sigma**2))
    return np.mean(gauss_dists)

def compute_similarity(true, pred, similarity_measure='energy_score'):
    """Return a similarity (reciprocal of a score) between ``true`` and ``pred``.

    Args:
        true: torch tensor holding the reference value.
        pred: torch tensor holding the prediction(s).
        similarity_measure: one of
            ``"energy_score"`` - reciprocal of compute_energy_score_r(pred, true);
            ``"new_es"``       - reciprocal of energy_score(true, pred);
            ``"Euclidean"``    - reciprocal of ``torch.norm(true - pred)``,
                                 or ``float('inf')`` when that norm is 0;
            ``"Gaussian"``     - reciprocal of gaussian_distance(true, pred, sigma)
                                 with ``sigma = sqrt(1 / (2 * mh_gamma))``.

    Returns:
        The similarity value (a tensor for the non-degenerate Euclidean case,
        a NumPy/Python float otherwise).

    Raises:
        ValueError: if ``similarity_measure`` is not one of the names above
            (previously the function fell through and returned ``None``).
    """
    true_np = true.detach().numpy()
    pred_np = pred.detach().numpy()

    if similarity_measure == "energy_score":
        return 1 / compute_energy_score_r(pred_np, true_np)
    if similarity_measure == "new_es":
        return 1 / energy_score(true_np, pred_np)
    if similarity_measure == "Euclidean":
        euclidean_distance = torch.norm(true - pred)
        return 1 / euclidean_distance if euclidean_distance.item() != 0 else float('inf')
    if similarity_measure == "Gaussian":
        # NOTE(review): `mh_gamma` is read from module scope here; it is not
        # defined in the visible part of the file - confirm it exists.
        # With this sigma, exp(-d^2 / (2 sigma^2)) equals exp(-mh_gamma * d^2).
        sigma = np.sqrt(1 / (2 * mh_gamma))
        return 1 / gaussian_distance(true_np, pred_np, sigma)
    raise ValueError(f"Unknown similarity measure: {similarity_measure}")

def project_to_positive_definite(matrix):
    """Rebuild ``matrix`` from its eigendecomposition with eigenvalues floored.

    Args:
        matrix: square tensor passed to ``torch.linalg.eigh``.

    Returns:
        ``V @ diag(max(w, 1e-6)) @ V.T`` where ``w, V`` are the eigenvalues
        and eigenvectors returned by ``torch.linalg.eigh``.
    """
    spectrum, basis = torch.linalg.eigh(matrix)
    floored = spectrum.clamp(min=1e-6)
    return basis @ torch.diag(floored) @ basis.T

def avoid_empty_weights(weights, lb=1e-2):
    """Raise entries of ``weights`` that fall below ``lb``, in place.

    While the smallest entry is below ``lb``: the shortfall of that entry,
    divided by ``len(weights) - 1``, is subtracted from every entry, and the
    smallest entry is then set to ``lb``. The total of the weights is left
    unchanged by each such step.

    Args:
        weights: 1-D tensor; it is modified in place.
        lb: lower bound enforced on the smallest entry.

    Returns:
        The same ``weights`` tensor object.
    """
    while True:
        smallest = torch.min(weights)
        if not (smallest < lb):
            return weights
        shift = (lb - smallest) / (weights.shape[0] - 1)
        lowest_idx = torch.argmin(weights)
        weights.sub_(shift)
        weights[lowest_idx] = lb

def metrics_evaluation(synthetic_data, real_data):
    """Compare a synthetic sample with the real sample on five metrics.

    Args:
        synthetic_data: torch tensor or pandas DataFrame of generated rows.
        real_data: torch tensor or pandas DataFrame of real rows.

    Returns:
        A one-row DataFrame (index ``'Mixture'``) with columns
        ``MMD``, ``JSD``, ``IKL``, ``P-PR`` and ``P-RE``, each rounded to
        7 decimals. ``P-PR``/``P-RE`` are the two values returned by
        ``compute_pprecision_precall(real, synthetic)``.
    """
    def _as_tensor_and_frame(data):
        # Each metric needs either a tensor (MMD, P-PR/P-RE) or a DataFrame
        # (synthcity loaders). Previously only tensor input worked: for a
        # DataFrame the `*_copy` variable was never bound (NameError).
        if isinstance(data, pd.DataFrame):
            return torch.as_tensor(data.to_numpy(dtype=float)), data
        return data, pd.DataFrame(data.detach().numpy())

    synthetic_tensor, synthetic_frame = _as_tensor_and_frame(synthetic_data)
    real_tensor, real_frame = _as_tensor_and_frame(real_data)

    # Initialize the data loaders
    synthetic_data_loader = GenericDataLoader(synthetic_frame)
    real_data_loader = GenericDataLoader(real_frame)

    # Initialize the metrics
    # NOTE(review): `mh_gamma` is read from module scope; it is not defined
    # in the visible part of the file - confirm it exists.
    mmd = MMDLoss(kernel=RBF(gamma=mh_gamma))
    jsd = JensenShannonDistance()
    ikl = InverseKLDivergence()

    # Calculate the metrics
    mmd_distance = mmd(synthetic_tensor, real_tensor)
    jsd_distance = jsd.evaluate(synthetic_data_loader, real_data_loader)
    ikl_distance = ikl.evaluate(synthetic_data_loader, real_data_loader)
    ppr, pre = compute_pprecision_precall(real_tensor.detach().numpy(), synthetic_tensor.detach().numpy())

    # The synthcity evaluators return dicts; the first value is the score.
    mmd_result = np.round(mmd_distance.item(), 7)
    jsd_result = np.round(list(jsd_distance.values())[0], 7)
    ikl_result = np.round(list(ikl_distance.values())[0], 7)
    ppr_result = np.round(ppr, 7)
    pre_result = np.round(pre, 7)

    # Store the result in the DataFrame
    results = pd.DataFrame(index=['Mixture'], columns=["MMD", "JSD", "IKL", "P-PR", "P-RE"])
    results.iloc[0] = [mmd_result, jsd_result, ikl_result, ppr_result, pre_result]
    return results

def rbf_median(X):
    """Return an RBF ``gamma`` from the median pairwise distance of ``X``.

    Args:
        X: tensor accepted by ``torch.cdist`` (rows are points).

    Returns:
        ``1 / (2 * m**2)`` where ``m`` is ``torch.median`` of all non-zero
        entries of the pairwise distance matrix.
    """
    distances = torch.cdist(X, X)
    nonzero = distances[distances != 0]  # drops self-distances and duplicates
    bandwidth = torch.median(nonzero)
    return 1 / (2 * bandwidth ** 2)

# Training algorithm SGD with MMD loss
def sgd_mmd(x, model, burnin=0, nstep=100, lr=1, num_samples = 1000, mh_gamma=1.0):
    """Fit ``model`` to the sample ``x`` by SGD on an RBF-kernel MMD loss.

    Args:
        x: tensor of real samples, compared with the generated batch.
        model: module called as ``model(num_samples=...)`` to produce a batch;
            its parameters are optimised in place.
        burnin: number of initial steps excluded from parameter averaging.
            When ``burnin >= 1`` the parameter values recorded after every
            step with index ``>= burnin`` are averaged and written back into
            the model once the last step has been taken.
        nstep: number of steps after the burn-in (``burnin + nstep`` in total).
        lr: SGD learning rate.
        num_samples: size of each generated batch.
        mh_gamma: ``gamma`` of the RBF kernel used by the loss.

    Returns:
        ``(model, losses)`` where ``losses`` holds the loss value of every step.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = MMDLoss(kernel=RBF(gamma=mh_gamma))

    total_steps = burnin + nstep
    averaging = burnin >= 1
    loss_trace = []
    # One list of recorded values per model parameter.
    snapshots = [[] for _ in model.parameters()]

    for step in range(total_steps):
        optimizer.zero_grad()
        fake_batch = model(num_samples=num_samples)
        loss = criterion(x, fake_batch)
        loss.backward(retain_graph=True)
        optimizer.step()

        if averaging and step >= burnin:
            with torch.no_grad():
                for history, param in zip(snapshots, model.parameters()):
                    history.append(param.data.clone())

        loss_trace.append(loss.item())

    # Replace the final iterate by the post-burn-in average.
    if averaging and total_steps > 0:
        with torch.no_grad():
            for history, param in zip(snapshots, model.parameters()):
                param.copy_(torch.stack(history, dim=0).mean(dim=0))

    return model, loss_trace

def assign_parameters(models, parameters):
    """Overwrite the parameters of each model with the given values, in place.

    Args:
        models: iterable of generator objects exposing a ``type`` string.
        parameters: iterable, parallel to ``models``, of per-model tuples:
            ``'StudentT'`` - ``(mean, cholesky_factor, dof)``;
            ``'Beta'``     - ``(alpha, beta)``;
            ``'Gumbel'``   - ``(location, scale)``;
            ``'Gaussian'`` - ``(mean, cov)`` where ``cov`` is a covariance
                             matrix whose Cholesky factor is stored.

    Raises:
        ValueError: if a model's ``type`` is none of the four names above.

    NOTE(review): for ``'Gaussian'`` this expects a covariance, whereas
    extract_parameters() returns the stored factor - the two are not
    inverses of each other for Gaussian models; confirm the intended usage.
    """
    for model, params in zip(models, parameters):
        if model.type == 'StudentT':
            mean, cholesky_factor, dof = params
            model.mean.data = mean
            model.cholesky_factor.data = cholesky_factor
            model.dof.data = dof
        elif model.type == 'Beta':
            alpha, beta = params
            model.alpha.data = alpha
            model.beta.data = beta
        elif model.type == 'Gumbel':
            location, scale = params
            model.location.data = location
            model.scale.data = scale
        elif model.type == 'Gaussian':
            mean, cov = params
            model.mean.data = mean
            # torch.linalg.cholesky replaces the deprecated Tensor.cholesky();
            # both return the lower-triangular factor by default.
            model.cholesky_factor.data = torch.linalg.cholesky(cov)
        else:
            raise ValueError(f"Unknown model type: {model.type}")

def extract_parameters(models):
    """Copy the current parameter values out of each generator.

    Args:
        models: iterable of generator objects exposing a ``type`` string.

    Returns:
        A list, parallel to ``models``, of tuples of cloned ``.data`` tensors:
        ``(mean, cholesky_factor, dof)`` for ``'StudentT'``,
        ``(alpha, beta)`` for ``'Beta'``, ``(location, scale)`` for
        ``'Gumbel'`` and ``(mean, cholesky_factor)`` for ``'Gaussian'``.

    Raises:
        ValueError: if a model's ``type`` is none of the four names above.
    """
    # Attribute names to read, in output order, for every supported type.
    field_names = {
        'StudentT': ('mean', 'cholesky_factor', 'dof'),
        'Beta': ('alpha', 'beta'),
        'Gumbel': ('location', 'scale'),
        'Gaussian': ('mean', 'cholesky_factor'),
    }

    snapshots = []
    for model in models:
        if model.type not in field_names:
            raise ValueError(f"Unknown model type: {model.type}")
        snapshots.append(
            tuple(getattr(model, name).data.clone() for name in field_names[model.type])
        )
    return snapshots

def time_decision(time_budget, pretrain_proportion, cluster_k, X, list_of_dist, n_samples=1000, max_iterations=100, discrete_columns=[], epsilon = 1, mh_gamma=1.0):
    """Turn a time budget into iteration counts for the two training stages.

    One model per distinct generator name is trained for ``max_iterations``
    iterations on the first ``X.shape[0] / cluster_k`` rows of ``X`` and
    timed. The measured time per iteration, weighted by how often each name
    occurs in ``list_of_dist``, gives the cost of one joint iteration.

    Reads ``n_features``, ``well_known_dist``, ``manual_dist`` and
    ``synthcity_dist`` from module scope (not defined in the visible part of
    the file).

    Args:
        time_budget: total budget, in the unit of ``time.perf_counter``
            (seconds).
        pretrain_proportion: share of the budget, in ``[0, 1]``, given to
            stage 1.
        cluster_k: number of clusters; sets the size of the timing subset.
        X: tensor of training rows.
        list_of_dist: list of generator names, one per component.
        n_samples: batch size forwarded to sgd_mmd.
        max_iterations: number of iterations used for the timing run.
        discrete_columns: forwarded to ``fit`` of the ``well_known_dist``
            models. The default list is never modified here.
        epsilon: learning rate forwarded to sgd_mmd.
        mh_gamma: RBF ``gamma`` forwarded to sgd_mmd.

    Returns:
        ``(step_1_ite, step_2_ite)``, each at least 1.

    Raises:
        ValueError: for a generator name that is not recognised.
    """
    assert 0 <= pretrain_proportion <= 1
    unique_dists = list(set(list_of_dist))

    list_of_models = []
    for i, dist in enumerate(unique_dists):
        if dist == 'Gaussian':
            list_of_models.append(GaussianGenerator(n_features))
        elif dist == 'Beta':
            list_of_models.append(BetaGenerator(n_features))
        elif dist == 'Gumbel':
            list_of_models.append(GumbelGenerator(n_features))
        elif dist == 'StudentT':
            list_of_models.append(MultivariateStudentTGenerator(n_features))
        elif dist == 'CTGAN':
            # NOTE(review): X is sliced by rows below, so X[i] is row i and
            # its shape[0] is the feature count, whereas pretrain() passes a
            # cluster's row count - confirm which embedding_dim is intended.
            list_of_models.append(CTGAN(embedding_dim=X[i].shape[0], batch_size=10, epochs=max_iterations, cuda=False))
        elif dist == 'TVAE':
            list_of_models.append(TVAE(embedding_dim=X[i].shape[0], batch_size=10, epochs=max_iterations, cuda=False))
        elif dist == 'ARF':
            list_of_models.append(arf.arf(epochs=max_iterations))
        elif dist in synthcity_dist:
            list_of_models.append(Plugins().get(dist, n_iter=max_iterations))
        else:
            raise ValueError(f"Unknown distribution name: {dist}")

    # Timing subset: as many leading rows as one average cluster would hold.
    test_num = int(X.shape[0] / cluster_k)
    test_set = X[:test_num]

    seconds_per_iteration = {}
    for j, dist in enumerate(unique_dists):
        # perf_counter is monotonic and high-resolution; time.time() can
        # jump with clock adjustments and is coarse on some platforms.
        start_time = time.perf_counter()
        if dist in well_known_dist:
            list_of_models[j].fit(test_set.detach().numpy(), discrete_columns=discrete_columns)
        elif dist in manual_dist:
            list_of_models[j], _ = sgd_mmd(test_set, list_of_models[j], nstep=max_iterations, lr=epsilon, num_samples=n_samples, mh_gamma=mh_gamma)
        elif dist in synthcity_dist:
            cluster_loader = GenericDataLoader(test_set.detach().numpy())
            list_of_models[j].fit(cluster_loader)
        else:
            raise ValueError(f"Unknown distribution name: {dist}")
        elapsed = time.perf_counter() - start_time

        seconds_per_iteration[dist] = elapsed / int(max_iterations)

    # Cost of one iteration over all components (names may repeat).
    time_per_ite = sum(list_of_dist.count(dist) * seconds_per_iteration[dist] for dist in unique_dists)

    step_1_time = time_budget * pretrain_proportion
    step_2_time = time_budget * (1 - pretrain_proportion)

    # Always run at least one iteration per stage.
    step_1_ite = max(int(step_1_time / time_per_ite) + 1, 1)
    step_2_ite = max(int(step_2_time / time_per_ite) + 1, 1)

    print(f'Step 1 Iterations: {step_1_ite}, Step 2 Iterations: {step_2_ite}')
    return step_1_ite, step_2_ite

def unit_time_dec(X, cluster_k, list_of_dist, n_samples=1000, max_iterations=100, discrete_columns=[], epsilon = 1, mh_gamma=1.0):
    """Measure how many iterations of each generator fit in one time unit.

    One model per distinct generator name is trained for ``max_iterations``
    iterations on the first ``X.shape[0] / cluster_k`` rows of ``X`` and
    timed. The slowest generator defines the time unit; every generator is
    mapped to the integer number of its own iterations that fit in that unit.

    Reads ``n_features``, ``well_known_dist``, ``manual_dist`` and
    ``synthcity_dist`` from module scope (not defined in the visible part of
    the file).

    Args:
        X: tensor of training rows.
        cluster_k: number of clusters; sets the size of the timing subset.
        list_of_dist: list of generator names.
        n_samples: batch size forwarded to sgd_mmd.
        max_iterations: number of iterations used for the timing run.
        discrete_columns: forwarded to ``fit`` of the ``well_known_dist``
            models. The default list is never modified here.
        epsilon: learning rate forwarded to sgd_mmd.
        mh_gamma: RBF ``gamma`` forwarded to sgd_mmd.

    Returns:
        Dict mapping each distinct generator name to
        ``int(slowest_time / its_time)`` (1 for the slowest generator).

    Raises:
        ValueError: for a generator name that is not recognised.
    """
    unique_dists = list(set(list_of_dist))

    list_of_models = []
    for i, dist in enumerate(unique_dists):
        if dist == 'Gaussian':
            list_of_models.append(GaussianGenerator(n_features))
        elif dist == 'Beta':
            list_of_models.append(BetaGenerator(n_features))
        elif dist == 'Gumbel':
            list_of_models.append(GumbelGenerator(n_features))
        elif dist == 'StudentT':
            list_of_models.append(MultivariateStudentTGenerator(n_features))
        elif dist == 'CTGAN':
            # NOTE(review): X is sliced by rows below, so X[i] is row i and
            # its shape[0] is the feature count, whereas pretrain() passes a
            # cluster's row count - confirm which embedding_dim is intended.
            list_of_models.append(CTGAN(embedding_dim=X[i].shape[0], batch_size=10, epochs=max_iterations, cuda=False))
        elif dist == 'TVAE':
            list_of_models.append(TVAE(embedding_dim=X[i].shape[0], batch_size=10, epochs=max_iterations, cuda=False))
        elif dist == 'ARF':
            list_of_models.append(arf.arf(epochs=max_iterations))
        elif dist in synthcity_dist:
            list_of_models.append(Plugins().get(dist, n_iter=max_iterations))
        else:
            raise ValueError(f"Unknown distribution name: {dist}")

    # Timing subset: as many leading rows as one average cluster would hold.
    test_num = int(X.shape[0] / cluster_k)
    test_set = X[:test_num]

    seconds_per_iteration = {}
    for j, dist in enumerate(unique_dists):
        # perf_counter is monotonic and high-resolution. With time.time() a
        # fast fit could measure 0.0 on a coarse clock and make the ratio
        # below raise ZeroDivisionError.
        start_time = time.perf_counter()
        if dist in well_known_dist:
            list_of_models[j].fit(test_set.detach().numpy(), discrete_columns=discrete_columns)
        elif dist in manual_dist:
            list_of_models[j], _ = sgd_mmd(test_set, list_of_models[j], nstep=max_iterations, lr=epsilon, num_samples=n_samples, mh_gamma=mh_gamma)
        elif dist in synthcity_dist:
            cluster_loader = GenericDataLoader(test_set.detach().numpy())
            list_of_models[j].fit(cluster_loader)
        else:
            raise ValueError(f"Unknown distribution name: {dist}")
        elapsed = time.perf_counter() - start_time

        seconds_per_iteration[dist] = elapsed / int(max_iterations)

    # Express every generator's speed relative to the slowest one.
    max_time = max(seconds_per_iteration.values())
    return {dist: int(max_time / seconds) for dist, seconds in seconds_per_iteration.items()}

# Bandit generator allocation
def bandit(time_dict, cluster, list_of_dist, bandit_type='R-UCBE', n_samples=1000, n_features=2, budget=20, window_size=0.25, explore_para=1, discrete_columns=[], epsilon = 1, mh_gamma=1.0):
    """Pick one generator for ``cluster`` with a bandit procedure.

    Every entry of ``list_of_dist`` is an arm. Pulling an arm trains its
    model on ``cluster`` for ``time_dict[name]`` iterations/epochs and scores
    it by the RBF-kernel MMD between generated samples and ``cluster``.

    Reads ``well_known_dist``, ``manual_dist`` and ``synthcity_dist`` from
    module scope (not defined in the visible part of the file).

    Args:
        time_dict: maps a generator name to the number of iterations that
            make up one pull of that arm.
        cluster: tensor of training rows for this cluster.
        list_of_dist: list of candidate generator names.
        bandit_type: ``'R-UCBE'`` (index-based; the arm with the largest
            index is pulled each round) or ``'R-SR'`` (one arm is eliminated
            per round until one is left).
        n_samples: number of synthetic rows generated for each evaluation.
        n_features: dimension passed to the hand-written generators.
        budget: number of pulls for R-UCBE; enters the round lengths of R-SR.
        window_size: fraction of an arm's reward history used as the window
            in R-UCBE; must lie strictly between 0 and 0.5.
        explore_para: numerator of the R-UCBE exploration term.
        discrete_columns: forwarded to ``fit`` of the ``well_known_dist`` models.
        epsilon: learning rate forwarded to sgd_mmd.
        mh_gamma: RBF ``gamma`` for the MMD reward and for sgd_mmd.

    Returns:
        The name (an entry of ``list_of_dist``) of the selected generator.

    Raises:
        ValueError: for an unknown generator name or ``bandit_type``.
    """
    n_components = len(list_of_dist)
    # One fresh model per arm; third-party models start with a single epoch
    # because their epoch count is overwritten before every fit below.
    list_of_models = []
    for i in range(n_components):
        if list_of_dist[i] == 'Gaussian':
            list_of_models.append(GaussianGenerator(n_features))
        elif list_of_dist[i] == 'Beta':
            list_of_models.append(BetaGenerator(n_features))
        elif list_of_dist[i] == 'Gumbel':
            list_of_models.append(GumbelGenerator(n_features))
        elif list_of_dist[i] == 'StudentT':
            list_of_models.append(MultivariateStudentTGenerator(n_features))
        elif list_of_dist[i] == 'CTGAN':
            list_of_models.append(CTGAN(embedding_dim=cluster.shape[0], batch_size=10, epochs=1, cuda=False))
        elif list_of_dist[i] == 'TVAE':
            list_of_models.append(TVAE(embedding_dim=cluster.shape[0], batch_size=10, epochs=1, cuda=False))
        elif list_of_dist[i] == 'ARF':
            list_of_models.append(arf.arf(epochs=1))
        elif list_of_dist[i] in synthcity_dist:
            list_of_models.append(Plugins().get(list_of_dist[i], n_iter=1))
        else:
            raise ValueError(f"Unknown distribution name: {list_of_dist[i]}")

    if bandit_type == 'R-UCBE':
        assert window_size > 0 and window_size < 0.5
        # The huge initial index forces every arm to be pulled before any
        # arm is pulled on the strength of its computed index.
        bandit_list = [1e+10 for _ in range(n_components)]
        # Reward history (negative MMD) of each arm, in pull order.
        mmd_available_list = [[] for _ in range(n_components)]

        # Bandit Parameters Initialization
        pe = 0  # mean of the last h rewards of the pulled arm
        oe = 0  # pe plus a projected-improvement term
        ep = 0  # exploration bonus

        for j in range(budget):
            bandit_index = torch.argmax(torch.tensor(bandit_list)).item()
            if list_of_dist[bandit_index] in well_known_dist:
                list_of_models[bandit_index].epochs = int(time_dict[list_of_dist[bandit_index]])
                list_of_models[bandit_index].fit(cluster.detach().numpy(), discrete_columns=discrete_columns)
                synthetic_samples = torch.tensor(list_of_models[bandit_index].sample(n_samples))
            elif list_of_dist[bandit_index] in manual_dist:
                list_of_models[bandit_index], mmd_list_empty = sgd_mmd(cluster, list_of_models[bandit_index], nstep=int(time_dict[list_of_dist[bandit_index]]), lr=epsilon, num_samples=n_samples, mh_gamma=mh_gamma)
                synthetic_samples = list_of_models[bandit_index].forward(num_samples=n_samples)
            elif list_of_dist[bandit_index] in synthcity_dist:
                # For 'ddpm' the iteration count is set on the wrapped
                # `.model` attribute rather than on the plugin itself.
                if list_of_dist[bandit_index] == 'ddpm':
                    list_of_models[bandit_index].model.n_iter = int(time_dict[list_of_dist[bandit_index]])
                else:
                    list_of_models[bandit_index].n_iter = int(time_dict[list_of_dist[bandit_index]])
                cluster_loader = GenericDataLoader(cluster.detach().numpy())
                list_of_models[bandit_index].fit(cluster_loader)
                synthetic_samples = torch.tensor(list_of_models[bandit_index].generate(count=n_samples).numpy())
            mmd_loss_fn = MMDLoss(kernel=RBF(gamma=mh_gamma))
            # NOTE(review): `reward` is a tensor; in the manual_dist branch
            # the samples come from forward() outside torch.no_grad(), so it
            # presumably carries autograd history - confirm whether a
            # detach()/item() is wanted before it is stored.
            reward = -1 * mmd_loss_fn(synthetic_samples, cluster)
            # The index is computed from the rewards recorded BEFORE this
            # pull; the new reward is appended only afterwards.
            n = len(mmd_available_list[bandit_index])
            h = int(n * window_size)
            pe = sum(mmd_available_list[bandit_index][-h:]) / h if h >= 1 else 0
            # Sum over the last h pulls of (budget - phi) * (reward[phi] - reward[phi - h]),
            # scaled by 1 / h^2; terms with phi - h < 0 contribute 0.
            oe = pe + sum([(budget - phi) * (mmd_available_list[bandit_index][phi] - mmd_available_list[bandit_index][phi - h] if phi - h >= 0 else 0) for phi in range(n-h, n)]) / (h**2) if h >= 1 else 0
            ep = np.sqrt(explore_para/(h**3)) / np.sqrt(budget - j + h) if h >= 1 else 0
            bandit_list[bandit_index] = oe + ep
            mmd_available_list[bandit_index].append(reward)

            print(f"\rIteration {j + 1}/{budget}, Gen: {list_of_dist[bandit_index]}, Bandit: {bandit_list[bandit_index]}", end='', flush=True)
        print()
        return list_of_dist[torch.argmax(torch.tensor(bandit_list)).item()]

    elif bandit_type == 'R-SR':
        # Cumulative pull count per surviving arm after round j:
        # ceil((budget - K) / (log_k * (K + 1 - j))), with K = n_components.
        def f_N(j, log_k):
            if j == 0:
                return 0
            else:
                return int(torch.ceil(torch.tensor((budget - n_components) / (log_k * (n_components + 1 - j)))))

        # As written: 1/2 + sum_{i=1..K} 1/i.
        log_k = sum([1 / i for i in range(1, n_components + 1)]) + 1 / 2
        # Indices (into list_of_dist) of the arms still in play.
        k_list = [i for i in range(n_components)]

        # K - 1 rounds; each removes exactly one arm.
        for j in range(1, n_components):
            nstep = f_N(j, log_k) - f_N(j-1, log_k)
            nstep = nstep if nstep >= 1 else 1
            # Lowest mean reward seen so far in this round. Starting at 0
            # means an arm is only recorded for elimination if its mean
            # reward (a negated MMD) is <= 0.
            pe = 0
            eliminated_gen_idx = None

            for pos_idx, gen_idx in enumerate(k_list):
                if list_of_dist[gen_idx] in well_known_dist:
                    mmd_list = []
                    for s in range(nstep):
                        # Each sub-step calls fit() again with a larger
                        # epoch count, (s + 1) pull units.
                        list_of_models[gen_idx].epochs = int(time_dict[list_of_dist[gen_idx]] * (s + 1))
                        list_of_models[gen_idx].fit(cluster.detach().numpy(), discrete_columns=discrete_columns)
                        synthetic_samples = torch.tensor(list_of_models[gen_idx].sample(n_samples))
                        mmd_loss_fn = MMDLoss(kernel=RBF(gamma=mh_gamma))
                        reward = mmd_loss_fn(synthetic_samples, cluster)
                        mmd_list.append(reward)
                elif list_of_dist[gen_idx] in manual_dist:
                    # Here mmd_list is the per-step loss trace returned by
                    # sgd_mmd (Python floats), not a list of tensors.
                    list_of_models[gen_idx], mmd_list= sgd_mmd(cluster, list_of_models[gen_idx], nstep=int(time_dict[list_of_dist[gen_idx]] * nstep), lr=epsilon, num_samples=n_samples, mh_gamma=mh_gamma)
                elif list_of_dist[gen_idx] in synthcity_dist:
                    mmd_list = []
                    for s in range(nstep):
                        if list_of_dist[gen_idx] == 'ddpm':
                            list_of_models[gen_idx].model.n_iter = int(time_dict[list_of_dist[gen_idx]] * (s + 1))
                        else:
                            list_of_models[gen_idx].n_iter = int(time_dict[list_of_dist[gen_idx]] * (s + 1))
                        cluster_loader = GenericDataLoader(cluster.detach().numpy())
                        list_of_models[gen_idx].fit(cluster_loader)
                        synthetic_samples = torch.tensor(list_of_models[gen_idx].generate(count=n_samples).numpy())
                        mmd_loss_fn = MMDLoss(kernel=RBF(gamma=mh_gamma))
                        reward = mmd_loss_fn(synthetic_samples, cluster)
                        mmd_list.append(reward)
                else:
                    raise ValueError(f"Unknown distribution name: {list_of_dist[gen_idx]}")
                # Mean reward of this arm over the round (negated mean MMD).
                reward_mean = -1 * sum(mmd_list) / len(mmd_list)
                if reward_mean <= pe:
                    pe = reward_mean
                    eliminated_gen_idx = pos_idx
                    el_gen = gen_idx
            # NOTE(review): if no arm had reward_mean <= 0 in this round,
            # eliminated_gen_idx is still None here and pop(None) raises
            # TypeError - confirm this cannot happen for the MMD in use.
            k_list.pop(eliminated_gen_idx)
            print(f"\rIteration {j}/{n_components - 1}, Eliminated Gen: {list_of_dist[el_gen]}", end='', flush=True)
        print()

        return list_of_dist[k_list[0]]

    else:
        raise ValueError(f"Unknown bandit type: {bandit_type}")

def pretrain(X, list_of_dist, n_samples=1000, max_iterations=100, discrete_columns=[], epsilon = 1, mh_gamma=1.0):
    """Create and fit one generator per cluster.

    All models are constructed first (so an unknown name fails before any
    training starts), then each model is fitted on its own cluster.

    Reads ``n_features``, ``well_known_dist``, ``manual_dist`` and
    ``synthcity_dist`` from module scope (not defined in the visible part of
    the file).

    Args:
        X: sequence of cluster tensors, one per component.
        list_of_dist: generator name for each cluster; same length as ``X``.
        n_samples: batch size forwarded to sgd_mmd.
        max_iterations: epochs / iterations / SGD steps used for every model.
        discrete_columns: forwarded to ``fit`` of the ``well_known_dist`` models.
        epsilon: learning rate forwarded to sgd_mmd.
        mh_gamma: RBF ``gamma`` forwarded to sgd_mmd.

    Returns:
        The list of fitted models, parallel to ``X``.

    Raises:
        ValueError: for a generator name that is not recognised.
    """
    assert len(X) == len(list_of_dist)

    def _make_model(idx):
        # Map the idx-th generator name to a new, untrained model.
        name = list_of_dist[idx]
        if name == 'Gaussian':
            return GaussianGenerator(n_features)
        if name == 'Beta':
            return BetaGenerator(n_features)
        if name == 'Gumbel':
            return GumbelGenerator(n_features)
        if name == 'StudentT':
            return MultivariateStudentTGenerator(n_features)
        if name == 'CTGAN':
            return CTGAN(embedding_dim=X[idx].shape[0], batch_size=10, epochs=max_iterations, cuda=False)
        if name == 'TVAE':
            return TVAE(embedding_dim=X[idx].shape[0], batch_size=10, epochs=max_iterations, cuda=False)
        if name == 'ARF':
            return arf.arf(epochs=max_iterations)
        if name in synthcity_dist:
            return Plugins().get(name, n_iter=max_iterations)
        raise ValueError(f"Unknown distribution name: {name}")

    list_of_models = [_make_model(idx) for idx in range(len(X))]

    for j, cluster in enumerate(X):
        dist_name = list_of_dist[j]
        print(f"Pretrain Cluster {j + 1}, Gen: {list_of_dist[j]}")
        if dist_name in well_known_dist:
            list_of_models[j].fit(cluster.detach().numpy(), discrete_columns=discrete_columns)
        elif dist_name in manual_dist:
            # sgd_mmd returns (model, loss trace); the trace is not needed here.
            list_of_models[j], _ = sgd_mmd(cluster, list_of_models[j], nstep=max_iterations, lr=epsilon, num_samples=n_samples, mh_gamma=mh_gamma)
        elif dist_name in synthcity_dist:
            list_of_models[j].fit(GenericDataLoader(cluster.detach().numpy()))
        else:
            raise ValueError(f"Unknown distribution name: {list_of_dist[j]}")

    return list_of_models

def generate_synthetic_data(X, n_samples, num_samples_comparison, weights, similarity_measure, burnin, nstep, list_of_models, max_iterations=1, batch_size=100, discrete_columns=None, mh_gamma=1.0):
    """Refine a mixture of pretrained generators on ``X`` and sample from it.

    Before the first iteration a sample is drawn from the pretrained mixture
    and plotted. Each iteration then:

    1. draws ``num_samples_comparison`` samples from every generator,
    2. scores rows ``X[0] .. X[n_samples - 1]`` against every generator with
       ``compute_similarity`` (times the current weight) and row-normalises
       the result into membership probabilities,
    3. refits every generator on ``n_samples`` rows of ``X`` resampled with
       its column-normalised membership probabilities,
    4. re-estimates the mixture weights from the membership column sums,
    5. draws a new synthetic set from the mixture (at least one sample per
       component) and records its MMD against ``X``.

    Relies on the module-level names ``plot_dir``, ``color_list``,
    ``well_known_dist`` and ``manual_dist``. Side effects: truncates and then
    appends to ``Optim_Weights.csv`` in ``plot_dir``, saves scatter plots of
    the first two columns there, prints progress, and replaces entries of
    ``list_of_models`` in place for the hand-written generators.

    Args:
        X: Training data tensor, indexed by row.
        n_samples: Number of rows of ``X`` to score, and the number of
            samples drawn per mixture draw / resampled batch.
        num_samples_comparison: Samples drawn per generator for the
            similarity computation.
        weights: 1-D tensor of initial mixture weights, one per generator.
        similarity_measure: Forwarded to ``compute_similarity``.
        burnin: Forwarded to ``sgd_mmd`` for the hand-written generators.
        nstep: Per-iteration training increment; at iteration ``t``
            (0-based) ``epochs`` / ``n_iter`` is set to ``nstep * (t + 1)``.
        list_of_models: Pretrained generators (mutated, see above).
        max_iterations: Number of refinement iterations; must be >= 1.
        batch_size: Forwarded to ``sgd_mmd`` as ``num_samples``.
        discrete_columns: Forwarded to ``fit`` of the CTGAN/TVAE/ARF models.
            ``None`` (the default) means an empty list.
        mh_gamma: RBF kernel gamma for the MMD and for ``sgd_mmd``.

    Returns:
        Dict with ``'samples'`` (stacked synthetic samples of the last
        iteration), ``'list_samples'`` (the same, per component),
        ``'weights'`` (final mixture weights) and ``'mmd_values'`` (rounded
        MMD per iteration).

    Raises:
        ValueError: If ``max_iterations`` < 1, if a model is of no known
            kind, or if a hand-written generator has no training schedule.
    """
    # With zero iterations there is nothing to return; fail before touching
    # the output directory instead of with an unbound-name error afterwards.
    if max_iterations < 1:
        raise ValueError(f"max_iterations must be at least 1, got {max_iterations}")
    discrete_columns = [] if discrete_columns is None else discrete_columns

    # (nstep, lr) handed to sgd_mmd for each hand-written generator type.
    manual_schedule = {
        'Gaussian': (50, 1),
        'Beta': (10, 30),
        'Gumbel': (50, 1),
        'StudentT': (1, 1),
    }

    n_components = len(list_of_models)

    def is_plugin(model):
        # synthcity plugins are recognised by their class-name suffix.
        return model.__class__.__name__.endswith('Plugin')

    def draw(model, count):
        """Draw ``count`` samples from one generator as a tensor."""
        if model.type in well_known_dist:
            return torch.tensor(model.sample(count))
        if model.type in manual_dist:
            return model.forward(num_samples=count)
        if is_plugin(model):
            return torch.tensor(model.generate(count=count).numpy())
        # Skipping here would leave the sample lists shorter than
        # n_components and misalign every later per-component index.
        raise ValueError(f"Unknown model type: {model.__class__.__name__}")

    def draw_mixture(component_weights, min_per_component):
        """Split n_samples across components by weight and sample each one."""
        component_indices = torch.multinomial(component_weights, n_samples, replacement=True)
        counts = torch.bincount(component_indices, minlength=n_components).tolist()
        counts = [max(count, min_per_component) for count in counts]
        return [draw(model, count) for model, count in zip(list_of_models, counts)]

    def save_scatter(component_samples, file_name):
        """Plot the first two columns of X and of each component's samples."""
        train_points = X.detach().numpy()
        plt.scatter(train_points[:, 0], train_points[:, 1], color="blue", label="Train Data", alpha=0.5)
        for i, samples in enumerate(component_samples):
            points = samples.detach().numpy()
            plt.scatter(points[:, 0], points[:, 1], color=color_list[i], label=f"Generated Cluster {i + 1}", alpha=0.2)
        plt.legend()
        plt.xlabel("X1")
        plt.ylabel("X2")
        plt.title("Mixture Model Sample Data")
        plt.savefig(f'{plot_dir}/{file_name}')
        plt.clf()

    mmd_values = []
    mmd_loss_fn = MMDLoss(kernel=RBF(gamma=mh_gamma))

    # Start every run with an empty weights log; rows are appended per iteration.
    csv_path = os.path.join(plot_dir, 'Optim_Weights.csv')
    with open(csv_path, 'w'):
        pass

    # Snapshot of the mixture as it comes out of pretraining.
    save_scatter(draw_mixture(weights, min_per_component=0), 'Original_and_Synthetic_Data_After_Pretrain.png')

    for iteration in range(max_iterations):
        synthetic_samples_list = [draw(model, num_samples_comparison) for model in list_of_models]

        # Membership probability matrix computation
        membership_probabilities = torch.zeros((n_samples, n_components))
        for i in range(n_samples):
            for j in range(n_components):
                similarity = compute_similarity(X[i], synthetic_samples_list[j], similarity_measure)
                membership_probabilities[i, j] = similarity * weights[j]
        membership_probabilities = torch.clamp(membership_probabilities, min=1e-3)  # Avoid zero probabilities
        membership_probabilities = membership_probabilities / membership_probabilities.sum(dim=1, keepdim=True)

        # Column-normalised memberships: per component, a distribution over rows of X.
        resampling_probabilities = membership_probabilities / membership_probabilities.sum(dim=0, keepdim=True)
        scheduled_steps = int(nstep * (iteration + 1))
        for j in range(n_components):
            batch_indices = torch.multinomial(resampling_probabilities[:, j], n_samples, replacement=True)
            X_batch = X[batch_indices]
            if X_batch.shape[0] == 0:
                continue

            model = list_of_models[j]
            if model.type in well_known_dist:
                model.epochs = scheduled_steps
                model.fit(X_batch.detach().numpy(), discrete_columns=discrete_columns)
            elif model.type in manual_dist:
                try:
                    model_nstep, model_lr = manual_schedule[model.type]
                except KeyError:
                    raise ValueError(f"Unknown model type: {model.type}") from None
                # sgd_mmd returns (model, mmd history); only the model is needed here.
                list_of_models[j], _ = sgd_mmd(X_batch, model, burnin=burnin, nstep=model_nstep, lr=model_lr, num_samples=batch_size, mh_gamma=mh_gamma)
            elif is_plugin(model):
                # TabDDPM keeps its iteration count on the wrapped model.
                if model.__class__.__name__[:-6] == 'TabDDPM':
                    model.model.n_iter = scheduled_steps
                else:
                    model.n_iter = scheduled_steps
                model.fit(GenericDataLoader(X_batch.detach().numpy()))

        # Calculate optim weights
        col_sums = membership_probabilities.sum(dim=0)
        weights = col_sums / n_samples
        optim_weights = col_sums / col_sums.sum()

        # At least one sample per component keeps the per-component list
        # aligned with list_of_models (and with color_list when plotting).
        list_synthetic_samples = draw_mixture(optim_weights, min_per_component=1)
        synthetic_samples = torch.vstack(list_synthetic_samples)

        mmd = mmd_loss_fn(synthetic_samples, X)
        mmd_values.append(np.round(mmd.item(), 4))

        if iteration % 50 == 0:
            save_scatter(list_synthetic_samples, f'Original_and_Synthetic_Data_Ite{iteration + 1}.png')

        print(f"\rIteration {iteration + 1}/{max_iterations}: MMD = {mmd_values[-1].item()}, Optim_Weights = {optim_weights}", end='', flush=True)
        with open(csv_path, 'ab') as f:
            # detach() so the write also works if the weights carry autograd history.
            np.savetxt(f, [np.round(optim_weights.detach().numpy(), 4)], delimiter=',')

    print()
    print(f'Weights: {optim_weights}')

    return {'samples': synthetic_samples, 'list_samples': list_synthetic_samples, 'weights': optim_weights, 'mmd_values': mmd_values}


# Exp Begin

# Fetch the dataset through ucimlrepo and keep only its feature columns.
# NOTE(review): the variable is named after HTRU2 - confirm that repository
# id 267 is the intended dataset.
htru2 = fetch_ucirepo(id=267)
dataset = htru2.data.features

# Convert the dataset to a PyTorch tensor
dataset = torch.tensor(dataset.values, dtype=torch.float32)
# Shuffle all rows so that the pretrain/train split below is random
indices = torch.randperm(dataset.size(0))
dataset = dataset[indices]

n_samples = dataset.shape[0]
n_features = dataset.shape[1]
# NOTE(review): every column index is listed as a discrete column - confirm
# this is what the CTGAN/TVAE/ARF fit() calls should receive for this data.
discrete_columns = [f'{i}' for i in range(n_features)]

# First `pretrain_portion` of the shuffled rows is used for clustering and
# pretraining, the remainder for the main training step.
pretrain_portion = 0.2
pretrain_size = int(n_samples * pretrain_portion)
pretrain_dataset = dataset[:pretrain_size]
train_dataset = dataset[pretrain_size:]

# Calculate the gamma for RBF with Median Heuristic
mh_gamma = rbf_median(dataset)

# Cluster the pretrain split; each cluster gets its own generator
labels, cluster_k = lim_cluster_tabular_data(pretrain_dataset.detach().numpy(), 10)
clusters = [pretrain_dataset[labels == i] for i in range(cluster_k)]

# Generator families; these lists are read as globals by pretrain() and
# generate_synthetic_data().
eval_max_ite = 50
bandit_type = 'R-SR'
manual_dist = ['Gaussian', 'Beta', 'Gumbel', 'StudentT']
well_known_dist = ['CTGAN', 'TVAE', 'ARF']
synthcity_dist = ['ctgan', 'great', 'pategan', 'tvae', 'dpgan', 'decaf', 'adsgan', 'nflow', 'rtvae', 'ddpm']
list_of_available_gen = ['ARF', 'ctgan', 'tvae', 'nflow', 'rtvae', 'ddpm']

# Generator selection: one bandit run per cluster over the available generators
avg_unit_ite = 100
unit_time_dict = unit_time_dec(train_dataset, cluster_k, list_of_available_gen, n_samples=n_samples, max_iterations=avg_unit_ite, discrete_columns=discrete_columns, epsilon = 1, mh_gamma=mh_gamma)
print("Unit Time Dictionary: ", unit_time_dict)
selected_gen = [bandit(unit_time_dict, cluster, list_of_available_gen, bandit_type=bandit_type, n_samples=100, n_features=n_features, 
                       budget=eval_max_ite, window_size=0.25, explore_para=0.1, discrete_columns=discrete_columns, epsilon = 1, mh_gamma=mh_gamma) for cluster in clusters]
print("Selected Generators: ", selected_gen)

# Budget Computation
budget = 5
avg_ite = 100
pretrain_budget = 0.2
pretrain_ite, num_of_ite = time_decision(time_budget=budget, pretrain_proportion=pretrain_budget, cluster_k=cluster_k, X=train_dataset, list_of_dist=selected_gen, n_samples=n_samples, max_iterations=avg_ite, discrete_columns=discrete_columns, epsilon = 1)
nstep_gap = 1
num_of_ite = num_of_ite // nstep_gap  # Reduce the number of iterations for the second step
print(f'Final Decision: Step 1 Iterations: {pretrain_ite}, Step 2 Iterations: {num_of_ite}')

# Step 1: pretrain one generator per cluster
train_list_of_models = pretrain(clusters, selected_gen, n_samples=n_samples, max_iterations=pretrain_ite, discrete_columns=discrete_columns, epsilon = 1, mh_gamma=mh_gamma)
print('Train Models List: ', [model.type for model in train_list_of_models])
# Initial mixture weights: share of the pretrain split that fell into each cluster
train_weights = torch.tensor([len(cluster) / pretrain_dataset.shape[0] for cluster in clusters])

# Output directory and per-component colors (read as globals by generate_synthetic_data)
plot_dir = 'Formal_demo/Record'
os.makedirs(plot_dir, exist_ok=True)

color_list = ['red', 'chocolate', 'green', 'purple', 'orange', 'yellow', 'pink', 'gray', 'black']
if len(color_list) < len(train_list_of_models):
    raise ValueError("Not enough colors for the number of models.")

# Step 2
synthetic_data = generate_synthetic_data(X=train_dataset, n_samples=train_dataset.shape[0], num_samples_comparison=train_dataset.shape[0], 
                                         weights=train_weights, similarity_measure='new_es', burnin=10, nstep=nstep_gap, list_of_models=train_list_of_models, 
                                         max_iterations=num_of_ite, batch_size=100, discrete_columns=discrete_columns, mh_gamma=mh_gamma)

# Final scatter plot of the first two columns: training data vs. each component's samples
train_points = train_dataset.detach().numpy()
plt.scatter(train_points[:, 0], train_points[:, 1], color="blue", label="Train Data", alpha=0.5)
for i, component_samples in enumerate(synthetic_data['list_samples']):
    component_points = component_samples.detach().numpy()
    plt.scatter(component_points[:, 0], component_points[:, 1], color=color_list[i], label=f"Generated Cluster {i + 1}", alpha=0.2)
plt.legend()
plt.xlabel("X1")
plt.ylabel("X2")
plt.title("Mixture Model Sample Data")
plt.savefig(f'{plot_dir}/Original_and_Synthetic_Data.png')
plt.clf()

# Plot MMD over iterations
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_of_ite + 1), synthetic_data['mmd_values'], marker='o', linestyle='-', color='b')
plt.xlabel('Iteration Number')
plt.ylabel('MMD Score')
plt.title('MMD Score over Iterations')
plt.grid(True)
plt.savefig(f'{plot_dir}/MMD_Score_over_Iterations.png')
plt.clf()

evl_result = metrics_evaluation(synthetic_data['samples'], train_dataset)
print(evl_result)
