import math
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression as lr
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
# Seed NumPy's global random generator at import time so that the draws made
# through np.random.randn in data_generator are reproducible from run to run.
np.random.seed(42)


def data_generator(T=10000):
    """
    Simulate a linear time series over the variables A, B, C, D, E, X and Y.

    Every variable is a weighted sum of other variables (taken at the same
    step or one step earlier) plus 0.1 times a standard-normal draw. At the
    very first simulated step the lag-1 terms are left out. ``T + 2`` steps
    are simulated and the first two are discarded, so ``T`` rows are returned.
    The noise comes from NumPy's global random state.

    Args:
        T (int): Number of rows in the returned frame.

    Returns:
        pd.DataFrame: Columns ``A, B, C, D, E, X, Y`` (in that order), one row
                      per time step, with a 0-based index named ``time_index``.
    """
    n_steps = T + 2

    # Edge weights of the linear model, keyed as "cause->effect".
    w = {
        "A->A": 0.20386037,
        "A->B": 0.71287403,
        "A->C": 0.73596003,
        "A->D": 0.2724924,
        "A->X": 0.51875674,
        "D->E": 0.67508935,
        "D->D": 0.50960556,
        "D->Y": 0.64090899,
        "E->C": 0.78529674,
        "E->E": 0.77100137,
        "E->X": 0.3773349,
        "X->B": 0.30968336,
        "X->Y": 0.25,
        "Y->Y": 0.23039564,
    }

    # Scaled noise per variable. The draw order A, B, C, D, E, X, Y fixes
    # which random numbers each variable receives for a given seed.
    shock = {name: 0.1 * np.random.randn(n_steps) for name in "ABCDEXY"}

    a = np.zeros(n_steps)
    d = np.zeros(n_steps)
    e = np.zeros(n_steps)
    x = np.zeros(n_steps)
    y = np.zeros(n_steps)

    # Autoregressive chain A -> D -> E: each step needs the previous one,
    # so it is rolled forward sequentially.
    a[0] = shock["A"][0]
    d[0] = w["A->D"] * a[0] + shock["D"][0]
    e[0] = shock["E"][0]
    for t in range(1, n_steps):
        a[t] = w["A->A"] * a[t - 1] + shock["A"][t]
        d[t] = w["D->D"] * d[t - 1] + w["A->D"] * a[t] + shock["D"][t]
        e[t] = w["E->E"] * e[t - 1] + w["D->E"] * d[t - 1] + shock["E"][t]

    # X uses A with a one-step lag (absent at step 0) and E at the same step.
    x[0] = w["E->X"] * e[0] + shock["X"][0]
    x[1:] = w["A->X"] * a[:-1] + w["E->X"] * e[1:] + shock["X"][1:]

    # C and B only depend on same-step values, so they are computed in one go.
    c = w["A->C"] * a + w["E->C"] * e + shock["C"]
    b = w["A->B"] * a + w["X->B"] * x + shock["B"]

    # Y depends on its own previous value, hence another sequential pass.
    y[0] = w["X->Y"] * x[0] + w["D->Y"] * d[0] + shock["Y"][0]
    for t in range(1, n_steps):
        y[t] = w["X->Y"] * x[t] + w["D->Y"] * d[t] + w["Y->Y"] * y[t - 1] + shock["Y"][t]

    series = pd.DataFrame({"A": a, "B": b, "C": c, "D": d, "E": e, "X": x, "Y": y})

    # Discard the two extra leading steps and renumber the remaining rows.
    series = series.drop(series.index[[0, 1]]).reset_index(drop=True)
    series.index.names = ['time_index']
    return series


