import numpy as np
from tigramite import data_processing as pp
from tigramite.pcmci import PCMCI
from tigramite.independence_tests.parcorr import ParCorr
from tigramite import plotting as tp
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pandas as pd
from glob import glob

def analyze_causal_graphs(trajectories):
    """Run PCMCI on state-visit counts and plot the gap to the reference data.

    For every timepoint the number of trajectories observed in each state is
    counted, giving a ``(n_timepoints, n_states)`` series. PCMCI with a
    ``ParCorr`` test (``tau_max=2``, ``pc_alpha=0.05``) is run on that series
    and its ``val_matrix`` is reduced with a maximum over axis 2. The result
    is written to ``Heatmap/truth_AIRL.csv`` and the absolute difference to
    ``compute_true_causality()`` is shown as a heatmap.

    Args:
        trajectories: Non-empty sequence of objects exposing ``obs``, a
            sequence of integer state ids. Trajectories whose length differs
            from the first one are skipped.

    Raises:
        ValueError: If ``trajectories`` is empty.
    """
    if len(trajectories) == 0:
        raise ValueError("trajectories must contain at least one trajectory")

    location_to_id = {
        'Allogroom': 0, 'Carry pup': 1, 'Dig burrow': 2, 'Foraging': 3, 'Groom': 4,
        'High sitting/standing (Vigilant)': 5, 'Human Interaction': 6,
        'Interact with pup': 7, 'Interacting with foreign object': 8,
        'Low sitting/standing (stationary)': 9, 'Lying/resting (stationary)': 10,
        'Moving to Around Mound': 11, 'Moving to Background': 12, 'Moving to Door': 13,
        'Moving to Foraging': 14, 'Moving to Left Sticks Area': 15, 'Moving to Mound': 16,
        'Moving to Right Sand Area': 17, 'Moving to Right Sticks Area': 18,
        'Moving to Sand Area': 19, 'Moving to Waiting Area': 20, 'No action': 21,
        'Playfight': 22, 'Raised Guarding (Vigilant)': 23, 'Sunbathe': 24,
    }
    # Labels ordered by state id, used for the CSV header and the heatmap ticks.
    names = sorted(location_to_id, key=location_to_id.get)

    # Derive the state count from the label table so the two cannot drift apart.
    n_states = len(location_to_id)
    n_timepoints = len(trajectories[0].obs)

    # state_data[t, s] = number of trajectories that are in state s at time t.
    state_data = np.zeros((n_timepoints, n_states))
    for traj in trajectories:
        if len(traj.obs) != n_timepoints:
            # Only trajectories as long as the first one fit the time axis.
            continue
        for t, state in enumerate(traj.obs):
            state_data[t, state] += 1

    dataframe = pp.DataFrame(state_data)
    pcmci = PCMCI(dataframe=dataframe, cond_ind_test=ParCorr())
    results = pcmci.run_pcmci(tau_max=2, pc_alpha=0.05)

    # Collapse axis 2 of val_matrix, keeping the largest value per (i, j) pair.
    max_causality_matrix_algo = np.max(results['val_matrix'], axis=2)
    max_causality_matrix_truth = compute_true_causality()
    difference_matrix = np.abs(max_causality_matrix_truth - max_causality_matrix_algo)

    # Build the path with os.path.join so it is valid on every platform, and
    # create the folder first: to_csv does not create missing directories.
    output_dir = os.path.join('.', 'Heatmap')
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): the file is named "truth" but stores the matrix computed
    # from `trajectories`, not the difference or the reference matrix; confirm
    # this is the intended content.
    df_algo_matrix = pd.DataFrame(max_causality_matrix_algo, columns=names, index=names)
    df_algo_matrix.to_csv(os.path.join(output_dir, 'truth_AIRL.csv'))

    plt.figure(figsize=(12, 10))
    sns.heatmap(difference_matrix, annot=False, cmap="YlGnBu",
                xticklabels=names, yticklabels=names)
    plt.title('AIRL', fontsize=16)
    plt.xlabel('')
    plt.ylabel('')

    # Slanted x labels keep the long state names readable.
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()
    plt.show()



def _load_reference_state_ids(root_folder, n_timepoints):
    """Read all CSVs under ``root_folder`` into an array of state ids.

    Column index 2 of every CSV (read without a header row) is concatenated,
    each distinct label is mapped to its rank in sorted order, and the ids are
    cut into consecutive rows of ``n_timepoints`` entries.

    Returns:
        Array of shape ``(n_trajectories, n_timepoints)``.

    Raises:
        ValueError: If no CSV file is found or fewer than ``n_timepoints``
            rows are available.
    """
    # Sorted so the concatenation order does not depend on the filesystem.
    csv_files = sorted(glob(os.path.join(root_folder, '**', '*.csv'), recursive=True))
    if not csv_files:
        raise ValueError('No CSV files found under {!r}'.format(root_folder))

    labels = pd.concat(
        [pd.read_csv(path, header=None, usecols=[2]) for path in csv_files],
        ignore_index=True,
    ).iloc[:, 0]

    unique_locations = sorted(labels.unique())
    location_to_id = {location: idx for idx, location in enumerate(unique_locations)}
    state_ids = labels.map(location_to_id).to_numpy()

    n_trajectories = len(state_ids) // n_timepoints
    if n_trajectories == 0:
        raise ValueError(
            'Need at least {} rows to form one trajectory, found {}'.format(
                n_timepoints, len(state_ids)))

    # Drop a trailing partial trajectory; reshaping the full array would fail
    # whenever the row count is not an exact multiple of n_timepoints.
    usable = n_trajectories * n_timepoints
    return state_ids[:usable].reshape((n_trajectories, n_timepoints))


def compute_true_causality(root_folder=None, n_states=25, n_timepoints=30):
    """Run PCMCI on state-visit counts built from the recorded CSV data.

    Args:
        root_folder: Directory searched recursively for ``*.csv`` files.
            Defaults to ``./output3``.
        n_states: Number of state ids that are counted; ids outside
            ``range(n_states)`` are ignored.
        n_timepoints: Number of consecutive rows that form one trajectory.

    Returns:
        Maximum over axis 2 of the PCMCI ``val_matrix`` (``ParCorr`` test,
        ``tau_max=2``, ``pc_alpha=0.05``).

    Raises:
        ValueError: If no CSV file is found or the data holds less than one
            full trajectory.
    """
    if root_folder is None:
        root_folder = os.path.join('.', 'output3')

    # NOTE(review): ids come from the sorted labels found in the data, so they
    # only line up with a fixed label table when every label occurs; confirm.
    trajectory_states = _load_reference_state_ids(root_folder, n_timepoints)
    trajectory_states = trajectory_states.astype(np.float64)

    # state_data[t, s] = number of trajectories that are in state s at time t,
    # obtained by comparing every entry against all state ids at once.
    matches = trajectory_states[:, :, np.newaxis] == np.arange(n_states)
    state_data = matches.sum(axis=0).astype(np.float64)

    dataframe = pp.DataFrame(state_data)
    pcmci = PCMCI(dataframe=dataframe, cond_ind_test=ParCorr())
    results = pcmci.run_pcmci(tau_max=2, pc_alpha=0.05)

    return np.max(results['val_matrix'], axis=2)