from sklearn.linear_model import LinearRegression, Ridge
from sklearn.model_selection import KFold
import networkx as nx  
import numpy as np
import joblib
import pickle
import torch
import yaml
from scipy.stats import norm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import evaluation_utils as evut
import modularised_utils as mut


def get_coefficients(data, G, return_noise=False, use_ridge=False, alpha=1.0):
    """
    Estimates coefficients and computes noise (residuals) simultaneously.
    """
    nodes = list(G.nodes)
    coeffs = {}
    
    noise = data.copy()

    for node in nx.topological_sort(G):
        parents = list(G.predecessors(node))
        if not parents:
            continue
            
        node_idx = nodes.index(node)
        parent_indices = [nodes.index(p) for p in parents]
        
        y = data[:, node_idx]
        X = data[:, parent_indices]
        
        model = Ridge(alpha=alpha) if use_ridge else LinearRegression()
        model.fit(X, y)
        
        for parent, coef in zip(parents, model.coef_):
            coeffs[(parent, node)] = coef
            
        # Residuals = y - y_predicted
        y_pred = model.predict(X)
        noise[:, node_idx] = y - y_pred
        
    if return_noise:
        return coeffs, noise
    else:
        return coeffs

def compute_radius_lb(N, eta, c):
    """
    Computes the concentration radius epsilon_N(eta) lower bound

    Parameters:
    - N: int, number of samples
    - eta: float, confidence parameter (0 < eta < 1)
    - c: float, constant from the theorem (c > 1)

    Returns:
    - epsilon: float, the concentration bound
    """
    assert 0 < eta <= 1, "eta must be in (0, 1]"
    assert c > 1, "c must be greater than 1"
    return np.log(c / eta) / np.sqrt(N)

def compute_empirical_radius(N, eta, c1=1.0, c2=1.0, alpha=2.0, m=3):
    assert 0 < eta < 1, "eta must be in (0,1)"
    threshold = np.log(c1 / eta) / c2
    if N >= threshold:
        exponent = min(1/m, 0.5)
    else:
        exponent = 1 / alpha

    epsilon = (np.log(c1 / eta) / (c2 * N)) ** exponent
    return epsilon

def calculate_abstraction_error(T_matrix, Dll_test, Dhl_test):
    # 1. Estimate parameters from the low-level test data
    mu_L_test    = np.mean(Dll_test, axis=0)
    Sigma_L_test = np.cov(Dll_test, rowvar=False)

    # 2. Estimate parameters from the high-level test data
    mu_H_test    = np.mean(Dhl_test, axis=0)
    Sigma_H_test = np.cov(Dhl_test, rowvar=False)

    # 3. Transform the low-level parameters using the T matrix
    # This projects the low-level distribution into the high-level space
    mu_V_predicted    = mu_L_test @ T_matrix.T
    Sigma_V_predicted = T_matrix @ Sigma_L_test @ T_matrix.T
    
    # 4. Compute the Wasserstein distance between the two resulting Gaussians
    try:
        wasserstein_dist = np.sqrt(mut.compute_wasserstein(mu_V_predicted, Sigma_V_predicted, mu_H_test, Sigma_H_test))
    except Exception as e:
        print(f"  - Warning: Could not compute Wasserstein distance. Error: {e}. Returning NaN.")
        return np.nan

    return wasserstein_dist


def calculate_empirical_error(T_matrix, Dll_test, Dhl_test, metric='fro'):
    if Dll_test.shape[0] == 0 or Dhl_test.shape[0] == 0:
        return np.nan # Cannot compute error on empty data

    try:
        # 1. Transform the low-level data samples using the learned T matrix
        Dhl_predicted = Dll_test @ T_matrix.T
        
        # 2. Compute the distance between the predicted and actual data matrices.
        error = evut.compute_empirical_distance(Dhl_predicted.T, Dhl_test.T, metric)
        
    except Exception as e:
        print(f"  - Warning: Could not compute empirical distance. Error: {e}. Returning NaN.")
        return np.nan

    return error