def get_temporal_nodes(g, gamma_max):
    """
    Build the lagged ("temporal") labels of every node of a graph.

    For a node ``n`` the labels are ``"n_t"`` for lag 0, followed by
    ``"n_t_1"``, ``"n_t_2"``, ... up to lag ``2 * gamma_max``.

    Args:
        g: Any object exposing an iterable ``nodes`` attribute (the script
           passes a ``networkx.DiGraph``).
        gamma_max (int): Labels are generated for lags ``0 .. 2 * gamma_max``.

    Returns:
        dict: Maps each node to its list of labels ordered by increasing lag,
              so that ``result[node][lag]`` is the label of that lag.
    """
    n_labels = 2 * gamma_max + 1
    nodes_to_temporal_nodes = {}
    for node in g.nodes:
        name = str(node)
        # Lag 0 carries no numeric suffix; every other lag is appended as "_<lag>".
        nodes_to_temporal_nodes[node] = [
            f"{name}_t" if lag == 0 else f"{name}_t_{lag}"
            for lag in range(n_labels)
        ]
    return nodes_to_temporal_nodes


def time_series_to_windows(series, window_size, nodes_to_temporal_nodes):
    """
    Cut a time series into consecutive, non-overlapping windows of fixed size.

    Window ``k`` covers rows ``k * window_size`` to ``(k + 1) * window_size - 1``
    of ``series``. Inside a window the values are laid out from the most recent
    row backwards: first every variable at the window's last row (lag 0), then
    every variable one row earlier (lag 1), and so on. Trailing rows that do
    not fill a complete window are ignored.

    Args:
        series (pd.DataFrame): Time series with variables as columns and time as rows.
        window_size (int): Number of consecutive time steps per window (must be >= 1).
        nodes_to_temporal_nodes (dict): For each column name, the list of labels
            indexed by lag; it needs at least ``window_size`` labels per column.

    Returns:
        pd.DataFrame: One row per window and ``window_size * n_columns`` columns,
                      named with the lag labels (all variables at lag 0, then
                      all variables at lag 1, ...).

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    # Column labels: all variables at lag 0, then all variables at lag 1, ...
    column_names = [
        nodes_to_temporal_nodes[node][lag]
        for lag in range(window_size)
        for node in series.columns
    ]

    n_columns = series.shape[1]
    n_windows = len(series) // window_size

    # Keep only the rows that form complete windows, then view them as
    # (window, step, variable). Reversing the step axis puts the last row of
    # each window first, which matches the lag ordering of the labels. This
    # replaces a per-row ``iloc`` lookup that dominated the run time on long
    # series.
    values = series.to_numpy()[: n_windows * window_size]
    stacked = values.reshape(n_windows, window_size, n_columns)[:, ::-1, :]
    flat = stacked.reshape(n_windows, window_size * n_columns)

    return pd.DataFrame(flat, columns=column_names)


def compute_adjustment_SCG(gamma_max, int_time, gamma_NC):
    """
    List the lagged variables that make up the SCG adjustment set.

    For every variable ``V`` of ``gamma_NC`` the lags ``gamma_NC[V] + 1`` up to
    ``gamma_max`` (inclusive) are selected; a value of ``-math.inf`` means the
    selection starts at lag 0. The pair ``('X', int_time)`` is always left out.

    Args:
        gamma_max (int): Largest lag that may enter the adjustment set.
        int_time (int): Lag of ``X`` that must be excluded from the set.
        gamma_NC (dict): Maps each variable name to the last lag that is
            skipped for it (``-math.inf`` to skip none).

    Returns:
        list: Labels ``"V_t"`` (lag 0) or ``"V_t_<lag>"``, grouped by variable
              in the iteration order of ``gamma_NC`` and sorted by lag.
    """
    # First lag to include per variable; -inf is mapped to "start at lag 0".
    first_lag = {
        variable: (0 if skipped == -math.inf else skipped + 1)
        for variable, skipped in gamma_NC.items()
    }

    return [
        f"{variable}_t_{lag}" if lag > 0 else f"{variable}_t"
        for variable, start in first_lag.items()
        for lag in range(start, gamma_max + 1)
        if (variable, lag) != ('X', int_time)
    ]
 

if __name__ == '__main__':
    # Experiment driver: simulate one long series, cut it into windows, then
    # estimate the coefficient of X_t on Y_t by linear regression, once with a
    # fixed FTCG adjustment set and once per gamma_max with the SCG adjustment
    # set. Results are printed (optionally as a markdown table) and plotted.
    PRINT_RESULTS_MARKDOWN = True
    PLOT_RESULTS = True

    gamma_maxs = list(range(3, 100, 10))
    print(gamma_maxs)
    # A window has to contain lags 0 .. max(gamma_maxs).
    window_size = max(gamma_maxs) + 1
    num_experiment = 100
    num_windows_per_experiment = 500
    # Enough time steps for num_experiment groups of num_windows_per_experiment windows.
    T = window_size * num_experiment * num_windows_per_experiment
    intervention_time = 0

    g1 = nx.DiGraph()
    g1.add_edges_from(
        [('A', 'A'),
         ('A', 'C'),
         ('A', 'B'),
         ('A', 'D'),
         ('A', 'X'),
         ('D', 'D'),
         ('D', 'Y'),
         ('D', 'E'),
         ('E', 'C'),
         ('E', 'X'),
         ('E', 'E'),
         ('X', 'B'),
         ('X', 'Y'),
         ('Y', 'Y')])
    backdoor_FTCG = ['Y_t_1', 'D_t']
    nodes_to_temporal_nodes = get_temporal_nodes(g1, max(gamma_maxs))
    gamma_NC = {'A': -math.inf, 'C': -math.inf, 'B': -math.inf, 'D': -math.inf, 'X': -math.inf, 'Y': 0, 'E': -math.inf}

    # Generate data
    print(f"generating {T} data points...")
    data = data_generator(T)
    print("Recovering windows...")
    window_data = time_series_to_windows(data, window_size, nodes_to_temporal_nodes)

    # FTCG estimation: one regression per group of windows.
    list_estimation_FTCG = []
    for i in range(num_experiment):
        start_window = i * num_windows_per_experiment
        end_window = start_window + num_windows_per_experiment
        if start_window >= len(window_data):
            print("WARNING: not enough data for computing last group of windows")
            break  # Stop if we run out of windows
        X_data_FTCG = window_data.iloc[start_window:end_window][['X_t'] + backdoor_FTCG]
        Y_data_FTCG = window_data.iloc[start_window:end_window][["Y_t"]]
        # coef_ has one row per target; X_t is the first regressor, so [0] is its coefficient.
        estimated_coef_FTCG = list(lr().fit(X_data_FTCG, Y_data_FTCG).coef_)[0]
        list_estimation_FTCG.append(estimated_coef_FTCG[0])

    mean_estimation_FTCG = np.mean(list_estimation_FTCG)
    std_estimation_FTCG = np.std(list_estimation_FTCG, ddof=1)

    # SCG estimations: same procedure, repeated for every gamma_max.
    adjustment_set_size = []
    mean_estimations = []
    std_estimations = []
    print("Estimating...")
    for gamma_max in gamma_maxs:
        ajustment_SCG = compute_adjustment_SCG(gamma_max, intervention_time, gamma_NC)
        adjustment_set_size.append(len(ajustment_SCG))

        list_estimation_SCG = []
        for i in range(num_experiment):
            start_window = i * num_windows_per_experiment
            end_window = start_window + num_windows_per_experiment
            if start_window >= len(window_data):
                print("WARNING: not enough data for computing last group of windows")
                break  # Stop if we run out of windows
            X_data_SCG = window_data.iloc[start_window:end_window][['X_t'] + ajustment_SCG]
            Y_data_SCG = window_data.iloc[start_window:end_window][["Y_t"]]
            estimated_coef_SCG = list(lr().fit(X_data_SCG, Y_data_SCG).coef_)[0]
            list_estimation_SCG.append(estimated_coef_SCG[0])

        mean_estimations.append(np.mean(list_estimation_SCG))
        std_estimations.append(np.std(list_estimation_SCG, ddof=1))

    print("Estimation over!")
    print("Window Size:", window_size)
    print("Number of experiments:", num_experiment)
    print("Number of windows per experiments:", num_windows_per_experiment)
    print("Mean estimation FTCG;", mean_estimation_FTCG)
    print("STD estimation FTCG;", std_estimation_FTCG)

    if PRINT_RESULTS_MARKDOWN:
        # Print the results as a markdown table.
        print("Results in a markdown table:")

        latex_gamma_max = '$\\gamma_{\\text{max}}$'  # LaTeX label for the gamma_max column
        size_header = 'Adjustment set size'

        # Each column is as wide as its widest cell, header included. The
        # width of the last column is derived from the header that is actually
        # printed, so that all rows line up.
        max_gamma_width = max(max(len(str(g)) for g in gamma_maxs), len(latex_gamma_max))
        max_estimation_width = max(max(len(f"{e:.4f}") for e in mean_estimations), len('Estimation'))
        max_std_width = max(max(len(f"{s:.4f}") for s in std_estimations), len('std'))
        max_size_width = max(max(len(str(s)) for s in adjustment_set_size), len(size_header))

        # Header row followed by the alignment row (centered columns).
        markdown_table = f"| {latex_gamma_max.ljust(max_gamma_width)} | {'Estimation'.ljust(max_estimation_width)} | {'std'.ljust(max_std_width)} | {size_header.ljust(max_size_width)} |\n"
        markdown_table += f"|:{'-' * max_gamma_width}:|:{'-' * max_estimation_width}:|:{'-' * max_std_width}:|:{'-' * max_size_width}:|\n"

        # One row per gamma_max.
        for gamma, estimation, std, size in zip(gamma_maxs, mean_estimations, std_estimations, adjustment_set_size):
            markdown_table += f"| {str(gamma).ljust(max_gamma_width)} | {f'{estimation:.4f}'.ljust(max_estimation_width)} | {f'{std:.4f}'.ljust(max_std_width)} | {str(size).ljust(max_size_width)} |\n"

        print(markdown_table)

    if PLOT_RESULTS:
        # Create figure and axes
        fig, ax1 = plt.subplots()

        # Primary x-axis for gamma_maxs
        ax1.set_xlabel(r"$\gamma_{\max}$", fontsize=14)
        ax1.set_xticks(gamma_maxs)
        ax1.set_xticklabels(gamma_maxs)
        ax1.set_xlim(0, max(gamma_maxs) + 4)  # Ensure proper scaling
        # Reference line at 0.25, the X->Y coefficient used by data_generator.
        ax1.axhline(y=0.25, color='gray', linestyle='--', linewidth=1.5, label="True Value")

        # Secondary x-axis for adjustment set size
        ax2 = ax1.twiny()
        ax2.set_xlabel("Adjustment Set Size", fontsize=14, labelpad=10)
        ax2.set_xlim(ax1.get_xlim())  # Match scaling to gamma_maxs
        ax2.set_xticks(gamma_maxs)  # Use gamma_maxs positions
        ax2.set_xticklabels(adjustment_set_size)  # Display adjustment set size

        # Set font size for ticks
        ax1.tick_params(axis='both', which='major', labelsize=12)
        ax2.tick_params(axis='x', which='major', labelsize=12)

        # Primary y-axis for estimation with error bars
        ax1.errorbar(
            gamma_maxs,
            mean_estimations,
            yerr=std_estimations,
            fmt='o',
            capsize=5,
            color='teal',
            ecolor='coral',
            elinewidth=1.5,
            markerfacecolor='teal',
            label="Estimation"
        )
        ax1.set_ylabel("Estimation", fontsize=14)
        ax1.legend(fontsize=12, loc='best')

        # Grid
        ax1.grid(True, linestyle='--', alpha=0.7)

        # Save and show the plot
        plt.savefig('estimations.pdf', dpi=3000, format='pdf')
        plt.show()
