import itertools
import numpy as np
import torch
torch.cuda.empty_cache()
import torch.nn as nn
import torch.optim as optim
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from prettytable import PrettyTable
import torch.optim.lr_scheduler as lr_scheduler
from itertools import product
import time
import csv
import os
import glob
from scipy.stats import entropy
class Simulator(object):
    def __init__(self, d, n, K):
        """Remember the problem dimensions.

        Args:
            d: number of binary latent variables.
            n: number of binary observed variables.
            K: number of domains.
        """
        self.d, self.n, self.K = d, n, K
    def constr(self, mat):
        shape = mat.shape
        mat_rs = mat.view(2, -1)
        T_softmax = nn.functional.softmax(mat_rs, dim=0)
        return T_softmax.view(*shape)
    def sample_truth(self):
        self.Z_star = np.random.rand(self.d, self.K)       # (i, j) entry: P(Zi=0) in domain j
        self.T_star = np.random.rand(self.n, 2 ** self.d)  # (i, j) entry: P(Xi=0|Z=j) where j is in decimal
        self.Z_tensor = torch.from_numpy(self.Z_star.T).reshape((1,self.K,self.d))
        self.Z_tensor = torch.cat((self.Z_tensor, 1 - self.Z_tensor), dim=0)
        self.T_tensor = torch.from_numpy(self.T_star.T).reshape((1, 2**self.d,self.n))
        self.T_tensor = torch.cat((self.T_tensor, 1 - self.T_tensor), dim=0)
        shape = (2,) + (2,) * d + (n,)
        self.T_tensor = self.T_tensor.view(*shape)
        shape2 = (2,) * self.n + (self.K,)
        result = torch.zeros(shape2)
        self.T_pred = self.constr(self.T_tensor)
        self.Z_pred = self.constr(self.Z_tensor)
        # Compute the outer product and element-wise multiplication
        for indices in product(range(2), repeat=self.d):
            # Compute R for the current indices
            R = self.T_pred[(slice(None),) + indices + (0,)]
            for i in range(n - 1):
                R = R.unsqueeze(-1) * self.T_pred[(slice(None),) + indices + (i+1,)]

            P = self.Z_pred[indices[0], :, 0]
            for dd in range(1, d):
                P = P * self.Z_pred[indices[dd], :, dd]
            # Multiply R and T and accumulate into the result tensor
            result = result + R.unsqueeze(-1)  * P.view(*([1]*n), K)
        # self.X_star = self.tensor_T_star @ self.flat_Z_star
        self.X_star = result
        print("============= X star ================")
        reshaped_tensor = self.X_star.reshape(2**n, K)
        print(reshaped_tensor[:20,:])
        # print("============= Z star ================")
        # print(self.Z_star)
        print(f'd={self.d}, n={self.n}, K={self.K}')

    def solve_estimate(self, epochs=10000, lr=0.01, tr_Init=False, num_restarts=40, top_n=2):
        """Fit the model from several starting points and report the best runs.

        Each restart builds a fresh ``PyTorchSolver`` on ``self.X_star`` and
        optimises it. The ``top_n`` runs with the smallest final loss are
        printed via ``present`` and probed via ``local_grad``.

        Args:
            epochs: optimisation steps per restart.
            lr: learning rate handed to the solver.
            tr_Init: when true, the solver is also given ``self.Z_tensor`` and
                ``self.T_tensor`` so it starts from a perturbation of them
                instead of a random point.
            num_restarts: number of independent solver runs.
            top_n: how many of the best runs to report and return.

        Returns:
            List of ``(final_loss, solver)`` tuples, lowest loss first.
        """
        runs = []
        for attempt in range(1, num_restarts + 1):
            print(f"Restart {attempt}/{num_restarts}")

            # The two extra positional arguments switch the solver to warm start.
            warm_start = (self.Z_tensor, self.T_tensor) if tr_Init else ()
            solver = PyTorchSolver(
                self.d, self.n, self.K, self.X_star, self.Z_pred, self.T_pred,
                epochs, lr, *warm_start,
            )
            solver.solve_from_pytorch()

            runs.append((solver.compute_loss().item(), solver))

        # Stable ordering by loss only; solvers themselves are never compared.
        ranked = sorted(runs, key=lambda run: run[0])
        top_results = ranked[:top_n]

        for rank, (loss, solve) in enumerate(top_results, start=1):
            print(f"Top {rank} - Loss: {loss}")
            solve.present(loss)
            solve.local_grad(loss)

        return top_results

