# analysis_modules/ablation_study.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import re
from typing import List, Dict
from .base_analyzer import BaseAnalyzer

class AblationStudyAnalyzer(BaseAnalyzer):
    """
    Handles the analysis for the 'ablation_study' experiment.

    From the 'episode_wise' metric results of the GIFT model variants this
    produces a LaTeX comparison table, one CSV per dataset and one
    mean/std-vs-tau plot per dataset.
    """

    # Variant keys in display order, as they appear in result keys of the
    # form '<variant>_shift_False'.
    _ABLATION_VARIANTS = ['full_model', 'no_dr', 'no_her', 'with_cql']
    # Human-readable row/legend label for each variant key.
    _ABLATION_VARIANT_LABELS = {
        'full_model': 'Full Model',
        'no_dr': 'w/o DR',
        'no_her': 'w/o Her',
        'with_cql': 'with CQL',
    }
    # Matches result keys such as 'no_dr_shift_False'; group 1 is the variant key.
    _VARIANT_KEY_PATTERN = re.compile(r'([a-zA-Z_]+)_shift_False')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _ablation_gift_data(results, data_path):
        """Return the 'gift' entry found at the '/'-separated `data_path` in `results`.

        Returns None when a path component is missing, the node at the end of
        the path is not a dict, or that node has no 'gift' entry.
        """
        node = results
        try:
            for part in data_path.split('/'):
                node = node[part]
        except (KeyError, TypeError):
            return None
        if not isinstance(node, dict):
            return None
        return node.get('gift')

    @classmethod
    def _ablation_parse_variants(cls, gift_data):
        """Map each known variant key to its per-tau results dict.

        Keys of `gift_data` are matched (from the start) against
        '<variant>_shift_False'; variants not listed in `_ABLATION_VARIANTS`
        are ignored. If several keys map to the same variant, the last one in
        iteration order wins.
        """
        parsed = {}
        for param_suffix, res in gift_data.items():
            match = cls._VARIANT_KEY_PATTERN.match(param_suffix)
            if match and match.group(1) in cls._ABLATION_VARIANTS:
                parsed[match.group(1)] = res
        return parsed

    @staticmethod
    def _ablation_sanitize_name(pretty_name):
        """Strip LaTeX characters from a dataset display name for use in a filename."""
        return (pretty_name.replace('\\', '').replace('(', '_').replace(')', '')
                .replace('=', '').replace('$', ''))

    @staticmethod
    def _ablation_parse_mean(cell):
        """Extract the mean from a formatted 'mean <pm> std' cell; np.inf if it cannot be parsed."""
        if not isinstance(cell, str) or '$\\pm$' not in cell:
            return np.inf
        try:
            return float(cell.split('$\\pm$')[0])
        except ValueError:
            return np.inf

    def _ablation_format_cell(self, variant_results, tau_key, dataset_type):
        """Format one result cell via `self._format_mean_std`, or return 'N/A'.

        A cell is 'N/A' when the variant has no results, `tau_key` is absent,
        or the stored mean is NaN.
        """
        if not variant_results or tau_key not in variant_results:
            return 'N/A'
        entry = variant_results[tau_key]
        mean = entry['mean']
        if np.isnan(mean):
            return 'N/A'
        return self._format_mean_std(mean, entry['std'], dataset_type)

    # ---------------------------------------------------------------- outputs

    def _create_ablation_table(self, exp_name: str, datasets_to_run: Dict[str, str], taus_to_show: List[int]):
        """Write a LaTeX table comparing the ablation variants across datasets and taus.

        Rows are variants; columns are (dataset, tau) pairs. Each cell is the
        output of `self._format_mean_std` or 'N/A'. The cell with the lowest
        parsed mean in each column is set in bold.

        Args:
            exp_name: Experiment name passed to the scan/output-path helpers.
            datasets_to_run: Display name -> '/'-separated path into the
                scanned results. Determines column order.
            taus_to_show: Tau values shown as sub-columns of every dataset.
        """
        metric_type = 'episode_wise'
        results = self.scan_experiments_for_metric_type(exp_name, metric_type)
        if not results:
            print(f"No results found for experiment '{exp_name}'.")
            return
        print(f"\n1. Generating Ablation Study Comparison Table (metric: {metric_type})...")
        if not taus_to_show:
            print(" - No tau values requested. Aborting table creation.")
            return

        all_data = {}
        for pretty_name, data_path in datasets_to_run.items():
            gift_data = self._ablation_gift_data(results, data_path)
            if gift_data is None:
                print(f" - Warning: Data for '{pretty_name}' at path '{data_path}' not found. Skipping.")
                continue
            all_data[pretty_name] = gift_data

        if not all_data:
            print(" - No valid data found for any specified dataset. Aborting table creation.")
            return

        tau_labels = [f'$\\tau={t}$' for t in taus_to_show]
        columns = pd.MultiIndex.from_product(
            [list(datasets_to_run.keys()), tau_labels],
            names=['Dataset', 'Tau']
        )
        row_labels = [self._ABLATION_VARIANT_LABELS[v] for v in self._ABLATION_VARIANTS]
        # Pre-fill with 'N/A' so datasets skipped above render as 'N/A', not 'nan'.
        df = pd.DataFrame('N/A', index=row_labels, columns=columns, dtype=object)

        # Populate the DataFrame with formatted 'mean ± std' strings.
        for pretty_name, gift_data in all_data.items():
            parsed_variants = self._ablation_parse_variants(gift_data)
            dataset_type = self._get_dataset_type(pretty_name)
            for variant_key in self._ABLATION_VARIANTS:
                variant_label = self._ABLATION_VARIANT_LABELS[variant_key]
                variant_results = parsed_variants.get(variant_key)
                for tau, tau_label in zip(taus_to_show, tau_labels):
                    df.loc[variant_label, (pretty_name, tau_label)] = self._ablation_format_cell(
                        variant_results, f'tau_{tau}', dataset_type)

        # Bold the lowest mean in each column (first row wins on ties). Columns
        # with no parsable value are left untouched.
        for col in df.columns:
            means = df[col].map(self._ablation_parse_mean)
            if means.empty or means.isin([np.inf]).all():
                continue
            best_row = means.idxmin()
            df.loc[best_row, col] = f"\\textbf{{{df.loc[best_row, col]}}}"

        # Each dataset spans one column per requested tau.
        num_datasets = len(datasets_to_run)
        num_taus = len(taus_to_show)
        col_format = "l " + " ".join(["c" * num_taus] * num_datasets)
        header1 = " & ".join(
            ["~"] + [f"\\multicolumn{{{num_taus}}}{{c}}{{{ds_name}}}" for ds_name in datasets_to_run]
        )
        # Column 1 holds the row labels, so dataset i covers columns
        # 2 + i*num_taus .. 1 + (i+1)*num_taus.
        cmidrules = " ".join(
            f"\\cmidrule(lr){{{2 + i * num_taus}-{1 + (i + 1) * num_taus}}}" for i in range(num_datasets)
        )
        header2 = " & ".join(["~"] + tau_labels * num_datasets)

        latex_parts = [
            "\\begin{table*}[t]",
            "    \\centering",
            "    \\small % Use small font size for better readability",
            "    \\caption{Ablation study results for different model configurations. Lower RMSE values are better, with the best performance in each column highlighted in \\textbf{bold}.}",
            "    \\label{tab:ablation_study_v2}",
            "    \\vskip 0.1in",
            f"    \\begin{{tabular}}{{{col_format}}}",
            "    \\toprule",
            "    " + header1 + " \\\\",
            "    " + cmidrules,
            "    " + header2 + " \\\\",
            "    \\midrule",
        ]
        latex_parts.extend(
            "    " + " & ".join(map(str, [label, *row.values])) + " \\\\"
            for label, row in df.iterrows()
        )
        latex_parts.extend([
            "    \\bottomrule",
            "    \\end{tabular}",
            "    \\vskip -0.05in",
            "\\end{table*}"
        ])
        latex_table = "\n".join(latex_parts)

        tex_path = self._get_output_path(exp_name, 'table', metric_type, "ablation_study_comparison_v2.tex")
        with open(tex_path, "w", encoding="utf-8") as f:
            f.write(latex_table)
        print(f" - Ablation table saved to: {tex_path}")

    def _create_ablation_csvs(self, exp_name: str, datasets_to_run: Dict[str, str]):
        """Write one CSV per dataset with variants as rows and taus 1..6 as columns.

        Cells are formatted by `self._format_mean_std`. A cell is 'N/A' when
        the value is missing or its mean is NaN, matching the LaTeX table.
        """
        metric_type = 'episode_wise'
        results = self.scan_experiments_for_metric_type(exp_name, metric_type)
        if not results:
            return
        print(f"\n2. Generating Ablation Study CSVs (metric: {metric_type})...")

        # Taus are fixed to 1..6 for the CSV export.
        tau_columns = [str(i) for i in range(1, 7)]
        row_labels = [self._ABLATION_VARIANT_LABELS[v] for v in self._ABLATION_VARIANTS]

        for pretty_name, data_path in datasets_to_run.items():
            gift_data = self._ablation_gift_data(results, data_path)
            if not gift_data:
                print(f" - Data for '{pretty_name}' not found. Skipping CSV.")
                continue
            parsed_variants = self._ablation_parse_variants(gift_data)
            if not parsed_variants:
                print(f" - No valid variant data found for '{pretty_name}'.")
                continue

            dataset_type = self._get_dataset_type(pretty_name)
            df = pd.DataFrame(index=row_labels, columns=tau_columns, dtype=object)
            for variant_key in self._ABLATION_VARIANTS:
                variant_label = self._ABLATION_VARIANT_LABELS[variant_key]
                variant_data = parsed_variants.get(variant_key)
                for tau in tau_columns:
                    df.at[variant_label, tau] = self._ablation_format_cell(
                        variant_data, f'tau_{tau}', dataset_type)

            sanitized_name = self._ablation_sanitize_name(pretty_name)
            csv_path = self._get_output_path(exp_name, 'csv', metric_type, f"ablation_{sanitized_name}.csv")
            df.to_csv(csv_path)
            print(f" - CSV for '{pretty_name}' saved to: {csv_path}")

    def _create_ablation_plots(self, exp_name: str, datasets_to_run: Dict[str, str]):
        """Save one PDF per dataset plotting each variant's mean (± std band) against tau.

        Only variants that have at least one 'tau_<int>' entry are drawn. No
        file is written for a dataset where no variant has plottable data.
        """
        metric_type = 'episode_wise'
        results = self.scan_experiments_for_metric_type(exp_name, metric_type)
        if not results:
            return

        print(f"\n3. Generating Ablation Study Plots (metric: {metric_type})...")
        for pretty_name, data_path in datasets_to_run.items():
            gift_data = self._ablation_gift_data(results, data_path)
            if not gift_data:
                print(f" - Data for '{pretty_name}' not found. Skipping plot.")
                continue
            parsed_variants = self._ablation_parse_variants(gift_data)

            fig, ax = plt.subplots(1, 1, figsize=(10, 7))
            plot_created = False

            # The index into the style lists follows the fixed variant order, so
            # a variant keeps the same color/line/marker across datasets.
            for i, variant_key in enumerate(self._ABLATION_VARIANTS):
                variant_data = parsed_variants.get(variant_key)
                if not variant_data:
                    continue
                tau_keys = sorted(
                    (k for k in variant_data if k.startswith('tau_')),
                    key=lambda k: int(k.split('_')[1])
                )
                if not tau_keys:
                    continue
                tau_range = [int(k.split('_')[1]) for k in tau_keys]
                means = np.asarray([variant_data[k]['mean'] for k in tau_keys], dtype=float)
                stds = np.asarray([variant_data[k]['std'] for k in tau_keys], dtype=float)

                color = self.plot_colors[i % len(self.plot_colors)]
                ls = self.linestyles[i % len(self.linestyles)]
                marker = self.markers[i % len(self.markers)]

                ax.plot(tau_range, means, marker=marker, color=color, linestyle=ls,
                        label=self._ABLATION_VARIANT_LABELS.get(variant_key, variant_key))
                ax.fill_between(tau_range, means - stds, means + stds, color=color, alpha=0.15)
                plot_created = True

            if not plot_created:
                plt.close(fig)
                continue

            ax.set_xlabel(r'$\tau$', fontsize=28)
            ax.set_ylabel('RMSE (Normalized)', fontsize=28)
            ax.set_title(f'GIFT Ablation Study on {pretty_name}', fontsize=24)
            ax.legend(fontsize=18)
            ax.grid(True, linestyle='--', alpha=0.5)
            fig.tight_layout()

            sanitized_name = self._ablation_sanitize_name(pretty_name)
            filename = f"ablation_{sanitized_name}_comparison.pdf"
            pdf_path = self._get_output_path(exp_name, 'figure', metric_type, filename)
            fig.savefig(pdf_path, dpi=300)
            plt.close(fig)
            print(f" - Plot for '{pretty_name}' saved to: {pdf_path}")

    def run_analysis(self, exp_name: str = "ablation_study"):
        """Run the table, CSV and plot generation for the ablation experiment.

        Args:
            exp_name: Experiment name used to locate results and name outputs.
        """
        print(f"\n--- Starting Analysis for Experiment: '{exp_name}' ---")
        datasets_to_run = {
            'MIMIC-III': 'mimic3_synthetic',
            'Tumor ($\\gamma=2$)': 'tumor_gamma=2',
            'Tumor ($\\gamma=4$)': 'tumor_gamma=4'
        }
        taus_to_show = [3, 6]
        self._create_ablation_table(exp_name, datasets_to_run, taus_to_show)
        self._create_ablation_csvs(exp_name, datasets_to_run)
        self._create_ablation_plots(exp_name, datasets_to_run)
        print(f"\n--- Analysis for '{exp_name}' complete. Outputs are in '{os.path.join(self.output_base_dir, exp_name)}' ---")