import os
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import pingouin as pg
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict

from utils import get_dataset_attributes

class AlignmentData:
    # Maps model keys to display names.
    # - Keys: as passed in ``models`` and used in the experiment folder names
    #   under ``outputs/``.
    # - Values: used in the DataFrame 'Model' column, plot titles and the
    #   per-model plot folders.
    # Every key passed in ``models`` must appear here.
    model_names = {
        "mamba": "Mamba-1.4B",
        "falcon3": "Falcon3-1B-Base",
        "llama3.2-1B": "Llama3.2-1B",
        "gemma": "Gemma-2B",
        "zamba": "Zamba2-1.2B",
    }

    def __init__(self, dataset, models, context_lengths, voxel_func, subj_func, root_dir):
        """Store the analysis configuration, create output folders and load ROI masks.

        Args:
            dataset: dataset name; used in data/output paths and passed to
                ``get_dataset_attributes``.
            models: model keys. Each must be a key of ``model_names``;
                a ``KeyError`` is raised otherwise.
            context_lengths: context lengths to analyse.
            voxel_func: 'mean' or 'median'; how ``load`` aggregates voxels.
            subj_func: 'mean' or 'median'; how ``load`` aggregates subjects.
            root_dir: project root. It must support ``/`` path joining
                (presumably a ``pathlib.Path``; confirm with callers).
        """
        (self.dataset, self.models, self.context_lengths,
         self.voxel_func, self.subj_func, self.root_dir) = (
            dataset, models, context_lengths, voxel_func, subj_func, root_dir)

        # <root>/outputs holds the per-experiment correlation scores read by load().
        self.outputs_dir = root_dir / "outputs"
        os.makedirs(self.outputs_dir, exist_ok=True)

        # Plot folders: <root>/plots/alignment/<display name>/{heatmaps,boxplots,lineplots}
        self.plot_dir = root_dir / 'plots/alignment'
        for model_key in models:
            model_dir = self.plot_dir / self.model_names[model_key]
            for target in (model_dir,
                           model_dir / 'heatmaps',
                           model_dir / 'boxplots',
                           model_dir / 'lineplots'):
                os.makedirs(target, exist_ok=True)

        # get_dataset_attributes yields four values; the third is not needed here.
        self.subject_idxs, self.brain_areas, _, roi_file = get_dataset_attributes(dataset)
        # NOTE(review): allow_pickle=True can execute code stored in the file;
        # only load ROI mask files from trusted sources.
        self.roi_masks = np.load(root_dir / roi_file, allow_pickle=True)
        # The brain-area list actually used is the set of ROI names stored for
        # the first subject (it replaces the one from get_dataset_attributes).
        masks_by_subject = self.roi_masks.item()
        self.brain_areas = masks_by_subject[self.subject_idxs[0]].keys()

        # heatmap_dict[brain_area][model] -> one list of per-layer scores per
        # context length; filled by load().
        self.heatmap_dict = defaultdict(lambda: defaultdict(list))
        self.df = None

    def load(self):
        """Load noise ceilings and correlation scores into ``self.df`` and ``self.heatmap_dict``.

        For every brain area, model, context length and subject this:

        1. Reads ``<outputs_dir>/<dataset>_<model>_<context_length>/correlation_scores_<subject>.npy``,
           a ``(n_layers, n_voxels)`` array.
        2. Keeps the voxels that belong to the brain area *and* have a noise
           ceiling >= 0.05.
        3. Aggregates those voxels per layer with ``self.voxel_func``
           ('mean' or 'median').

        Side effects:
            * ``self.noise_ceilings`` / ``self.mean_noise_ceilings``: per-subject
              voxel-wise noise ceilings and their means.
            * ``self.df``: long-format DataFrame with columns BrainArea, Model,
              Subject, ContextLength, LayerIndex, Correlation (one row per
              subject and layer).
            * ``self.heatmap_dict[brain_area][model]``: one list of per-layer
              scores per kept context length, aggregated over subjects with
              ``self.subj_func``.
            * One barplot per (brain area, model) of the fraction of selected
              voxels whose correlation exceeds their noise ceiling.

        Raises:
            FileNotFoundError: if a correlation-score file is missing.
        """
        self._load_noise_ceilings()

        data_df = []
        for brain_area in self.brain_areas:
            for model in self.models:
                nc_df = []
                for context_length in self.context_lengths:
                    # GPT-2 variants are not evaluated beyond 1024 tokens of context.
                    if model in ('gpt2', 'gpt2-2') and context_length > 1024:
                        continue
                    subjects_corrs = []
                    for s, subject_idx in enumerate(self.subject_idxs):
                        corr_scores_path = (
                            self.outputs_dir
                            / f'{self.dataset}_{model}_{context_length}'
                            / f'correlation_scores_{subject_idx}.npy'
                        )
                        print(corr_scores_path)
                        if not os.path.isfile(corr_scores_path):
                            raise FileNotFoundError(f"Correlation scores file not found: {corr_scores_path}")

                        correlation_scores = np.load(corr_scores_path)  # (n_layers, n_voxels)
                        # Optional normalisation by the noise ceiling (currently disabled):
                        # correlation_scores = correlation_scores / np.expand_dims(self.noise_ceilings[s], axis=0)
                        noise_ceiling = self.noise_ceilings[s]

                        # Selected voxels: inside the ROI and with a reliable noise ceiling.
                        voxel_mask = np.logical_and(
                            self.roi_masks.item()[subject_idx][brain_area],
                            noise_ceiling >= 0.05,
                        )
                        # Check the voxel dimension *before* indexing with the mask.
                        # Otherwise the boolean indexing below raises first and this
                        # guard never triggers.
                        if voxel_mask.shape != correlation_scores.shape[1:]:
                            print(f"Model {model}, brain area {brain_area}, context_length {context_length}, subj {subject_idx}: Shapes not matching! ", correlation_scores.shape, voxel_mask.shape)
                            continue

                        nc_df.append({
                            'Model': self.model_names[model],
                            'Subject': subject_idx,
                            'ContextLength': context_length,
                            # Fraction (0-1) of selected voxels above their noise
                            # ceiling, averaged over layers.
                            'Percentage': np.sum(correlation_scores[:, voxel_mask] > noise_ceiling[voxel_mask], axis=1).mean() / voxel_mask.sum(),
                        })

                        layer_corrs = self._aggregate_voxels(correlation_scores, voxel_mask)  # (n_layers,)
                        subjects_corrs.append(layer_corrs)
                        data_df.extend(
                            {
                                'BrainArea': brain_area,
                                'Model': self.model_names[model],
                                'Subject': subject_idx,
                                'ContextLength': context_length,
                                'LayerIndex': layer_idx,
                                'Correlation': layer_corr,
                            }
                            for layer_idx, layer_corr in enumerate(layer_corrs)
                        )

                    subjects_corrs = np.stack(subjects_corrs, axis=0)  # (n_subjects, n_layers)

                    # Aggregate the subjects' per-layer scores by mean/median.
                    if self.subj_func == 'mean':
                        layer_corrs = np.mean(subjects_corrs, axis=0)  # (n_layers,)
                    else:
                        if self.subj_func != 'median':
                            print("Aggregation function for subject values not supported. Using median.")
                        layer_corrs = np.median(subjects_corrs, axis=0)  # (n_layers,)

                    self.heatmap_dict[brain_area][model].append(layer_corrs.tolist())

                self._plot_noise_ceiling_barplot(pd.DataFrame(nc_df), brain_area, model)

        self.df = pd.DataFrame(data_df)
        print("Data loaded successfully.")

    def _load_noise_ceilings(self):
        """Load each subject's voxel-wise noise ceiling and its mean.

        Fills ``self.noise_ceilings`` (one array per subject, in
        ``self.subject_idxs`` order) and ``self.mean_noise_ceilings``.
        """
        self.noise_ceilings, self.mean_noise_ceilings = [], []
        for subject_idx in self.subject_idxs:
            noise_ceiling = np.load(self.root_dir / f'data/{self.dataset}/noise_ceiling/subject_{subject_idx}_kernel_ridge.npy')
            self.noise_ceilings.append(noise_ceiling)
            self.mean_noise_ceilings.append(noise_ceiling.mean())

    def _aggregate_voxels(self, correlation_scores, voxel_mask):
        """Aggregate ``(n_layers, n_voxels)`` scores over the voxels where ``voxel_mask`` is True.

        Uses ``self.voxel_func`` ('mean', otherwise median). Returns a plain float
        array of shape ``(n_layers,)``. Layers with no selected voxel give NaN.
        """
        # numpy.ma treats mask=True as *invalid* (excluded). The selection mask
        # must therefore be inverted so that exactly the selected voxels are kept.
        excluded = np.tile(~voxel_mask, (correlation_scores.shape[0], 1))
        masked_corrs = np.ma.array(correlation_scores, mask=excluded)
        if self.voxel_func == 'mean':
            return np.ma.mean(masked_corrs, axis=1).filled(np.nan)
        if self.voxel_func != 'median':
            print("Aggregation function for voxel values not supported. Using median.")
        return np.nanmedian(masked_corrs.filled(np.nan), axis=1)

    def _plot_noise_ceiling_barplot(self, nc_df, brain_area, model):
        """Save a barplot of the per-context-length fraction of voxels above the noise ceiling.

        ``nc_df`` holds one row per (subject, context length), as built in ``load``.
        The file is ``<plot_dir>/<Model>/barplot_nc_<brain_area>.png``.
        Nothing is saved when ``nc_df`` is empty.
        """
        if nc_df.empty:
            return
        plt.figure(figsize=(12, 8))
        sns.set_style("whitegrid")

        # Bars: mean across subjects; error bars: standard error.
        sns.barplot(
            x='ContextLength',
            y='Percentage',
            data=nc_df,
            estimator=np.mean,
            errorbar='se',
            capsize=0.1
        )

        plt.title(f"Percentage of Correlations > Noise Ceiling for {brain_area}")
        plt.xlabel("Context Length")
        plt.ylabel("Percentage")
        plt.tight_layout()

        plt.savefig(self.plot_dir / self.model_names[model] / f"barplot_nc_{brain_area}.png")
        plt.close()
    
    def perform_significance_tests_across_models(self):
        """Compare models at every (context length, layer index) with repeated-measures tests.

        Data preparation: the aggregate 'all' brain area is excluded. The rest
        are averaged across brain areas, giving one value per (ContextLength,
        LayerIndex, Model, Subject).

        Tests, for every (ContextLength, LayerIndex) with at least two models:
            * a repeated-measures ANOVA (uncorrected sphericity);
            * a Friedman test;
            * FDR (Benjamini-Hochberg) corrected pairwise tests between models.
        Subjects are the repeated unit in all three.

        Output: rows with p < 0.05 are written to
        ``<plot_dir>/significance_tests_results.txt``. Omnibus tests are
        filtered on ``p-unc``; pairwise tests on the corrected ``p-corr``.
        """
        results_anova = []     # pg.rm_anova tables
        results_friedman = []  # pg.friedman tables
        results_posthoc = []   # pg.pairwise_tests tables

        # Drop the aggregate 'all' area, then average across brain areas so that
        # each (ContextLength, LayerIndex, Model, Subject) has exactly one row.
        filtered_df = self.df[self.df['BrainArea'] != 'all']
        subject_df = filtered_df.groupby(
            ['ContextLength', 'LayerIndex', 'Model', 'Subject'], as_index=False
        )['Correlation'].mean()

        # Test each (ContextLength, LayerIndex) combination separately.
        for (context_length, layer), df_group in subject_df.groupby(['ContextLength', 'LayerIndex']):
            # A between-model comparison needs at least two models.
            if df_group['Model'].nunique() < 2:
                continue

            # Repeated-measures ANOVA: Model is the within factor, Subject the repeated unit.
            anova_res = pg.rm_anova(
                data=df_group,
                dv='Correlation',
                within='Model',
                subject='Subject',
                correction=False
            )
            anova_res['ContextLength'] = context_length
            anova_res['Layer'] = layer
            results_anova.append(anova_res)

            # Rank-based counterpart of the ANOVA.
            friedman_res = pg.friedman(data=df_group, dv='Correlation', within='Model', subject='Subject')
            friedman_res['ContextLength'] = context_length
            friedman_res['Layer'] = layer
            results_friedman.append(friedman_res)

            # Post hoc comparisons between every pair of models.
            ph_res = pg.pairwise_tests(
                data=df_group,
                dv='Correlation',
                within='Model',
                subject='Subject',
                padjust='fdr_bh',
                effsize='r'
            )
            ph_res['ContextLength'] = context_length
            ph_res['Layer'] = layer
            results_posthoc.append(ph_res)

        def _significant_rows(results, corrected=False):
            """Concatenate result tables and render the rows with p < 0.05 as text."""
            # pd.concat raises on an empty list (e.g. fewer than two models loaded).
            if not results:
                return "No results."
            df = pd.concat(results, ignore_index=True)
            pvals = df['p-unc']
            if corrected and 'p-corr' in df.columns:
                # 'p-corr' may be None in some rows (presumably when pingouin skips
                # correction for a single contrast); fall back to 'p-unc' there.
                pvals = df['p-corr'].fillna(df['p-unc'])
            return df[pd.to_numeric(pvals) < 0.05].to_string(index=False)

        with open(self.plot_dir / 'significance_tests_results.txt', 'w', encoding='utf-8') as f:
            f.write("=== ANOVA Results ===\n")
            f.write(_significant_rows(results_anova))
            f.write("\n\n=== Friedman Results ===\n")
            f.write(_significant_rows(results_friedman))
            f.write("\n\n=== Post Hoc (Pairwise tests) Results ===\n")
            f.write(_significant_rows(results_posthoc, corrected=True))

    def perform_significance_tests(self):
        """Test, per model and layer-depth bin, whether correlations differ across context lengths.

        Data preparation:
            * the 'all' brain area is excluded and the rest averaged across
              brain areas;
            * each model's layer indices are split into three bins (Early /
              Middle / Final Layers) and averaged within each bin.

        Tests, for each (model, depth) with more than one context length:
            * a Friedman test across context lengths, with subjects as blocks;
            * if p < 0.05, FDR-corrected pairwise tests between context lengths.

        Results are written to ``friedman_results_cl.csv`` and
        ``posthoc_results_cl.csv`` in ``self.plot_dir``. A file is still written
        when its result list is empty.
        """
        results_friedman = []  # pg.friedman tables, one per (model, depth)
        results_posthoc = []   # pg.pairwise_tests tables for significant Friedman tests

        filtered_df = self.df[self.df['BrainArea'] != 'all'].copy()
        filtered_df = filtered_df.groupby(['BrainArea', 'ContextLength', 'Model', 'LayerIndex', 'Subject'], as_index=False)['Correlation'].mean()  # Compute mean correlation within each ROI
        filtered_df = filtered_df.groupby(['ContextLength', 'LayerIndex', 'Model', 'Subject'], as_index=False)['Correlation'].mean()  # Compute mean correlation across all ROIs

        # For each model, split its layer indices into three bins: Early, Middle, Final.
        filtered_df['Depth'] = filtered_df.groupby('Model')['LayerIndex'].transform(
            lambda x: pd.cut(
                x,
                bins=[-1, x.max() // 3, 2 * x.max() // 3, x.max()],
                labels=['Early Layers', 'Middle Layers', 'Final Layers']
            )
        )
        # Depth may be categorical (it comes from pd.cut). observed=True keeps
        # only the combinations that actually occur instead of adding empty
        # (NaN) rows for unobserved ones.
        filtered_df = filtered_df.groupby(
            ['ContextLength', 'Depth', 'Model', 'Subject'], as_index=False, observed=True
        )['Correlation'].mean()

        for model in filtered_df['Model'].unique():
            df_model = filtered_df[filtered_df['Model'] == model]
            for depth in df_model['Depth'].unique():
                df_depth = df_model[df_model['Depth'] == depth]

                # A comparison across context lengths needs at least two of them.
                if df_depth['ContextLength'].nunique() < 2:
                    continue

                friedman_res = pg.friedman(
                    data=df_depth,
                    dv='Correlation',
                    within='ContextLength',
                    subject='Subject',
                )
                friedman_res['Model'] = model
                friedman_res['Depth'] = depth
                results_friedman.append(friedman_res)
                print(f"Friedman test results for model {model}, depth {depth}:", friedman_res)

                # Use positional .iloc: Series[0] on a non-integer index only
                # works through a deprecated positional fallback in pandas.
                if friedman_res['p-unc'].iloc[0] < 0.05:
                    posthoc_res = pg.pairwise_tests(
                        data=df_depth,
                        dv='Correlation',
                        within='ContextLength',
                        subject='Subject',
                        padjust='fdr_bh',
                        effsize='r'
                    )
                    posthoc_res['Model'] = model
                    posthoc_res['Depth'] = depth
                    results_posthoc.append(posthoc_res)
                    print(f"Post hoc pairwise results for model {model}, depth {depth}:", posthoc_res)

        # pd.concat raises on an empty list (e.g. no significant Friedman test),
        # so fall back to an empty frame.
        df_friedman_results = pd.concat(results_friedman, ignore_index=True) if results_friedman else pd.DataFrame()
        df_posthoc_results = pd.concat(results_posthoc, ignore_index=True) if results_posthoc else pd.DataFrame()

        df_friedman_results.to_csv(self.plot_dir / 'friedman_results_cl.csv', index=False)
        df_posthoc_results.to_csv(self.plot_dir / 'posthoc_results_cl.csv', index=False)

    def plot_layerwise_mean_corr_heathmap(self):
        """Save one (context length x layer) heatmap of aggregated correlations per brain area and model.

        Reads ``self.heatmap_dict`` (filled by ``load``). Rows are the context
        lengths kept by ``load``; columns are layer indices.

        Output: ``<plot_dir>/<Model>/heatmaps/heathmap_model_<model>_roi_<area>.png``.
        Pairs with no loaded data are skipped.
        """
        for brain_area in self.brain_areas:
            for model in self.models:
                heatmap_values = self.heatmap_dict[brain_area][model]
                if not heatmap_values:
                    # Nothing loaded for this (area, model), e.g. load() not called.
                    continue

                # Row labels must mirror the context lengths load() kept
                # (it skips lengths > 1024 for both GPT-2 variants).
                context_lengths_str = [
                    str(c) for c in self.context_lengths
                    if not (model in ('gpt2', 'gpt2-2') and c > 1024)
                ]
                model_name = self.model_names.get(model, model)

                plt.figure(figsize=(30, 20))
                sns.heatmap(
                    heatmap_values,
                    annot=True,
                    fmt=".3f",
                    cmap='Blues',
                    xticklabels=list(range(len(heatmap_values[0]))),
                    yticklabels=context_lengths_str,
                )
                plt.title(f'{model_name} - {brain_area}', fontsize=20)
                plt.xlabel('Layer', fontsize=16)
                plt.ylabel('Context length', fontsize=16)
                plt.xticks(fontsize=14)
                plt.yticks(fontsize=14)
                plt.tight_layout()
                plt.savefig(self.plot_dir / model_name / f'heatmaps/heathmap_model_{model}_roi_{brain_area}.png', dpi=800)
                plt.close()
    
    def plot_layerwise_boxplots(self):
        """Save one boxplot per (model, layer, brain area) comparing context lengths.

        Each box summarises the per-subject correlations in ``self.df`` for one
        context length, with the mean marked.

        Output: ``<plot_dir>/<Model>/boxplots/boxplot_layer_<layer>_roi_<brain_area>.png``,
        where <Model> is the display name stored in ``self.df['Model']``.
        """
        for (model, layer, brain_area), df_group in self.df.groupby(['Model', 'LayerIndex', 'BrainArea']):
            # Treat context lengths as categories on the x axis. Build a new frame
            # instead of assigning into the group yielded by groupby, which can
            # trigger pandas' chained-assignment (SettingWithCopy) warning.
            df_plot = df_group.assign(ContextLength=df_group['ContextLength'].astype(str))

            plt.figure(figsize=(12, 8))
            sns.set_style("whitegrid")

            # One box per context length.
            sns.boxplot(
                x='ContextLength',
                y='Correlation',
                hue='ContextLength',
                data=df_plot,
                palette="Blues",
                showmeans=True
            )

            plt.title(f"Model: {model}, Layer: {layer}, ROI: {brain_area}")
            plt.xlabel("Context Length")
            plt.ylabel("Correlation")
            plt.tight_layout()
            plt.savefig(self.plot_dir / model / f"boxplots/boxplot_layer_{layer}_roi_{brain_area}.png")
            plt.close()
    
    def plot_layers_lineplot(self):
        """Plot mean correlation vs. context length for early/middle/late layers.

        Data preparation:
            * the 'all' brain area is excluded; the rest are averaged across
              brain areas;
            * each model's layer indices are split into three bins (Early /
              Middle / Late Layers);
            * values are averaged per (context length, depth, model).
        Each depth line shows the mean across models with standard-error bands.

        Significance: for each context length, a Friedman test across the three
        depth bins is run (subjects as blocks). When p < 0.05, FDR-corrected
        pairwise tests follow. Stars (1/2/3 for p < .05/.01/.001) are drawn when
        the middle layers are significantly higher than another depth bin.

        Saves ``average_layers_lineplot.pdf`` and ``average_layers_lineplot.png``
        to ``self.plot_dir``.
        """
        fig = plt.figure(figsize=(12, 6))
        sns.set_style("whitegrid")

        # Mean correlation across layers for all models.
        df_all_models = self.df[self.df['BrainArea'] != 'all'].copy()
        df_all_models = df_all_models.groupby(['BrainArea', 'ContextLength', 'Model', 'LayerIndex', 'Subject'], as_index=False)['Correlation'].mean()  # Compute mean correlation within each ROI
        df_all_models = df_all_models.groupby(['ContextLength', 'LayerIndex', 'Model', 'Subject'], as_index=False)['Correlation'].mean()  # Compute mean correlation across all ROIs
        df_all_models['Depth'] = df_all_models.groupby('Model')['LayerIndex'].transform(
            lambda x: pd.cut(
                x,
                bins=[-1, x.max() // 3, 2 * x.max() // 3, x.max()],
                labels=['Early Layers', 'Middle Layers', 'Late Layers']
            )
        )
        df_all_models_mean = df_all_models.groupby(['ContextLength', 'Depth', 'Model'], as_index=False)['Correlation'].mean()
        # Plot against the position of each context length so ticks are evenly spaced.
        df_all_models_mean['ContextLength'] = [self.context_lengths.index(c) for c in df_all_models_mean['ContextLength']]

        palette = sns.color_palette("Set1", 3)  # one colour per depth bin
        sns.lineplot(
            x='ContextLength',
            y='Correlation',
            hue='Depth',
            data=df_all_models_mean,
            palette=palette,
            estimator=np.mean,  # Aggregate across models by mean
            errorbar='se',      # Standard error as error band
            markers=True,
        )

        # Friedman test across depths per context length; post hoc tests when significant.
        significance_results = []
        for context_length in df_all_models['ContextLength'].unique():
            df_context = df_all_models[df_all_models['ContextLength'] == context_length]
            if df_context['Depth'].nunique() != 3:  # All three depths must be present
                continue
            friedman_res = pg.friedman(
                data=df_context,
                dv='Correlation',
                within='Depth',
                subject='Subject',
            )
            # Positional access: Series[0] on a non-integer index relies on a
            # deprecated pandas fallback.
            if friedman_res['p-unc'].iloc[0] < 0.05:
                posthoc_res = pg.pairwise_tests(
                    data=df_context,
                    dv='Correlation',
                    within='Depth',
                    subject='Subject',
                    padjust='fdr_bh',
                    effsize='r'
                )
                significance_results.append((context_length, posthoc_res))
                print(f"Significant results for context length {context_length}:", posthoc_res)

        # Star markers where the middle layers beat another depth bin
        # (a positive T is treated as A > B). The star count is the strongest level reached.
        ax = plt.gca()
        for context_length, posthoc_res in significance_results:
            x_pos = self.context_lengths.index(context_length)  # aligns with the x ticks
            count = 0
            for _, row in posthoc_res.iterrows():
                middle_is_higher = (
                    (row['A'] == 'Middle Layers' and row['T'] >= 0)
                    or (row['B'] == 'Middle Layers' and row['T'] < 0)
                )
                if middle_is_higher and row['p-corr'] < 0.05:
                    stars = 3 if row['p-corr'] < 0.001 else 2 if row['p-corr'] < 0.01 else 1
                    count = max(count, stars)
            if count > 0:
                # Fixed y position: assumes plotted correlations stay below ~0.08.
                ax.text(x_pos, 0.08, '*'*count, ha='center', fontsize=20, color=palette[1], fontweight='extra bold')

        # Axes labels and limits.
        current_ylim = ax.get_ylim()
        ax.set_ylim(current_ylim[0], 0.085)
        ax.set_xlabel("Context Length", fontsize=20)
        ax.set_ylabel("Mean Correlation", fontsize=20)
        ax.tick_params(axis='y', labelsize=16)
        ax.set_xticks(list(range(len(self.context_lengths))))
        ax.set_xticklabels(self.context_lengths, fontsize=18)

        # Add a legend item for the significance marker.
        handles, labels = ax.get_legend_handles_labels()
        handles.append(plt.Line2D([0], [0], color=palette[1], marker='*', linestyle='None', markersize=10, label='Significantly higher'))
        labels.append('Sig. Higher')

        # Replace the axes legend (if seaborn created one) with a figure-level legend.
        existing_legend = ax.get_legend()
        if existing_legend is not None:
            existing_legend.remove()

        fig.legend(
            handles=handles,
            labels=labels,
            title=None,
            fontsize=18,
            loc='upper center',
            bbox_to_anchor=(0.5, 1.15),
            bbox_transform=ax.transAxes,
            ncol=len(labels),
        )

        # Leave enough room at the top for the legend.
        plt.subplots_adjust(left=0.10, right=0.98, top=0.88, bottom=0.12)

        fig.savefig(self.plot_dir / "average_layers_lineplot.pdf", dpi=1200)
        fig.savefig(self.plot_dir / "average_layers_lineplot.png")
        plt.close(fig)

    
    def plot_layers_lineplot_by_model_family(self):
        """Plot mean correlation vs. context length per model family and layer depth.

        Data preparation:
            * the 'all' brain area is excluded; the rest are averaged across
              brain areas;
            * each model's layer indices are split into three bins (Early /
              Middle / Final Layers);
            * models are grouped into families (SSMs, Hybrid, Transformers).
        One line is drawn per (family, depth). Families get distinct colour
        maps; depths get different shades of that map.

        Significance: for every (depth, context length) with at least two
        families present, FDR-corrected pairwise tests between families are run
        and saved to ``wilcoxon_results_by_family.csv``. Markers use 1/2/3 stars
        for p < .05/.01/.001:
            * stars alone: the family is significantly higher than all other
              families;
            * stars plus a family initial: the family is significantly higher
              than exactly that one other family.

        NOTE(review): despite the "wilcoxon" naming, ``pg.pairwise_tests`` is
        called without ``parametric=False``, and the direction is read from its
        ``T`` column. These are presumably paired t-tests; confirm which test is
        intended.

        Saves ``average_layers_lineplot_by_family.png`` to ``self.plot_dir``.
        """
        model_families = {
            'Mamba-1.4B': 'SSMs',
            'Zamba2-1.2B': 'Hybrid',
            'Gemma-2B': 'Transformers',
            'Llama3.2-1B': 'Transformers',
            'Falcon3-1B-Base': 'Transformers',
        }
        # Fixed depth order, shared by the line shades and the star colours.
        depth_labels = ['Early Layers', 'Middle Layers', 'Final Layers']

        df_all_models = self.df[self.df['BrainArea'] != 'all'].copy()
        # Compute mean correlation within each ROI
        df_all_models = df_all_models.groupby(['BrainArea', 'ContextLength', 'Model', 'LayerIndex', 'Subject'], as_index=False)['Correlation'].mean()
        # Compute mean correlation across all ROIs
        df_all_models = df_all_models.groupby(['ContextLength', 'LayerIndex', 'Model', 'Subject'], as_index=False)['Correlation'].mean()

        # Split each model's layer indices into three bins.
        df_all_models['Depth'] = df_all_models.groupby('Model')['LayerIndex'].transform(
            lambda x: pd.cut(
                x,
                bins=[-1, x.max() // 3, 2 * x.max() // 3, x.max()],
                labels=depth_labels
            )
        )

        # Mean correlation for each model and depth; x is the context-length position.
        df_all_models_mean = df_all_models.groupby(['ContextLength', 'Depth', 'Model'], as_index=False)['Correlation'].mean()
        df_all_models_mean['ContextLength'] = [self.context_lengths.index(c) for c in df_all_models_mean['ContextLength']]
        df_all_models_mean['Family'] = df_all_models_mean['Model'].map(model_families)
        df_all_models_mean.drop(columns=['Model'], inplace=True)

        # One family -> index mapping, used for both line colours and markers.
        families = list(df_all_models_mean['Family'].dropna().unique())
        family_to_idx = {family: idx for idx, family in enumerate(families)}

        # One colour map per family, with one shade per depth.
        family_palettes = [sns.color_palette("Blues_d", 3),
                           sns.color_palette("Greens_d", 3),
                           sns.color_palette("Reds_d", 3)]

        plt.figure(figsize=(12, 6))
        sns.set_style("whitegrid")
        for family, family_idx in family_to_idx.items():
            for depth_idx, depth in enumerate(depth_labels):
                df_family_depth = df_all_models_mean[
                    (df_all_models_mean['Family'] == family) & (df_all_models_mean['Depth'] == depth)
                ]
                sns.lineplot(
                    x='ContextLength',
                    y='Correlation',
                    data=df_family_depth,
                    color=family_palettes[family_idx][depth_idx],
                    # One legend entry per family, attached to its last depth line.
                    label=family if depth_idx == len(depth_labels) - 1 else None,
                    estimator=np.mean,
                    errorbar='se',
                    markers=True
                )

        # Pairwise tests between families for every (depth, context length).
        df_all_models['Family'] = df_all_models['Model'].map(model_families)
        df_all_models.drop(columns=['Model'], inplace=True)
        n_other_families = len(family_to_idx) - 1
        significance_results = []
        ax = plt.gca()
        for depth_idx, depth in enumerate(depth_labels):
            for context_length in df_all_models['ContextLength'].unique():
                df_subset = df_all_models[(df_all_models['Depth'] == depth) & (df_all_models['ContextLength'] == context_length)]
                if df_subset['Family'].nunique() < 2:  # Need at least two families to compare
                    continue
                wilcoxon_res = pg.pairwise_tests(df_subset, dv='Correlation', within='Family', subject='Subject', padjust='fdr_bh')
                wilcoxon_res['ContextLength'] = context_length
                wilcoxon_res['Depth'] = depth
                print(wilcoxon_res)
                significance_results.append(wilcoxon_res)

                x_pos = self.context_lengths.index(context_length)  # aligns with the x ticks
                y_pos = 0.085 + (0.0035 * depth_idx)

                # wins[winner_idx][loser_family] -> number of stars for that pair.
                wins = defaultdict(dict)
                for _, row in wilcoxon_res.iterrows():
                    p_value = row['p-unc']
                    if 'p-corr' in row.index and pd.notna(row['p-corr']):
                        p_value = row['p-corr']
                    if not p_value < 0.05:
                        continue
                    # A positive T is treated as A > B.
                    winner, loser = (row['A'], row['B']) if row['T'] >= 0 else (row['B'], row['A'])
                    # Stars are computed per pair (not carried over from earlier rows).
                    wins[family_to_idx[winner]][loser] = 3 if p_value < 0.001 else 2 if p_value < 0.01 else 1

                for family_idx in range(len(family_to_idx)):
                    beaten = wins.get(family_idx)
                    if not beaten:
                        continue
                    color = family_palettes[family_idx][depth_idx]
                    # The literal context lengths 40 and 1 only tweak text alignment
                    # (presumably the last/first x positions; TODO confirm).
                    if len(beaten) == n_other_families:
                        # Significantly higher than every other family: stars only.
                        align = 'right' if context_length == 40 else 'center'
                        ax.text(x_pos, y_pos, '*' * max(beaten.values()), ha=align, fontsize=20, color=color, fontweight='extra bold')
                    elif len(beaten) == 1:
                        # Significantly higher than exactly one other family:
                        # stars followed by that family's initial.
                        loser, stars = next(iter(beaten.items()))
                        align = 'left' if context_length == 40 else 'center'
                        if context_length == 1:
                            ax.text(x_pos, y_pos, '*' * stars, ha='center', fontsize=20, color=color, fontweight='extra bold')
                            ax.text(x_pos, y_pos, '    ' + loser[0], ha=align, fontsize=20, color=color, fontweight='extra bold')
                        else:
                            ax.text(x_pos, y_pos, '*' * stars + loser[0], ha=align, fontsize=20, color=color, fontweight='extra bold')

        # Combine results into a DataFrame and save to CSV.
        if significance_results:
            df_significance_results = pd.concat(significance_results, ignore_index=True)
            df_significance_results.to_csv(self.plot_dir / 'wilcoxon_results_by_family.csv', index=False)

        # Axes labels and limits.
        current_ylim = ax.get_ylim()
        ax.set_ylim(current_ylim[0], 0.1)
        plt.xlabel("Context Length", fontsize=20)
        plt.ylabel("Mean Correlation", fontsize=20)
        plt.yticks(fontsize=16)
        plt.xticks(ticks=list(range(len(self.context_lengths))), labels=self.context_lengths, fontsize=18)
        plt.legend(title=None, fontsize=14)  # No legend title, larger font
        plt.tight_layout()
        plt.savefig(self.plot_dir / "average_layers_lineplot_by_family.png")
        plt.close()
    
    def plot_layers_barplot_by_model_family(self):
        # Create a mapping of models to their families
        model_families = {
            'Mamba-1.4B': 'SSMs',
            'Zamba2-1.2B': 'Hybrid',
            'Gemma-2B': 'Transformers',
            'Llama3.2-1B': 'Transformers',
            'Falcon3-1B-Base': 'Transformers',
        }

        df_all_models = self.df[self.df['BrainArea'] != 'all'].copy()

        # Compute mean correlation within each ROI
        df_all_models = df_all_models.groupby(['BrainArea', 'ContextLength', 'Model', 'LayerIndex', 'Subject'], as_index=False)['Correlation'].mean()
        # Compute mean correlation across all ROIs
        df_all_models = df_all_models.groupby(['ContextLength', 'LayerIndex', 'Model', 'Subject'], as_index=False)['Correlation'].mean()

        # Group layers into Early, Middle, Final for each model
        df_all_models['Depth'] = df_all_models.groupby('Model')['LayerIndex'].transform(
            lambda x: pd.cut(
            x,
            bins=[-1, x.max() // 3, 2 * x.max() // 3, x.max()],
            labels=['Early Layers', 'Middle Layers', 'Late Layers']
            )
        )

        # Compute mean across context lengths
        df_1 = df_all_models[df_all_models['ContextLength'] == 1].copy()
        df_other_avg = df_all_models[df_all_models['ContextLength'] != 1].copy()
        df_other_avg = df_other_avg.groupby(['Depth', 'Model', 'Subject'], as_index=False)['Correlation'].mean()

        # Group by model family (transformer, ssm, hybrid)
        df_1['Family'] = df_1['Model'].map(model_families)
        df_1.drop(columns=['Model'], inplace=True)
        df_other_avg['Family'] = df_other_avg['Model'].map(model_families)
        df_other_avg.drop(columns=['Model'], inplace=True)

        palette = sns.color_palette("Set1", 3)

        # Create a figure with two subplots
        fig, axes = plt.subplots(1, 2, figsize=(15, 6), sharey=True)
        sns.set_style("whitegrid")

        # Plot the first barplot
        sns.barplot(
            x='Depth',
            y='Correlation',
            hue='Family',
            data=df_1,
            palette=palette,  # Use three predefined pastel colors
            estimator=np.mean,
            errorbar='se',
            ax=axes[0]
        )
        axes[0].set_title("Context Length = 1", fontsize=24)
        axes[0].set_xlabel("")
        axes[0].set_ylabel("Mean Correlation", fontsize=22)
        axes[0].tick_params(axis='y', labelsize=20)
        axes[0].tick_params(axis='x', labelsize=20)
        axes[0].legend(title=None, fontsize=20)

        # Perform significance tests for the first barplot
        significance_results = []
        bars = [patch for patch in axes[0].patches if isinstance(patch, plt.Rectangle) and patch.get_height() > 0]
        bars = sorted(bars, key=lambda bar: bar.get_x())
        bars_count = [0 for _ in range(len(bars))]
        family_to_idx = {family: idx for idx, family in enumerate(df_1['Family'].unique())}
        for depth_idx, depth in enumerate(df_1['Depth'].unique()):
            df_subset = df_1[df_1['Depth'] == depth]
            if df_subset['Family'].nunique() > 1:  # Ensure multiple families are present
                # Perform pairwise Wilcoxon tests across families
                wilcoxon_res = pg.pairwise_tests(df_subset, dv='Correlation', within='Family', subject='Subject', padjust='fdr_bh')
                wilcoxon_res['Depth'] = depth
                print(wilcoxon_res)
                significance_results.append(wilcoxon_res)
                for _, row in wilcoxon_res.iterrows():
                    if row['p-corr'] < 0.05:
                        if row['T'] >= 0:
                            color = palette[family_to_idx[row['A']]]
                            bar_idx = depth_idx*len(family_to_idx.keys()) + family_to_idx[row['B']]
                        else:
                            color = palette[family_to_idx[row['B']]]
                            bar_idx = depth_idx*len(family_to_idx.keys()) + family_to_idx[row['A']]
                        count = 1
                        if row['p-corr'] < 0.001 and count < 3:
                            count = 3
                        elif row['p-corr'] < 0.01 and count < 2:
                            count = 2
                        x_coord = bars[bar_idx].get_x() + bars[bar_idx].get_width() / 2
                        print(f"Significant result: {row['A']} vs {row['B']}, p-corr: {row['p-corr']}, x_pos: {x_coord}, count: {count}")
                        axes[0].text(x_coord, 0.085 + (0.0035 * bars_count[bar_idx]), '*' * count, ha='center', fontsize=20, color=color, fontweight='extra bold')
                        bars_count[bar_idx] += 1

        # Plot the second barplot
        sns.barplot(
            x='Depth',
            y='Correlation',
            hue='Family',
            data=df_other_avg,
            palette=palette,  # Use three predefined pastel colors
            estimator=np.mean,
            errorbar='se',
            ax=axes[1]
        )
        axes[1].set_title("Average across context lengths > 1", fontsize=24)
        axes[1].set_xlabel("")
        axes[1].tick_params(axis='x', labelsize=20)
        axes[1].legend(title=None, fontsize=20)

        # Perform significance tests for the second barplot
        significance_results = []
        bars = [patch for patch in axes[1].patches if isinstance(patch, plt.Rectangle) and patch.get_height() > 0]
        bars = sorted(bars, key=lambda bar: bar.get_x())
        bars_count = [0 for _ in range(len(bars))]
        family_to_idx = {family: idx for idx, family in enumerate(df_other_avg['Family'].unique())}
        for depth_idx, depth in enumerate(df_other_avg['Depth'].unique()):
            df_subset = df_other_avg[df_other_avg['Depth'] == depth]
            if df_subset['Family'].nunique() > 1:  # Ensure multiple families are present
                # Perform pairwise Wilcoxon tests across families
                wilcoxon_res = pg.pairwise_tests(df_subset, dv='Correlation', within='Family', subject='Subject', padjust='fdr_bh')
                wilcoxon_res['Depth'] = depth
                print(wilcoxon_res)
                significance_results.append(wilcoxon_res)
                for _, row in wilcoxon_res.iterrows():
                    if row['p-corr'] < 0.05:
                        if row['T'] >= 0:
                            color = palette[family_to_idx[row['A']]]
                            bar_idx = depth_idx*len(family_to_idx.keys()) + family_to_idx[row['B']]
                        else:
                            color = palette[family_to_idx[row['B']]]
                            bar_idx = depth_idx*len(family_to_idx.keys()) + family_to_idx[row['A']]
                        count = 1
                        if row['p-corr'] < 0.001 and count < 3:
                            count = 3
                        elif row['p-corr'] < 0.01 and count < 2:
                            count = 2
                        x_coord = bars[bar_idx].get_x() + bars[bar_idx].get_width() / 2
                        print(f"Significant result: {row['A']} vs {row['B']}, p-corr: {row['p-corr']}, x_pos: {x_coord}, count: {count}")
                        axes[1].text(x_coord, 0.085 + (0.0035 * bars_count[bar_idx]), '*' * count, ha='center', fontsize=20, color=color, fontweight='extra bold')
                        bars_count[bar_idx] += 1

        # Adjust layout and save the combined plot
        current_ylim = axes[1].get_ylim()
        axes[1].set_ylim(current_ylim[0], 0.1)
        current_ylim = axes[0].get_ylim()
        axes[0].set_ylim(current_ylim[0], 0.1)
        plt.tight_layout()
        plt.savefig(self.plot_dir / "average_layers_barplot_by_family_combined.png")
        plt.close()
    
    def plot_aggregated_lineplot(self):
        """Plot mean correlation vs. context length (one line per model) with significance stars.

        Rows for the pooled ``'all'`` brain area are excluded. Correlations are
        averaged first within each ROI (per layer and subject), then across ROIs
        and layers, giving one value per (Model, ContextLength, Subject). Context
        lengths are drawn at evenly spaced x positions (their index in
        ``self.context_lengths``); rows with context lengths not listed there are
        ignored.

        For every context length with more than one model a Friedman test across
        models is run. If it is significant (p < 0.05), pairwise post-hoc tests
        with FDR-BH correction are run, and every pair whose corrected p < 0.05
        gets two stars at that x position: the better model's colour right of the
        tick and the worse model's colour left of it (bold when corrected
        p < 0.005). Each significant pair is stacked on its own row so markers do
        not overlap.

        Saves ``aggregated_lineplot.png`` under ``self.plot_dir``.
        """
        df_all = self.df[self.df['BrainArea'] != 'all'].copy()
        # Only context lengths we know how to place on the x axis.
        df_all = df_all[df_all['ContextLength'].isin(self.context_lengths)]
        # Mean correlation within each ROI (per layer and subject).
        df_all = df_all.groupby(['BrainArea', 'ContextLength', 'Model', 'LayerIndex', 'Subject'], as_index=False)['Correlation'].mean()
        # Mean across ROIs and layers -> one value per (Model, ContextLength, Subject).
        df_all = df_all.groupby(['Model', 'ContextLength', 'Subject'], as_index=False)['Correlation'].mean()
        # Replace raw context lengths by their position so the x axis is evenly spaced.
        df_all['ContextLength'] = df_all['ContextLength'].map(self.context_lengths.index)

        # One colour per model, shared by the lines and the significance stars.
        models = df_all['Model'].unique().tolist()
        model_color = dict(zip(models, sns.color_palette("Set1", len(models))))

        plt.figure(figsize=(12, 8))
        sns.set_style("whitegrid")
        sns.lineplot(
            x='ContextLength',
            y='Correlation',
            hue='Model',
            data=df_all,
            palette=model_color,
            estimator=np.mean,
            # NOTE(review): for 'se' the second tuple element is a scale factor,
            # so this draws 0.95 x SE bars (not a 95% interval) -- confirm intent.
            errorbar=('se', 0.95),
            err_style='bars',
            markers=True
        )

        # Friedman omnibus test per context length, then FDR-corrected post-hoc tests.
        significance_results = []  # list of (x position, context length, post-hoc DataFrame)
        for x_pos, context_length in enumerate(self.context_lengths):
            df_context = df_all[df_all['ContextLength'] == x_pos]
            if df_context['Model'].nunique() < 2:
                continue
            friedman_res = pg.friedman(
                data=df_context,
                dv='Correlation',
                within='Model',
                subject='Subject',
            )
            # `not (p < α)` also skips NaN p-values, like the original comparison.
            if not friedman_res['p-unc'].iloc[0] < 0.05:
                continue
            posthoc_res = pg.pairwise_tests(
                data=df_context,
                dv='Correlation',
                within='Model',
                subject='Subject',
                padjust='fdr_bh',
                effsize='r'
            )
            significance_results.append((x_pos, context_length, posthoc_res))
            print(f"Significant results for context length {context_length}:", posthoc_res)

        # Draw one pair of stars per significant comparison.
        ax = plt.gca()
        for x_pos, _, posthoc_res in significance_results:
            n_marked = 0  # significant pairs already drawn at this x position
            for _, row in posthoc_res.iterrows():
                # pingouin omits 'p-corr' when there is only one comparison
                # (correction is then the identity), so fall back to 'p-unc'.
                p_value = row['p-corr'] if 'p-corr' in row.index else row['p-unc']
                if not p_value < 0.05:
                    continue
                # T > 0 means A's mean exceeds B's; ties count as B being better.
                if row['T'] > 0:
                    winner, loser = row['A'], row['B']
                else:
                    winner, loser = row['B'], row['A']
                y_pos = 0.073 + n_marked * 0.002
                fontweight = 'extra bold' if p_value < 0.005 else 'normal'
                ax.text(x_pos, y_pos, '*', ha='left', fontsize=12, color=model_color[winner], fontweight=fontweight)
                ax.text(x_pos, y_pos, '*', ha='right', fontsize=12, color=model_color[loser], fontweight=fontweight)
                n_marked += 1

        # Label axes and title
        plt.title("Mean Correlation across layers, voxels, and subjects")
        plt.xticks(ticks=list(range(len(self.context_lengths))), labels=self.context_lengths)
        plt.xlabel("Context Length")
        plt.ylabel("Mean Correlation")
        plt.legend(title='Model')
        plt.tight_layout()
        plt.savefig(self.plot_dir / "aggregated_lineplot.png")
        plt.close()
    
    def plot_modelwise_barplot(self, context_length=640):
        """Compare models at one context length: a barplot plus a p-value heatmap.

        Only rows for the pooled ``'all'`` brain area at ``context_length`` are
        used. Correlations are averaged over the remaining rows (presumably
        layers -- confirm against ``load``) to one value per (Model, Subject);
        each bar shows the mean over subjects.

        For every pair of models a paired t-test across subjects is run.
        Subjects are aligned explicitly, and only subjects present for both
        models are used. The uncorrected p-values are shown as a symmetric
        heatmap with the diagonal masked.

        Args:
            context_length: Context length to analyse. Defaults to 640, the value
                this method originally hard-coded. It is also embedded in the
                output filenames.

        Saves ``modelwise_barplot_cl_{context_length}.{png,pdf}`` and
        ``modelwise_heatmap_cl_{context_length}.{png,pdf}`` under ``self.plot_dir``.
        """
        df_all = self.df[(self.df['BrainArea'] == 'all') & (self.df['ContextLength'] == context_length)].copy()
        if df_all.empty:
            print(f"No 'all'-area data for context length {context_length}; skipping modelwise barplot.")
            return
        # One value per (Model, Subject): mean over the remaining rows (e.g. layers).
        df_all = df_all.groupby(['Model', 'Subject'], as_index=False)['Correlation'].mean()

        # One Set1 colour per model
        unique_models = df_all['Model'].unique()
        n_models = len(unique_models)
        palette = sns.color_palette("Set1", n_models)

        # Barplot with error bars
        plt.figure(figsize=(8, 6))
        sns.set_style("whitegrid")
        sns.barplot(
            x='Model',
            y='Correlation',
            estimator=np.mean,
            # NOTE(review): for 'se' the second tuple element is a scale factor,
            # so this draws 0.95 x SE bars (not a 95% interval) -- confirm intent.
            errorbar=('se', 0.95),
            data=df_all,
            hue='Model',
            palette=palette,
            legend=False
        )
        plt.xlabel("Model", fontsize=18)
        plt.ylabel("Mean Correlation", fontsize=18)
        plt.xticks(rotation=45, fontsize=16)
        plt.yticks(fontsize=16)
        plt.tight_layout()
        plt.savefig(self.plot_dir / f"modelwise_barplot_cl_{context_length}.png", dpi=300)
        plt.savefig(self.plot_dir / f"modelwise_barplot_cl_{context_length}.pdf", dpi=1200)
        plt.close()

        if n_models < 2:
            # No model pairs to compare; the heatmap would be entirely masked.
            return

        # Subject x Model table so each paired test compares the same subjects,
        # independent of row order or subjects missing for some model.
        wide = df_all.pivot(index='Subject', columns='Model', values='Correlation')

        # Pairwise paired t-tests; p-values are uncorrected for multiple comparisons.
        significance_matrix = np.full((n_models, n_models), np.nan)
        for i, model_a in enumerate(unique_models):
            for j in range(i + 1, n_models):
                model_b = unique_models[j]
                pair = wide[[model_a, model_b]].dropna()
                if len(pair) < 2:
                    continue  # not enough shared subjects for a paired test
                p_value = pg.ttest(pair[model_a], pair[model_b], paired=True)['p-val'].iloc[0]
                significance_matrix[i, j] = p_value
                significance_matrix[j, i] = p_value

        # Heatmap for the significance matrix
        plt.figure(figsize=(8, 6))
        ax = sns.heatmap(
            significance_matrix,
            annot=True,
            fmt=".3f",
            cmap="coolwarm",
            xticklabels=unique_models,
            yticklabels=unique_models,
            mask=np.eye(n_models, dtype=bool),  # Mask diagonal values
            annot_kws={"size": 14}  # Increase font size for annotations
        )
        # The last axes of the figure is the colorbar added by sns.heatmap.
        ax.figure.axes[-1].set_ylabel('p-value', size=20)
        ax.figure.axes[-1].tick_params(labelsize=18)
        plt.xlabel("Model", fontsize=20)
        plt.ylabel("Model", fontsize=20)
        plt.xticks(rotation=45, fontsize=18)
        plt.yticks(fontsize=18)
        plt.tight_layout()
        plt.savefig(self.plot_dir / f"modelwise_heatmap_cl_{context_length}.png", dpi=300)
        plt.savefig(self.plot_dir / f"modelwise_heatmap_cl_{context_length}.pdf", dpi=1200)
        plt.close()

if __name__ == '__main__':
    # Command-line entry point: load alignment results for the chosen models and draw plots.
    parser = argparse.ArgumentParser()
    # Required and restricted to known models: AlignmentData iterates this list
    # and looks each name up in AlignmentData.model_names.
    parser.add_argument('--models', '-m', type=str, nargs='+', required=True,
                        choices=list(AlignmentData.model_names), help="Models to analyze.")
    parser.add_argument('--context-lengths', '-l', type=int, nargs='+', default=[20, 40, 80, 160, 320, 640], help="Context lengths to analyze.")
    parser.add_argument('--dataset', '-d', type=str, default='HarryPotter', help="Name of the dataset [HarryPotter].")
    parser.add_argument('--voxel-func', '-f', default='mean', type=str, help="How to aggregate correlation values across voxels [mean, median].")
    parser.add_argument('--subj-func', '-s', default='mean', type=str, help="How to aggregate correlation values across subjects [mean, median].")
    # NOTE(review): `overrides` is parsed but not used anywhere in this script.
    parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values")

    args = parser.parse_args()
    context_lengths = args.context_lengths
    dataset = args.dataset
    models = args.models
    voxel_func = args.voxel_func
    subj_func = args.subj_func

    # All inputs and outputs are resolved relative to the working directory.
    root_dir = Path.cwd()

    ### Get alignment data and generate plots
    alignment_data = AlignmentData(dataset, models, context_lengths, voxel_func, subj_func, root_dir)
    print("Loading data...")
    alignment_data.load()

    # Uncomment the analyses to run.
    #print("Performing significance tests...")
    #if len(models) > 1:
    #    alignment_data.perform_significance_tests_across_models()
    #alignment_data.perform_significance_tests()
    #print("Plotting heatmaps...")
    #alignment_data.plot_layerwise_mean_corr_heathmap()
    #print("Plotting boxplots...")
    #alignment_data.plot_layerwise_boxplots()
    #print("Plotting lineplots...")
    #alignment_data.plot_rois_lineplot()
    #alignment_data.plot_layers_barplot_by_model_family()
    #alignment_data.plot_layers_lineplot_by_model_family()
    #alignment_data.plot_layers_lineplot()
    #alignment_data.plot_aggregated_lineplot()
    alignment_data.plot_modelwise_barplot()