import itertools
import numpy as np
import torch
torch.cuda.empty_cache()
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from prettytable import PrettyTable
import torch.optim.lr_scheduler as lr_scheduler
from itertools import product
import time
import csv
import os
import glob
from scipy.stats import entropy


class LightweightSolver:
    """CPU-only snapshot of a fitted ``PyTorchSolver`` used for reporting.

    Worker processes build one of these from their solver (every tensor is
    moved to CPU and detached) and send it back through a multiprocessing
    manager list; the parent process then calls :meth:`present` to print the
    fitted parameters next to the true ones and append them to CSV files.

    Tensor layouts, as fixed by the indexing in :meth:`compute_prediction`:

    * ``Z_free`` / ``Z_star``: ``(2, K, d)``; axis 0 is the binary value of a
      latent variable.
    * ``Tis_free`` / ``T_star``: ``(2,) + (2,) * d + (n,)``; axis 0 is the
      binary value of an observed variable, the middle ``d`` axes index the
      latent configuration.
    * ``X_obsv``: ``(2,) * n + (K,)``.
    """

    # Root directory for CSV output; override on an instance (or on the class)
    # to write somewhere else.
    results_root = "/jet/home/yhan6/discrete/multidimensional_ind_binary/results3"

    # Row-label prefix that marks the probability rows of each table kind
    # understood by process_csv.
    _KL_ROW_PREFIXES = {'XZ': "P(X_", 'Z': "P(Z_", 'X': 'P(X='}

    # A table "passes" when the mean KL divergence over its sections is at
    # most this value.
    KL_PASS_THRESHOLD = 0.0001

    def __init__(self, Z_free, Z_star, Tis_free, T_star, X_obsv, d, n, K):
        self.d = d  # number of latent binary variables
        self.n = n  # number of observed binary variables
        self.K = K  # number of domains
        self.Z_free = Z_free.cpu().detach()
        self.Z_star = Z_star.cpu().detach()
        self.Tis_free = Tis_free.cpu().detach()
        self.T_star = T_star.cpu().detach()
        self.X_obsv = X_obsv.cpu().detach()

    def present(self, loss):
        """Canonicalise estimated and true parameters, then print/save them.

        Args:
            loss: final loss of the run, recorded alongside the tables.
        """
        self.Z_dp = self.transform(self.constr(self.Z_free))
        self.T_dp = self.transform(self.constr(self.Tis_free).reshape(2, 2 ** self.d, self.n))
        self.Z_star_dp = self.transform(self.Z_star)
        self.T_star_dp = self.transform(self.T_star.reshape(2, 2 ** self.d, self.n))
        self.print_estimated_parameters(loss)

    def compute_prediction(self):
        """Return the model's joint table over the observed variables.

        ``result[x_1, ..., x_n, k] = sum_z P(z | k) * prod_i P(x_i | z)``, where
        ``P(z | k)`` is the product of the ``d`` per-latent marginals. Both
        factors are softmax-normalised from the free parameters by
        :meth:`constr`.

        Returns:
            Tensor of shape ``(2,) * n + (K,)``.
        """
        T_constr = self.constr(self.Tis_free)
        Z_constr = self.constr(self.Z_free)
        result = torch.zeros((2,) * self.n + (self.K,))
        for indices in product(range(2), repeat=self.d):
            # Outer product over the n observed variables of P(x_i | z=indices).
            R = T_constr[(slice(None),) + indices + (0,)]
            for i in range(1, self.n):
                R = R.unsqueeze(-1) * T_constr[(slice(None),) + indices + (i,)]
            # P(z=indices | k) for every domain k, shape (K,).
            P = Z_constr[indices[0], :, 0]
            for dim in range(1, self.d):
                P = P * Z_constr[indices[dim], :, dim]
            result = result + R.unsqueeze(-1) * P.view(*([1] * self.n), self.K)
        return result

    def sort_tensor(self, X):
        """Return a copy of ``X`` with ``X[0]`` and ``X[1]`` swapped where ``X[0] > X[1]``.

        After the call the slice at index 0 holds the element-wise smaller value.

        Raises:
            AssertionError: if ``X.shape[0] != 2``.
        """
        assert X.shape[0] == 2, "The first dimension of X must be 2"
        mask = X[0] > X[1]
        X_sorted = X.clone()
        X_sorted[0][mask], X_sorted[1][mask] = X[1][mask], X[0][mask]
        return X_sorted

    def transform(self, mat):
        """Canonicalise a ``(2, A, B)`` tensor for comparison.

        Keeps the element-wise smaller slice, transposes it to ``(B, A)``, then
        sorts along rows and along columns (presumably so that true and
        estimated parameters can be compared up to label permutation).
        """
        mat_dp = self.sort_tensor(mat)
        mat_dp = mat_dp[0, :, :].transpose(0, 1)
        mat_dp = torch.sort(mat_dp, dim=1).values
        mat_dp = torch.sort(mat_dp, dim=0).values
        return mat_dp

    def constr(self, mat):
        """Softmax ``mat`` over its leading axis (which must have size 2), keeping its shape."""
        shape = mat.shape
        mat_rs = mat.view(2, -1)
        T_softmax = nn.functional.softmax(mat_rs, dim=0)
        return T_softmax.view(*shape)

    def print_estimated_parameters(self, loss, num_x_rows=5):
        """Print Z, X|Z and X tables (true vs. estimated) and append them to CSV.

        Args:
            loss: final loss of the run, recorded in each CSV section.
            num_x_rows: how many observed configurations (in lexicographic
                order) to list in the X table.
        """
        print("\nEstimated values of parameters vs. true ones (all, containing redundant ones):")

        t1 = PrettyTable(['Z table', 'True', 'Estimated'])
        for k in range(self.K):
            for zid in range(self.d):
                t1.add_row([f'P(Z_{zid+1}=0) in domain {k + 1}', f'{self.Z_star_dp[zid, k]:.5f}', f'{self.Z_dp[zid, k]:.5f}'])
        print(t1)

        t2 = PrettyTable(['X|Z', 'True', 'Estimated'])
        for binary_config in itertools.product((0, 1), repeat=self.d):
            bin_config_str = ''.join(str(bit) for bit in binary_config)
            dec_config = int(bin_config_str, 2)
            for nid in range(self.n):
                t2.add_row([f'P(X_{nid + 1}=0 | Z={bin_config_str})', f'{self.T_star_dp[nid, dec_config]:.5f}', f'{self.T_dp[nid, dec_config]:.5f}'])
        print(t2)

        t3 = PrettyTable(['X table', 'True', 'Estimated'])
        predicted = self.compute_prediction()
        # Use self.n (not a module global) so the class also works when imported.
        for indices in itertools.islice(itertools.product((0, 1), repeat=self.n), num_x_rows):
            bin_config_str = ''.join(str(bit) for bit in indices)
            actual = self.X_obsv[indices]
            pred = predicted[indices]
            for k in range(self.K):
                t3.add_row(
                    [f'P(X={bin_config_str}) in domain {k + 1}', f'{actual[k]:.5f}', f'{pred[k]:.5f}'])
        print(t3)

        self.save_table_to_csv(t1, 'estimated_parameters_Z.csv', loss, 'Z')
        self.save_table_to_csv(t2, 'estimated_parameters_XZ.csv', loss, 'XZ')
        self.save_table_to_csv(t3, 'estimated_X.csv', loss, 'X')

    def compute_loss(self):
        """Return ``mean(sum_k X * log((X + eps) / (pred + eps)))`` as a 0-dim tensor.

        The inner sum runs over the last (domain) axis; the mean runs over all
        observed configurations. ``eps = 1e-10`` guards against ``log(0)`` and
        division by zero.
        """
        pred = self.compute_prediction()
        epsilon = 1e-10
        kl_divergence = self.X_obsv * torch.log((self.X_obsv + epsilon) / (pred + epsilon))
        kl_divergence_sum = torch.sum(kl_divergence, axis=-1)
        return torch.mean(kl_divergence_sum)

    def local_grad(self, loss):
        """Print and save how the loss changes at random points near the solution.

        Args:
            loss: reference (final) loss as a float.
        """
        final_Z_free = self.Z_free.clone().detach()
        final_Tis_free = self.Tis_free.clone().detach()

        losses_around_point = self.check_gradients_around_point(self.compute_loss, final_Z_free, final_Tis_free)

        t = PrettyTable(['Point', 'Loss_Point - Loss_Final'])
        for point, losses in losses_around_point.items():
            t.add_row([f"{point}:", f'{losses["Loss"] - loss}'])
        print(t)
        self.save_table_to_csv(t, 'loss_around_localsolution.csv', loss, None)

    def check_gradients_around_point(self, loss_fn, final_Z_free, final_Tis_free, perturbation_size=0.0001, num_points=30):
        """Evaluate ``loss_fn`` at random points in a small neighbourhood.

        Each point adds i.i.d. Gaussian noise (std ``perturbation_size``) to the
        concatenation of ``final_Z_free`` and ``final_Tis_free``, installs the
        result as ``self.Z_free`` / ``self.Tis_free`` and calls ``loss_fn()``.
        The original parameter tensors are put back afterwards, even if
        ``loss_fn`` raises.

        Returns:
            dict mapping ``'Point_<i>'`` to ``{'Loss': float}``.
        """
        combined_params = torch.cat([final_Z_free.flatten(), final_Tis_free.flatten()])
        split_index = final_Z_free.numel()
        saved_params = (self.Z_free, self.Tis_free)
        losses = {}
        try:
            for i in range(num_points):
                perturbed = combined_params + torch.randn_like(combined_params) * perturbation_size
                self.Z_free = perturbed[:split_index].reshape(final_Z_free.shape)
                self.Tis_free = perturbed[split_index:].reshape(final_Tis_free.shape)
                losses[f'Point_{i}'] = {'Loss': float(loss_fn())}
        finally:
            self.Z_free, self.Tis_free = saved_params
        return losses

    def save_table_to_csv(self, pretty_table, filename, loss, R):
        """Append ``pretty_table`` as a new section of a CSV under ``results_root``.

        The file lives in ``<results_root>/N=<n>_D=<d>_K=<K>/``; ``.csv`` is
        added to ``filename`` only if it is missing. The header row is written
        once, when the file is created; later calls write an all-empty
        separator row instead. When ``R`` is not ``None``, KL statistics for the
        whole file are appended to ``KL_result.csv`` via :meth:`process_csv`.
        """
        print(f"The parameters are: N={self.n}_D={self.d}_K={self.K}")
        folder_path = os.path.join(self.results_root, f"N={self.n}_D={self.d}_K={self.K}")
        os.makedirs(folder_path, exist_ok=True)
        if not filename.endswith('.csv'):
            filename += '.csv'
        csv_file_path = os.path.join(folder_path, filename)
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            if file_exists:
                writer.writerow([''] * len(pretty_table.field_names))
            else:
                writer.writerow(pretty_table.field_names)
            writer.writerow([f"Observed: {self.n}, Latent: {self.d}, Domain: {self.K}"])
            writer.writerow([f"Loss: {loss}"])
            writer.writerows(pretty_table.rows)
        if R is not None:
            print(R)
            self.process_csv(csv_file_path, os.path.join(folder_path, "KL_result.csv"), R)

    def calculate_kl_divergence(self, p_true, p_est):
        """Return ``scipy.stats.entropy(p_true, p_est)`` after clipping both to ``[1e-10, 1]``."""
        p_true = np.clip(p_true, 1e-10, 1)
        p_est = np.clip(p_est, 1e-10, 1)
        return entropy(p_true, p_est)

    def _flush_kl_section(self, writer, loss, p_true, p_est, kl_divergence_list):
        """Write the KL row for one buffered section (if any) and clear the buffers."""
        if p_true and p_est:
            kl_div = self.calculate_kl_divergence(np.array(p_true), np.array(p_est))
            kl_divergence_list.append(kl_div)
            writer.writerow([loss, kl_div])
        p_true.clear()
        p_est.clear()

    def process_csv(self, input_file, output_file, R):
        """Append per-section KL divergences of a saved table to ``output_file``.

        ``input_file`` is a CSV written by :meth:`save_table_to_csv`: one
        section per call, separated by all-empty rows, each containing a
        ``"Loss: <value>"`` row followed by ``label, true, estimated`` rows.
        Rows whose label starts with the prefix for ``R`` are collected per
        section and ``[loss, KL(true || estimated)]`` is written for each.
        Because ``input_file`` accumulates across runs, every call re-reports
        all sections seen so far. A final ``["Status", "Pass"|"Fail"]`` row is
        ``Pass`` when the mean KL is at most ``KL_PASS_THRESHOLD``.

        Raises:
            ValueError: if ``R`` is not one of ``'Z'``, ``'XZ'``, ``'X'``.
        """
        try:
            prefix = self._KL_ROW_PREFIXES[R]
        except KeyError:
            raise ValueError(f"Unknown table kind {R!r}; expected one of {sorted(self._KL_ROW_PREFIXES)}") from None

        kl_divergence_list = []
        with open(input_file, 'r', newline='') as infile, open(output_file, 'a', newline='') as outfile:
            reader = csv.reader(infile)
            writer = csv.writer(outfile)

            # KL_result.csv is shared by several table kinds, so every call
            # writes a header naming the kind of the rows that follow.
            writer.writerow(["Loss", "KL Divergence", R])

            loss = None
            p_true = []
            p_est = []
            for row in reader:
                first = row[0] if row else ''
                if not any(cell.strip() for cell in row):
                    # Separator rows are written as ['', '', ''], which the
                    # reader returns as a non-empty list of empty strings.
                    self._flush_kl_section(writer, loss, p_true, p_est, kl_divergence_list)
                elif first.startswith("Loss:"):
                    self._flush_kl_section(writer, loss, p_true, p_est, kl_divergence_list)
                    loss = float(first.split(": ", 1)[1])
                elif first.startswith(prefix):
                    p_true.append(float(row[1]))
                    p_est.append(float(row[2]))
            self._flush_kl_section(writer, loss, p_true, p_est, kl_divergence_list)

            passed = bool(kl_divergence_list) and np.mean(kl_divergence_list) <= self.KL_PASS_THRESHOLD
            writer.writerow(["Status", "Pass" if passed else "Fail"])
