import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
import numpy as np
from LBFGS import L_BFGS_B
from WeightedGaussNB import WeightedGaussNB
from tqdm import trange


class MCE(nn.Module):
    """
    Multiple correlation encoder (MCE).

    The encoder is a single hidden layer (Linear + ReLU) followed by a linear
    projection to the latent space. The decoder takes the latent vector
    concatenated with a condition vector and maps it back to the input
    dimension through one hidden layer, ending in a Sigmoid.
    """

    def __init__(self, input_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()

        # Fix the global torch RNG seed so the weights created below are
        # reproducible. Note that this affects the process-wide RNG state.
        torch.manual_seed(1234)

        # Layers are instantiated in the order encoder -> z -> decoder; the
        # order matters because each nn.Linear draws from the seeded RNG.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.z = nn.Linear(hidden_dim, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim + cond_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map input ``x`` to its latent embedding."""
        return self.z(self.encoder(x))

    def decode(self, z, y):
        """Reconstruct the input from latent ``z`` and condition ``y`` (joined on the last axis)."""
        return self.decoder(torch.cat((z, y), dim=-1))

    def forward(self, x, y):
        """Return the pair ``(latent embedding, reconstruction)`` for input ``x`` and condition ``y``."""
        latent = self.encode(x)
        return latent, self.decode(latent, y)


def calculate_BCE(x_recon, x):
    """
    Return the binary cross entropy, summed over all elements, between the
    reconstructed attribute value vectors ``x_recon`` and the original
    attribute value vectors ``x``.
    """
    return F.binary_cross_entropy(input=x_recon, target=x, reduction='sum')


def pearson_corrcoef_matrix(x):
    """
    Return the matrix of absolute Pearson correlations between the columns of ``x``.

    ``x`` must be 2-dimensional (rows are observations, columns are variables);
    otherwise a ValueError is raised. Columns with zero standard deviation are
    divided by 1e-9 instead of 0, which yields a correlation of 0 for them.
    """
    if x.dim() != 2:
        raise ValueError("Input tensor must be 2-dimensional")

    n_rows = x.size(0)
    centered = x - x.mean(dim=0, keepdim=True)

    # Sample (unbiased) standard deviation per column, guarded against zero.
    spread = x.std(dim=0, unbiased=True, keepdim=True)
    spread = torch.where(spread == 0, torch.tensor(1e-9, device=x.device), spread)

    standardized = centered / spread
    return (standardized.T @ standardized / (n_rows - 1)).abs()


def _mean_offdiag_abs_corr(samples):
    """
    Return the mean absolute Pearson correlation over all ordered pairs of
    distinct rows of ``samples`` (a 2-D tensor, one row per instance).
    """
    n = samples.shape[0]
    # pearson_corrcoef_matrix correlates columns, so transpose to correlate
    # instances; it already returns absolute values.
    correlation = pearson_corrcoef_matrix(samples.T)
    correlation.fill_diagonal_(0)  # ignore each instance's self-correlation
    return torch.sum(correlation) / (n * (n - 1))


def calculate_L2(embeddings, labels):
    """
    calculate the correlation loss from the instance perspective

    For every class with at least two instances, the mean absolute
    correlation between its instances' embeddings is divided by the same
    quantity computed over all instances. The loss is the negated sum of
    these ratios divided by the number of distinct labels.

    :param embeddings: 2-D tensor, one embedding per row
    :param labels: 1-D tensor of class labels aligned with the rows of ``embeddings``
    :return: the loss (a tensor, or the float 0.0 when every class has fewer
        than two instances)
    """
    corr_all = _mean_offdiag_abs_corr(embeddings)

    unique_labels = torch.unique(labels)
    total_loss = 0.0
    for label in unique_labels:
        class_embeddings = embeddings[labels == label]
        if class_embeddings.size(0) < 2:
            # A single instance has no pairwise correlation to measure.
            continue
        total_loss -= _mean_offdiag_abs_corr(class_embeddings) / corr_all

    # NOTE(review): skipped classes still count in the divisor here, whereas
    # calculate_KL excludes them -- confirm which normalisation is intended.
    return total_loss / len(unique_labels)


def pearson_correlation_coefficient(x, y):
    """
    Return the Pearson correlation coefficient between ``x`` and ``y``.

    Both tensors are flattened first and must contain the same number of
    elements. If either input has zero spread, 1e-9 is added to the
    denominator so the result is 0 rather than NaN.

    :param x: tensor of any shape (contiguous or not)
    :param y: tensor with as many elements as ``x``
    :return: 0-dimensional tensor holding the correlation coefficient
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. a
    # transposed matrix, copying only when it has to.
    x_flat = x.reshape(-1)
    y_flat = y.reshape(-1)

    x_deviation = x_flat - torch.mean(x_flat)
    y_deviation = y_flat - torch.mean(y_flat)

    covariance = torch.sum(x_deviation * y_deviation)
    x_std = torch.sqrt(torch.sum(x_deviation ** 2))
    y_std = torch.sqrt(torch.sum(y_deviation ** 2))

    epsilon = 1e-9
    denominator = x_std * y_std
    if x_std == 0 or y_std == 0:
        # Constant input: avoid 0 / 0.
        denominator = denominator + epsilon
    return covariance / denominator


def calculate_L1(X, Y):
    """
    calculate the correlation loss from the attribute perspective

    For each column of ``X`` the score is |corr(column, Y)| minus the mean
    |corr(column, other column)| over all other columns. The loss is the
    negated mean score over the columns.

    :param X: 2-D tensor, one attribute per column
    :param Y: tensor with as many elements as ``X`` has rows
    :return: 0-dimensional tensor holding the loss
    """
    n_attrs = X.shape[1]
    scores = []
    for i in range(n_attrs):
        col_i = X[:, i]
        corr_ac = torch.abs(pearson_correlation_coefficient(col_i, Y))
        corr_aa = [
            torch.abs(pearson_correlation_coefficient(col_i, X[:, j]))
            for j in range(n_attrs)
            if j != i
        ]
        if corr_aa:
            redundancy = torch.sum(torch.stack(corr_aa)) / (n_attrs - 1)
            scores.append(corr_ac - redundancy)
        else:
            # Single attribute: there are no other columns to be redundant
            # with, so the redundancy term is 0 (torch.stack([]) would raise).
            scores.append(corr_ac)
    correlation = torch.sum(torch.stack(scores)) / n_attrs
    return -correlation


def calculate_KL(embeddings, labels):
    """
    calculate the KL divergence between the set of embedding vectors and the variational prior

    The divergence is computed separately for every class that has at least
    two embeddings and averaged over those classes.

    :param embeddings: 2-D tensor, one embedding per row
    :param labels: 1-D tensor of class labels aligned with the rows of ``embeddings``
    :return: 0-dimensional tensor; zero when no class has two or more
        embeddings (previously this case raised ZeroDivisionError)
    """
    class_kls = []
    for label in torch.unique(labels):
        class_embeddings = embeddings[labels == label]
        if class_embeddings.size(0) < 2:
            # Too few samples to estimate a per-class variance.
            continue
        class_kls.append(calculate_single_KL(class_embeddings))

    if not class_kls:
        # No class contributed a term, so there is nothing to penalise.
        return embeddings.new_zeros(())
    return sum(class_kls) / len(class_kls)


def calculate_single_KL(matrix):
    """
    Return the KL term 0.5 * (mu^2 + sigma^2 - 1 - log(sigma^2)), averaged over
    the columns of ``matrix``.

    ``mu`` and ``sigma^2`` are the per-column mean and biased variance of
    ``matrix`` (rows are samples).

    :param matrix: 2-D tensor, one sample per row
    :return: 0-dimensional tensor holding the mean per-column divergence
    """
    eps = 1e-9
    mu = matrix.mean(dim=0)
    sigma_squared = matrix.var(dim=0, unbiased=False)
    # Clamp only inside the log: a column with zero variance would otherwise
    # give log(0) = -inf and an infinite loss. Values >= eps are unaffected.
    log_sigma_squared = torch.log(sigma_squared.clamp_min(eps))
    kl_divs = 0.5 * ((mu ** 2) + sigma_squared - 1 - log_sigma_squared)
    return kl_divs.mean()


def _encode_rows(model, rows):
    """Encode every row of ``rows`` with ``model``'s encoder and return the embeddings as a NumPy matrix."""
    with torch.no_grad():
        encoded = [model.encode(torch.tensor(row).float()) for row in rows]
    return torch.stack(encoded, dim=0).numpy()


def run(X, Y, max_epoch=200, learning_rate=0.01, abla_var=0, random_seed=42):
    """
    train and test MCENB

    :param X: 2-D NumPy array of attribute values, one instance per row.
        The values are used as binary-cross-entropy targets, so they are
        presumably expected to lie in [0, 1] -- TODO confirm with callers.
    :param Y: 1-D NumPy array of class labels aligned with the rows of ``X``
    :param max_epoch: number of full-batch training epochs for the encoder
    :param learning_rate: Adam learning rate
    :param abla_var: ablation variant:
        0 = full MCENB,
        1 = no attribute augmentation (classifier sees only ``X``),
        2 = no attribute weighting (plain GaussianNB),
        3 = train without the attribute correlation loss L1,
        4 = train without the instance correlation loss L2
    :param random_seed: seed for the 80/20 train/test split
    :return: NumPy array of shape (1,) holding the test accuracy
    :raises ValueError: if ``abla_var`` is not one of 0-4
    """
    variant_labels = {
        0: "MCENB",
        1: "MCENB_noA",
        2: "MCENB_noW",
        3: "MCENB_noL1",
        4: "MCENB_noL2",
    }
    # Fail fast: an unknown variant previously surfaced as a NameError only
    # after a full forward pass.
    if abla_var not in variant_labels:
        raise ValueError(
            f"abla_var must be one of {sorted(variant_labels)}, got {abla_var!r}"
        )

    # split the training set and the test set
    # Boolean masks keep both subsets in their original row order.
    indices = np.arange(X.shape[0])
    _, _, indices_train, indices_test = train_test_split(
        X, indices, test_size=0.2, random_state=random_seed
    )
    train_idx = np.zeros(len(Y), dtype=bool)
    test_idx = np.zeros(len(Y), dtype=bool)
    train_idx[indices_train] = True
    test_idx[indices_test] = True
    X_train, X_test = X[train_idx], X[test_idx]
    Y_train, Y_test = Y[train_idx], Y[test_idx]

    # construct MCE
    input_dim = X.shape[1]
    # Map labels to codes 0..C-1 for the one-hot condition so that label sets
    # that are not exactly {0, ..., C-1} also work; identical to the raw
    # labels when they already have that form.
    classes, label_codes = np.unique(Y, return_inverse=True)
    cond_dim = len(classes)
    hidden_dim = 2 * input_dim
    latent_dim = input_dim
    y_onehot = F.one_hot(
        torch.tensor(label_codes, dtype=torch.int64), num_classes=cond_dim
    ).float()
    model = MCE(input_dim, cond_dim, latent_dim, hidden_dim)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Loop-invariant tensors, built once instead of once per sample per epoch.
    x_train_rows = [torch.tensor(row).float() for row in X_train]
    y_train_onehot = y_onehot[train_idx]
    y_tensor = torch.from_numpy(Y_train)

    use_L1 = abla_var != 3  # attribute correlation loss
    use_L2 = abla_var != 4  # instance correlation loss

    # training process (one full-batch optimizer step per epoch)
    for _ in trange(max_epoch):

        # attribute generation
        z_rows = []
        recon_losses = []
        for x, y in zip(x_train_rows, y_train_onehot):
            z, x_recon = model(x, y)
            recon_losses.append(calculate_BCE(x_recon, x))
            z_rows.append(z)
        z_tensor = torch.stack(z_rows, dim=0)

        L_elbo_1 = sum(recon_losses) / len(X_train)  # binary cross entropy
        L_elbo_2 = calculate_KL(z_tensor, y_tensor)  # KL divergence
        loss = L_elbo_1 + L_elbo_2
        if use_L1:
            loss = loss + calculate_L1(z_tensor, y_tensor.float())
        if use_L2:
            loss = loss + calculate_L2(z_tensor, y_tensor)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # classification process
    # Embed both subsets with the final weights. Previously the training
    # embeddings came from the forward pass before the last optimizer step
    # (and were empty for max_epoch=0), so they did not match the encoder
    # used for the test embeddings. Only the encoder is needed here, so the
    # test labels are never fed to the model.
    Z_train = _encode_rows(model, X_train)
    Z_test = _encode_rows(model, X_test)

    # ablation variant setup
    augment = abla_var != 1
    weighted = abla_var != 2
    if augment:
        # attribute augmentation
        train_features = np.concatenate((X_train, Z_train), axis=1)
        test_features = np.concatenate((X_test, Z_test), axis=1)
    else:
        train_features, test_features = X_train, X_test

    if weighted:
        gnb = WeightedGaussNB()
        gnb.fit(train_features, Y_train)
        w = L_BFGS_B(train_features, Y_train)  # attribute weighting
        gnb.setWeight(w)
    else:
        gnb = GaussianNB()
        gnb.fit(train_features, Y_train)

    score = gnb.score(test_features, Y_test)
    print(f"{variant_labels[abla_var]} {100 * score:.2f}")

    return np.array([score])