def load_all_data(experiment_name):
    """Loads all model blueprints and abstraction data for a given experiment."""
    path = f"data/{experiment_name}"
    
    def safe_load_pickle(filepath):
        """Try multiple methods to load pickle files for compatibility."""
        import os
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Pickle file not found: {filepath}")
        
        try:
            with open(filepath, 'rb') as f:
                first_bytes = f.read(50)
                if first_bytes.startswith(b'version https://git-lfs'):
                    raise IOError(
                        f"File {filepath} is a Git LFS pointer, not the actual file. "
                        f"Please run 'git lfs pull' to download the actual files, or "
                        f"ensure Git LFS is properly configured and the files are available."
                    )
        except (UnicodeDecodeError, IOError):
            pass  # Not a text file, continue with normal loading
        
        try:
            # First try with joblib (standard method)
            return joblib.load(filepath)
        except (KeyError, ValueError, EOFError, pickle.UnpicklingError) as e:
            print(f"Warning: joblib.load failed for {filepath}: {e}")
            # If joblib fails, try with pickle directly
            try:
                with open(filepath, 'rb') as f:
                    # Try with pickle.load directly
                    return pickle.load(f)
            except Exception as e2:
                print(f"Warning: pickle.load failed for {filepath}: {e2}")
                # Try with errors='ignore' for bytes decoding issues
                try:
                    with open(filepath, 'rb') as f:
                        unpickler = pickle.Unpickler(f)
                        # Set encoding for Python 2/3 compatibility if needed
                        if hasattr(unpickler, 'encoding'):
                            unpickler.encoding = 'latin1'
                        return unpickler.load()
                except Exception as e3:
                    raise IOError(f"Failed to load {filepath} with multiple methods. "
                                f"Original error: {e}, pickle error: {e2}, final error: {e3}")
    
    data = {
        'LLmodel': safe_load_pickle(f"{path}/LLmodel.pkl"),
        'HLmodel': safe_load_pickle(f"{path}/HLmodel.pkl"),
        'abstraction_data': safe_load_pickle(f"{path}/abstraction_data.pkl")
    }
    print(f"Data loaded for '{experiment_name}'.")

    return data

def prepare_cv_folds(observational_data, k, random_state, save_path):
    """Generates and saves K-Fold train/test indices."""
    kf = KFold(n_splits=k, shuffle=True, random_state=random_state)
    num_samples = observational_data.shape[0]
    
    fold_indices = [{'train': train_idx, 'test': test_idx} 
                    for train_idx, test_idx in kf.split(np.arange(num_samples))]
    
    joblib.dump(fold_indices, save_path)
    print(f"Created and saved {len(fold_indices)} folds to '{save_path}'")
    return fold_indices

def assemble_fold_parameters(fold_indices, all_data, hyperparameters):
    """Assembles the final opt_params dictionary for a specific fold."""
    opt_params = hyperparameters.copy()

    opt_params['LLmodels']      = all_data['LLmodel'].get('scm_instances')
    opt_params['HLmodels']      = all_data['HLmodel'].get('scm_instances')
    opt_params['omega']         = all_data['abstraction_data']['omega']
    opt_params['experiment']    = all_data['experiment_name']
    opt_params['initial_theta'] = 'empirical'
    
    train_n  = len(fold_indices['train'])
    ll_bound = round(compute_radius_lb(N=train_n, eta=0.05, c=1000), 3)
    hl_bound = round(compute_radius_lb(N=train_n, eta=0.05, c=1000), 3)

    opt_params['theta_hatL'] = {
                                    'mu_U': all_data['LLmodel']['noise_dist']['mu'], 
                                    'Sigma_U': all_data['LLmodel']['noise_dist']['sigma'], 
                                    'radius': ll_bound
                                }
    opt_params['theta_hatH'] = {
                                    'mu_U': all_data['HLmodel']['noise_dist']['mu'], 
                                    'Sigma_U': all_data['HLmodel']['noise_dist']['sigma'], 
                                    'radius': hl_bound
                                }
    
    return opt_params