class Simulator(object):
    """Sample a ground-truth model and try to recover it with many restarts.

    :meth:`sample_truth` draws random parameters and builds the joint table
    ``X_star``. :meth:`solve_estimate` then fits ``PyTorchSolver`` instances
    to ``X_star`` in separate worker processes and presents the best fits.
    """

    def __init__(self, d, n, K):
        self.d = d  # number of latent binary variables
        self.n = n  # number of observed binary variables
        self.K = K  # number of domains
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device_count = torch.cuda.device_count()

    def constr(self, mat):
        """Softmax ``mat`` over its leading axis (size 2) and move it to ``self.device``."""
        shape = mat.shape
        mat_rs = mat.view(2, -1)
        T_softmax = nn.functional.softmax(mat_rs, dim=0).to(self.device)
        return T_softmax.view(*shape)

    def sample_truth(self):
        """Draw random true parameters and build the observed joint table.

        Sets:
            Z_tensor: raw latent parameters, shape ``(2, K, d)``.
            T_tensor: raw conditional parameters, shape ``(2,) + (2,) * d + (n,)``.
            Z_pred, T_pred: their softmax-normalised versions (see :meth:`constr`).
            X_star: joint table of shape ``(2,) * n + (K,)``.
        """
        self.Z_star = torch.rand(self.d, self.K, device=self.device)
        self.T_star = torch.rand(self.n, 2 ** self.d, device=self.device)
        # Stack p and 1 - p on a new leading axis holding the binary value.
        self.Z_tensor = torch.cat((self.Z_star.T.unsqueeze(0), 1 - self.Z_star.T.unsqueeze(0)), dim=0)
        self.T_tensor = torch.cat((self.T_star.T.unsqueeze(0), 1 - self.T_star.T.unsqueeze(0)), dim=0)
        self.T_tensor = self.T_tensor.view(*((2,) + (2,) * self.d + (self.n,)))

        # The stacked values are treated as free parameters and pushed through
        # the same softmax the solver applies, so truth and estimate share one
        # parametrisation.
        self.T_pred = self.constr(self.T_tensor)
        self.Z_pred = self.constr(self.Z_tensor)

        result = torch.zeros((2,) * self.n + (self.K,), device=self.device)
        for indices in product(range(2), repeat=self.d):
            # Outer product over the n observed variables of P(x_i | z=indices).
            R = self.T_pred[(slice(None),) + indices + (0,)]
            for i in range(1, self.n):
                R = R.unsqueeze(-1) * self.T_pred[(slice(None),) + indices + (i,)]
            # P(z=indices | k) for every domain k, shape (K,).
            P = self.Z_pred[indices[0], :, 0]
            for dd in range(1, self.d):
                P = P * self.Z_pred[indices[dd], :, dd]
            result += R.unsqueeze(-1) * P.view(*([1] * self.n), self.K)

        self.X_star = result
        print(f"============= X star ================")
        print(self.X_star.reshape(2 ** self.n, self.K)[:20, :])
        print(f'd={self.d}, n={self.n}, K={self.K}')

    def solve_estimate(self, epochs=10000, lr=0.01, tr_Init=False, num_restarts=100, top_n=20, max_concurrent=None):
        """Run ``num_restarts`` independent fits in worker processes and present the best.

        Args:
            epochs, lr: optimiser settings forwarded to each ``PyTorchSolver``.
            tr_Init: if True, start each restart near the true parameters.
            num_restarts: number of worker processes (one fit each).
            top_n: how many of the lowest-loss fits to present.
            max_concurrent: maximum number of workers alive at once; ``None``
                (the default) starts all of them together.

        Returns:
            List of ``(loss, LightweightSolver)`` tuples sorted by loss, at most
            ``top_n`` long.
        """
        batch_size = max(1, num_restarts if max_concurrent is None else int(max_concurrent))
        # The manager is a server process; the context manager shuts it down.
        # Results must be copied out before the proxy list disappears.
        with mp.Manager() as manager:
            shared_results = manager.list()
            for first in range(0, num_restarts, batch_size):
                processes = [
                    mp.Process(target=self.run_one_restart, args=(i, epochs, lr, tr_Init, shared_results))
                    for i in range(first, min(first + batch_size, num_restarts))
                ]
                for p in processes:
                    p.start()
                for p in processes:
                    p.join()
                    if p.exitcode != 0:
                        print(f"Warning: worker {p.name} exited with code {p.exitcode}; its result is missing.")
            results = list(shared_results)

        results.sort(key=lambda x: x[0])
        top_results = results[:top_n]
        for idx, (loss, solver) in enumerate(top_results):
            print(f"Top {idx + 1} - Loss: {loss}")
            solver.present(loss)
        return top_results

    def _select_gpu(self, restart_index):
        """Make GPU ``restart_index % device_count`` current; return its id or None on CPU."""
        if self.device_count == 0 or not torch.cuda.is_available():
            return None
        gpu_id = restart_index % self.device_count  # round-robin over GPUs
        torch.cuda.set_device(gpu_id)
        return gpu_id

    def run_one_restart(self, restart_index, epochs, lr, tr_Init, results):
        """Fit one ``PyTorchSolver`` and append ``(final_loss, LightweightSolver)`` to ``results``.

        Runs inside a worker process. Without CUDA it falls back to the CPU
        instead of failing on the GPU round-robin.
        """
        gpu_id = self._select_gpu(restart_index)
        where = f"GPU {gpu_id}" if gpu_id is not None else "CPU"
        print(f"Restart {restart_index + 1} using {where}")

        if tr_Init:
            solver = PyTorchSolver(self.d, self.n, self.K, self.X_star, self.Z_pred, self.T_pred, epochs, lr, self.Z_tensor, self.T_tensor)
        else:
            solver = PyTorchSolver(self.d, self.n, self.K, self.X_star, self.Z_pred, self.T_pred, epochs, lr)

        solver.solve_from_pytorch()

        final_loss = solver.compute_loss().item()
        # Only a CPU snapshot is sent back; it pickles cheaply through the manager.
        results.append((final_loss, solver.get_lightweight_solver()))