class PyTorchSolver:
    def __init__(self, d, n, K, X_obsv, Z_star, T_star, epochs=10000, lr=0.01, flat_Z_star = None, flat_T_star = None):
        self.d = d
        self.n = n
        self.K = K
        # self.X_obsv = torch.tensor(X_obsv, dtype=torch.float32)
        self.X_obsv = X_obsv
        # self.Z_star = torch.tensor(Z_star, dtype=torch.float32)
        # self.T_star = torch.tensor(T_star, dtype=torch.float32)
        self.Z_star = Z_star.clone().detach()
        self.T_star = T_star.clone().detach()
        self.epochs = epochs
        self.lr = lr
        if flat_Z_star == None:
            Z_free = torch.rand(1, K, d, dtype=torch.float32)
            self.Z_free = torch.cat((Z_free, 1 - Z_free), dim=0)
            self.Z_free.requires_grad_(True)
            shape = (1,) + (2,) * d + (n,)
            Tis_free = torch.rand(shape, dtype=torch.float32)
            self.Tis_free = torch.cat((Tis_free, 1 - Tis_free), dim=0)
            self.Tis_free.requires_grad_(True)
        else:
            combined_params = torch.cat([flat_Z_star.flatten(), flat_T_star.flatten()])
            perturbed_params = combined_params + torch.randn_like(combined_params) * 0.05
            perturbed_params.requires_grad_(True)
            
            # Split perturbed_params back into Z_free and Tis_free
            split_index = flat_Z_star.numel()
            perturbed_Z_free = perturbed_params[:split_index].reshape(flat_Z_star.shape)
            perturbed_Tis_free = perturbed_params[split_index:].reshape(flat_T_star.shape)
            self.Z_free = torch.tensor(perturbed_Z_free, dtype=torch.float32)
            self.Z_free.requires_grad_(True)
            self.Tis_free = torch.tensor(perturbed_Tis_free, dtype=torch.float32)
            self.Tis_free.requires_grad_(True)

    def solve_from_pytorch(self):
        optimizer = optim.Adam([self.Z_free] + [self.Tis_free], lr=self.lr)
        # scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
        torch.autograd.set_detect_anomaly(True)
        print("Now solving...")
        for epoch in range(self.epochs):
            optimizer.zero_grad()
            loss = self.compute_loss()
            loss.backward()
            optimizer.step()
            # scheduler.step(loss)
            if epoch % 100 == 0:
                print(f'    Epoch {epoch}, Loss: {loss.item()}')
        print("Optimization finished.")
        # self.present(loss)
    def local_grad(self, loss):
        """Tabulate how the loss changes at random points near the current solution.

        Snapshots the current parameters, asks
        ``check_gradients_around_point`` for losses at nearby points, prints a
        table of ``loss_at_point - loss`` and appends that table to
        ``loss_around_localsolution.csv`` through ``save_table_to_csv``.

        Args:
            loss: reference loss value subtracted from every probed loss.
        """
        z_snapshot = self.Z_free.clone().detach()
        t_snapshot = self.Tis_free.clone().detach()

        probes = self.check_gradients_around_point(self.compute_loss, z_snapshot, t_snapshot)

        table = PrettyTable(['Point', 'Loss_Point - Loss_Final'])
        for label, record in probes.items():
            delta = record["Loss"] - loss
            table.add_row([f"{label}:", f'{delta}'])
        print(table)
        # R=None: only the raw table is written, with no KL post-processing.
        self.save_table_to_csv(table, 'loss_around_localsolution.csv', loss, None)
    def constr(self, mat):
        shape = mat.shape
        mat_rs = mat.view(2, -1)
        T_softmax = nn.functional.softmax(mat_rs, dim=0)
        return T_softmax.view(*shape)
    
    def compute_loss(self):
        pred = self.compute_prediction()
        # print(torch.sum(pred, axis=-1, keepdims=True))
        # assert torch.sum(pred, axis=-1, keepdims=True) == torch.ones(pred.shape[-1])
        # assert torch.sum(self.X_obsv, axis=-1, keepdims=True) == torch.ones(pred.shape[-1])

        # Compute the KL divergence
        # Add a small value (epsilon) to avoid division by zero and log(0)
        epsilon = 1e-10
        kl_divergence = self.X_obsv * torch.log((self.X_obsv + epsilon) / (pred + epsilon))

        # Sum over the last axis where the distributions lie
        kl_divergence_sum = torch.sum(kl_divergence, axis=-1)

        # Mean KL divergence
        mean_kl_divergence = torch.mean(kl_divergence_sum)

        return mean_kl_divergence
    def compute_prediction(self):
        T_constr = self.constr(self.Tis_free)
        Z_constr = self.constr(self.Z_free)
        shape = (2,) * self.n + (self.K,)
        result = torch.zeros(shape)
        # Compute the outer product and element-wise multiplication
        for indices in product(range(2), repeat=self.d):
            # Compute R for the current indices
            R = T_constr[(slice(None),) + indices + (0,)]
            for i in range(self.n - 1):
                R = R.unsqueeze(-1) * T_constr[(slice(None),) + indices + (i+1,)]

            T = Z_constr[indices[0], :, 0]
            for d in range(1, self.d):
                T = T * Z_constr[indices[d], :, d]
            # Multiply R and T and accumulate into the result tensor
            result = result + R.unsqueeze(-1)  * T.view(*([1]*self.n), self.K)
        return result
    def sort_tensor(self, X):
        assert X.shape[0] == 2, "The first dimension of X must be 2"
        mask = X[0] > X[1]
        X_sorted = X.clone()
        X_sorted[0][mask], X_sorted[1][mask] = X[1][mask], X[0][mask]
        return X_sorted
    def transform(self, mat):
        mat_dp = self.sort_tensor(mat)
        mat_dp = mat_dp[0,:,:].transpose(0, 1) 
        mat_dp = torch.sort(mat_dp, dim=1).values
        mat_dp = torch.sort(mat_dp, dim=0).values
        return mat_dp
    def present(self, loss):
        """Prepare sorted estimated and reference parameter matrices, then report them.

        Sets ``Z_dp`` / ``T_dp`` (from the current free parameters, after
        ``constr``) and ``Z_star_dp`` / ``T_star_dp`` (from the stored
        reference tensors), each passed through ``transform``. The conditional
        tensors are flattened to ``(2, 2**d, n)`` first. Finally calls
        ``print_estimated_parameters``.

        Args:
            loss: loss value forwarded to the report.
        """
        flat_cond_shape = (2, 2 ** self.d, self.n)

        estimated_Z = self.constr(self.Z_free)
        self.Z_dp = self.transform(estimated_Z)

        estimated_T = self.constr(self.Tis_free).reshape(*flat_cond_shape)
        self.T_dp = self.transform(estimated_T)

        self.Z_star_dp = self.transform(self.Z_star)
        self.T_star_dp = self.transform(self.T_star.reshape(*flat_cond_shape))

        self.print_estimated_parameters(loss)

    def check_gradients_around_point(self, loss_fn, final_Z_free, final_Tis_free, perturbation_size=0.0001, num_points=30):
        losses = {}
        
        # Combine parameters into a single tensor for perturbation
        combined_params = torch.cat([final_Z_free.flatten(), final_Tis_free.flatten()])
        
        for i in range(num_points):
            perturbed_params = combined_params + torch.randn_like(combined_params) * perturbation_size
            perturbed_params.requires_grad_(True)
            
            # Split perturbed_params back into Z_free and Tis_free
            split_index = final_Z_free.numel()
            perturbed_Z_free = perturbed_params[:split_index].reshape(final_Z_free.shape)
            perturbed_Tis_free = perturbed_params[split_index:].reshape(final_Tis_free.shape)
            # Assign perturbed values
            self.Z_free.data = perturbed_Z_free
            self.Tis_free.data = perturbed_Tis_free
            
            # Compute loss and gradients
            loss = loss_fn()
            
            # Store gradients
            losses[f'Point_{i}'] = {
                'Loss': loss,
            }
            
            # Zero the gradients for the next iteration
            self.Z_free.grad.zero_()
            self.Tis_free.grad.zero_()

        return losses
    def print_estimated_parameters(self, loss):
        """Print and save side-by-side tables of reference and estimated parameters.

        Reads ``Z_star_dp`` / ``Z_dp`` (indexed ``[latent, domain]``) and
        ``T_star_dp`` / ``T_dp`` (indexed ``[observed, latent configuration]``).
        Both tables are printed and then appended to CSV files through
        ``save_table_to_csv``.

        Args:
            loss: loss value recorded alongside each saved table.
        """
        print("\nEstimated values of parameters vs. true ones (all, containing redundant ones):")

        z_table = PrettyTable(['Z table', 'True', 'Estimated'])
        for k, zid in product(range(self.K), range(self.d)):
            z_table.add_row([
                f'P(Z_{zid+1}=0) in domain {k + 1}',
                f'{self.Z_star_dp[zid, k]:.5f}',
                f'{self.Z_dp[zid, k]:.5f}',
            ])
        print(z_table)

        xz_table = PrettyTable(['X|Z', 'True', 'Estimated'])
        for bits in itertools.product((0, 1), repeat=self.d):
            # Latent configuration as a bit string and as its decimal column index.
            bin_config_str = ''.join(map(str, bits))
            dec_config = int(bin_config_str, 2)
            for nid in range(self.n):
                xz_table.add_row([
                    f'P(X_{nid + 1}=0 | Z={bin_config_str})',
                    f'{self.T_star_dp[nid, dec_config]:.5f}',
                    f'{self.T_dp[nid, dec_config]:.5f}',
                ])
        print(xz_table)

        self.save_table_to_csv(z_table, 'estimated_parameters_Z.csv', loss, 'Z')
        self.save_table_to_csv(xz_table, 'estimated_parameters_XZ.csv', loss, 'XZ')

    def save_table_to_csv(self, pretty_table, filename, loss, R):
        """Append a table to a per-configuration CSV file and optionally post-process it.

        The file lives under ``results/local_4/N=<n>_D=<d>_K=<K>/``. On first
        creation the table's field names are written as a header; on later
        calls a row of empty cells separates the new section. Each section
        consists of a hyper-parameter row, a ``Loss: ...`` row and the table
        rows.

        Args:
            pretty_table: object exposing ``field_names`` and ``rows``.
            filename: base file name; ``.csv`` is appended to it.
            loss: loss value recorded in the section.
            R: ``None`` to skip post-processing; otherwise passed to
                ``process_csv`` together with the written file, with
                ``KL_result.csv`` in the same folder as output.
        """
        print(f"The parameters are: N={self.n}_D={self.d}_K={self.K}")
        print(self.n, self.d, self.K)
        folder_path = f"results/local_4/N={self.n}_D={self.d}_K={self.K}"
        os.makedirs(folder_path, exist_ok=True)
        # NOTE(review): callers in this file already pass names ending in
        # ".csv", so files end up as "*.csv.csv". The path is kept unchanged
        # so that existing append-mode result files continue to be used.
        csv_file_path = f"{folder_path}/{filename}.csv"
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            if not file_exists:
                writer.writerow(pretty_table.field_names)  # Write header only if file does not exist
            else:
                writer.writerow([''] * len(pretty_table.field_names))
            hyperparam = f"Observed: {self.n}, Latent: {self.d}, Domain: {self.K}"
            formatted_loss = f"Loss: {loss }"
            writer.writerow([hyperparam])
            writer.writerow([formatted_loss])
            writer.writerows(pretty_table.rows)
        if R is not None:
            outpath = f"{folder_path}/KL_result.csv"
            print(R)
            self.process_csv(csv_file_path, outpath, R)


    def calculate_kl_divergence(self, p_true, p_est):
        # Ensure no zero values to avoid log(0) issues
        p_true = np.clip(p_true, 1e-10, 1)
        p_est = np.clip(p_est, 1e-10, 1)
        return entropy(p_true, p_est)
    def kl_divergence_binary(self, p_true_0, p_est_0):
        # Clip probabilities to avoid log(0)
        p_true_0 = np.clip(p_true_0, 1e-10, 1)
        p_est_0 = np.clip(p_est_0, 1e-10, 1)
        

        kl = p_est_0 * np.log(p_est_0 / p_true_0) + (1 - p_est_0) * np.log((1 - p_est_0) / (1 - p_true_0))
        return kl


    def process_csv(self, input_file, output_file, R):
        """Summarise per-parameter KL divergences from a results CSV.

        Reads every row of ``input_file``. Rows starting with ``Loss:`` update
        the current loss; rows whose first cell starts with the prefix
        selected by ``R`` contribute ``kl_divergence_binary(row[1], row[2])``
        to a running list that is never reset, so every mean below covers all
        matching rows seen so far. Appends to ``output_file``:

        * a ``["Loss", "KL Divergence", R]`` row on every call, which also
          labels the section;
        * for each empty input row, ``[loss, mean KL so far]``;
        * a final ``["Status", Pass|Fail, "KL", mean KL]`` row.

        Args:
            input_file: path of the CSV to read.
            output_file: path of the CSV to append to.
            R: ``'Z'`` (rows starting with ``P(Z_``) or ``'XZ'`` (rows
                starting with ``P(X_``).

        Raises:
            ValueError: if ``R`` is neither ``'Z'`` nor ``'XZ'``. Raised before
                any file is opened, so no partial output is written.
        """
        row_prefixes = {'XZ': "P(X_", 'Z': "P(Z_"}
        if R not in row_prefixes:
            raise ValueError(f"R must be 'Z' or 'XZ', got {R!r}")
        Init = row_prefixes[R]

        def mean_or_nan(values):
            # Same value np.mean gives for an empty list, without the warning.
            return np.mean(values) if values else float('nan')

        kl_divergence_list = []
        loss = None

        with open(input_file, 'r', newline='') as infile, open(output_file, 'a', newline='') as outfile:
            reader = csv.reader(infile)
            writer = csv.writer(outfile)

            writer.writerow(["Loss", "KL Divergence", R])

            for row in reader:
                print(not row)
                if not row:
                    # NOTE(review): csv.reader yields [] only for a truly empty
                    # line. Separator rows made of empty cells (as written by
                    # save_table_to_csv) are non-empty lists and do not reach
                    # this branch -- confirm that is intended.
                    print("haha")
                    kl_div = mean_or_nan(kl_divergence_list)
                    writer.writerow([loss, kl_div])
                elif row[0].startswith("Loss:"):
                    loss = float(row[0].split(": ")[1])
                elif row[0].startswith(Init):
                    print("row 1 is", np.array(float(row[1])), 'row 2 is', np.array(float(row[2])))
                    kl_div = self.kl_divergence_binary(np.array(float(row[1])), np.array(float(row[2])))
                    kl_divergence_list.append(kl_div)
            print(kl_divergence_list)

            average_kl_div = mean_or_nan(kl_divergence_list)
            # NOTE(review): for non-negative values the `any` clause is implied
            # by the mean test -- confirm whether `all` was intended.
            status = "Pass" if average_kl_div <= 0.0001 and any(value <= 0.0001 for value in kl_divergence_list) else "Fail"

            writer.writerow(["Status", status, "KL", average_kl_div])
# Script entry point: sample a random ground truth, then time the
# multi-restart estimation on it.
if __name__ == '__main__':
    print("==========haha==========")
    # Problem size, assigned at module level.
    d = 4  # number of binary latent variables
    n = 3  # number of binary observed variables
    tr_Init = True  # start every restart from a perturbation of the sampled truth
    K = 6  # number of domains
    ss = Simulator(d, n, K)
    ss.sample_truth()
    # Wall-clock timing of the whole estimation (all restarts plus reporting).
    start = time.time()
    ss.solve_estimate(epochs=8000, lr=0.001, tr_Init = tr_Init)
    end = time.time()
    print("time passed:", end - start)
 