def assemble_barycentric_parameters(fold_info, all_data, baryca_hyperparams):
    """
    Assembles the final arguments dictionary specifically for barycentric_optimization.
    """
    opt_args = baryca_hyperparams.copy()

    opt_args['LLmodels'] = all_data['LLmodel'].get('scm_instances')
    opt_args['HLmodels'] = all_data['HLmodel'].get('scm_instances')
    opt_args['Ill'] = all_data['LLmodel']['intervention_set']
    opt_args['Ihl'] = all_data['HLmodel']['intervention_set']
    
    train_n  = len(fold_info['train'])
    ll_bound = round(compute_radius_lb(N=train_n, eta=0.05, c=1000), 3)
    hl_bound = round(compute_radius_lb(N=train_n, eta=0.05, c=1000), 3)

    opt_args['theta_L'] = {
        'mu_U': all_data['LLmodel']['noise_dist']['mu'], 
        'Sigma_U': all_data['LLmodel']['noise_dist']['sigma'], 
        'radius': ll_bound
    }
    opt_args['theta_H'] = {
        'mu_U': all_data['HLmodel']['noise_dist']['mu'], 
        'Sigma_U': all_data['HLmodel']['noise_dist']['sigma'], 
        'radius': hl_bound
    }
    
    return opt_args

def assemble_empirical_parameters(U_ll_train, U_hl_train, all_data, empirical_hyperparams):
    """
    Assembles the final arguments dictionary specifically for the empirical optimization.
    """
    opt_args = empirical_hyperparams.copy()

    opt_args['U_L'] = U_ll_train
    opt_args['U_H'] = U_hl_train
    opt_args['L_models'] = all_data['LLmodel'].get('scm_instances')
    opt_args['H_models'] = all_data['HLmodel'].get('scm_instances')
    opt_args['omega'] = all_data['abstraction_data']['omega']
    opt_args['experiment'] = all_data['experiment_name']
    
    return opt_args

def compute_struc_matrices(models, intervention_set):
    """Computes the reduced-form matrix F for each SCM instance."""
    return [torch.from_numpy(models[iota].F).float() for iota in intervention_set]


def print_distribution_summary(final_params, initial_params, name=""):
    """Prints a summary comparing initial and final Gaussian parameters."""
    
    mu_final = final_params['mu_U']
    sigma_final = final_params['Sigma_U']
    
    mu_initial = initial_params['mu_U']
    sigma_initial = initial_params['Sigma_U']
    
    print(f"\n--- Distribution Summary: {name} ---")
    
    # Compare Means
    print("\nMean (μ):")
    print(f"  - Initial: {np.round(mu_initial, 3)}")
    print(f"  - Final  : {np.round(mu_final, 3)}")
    
    # Compare Variances (diagonal of the covariance matrix)
    print("\nVariances (diag(Σ)):")
    print(f"  - Initial: {np.round(np.diag(sigma_initial), 3)}")
    print(f"  - Final  : {np.round(np.diag(sigma_final), 3)}")
    
    std_devs = np.sqrt(np.diag(sigma_final))
    # Add a small epsilon to avoid division by zero if variance is zero
    corr_matrix = sigma_final / np.outer(std_devs + 1e-8, std_devs + 1e-8)
    print("\nFinal Correlation Matrix:")
    print(np.round(corr_matrix, 2))
    print("-"*(25 + len(name)))
    