class PyTorchSolver:
    """Fit the latent-variable model to an observed joint table with Adam.

    Free (unconstrained) parameters, mapped to probabilities by :meth:`constr`:

    * ``Z_free``: shape ``(2, K, d)``.
    * ``Tis_free``: shape ``(2,) + (2,) * d + (n,)``.

    The loss (:meth:`compute_loss`) is ``X * log(X / pred)`` summed over the
    last (domain) axis and averaged over all observed configurations.
    """

    # Root directory for CSV output; override on an instance (or on the class)
    # to write somewhere else.
    results_root = "/jet/home/yhan6/discrete/multidimensional_ind_binary/results"

    # Row-label prefix that marks the probability rows of each table kind
    # understood by process_csv.
    _KL_ROW_PREFIXES = {'XZ': "P(X_", 'Z': "P(Z_", 'X': 'P(X='}

    # A table "passes" when the mean KL divergence over its sections is at
    # most this value.
    KL_PASS_THRESHOLD = 0.0001

    def __init__(self, d, n, K, X_obsv, Z_star, T_star, epochs=10000, lr=0.01, flat_Z_star=None, flat_T_star=None):
        """Create a solver.

        Args:
            d, n, K: numbers of latent variables, observed variables and domains.
            X_obsv: observed joint table, shape ``(2,) * n + (K,)``.
            Z_star, T_star: true (normalised) parameters, kept for reporting.
            epochs, lr: Adam settings used by :meth:`solve_from_pytorch`.
            flat_Z_star, flat_T_star: optional warm start. When ``flat_Z_star``
                is given, both are concatenated, perturbed with Gaussian noise
                (std 0.01) and split back into ``Z_free`` / ``Tis_free``;
                otherwise the free parameters are drawn at random.
        """
        self.d = d
        self.n = n
        self.K = K
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.X_obsv = X_obsv.to(self.device)
        self.Z_star = Z_star.detach().clone()
        self.T_star = T_star.detach().clone()
        self.epochs = epochs
        self.lr = lr
        if flat_Z_star is None:
            # Random start: draw p in [0, 1) and stack (p, 1 - p) on axis 0.
            Z_free = torch.rand(1, K, d, dtype=torch.float32).to(self.device)
            self.Z_free = torch.cat((Z_free, 1 - Z_free), dim=0).requires_grad_(True)
            shape = (1,) + (2,) * d + (n,)
            Tis_free = torch.rand(shape, dtype=torch.float32).to(self.device)
            self.Tis_free = torch.cat((Tis_free, 1 - Tis_free), dim=0).requires_grad_(True)
        else:
            combined_params = torch.cat([flat_Z_star.flatten(), flat_T_star.flatten()]).to(self.device)
            perturbed_params = combined_params + torch.randn_like(combined_params) * 0.01
            split_index = flat_Z_star.numel()
            # detach().clone() makes each part an independent *leaf* tensor;
            # torch.optim refuses to optimise non-leaf tensors such as slices
            # of a tensor that already requires grad.
            self.Z_free = (perturbed_params[:split_index].reshape(flat_Z_star.shape)
                           .detach().clone().requires_grad_(True))
            self.Tis_free = (perturbed_params[split_index:].reshape(flat_T_star.shape)
                             .detach().clone().requires_grad_(True))

    def get_lightweight_solver(self):
        """Return a CPU-only ``LightweightSolver`` snapshot of this solver."""
        return LightweightSolver(self.Z_free, self.Z_star, self.Tis_free, self.T_star, self.X_obsv, self.d, self.n, self.K)

    def solve_from_pytorch(self, detect_anomaly=True):
        """Minimise :meth:`compute_loss` with Adam for ``self.epochs`` steps.

        Args:
            detect_anomaly: enable autograd anomaly detection during the loop.
                It is scoped to this call, and the previous global setting is
                restored afterwards.
        """
        optimizer = optim.Adam([self.Z_free, self.Tis_free], lr=self.lr)
        print("Now solving...")
        with torch.autograd.set_detect_anomaly(detect_anomaly):
            for epoch in range(self.epochs):
                optimizer.zero_grad()
                loss = self.compute_loss()
                loss.backward()
                optimizer.step()
                if epoch % 100 == 0:
                    print(f'    Epoch {epoch}, Loss: {loss.item()}')
        print("Optimization finished.")

    def local_grad(self, loss):
        """Print and save how the loss changes at random points near the solution.

        Args:
            loss: reference (final) loss as a float.
        """
        final_Z_free = self.Z_free.clone().detach().to(self.device)
        final_Tis_free = self.Tis_free.clone().detach().to(self.device)

        losses_around_point = self.check_gradients_around_point(self.compute_loss, final_Z_free, final_Tis_free)

        t = PrettyTable(['Point', 'Loss_Point - Loss_Final'])
        for point, losses in losses_around_point.items():
            t.add_row([f"{point}:", f'{losses["Loss"] - loss}'])
        print(t)
        self.save_table_to_csv(t, 'loss_around_localsolution.csv', loss, None)

    def constr(self, mat):
        """Softmax ``mat`` over its leading axis (size 2) on ``self.device``, keeping its shape."""
        shape = mat.shape
        mat_rs = mat.view(2, -1)
        T_softmax = nn.functional.softmax(mat_rs, dim=0).to(self.device)
        return T_softmax.view(*shape)

    def compute_loss(self):
        """Return ``mean(sum_k X * log((X + eps) / (pred + eps)))`` as a 0-dim tensor.

        The inner sum runs over the last (domain) axis; the mean runs over all
        observed configurations. ``eps = 1e-10`` guards against ``log(0)`` and
        division by zero.
        """
        pred = self.compute_prediction()
        epsilon = 1e-10
        kl_divergence = self.X_obsv * torch.log((self.X_obsv + epsilon) / (pred + epsilon))
        kl_divergence_sum = torch.sum(kl_divergence, axis=-1)
        return torch.mean(kl_divergence_sum)

    def compute_prediction(self):
        """Return the model's joint table, shape ``(2,) * n + (K,)``, on ``self.device``.

        ``result[x_1, ..., x_n, k] = sum_z P(z | k) * prod_i P(x_i | z)``, where
        ``P(z | k)`` is the product of the ``d`` per-latent marginals.
        """
        T_constr = self.constr(self.Tis_free)
        Z_constr = self.constr(self.Z_free)
        result = torch.zeros((2,) * self.n + (self.K,), device=self.device)
        for indices in product(range(2), repeat=self.d):
            # Outer product over the n observed variables of P(x_i | z=indices).
            R = T_constr[(slice(None),) + indices + (0,)]
            for i in range(1, self.n):
                R = R.unsqueeze(-1) * T_constr[(slice(None),) + indices + (i,)]
            # P(z=indices | k) for every domain k, shape (K,).
            P = Z_constr[indices[0], :, 0]
            for dim in range(1, self.d):
                P = P * Z_constr[indices[dim], :, dim]
            result = result + R.unsqueeze(-1) * P.view(*([1] * self.n), self.K)
        return result

    def sort_tensor(self, X):
        """Return a copy of ``X`` (on ``self.device``) with ``X[0]``/``X[1]`` swapped where ``X[0] > X[1]``.

        Raises:
            AssertionError: if ``X.shape[0] != 2``.
        """
        assert X.shape[0] == 2, "The first dimension of X must be 2"
        X = X.to(self.device)
        mask = X[0] > X[1]
        X_sorted = X.clone()
        X_sorted[0][mask], X_sorted[1][mask] = X[1][mask], X[0][mask]
        return X_sorted

    def transform(self, mat):
        """Canonicalise a ``(2, A, B)`` tensor for comparison.

        Keeps the element-wise smaller slice, transposes it to ``(B, A)``, then
        sorts along rows and along columns (presumably so that true and
        estimated parameters can be compared up to label permutation).
        """
        mat_dp = self.sort_tensor(mat)
        mat_dp = mat_dp[0, :, :].transpose(0, 1)
        mat_dp = torch.sort(mat_dp, dim=1).values
        mat_dp = torch.sort(mat_dp, dim=0).values
        return mat_dp

    def present(self, loss):
        """Canonicalise estimated and true parameters, then print/save them.

        This sets the ``*_dp`` attributes that :meth:`print_estimated_parameters`
        reads.
        """
        with torch.no_grad():
            self.Z_dp = self.transform(self.constr(self.Z_free))
            self.T_dp = self.transform(self.constr(self.Tis_free).reshape(2, 2 ** self.d, self.n))
            self.Z_star_dp = self.transform(self.Z_star)
            self.T_star_dp = self.transform(self.T_star.reshape(2, 2 ** self.d, self.n))
        self.print_estimated_parameters(loss)

    def check_gradients_around_point(self, loss_fn, final_Z_free, final_Tis_free, perturbation_size=0.0001, num_points=30):
        """Evaluate ``loss_fn`` at random points in a small neighbourhood.

        Each point adds i.i.d. Gaussian noise (std ``perturbation_size``) to the
        concatenation of ``final_Z_free`` and ``final_Tis_free``, copies it into
        ``self.Z_free`` / ``self.Tis_free`` and calls ``loss_fn()`` without
        building an autograd graph. The parameter values are restored
        afterwards, even if ``loss_fn`` raises. Existing ``.grad`` buffers are
        left untouched.

        Returns:
            dict mapping ``'Point_<i>'`` to ``{'Loss': float}``.
        """
        combined_params = torch.cat([final_Z_free.flatten(), final_Tis_free.flatten()])
        split_index = final_Z_free.numel()
        saved_Z_free = self.Z_free.detach().clone()
        saved_Tis_free = self.Tis_free.detach().clone()
        losses = {}
        try:
            with torch.no_grad():
                for i in range(num_points):
                    perturbed = combined_params + torch.randn_like(combined_params) * perturbation_size
                    self.Z_free.copy_(perturbed[:split_index].reshape(final_Z_free.shape))
                    self.Tis_free.copy_(perturbed[split_index:].reshape(final_Tis_free.shape))
                    losses[f'Point_{i}'] = {'Loss': float(loss_fn())}
        finally:
            with torch.no_grad():
                self.Z_free.copy_(saved_Z_free)
                self.Tis_free.copy_(saved_Tis_free)
        return losses

    def print_estimated_parameters(self, loss, num_x_rows=50):
        """Print Z, X|Z and X tables (true vs. estimated) and append them to CSV.

        Requires :meth:`present` to have set the ``*_dp`` attributes.

        Args:
            loss: final loss of the run, recorded in each CSV section.
            num_x_rows: how many observed configurations (in lexicographic
                order) to list in the X table.
        """
        print("\nEstimated values of parameters vs. true ones (all, containing redundant ones):")

        t1 = PrettyTable(['Z table', 'True', 'Estimated'])
        for k in range(self.K):
            for zid in range(self.d):
                t1.add_row([f'P(Z_{zid+1}=0) in domain {k + 1}', f'{self.Z_star_dp[zid, k]:.5f}', f'{self.Z_dp[zid, k]:.5f}'])
        print(t1)

        t2 = PrettyTable(['X|Z', 'True', 'Estimated'])
        for binary_config in itertools.product((0, 1), repeat=self.d):
            bin_config_str = ''.join(str(bit) for bit in binary_config)
            dec_config = int(bin_config_str, 2)
            for nid in range(self.n):
                t2.add_row([f'P(X_{nid + 1}=0 | Z={bin_config_str})', f'{self.T_star_dp[nid, dec_config]:.5f}', f'{self.T_dp[nid, dec_config]:.5f}'])
        print(t2)

        t3 = PrettyTable(['X table', 'True', 'Estimated'])
        with torch.no_grad():
            predicted = self.compute_prediction()
        # Use self.n (not a module global) so the class also works when imported.
        for indices in itertools.islice(itertools.product((0, 1), repeat=self.n), num_x_rows):
            bin_config_str = ''.join(str(bit) for bit in indices)
            pred = predicted[indices]
            actual = self.X_obsv[indices]
            for k in range(self.K):
                t3.add_row(
                    [f'P(X={bin_config_str}) in domain {k + 1}', f'{actual[k]:.5f}', f'{pred[k]:.5f}'])
        print(t3)

        self.save_table_to_csv(t1, 'estimated_parameters_Z.csv', loss, 'Z')
        self.save_table_to_csv(t2, 'estimated_parameters_XZ.csv', loss, 'XZ')
        self.save_table_to_csv(t3, 'estimated_X.csv', loss, 'X')

    def save_table_to_csv(self, pretty_table, filename, loss, R):
        """Append ``pretty_table`` as a new section of a CSV under ``results_root``.

        The file lives in ``<results_root>/N=<n>_D=<d>_K=<K>/``; ``.csv`` is
        added to ``filename`` only if it is missing. The header row is written
        once, when the file is created; later calls write an all-empty
        separator row instead. When ``R`` is not ``None``, KL statistics for the
        whole file are appended to ``KL_result.csv`` via :meth:`process_csv`.
        """
        print(f"The parameters are: N={self.n}_D={self.d}_K={self.K}")
        folder_path = os.path.join(self.results_root, f"N={self.n}_D={self.d}_K={self.K}")
        os.makedirs(folder_path, exist_ok=True)
        if not filename.endswith('.csv'):
            filename += '.csv'
        csv_file_path = os.path.join(folder_path, filename)
        file_exists = os.path.isfile(csv_file_path)
        with open(csv_file_path, mode='a', newline='') as file:
            writer = csv.writer(file)
            if file_exists:
                writer.writerow([''] * len(pretty_table.field_names))
            else:
                writer.writerow(pretty_table.field_names)
            writer.writerow([f"Observed: {self.n}, Latent: {self.d}, Domain: {self.K}"])
            writer.writerow([f"Loss: {loss}"])
            writer.writerows(pretty_table.rows)
        if R is not None:
            print(R)
            self.process_csv(csv_file_path, os.path.join(folder_path, "KL_result.csv"), R)

    def calculate_kl_divergence(self, p_true, p_est):
        """Return ``scipy.stats.entropy(p_true, p_est)`` after clipping both to ``[1e-10, 1]``."""
        p_true = np.clip(p_true, 1e-10, 1)
        p_est = np.clip(p_est, 1e-10, 1)
        return entropy(p_true, p_est)

    def _flush_kl_section(self, writer, loss, p_true, p_est, kl_divergence_list):
        """Write the KL row for one buffered section (if any) and clear the buffers."""
        if p_true and p_est:
            kl_div = self.calculate_kl_divergence(np.array(p_true), np.array(p_est))
            kl_divergence_list.append(kl_div)
            writer.writerow([loss, kl_div])
        p_true.clear()
        p_est.clear()

    def process_csv(self, input_file, output_file, R):
        """Append per-section KL divergences of a saved table to ``output_file``.

        ``input_file`` is a CSV written by :meth:`save_table_to_csv`: one
        section per call, separated by all-empty rows, each containing a
        ``"Loss: <value>"`` row followed by ``label, true, estimated`` rows.
        Rows whose label starts with the prefix for ``R`` are collected per
        section and ``[loss, KL(true || estimated)]`` is written for each.
        Because ``input_file`` accumulates across runs, every call re-reports
        all sections seen so far. A final ``["Status", "Pass"|"Fail"]`` row is
        ``Pass`` when the mean KL is at most ``KL_PASS_THRESHOLD``.

        Raises:
            ValueError: if ``R`` is not one of ``'Z'``, ``'XZ'``, ``'X'``.
        """
        try:
            prefix = self._KL_ROW_PREFIXES[R]
        except KeyError:
            raise ValueError(f"Unknown table kind {R!r}; expected one of {sorted(self._KL_ROW_PREFIXES)}") from None

        kl_divergence_list = []
        with open(input_file, 'r', newline='') as infile, open(output_file, 'a', newline='') as outfile:
            reader = csv.reader(infile)
            writer = csv.writer(outfile)

            # KL_result.csv is shared by several table kinds, so every call
            # writes a header naming the kind of the rows that follow.
            writer.writerow(["Loss", "KL Divergence", R])

            loss = None
            p_true = []
            p_est = []
            for row in reader:
                first = row[0] if row else ''
                if not any(cell.strip() for cell in row):
                    # Separator rows are written as ['', '', ''], which the
                    # reader returns as a non-empty list of empty strings.
                    self._flush_kl_section(writer, loss, p_true, p_est, kl_divergence_list)
                elif first.startswith("Loss:"):
                    self._flush_kl_section(writer, loss, p_true, p_est, kl_divergence_list)
                    loss = float(first.split(": ", 1)[1])
                elif first.startswith(prefix):
                    p_true.append(float(row[1]))
                    p_est.append(float(row[2]))
            self._flush_kl_section(writer, loss, p_true, p_est, kl_divergence_list)

            passed = bool(kl_divergence_list) and np.mean(kl_divergence_list) <= self.KL_PASS_THRESHOLD
            writer.writerow(["Status", "Pass" if passed else "Fail"])
if __name__ == '__main__':
    # Workers are started with 'spawn' (the start method PyTorch documents for
    # using CUDA in subprocesses); force=True overrides any earlier setting.
    mp.set_start_method('spawn', force=True)

    # Problem size: d latent binary variables, n observed binary variables,
    # K domains. Earlier experiments required n >= 2 ** d and used
    # K = max(1, 2 * 2 ** d - n).
    d, n, K = 3, 3, 100
    tr_Init = False  # True: start every restart near the true parameters.

    ss = Simulator(d, n, K)
    ss.sample_truth()

    start = time.time()
    ss.solve_estimate(epochs=8000, lr=0.03, tr_Init=tr_Init)
    print("time passed:", time.time() - start)

    # TODO: print the likelihood of points around the local solution, and try
    # the true values as the initialisation point (tr_Init=True).