def plot_marginal_distributions(final_params, initial_params, var_names, model_name=""):
    """Plots a comparison of the 1D marginals for each variable."""
    
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman"],
        "font.size": 14,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "figure.titlesize": 18
    })
    
    mu_final, sigma_final = final_params['mu_U'], final_params['Sigma_U']
    mu_initial, sigma_initial = initial_params['mu_U'], initial_params['Sigma_U']
    
    n_vars = len(var_names)
    n_cols = 3
    n_rows = (n_vars + n_cols - 1) // n_cols # Calculate rows needed
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 5 * n_rows))
    axes = axes.flatten()
    fig.suptitle(f'Marginal Noise Distributions: {model_name}', fontsize=20)

    for i in range(n_vars):
        # Initial Distribution
        mean_i, std_i = mu_initial[i], np.sqrt(sigma_initial[i, i])
        x = np.linspace(mean_i - 3*std_i, mean_i + 3*std_i, 200)
        axes[i].plot(x, norm.pdf(x, mean_i, std_i), 'b-', lw=3, label='Initial (Empirical)')

        # Final (Worst-Case) Distribution
        mean_f, std_f = mu_final[i], np.sqrt(sigma_final[i, i])
        x = np.linspace(mean_f - 3*std_f, mean_f + 3*std_f, 200)
        axes[i].plot(x, norm.pdf(x, mean_f, std_f), 'r--', lw=3, label='Final (Worst-Case)')
        
        axes[i].set_title(var_names[i], fontsize=18)
        axes[i].legend(fontsize=14)
        axes[i].tick_params(axis='both', which='major', labelsize=14)

    # Hide any unused subplots
    for j in range(n_vars, len(axes)):
        fig.delaxes(axes[j])
        
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def create_optimization_animation(monitor, initial_params, var_names, model_level='L', filename='optimization.gif'):
    """Creates and saves an animation of the distribution's evolution."""
    
    # Select which history to use (Low-Level or High-Level)
    if model_level == 'L':
        mu_history = monitor.mu_L_history
        sigma_history = monitor.sigma_L_history
        initial_mu = initial_params['theta_hatL']['mu_U']
        initial_sigma = initial_params['theta_hatL']['Sigma_U']
    else:
        mu_history = monitor.mu_H_history
        sigma_history = monitor.sigma_H_history
        initial_mu = initial_params['theta_hatH']['mu_U']
        initial_sigma = initial_params['theta_hatH']['Sigma_U']

    num_epochs = len(mu_history)
    n_vars = len(var_names)
    fig, axes = plt.subplots(1, n_vars, figsize=(6 * n_vars, 5))
    if n_vars == 1: axes = [axes]
    fig.suptitle(f'Evolution of Worst-Case Distribution ({model_level} model)', fontsize=16)

    def update(epoch):
        for i in range(n_vars):
            ax = axes[i]
            ax.clear()

            # Plot Initial distribution (static blue line)
            mean_i, std_i = initial_mu[i], np.sqrt(initial_sigma[i, i] + 1e-8)
            x_i = np.linspace(mean_i - 4*std_i, mean_i + 4*std_i, 200)
            ax.plot(x_i, norm.pdf(x_i, mean_i, std_i), 'b-', lw=2, label='Initial (Empirical)')

            # Plot Final distribution (static red line)
            mean_f, std_f = mu_history[-1][i], np.sqrt(sigma_history[-1][i, i] + 1e-8)
            x_f = np.linspace(mean_f - 4*std_f, mean_f + 4*std_f, 200)
            ax.plot(x_f, norm.pdf(x_f, mean_f, std_f), 'r--', lw=2, label='Final (Worst-Case)')

            # Plot the CURRENT distribution for this epoch (moving green line)
            mean_e, std_e = mu_history[epoch][i], np.sqrt(sigma_history[epoch][i, i] + 1e-8)
            x_e = np.linspace(mean_e - 4*std_e, mean_e + 4*std_e, 200)
            ax.plot(x_e, norm.pdf(x_e, mean_e, std_e), 'g--', lw=2.5, label=f'Epoch {epoch+1}')
            
            ax.set_title(var_names[i])
            ax.set_ylim(bottom=0)
            ax.legend()
    
    ani = FuncAnimation(fig, update, frames=num_epochs, blit=False, repeat=False)
    
    print(f"Saving animation to {filename}...")
    ani.save(filename, writer='pillow', fps=5)
    plt.close() 
    print("Done.")
    return HTML(ani.to_jshtml()) 

def load_configs(config_files):

    configs = {}
    for var_name, file_path in config_files.items():
        try:
            with open(file_path, 'r') as f:
                configs[var_name] = yaml.safe_load(f)
        except FileNotFoundError:
            print(f"Warning: Config file {file_path} not found")
            configs[var_name] = {}
        except yaml.YAMLError as e:
            print(f"Error parsing {file_path}: {e}")
            configs[var_name] = {}
    
    return configs