import pandas as pd
import numpy as np
import os
from .base_analyzer import BaseAnalyzer
from typing import List, Dict, Tuple, Optional

class OptimizeStudyAnalyzer(BaseAnalyzer):
    """
    Analyzer that compares step-wise and episode-wise optimization strategies across models.

    For every dataset found in the input experiment a LaTeX table is generated in which:
    - The 'rmsn', 'crn', 'ct', 'actin' and 'vcip' models each get one row per strategy
      ('step' and 'episode').
    - The 'gift' model is listed separately as a benchmark (from the episode-wise results).
    - Model groups are separated by a LaTeX midrule.
    - The best result (minimum RMSE mean) in each tau column is bolded.
    - Data is read from one experiment (e.g. 'main_comparison') and the tables are written
      to a different analysis directory (e.g. 'optimize_study').
    """

    def __init__(self, results_base_path: str = "results", output_base_dir: str = "analysis_output"):
        """
        Initialize the analyzer.

        Args:
            results_base_path: Root directory of the experiment results (passed to BaseAnalyzer).
            output_base_dir: Root directory for generated analysis artifacts (passed to BaseAnalyzer).
        """
        super().__init__(results_base_path, output_base_dir)

    def _get_optimize_output_path(self, exp_name: str, output_type: str, dataset_key: str, filename: str) -> str:
        """
        Build (and create the directory for) an output path for optimize_study artifacts.

        Layout: <output_base_dir>/<exp_name>/<output_type>/<dataset_key>/<filename>
        Example: analysis_output/optimize_study/table/mimic3_synthetic/comparison.tex

        Returns:
            The full file path; its parent directory is guaranteed to exist.
        """
        output_directory = os.path.join(self.output_base_dir, exp_name, output_type, dataset_key)
        os.makedirs(output_directory, exist_ok=True)
        return os.path.join(output_directory, filename)

    @staticmethod
    def _tau_column_name(tau: int) -> str:
        """Return the LaTeX column header used for the given tau value."""
        return f'$\\tau = {tau}$'

    @staticmethod
    def _parse_tau_key(key) -> Optional[int]:
        """
        Return N for a key of the exact form 'tau_N' (N a decimal integer), otherwise None.

        Unlike a plain ``replace('tau_', '')`` + ``int``, this only strips the prefix and
        ignores non-numeric suffixes instead of raising ValueError.
        """
        prefix = 'tau_'
        if isinstance(key, str) and key.startswith(prefix):
            suffix = key[len(prefix):]
            if suffix.isdecimal():
                return int(suffix)
        return None

    def _extract_model_data(self, results: Dict, dataset_key: str, model: str, tau_keys: List[str], dataset_type: str) -> Tuple[Optional[Dict], Optional[Dict]]:
        """
        Extract the formatted and raw per-tau values for a single model on a single dataset.

        The configuration used is 'default' if present, otherwise 'shift_False', otherwise the
        first configuration stored for the model.

        Column labels are positional: the i-th entry of ``tau_keys`` (1-based) becomes the
        column ``$\\tau = i$``.

        Returns:
            (row_data, row_means) where ``row_data`` maps column label -> formatted string
            (or 'N/A') and ``row_means`` maps column label -> raw mean (or ``np.inf`` when the
            value is missing, None or NaN, so it can never win a minimum comparison).
            Returns (None, None) when no usable configuration exists for the model.
        """
        model_configs = results.get(dataset_key, {}).get(model, {})

        data_source = None
        if 'default' in model_configs:
            data_source = model_configs['default']
        elif model_configs:
            data_source = model_configs.get('shift_False', next(iter(model_configs.values()), None))

        if not data_source:
            return None, None

        row_data = {}
        row_means = {}
        for position, tau_key in enumerate(tau_keys, start=1):
            col_name = self._tau_column_name(position)
            metrics = data_source.get(tau_key) if tau_key in data_source else None
            mean = metrics.get('mean') if isinstance(metrics, dict) else None
            # A missing/None/NaN mean is reported as N/A instead of crashing np.isnan.
            if mean is not None and not np.isnan(mean):
                row_data[col_name] = self._format_mean_std(mean, metrics['std'], dataset_type)
                row_means[col_name] = mean
            else:
                row_data[col_name] = 'N/A'
                row_means[col_name] = np.inf
        return row_data, row_means

    def _get_max_tau_for_dataset(self, results_list: List[Dict], dataset_key: str, models: List[str]) -> int:
        """
        Find the largest tau index recorded for a dataset across several result dictionaries.

        Only keys of the form 'tau_N' inside each model configuration are considered; other
        keys (including non-numeric 'tau_*' keys) are ignored.

        Returns:
            The maximum N found, or 0 when the dataset has no tau entries for the given models.
        """
        max_tau = 0
        for results in results_list:
            dataset_results = results.get(dataset_key, {})
            for model in models:
                for config_data in dataset_results.get(model, {}).values():
                    if not isinstance(config_data, dict):
                        continue
                    for key in config_data:
                        tau = self._parse_tau_key(key)
                        if tau is not None:
                            max_tau = max(max_tau, tau)
        return max_tau

    @staticmethod
    def _build_caption(dataset_key: str) -> str:
        """
        Build the table caption for a dataset key.

        Underscores in the key are shown as spaces. Tumor datasets use the text after the
        first '=' of the key as the gamma value (the whole display key if there is no '=').
        MIMIC datasets get a fixed caption; any other dataset falls back to a generic caption
        instead of being mislabelled as MIMIC-III.
        """
        display_dataset_key = dataset_key.replace('_', ' ')
        if 'tumor' in display_dataset_key:
            gamma = display_dataset_key.split('=', 1)[-1]
            return ("Performance comparison of GIFT and baseline models under various optimization "
                    f"strategies on the Tumor dataset ($ \\gamma={gamma}$).")
        if 'mimic' in display_dataset_key.lower():
            return ("Performance comparison of GIFT and baseline models under various optimization "
                    "strategies on the MIMIC-III dataset.")
        return f"Performance Comparison of Optimization Strategies on {display_dataset_key} (RMSE$\\pm$STD)"

    def _to_latex_with_midrules(self, df: pd.DataFrame, dataset_key: str) -> str:
        """
        Render ``df`` as a booktabs-style LaTeX table with a midrule between model groups.

        ``df`` must be indexed by a ('Model', 'Strategy') MultiIndex; a midrule is emitted
        whenever the 'Model' level changes between consecutive rows. Cell values are written
        verbatim (no LaTeX escaping is applied).

        Returns:
            The complete ``table`` environment as a single string.
        """
        caption = self._build_caption(dataset_key)
        label = f"tab:{dataset_key}_optimize_comp"
        # Two left-aligned text columns (Model, Strategy), then one centred column per tau.
        col_format = "ll" + "c" * len(df.columns)

        header = " & ".join(['Model', 'Strategy'] + list(df.columns)) + " \\\\"
        latex_parts = [
            "\\begin{table}[t]",
            "\\centering",
            f"\\caption{{{caption}}}",
            f"\\label{{{label}}}",
            f"\\begin{{tabular}}{{{col_format}}}",
            "\\toprule",
            header,
            "\\midrule",
        ]

        if not df.empty:
            last_model = df.index.get_level_values('Model')[0]
            for (model, strategy), row in df.iterrows():
                if model != last_model:
                    latex_parts.append("\\midrule")
                row_str = f"{model} & {strategy} & " + " & ".join(map(str, row.values)) + " \\\\"
                latex_parts.append(row_str)
                last_model = model

        latex_parts.extend(["\\bottomrule", "\\end{tabular}", "\\end{table}"])
        return "\n".join(latex_parts)

    @staticmethod
    def _bold_column_minima(df_display: pd.DataFrame, df_means: pd.DataFrame) -> None:
        """
        Bold (in place) the entry of ``df_display`` with the smallest mean in each column.

        ``df_means`` must share index and columns with ``df_display`` and hold ``np.inf`` for
        missing values. Columns without any finite mean are left untouched; on ties only the
        first row in index order is bolded.
        """
        for col in df_display.columns:
            means = df_means[col]
            if means.empty or not np.isfinite(means.min()):
                continue
            best_idx = means.idxmin()
            value = df_display.loc[best_idx, col]
            if value != 'N/A':
                df_display.loc[best_idx, col] = f"\\textbf{{{value}}}"

    def create_optimization_strategy_table(self, output_exp_name: str, input_exp_name: str):
        """
        Create one LaTeX comparison table per dataset for the 'optimize_study' experiment.

        Args:
            output_exp_name: Analysis directory name the tables are written under.
            input_exp_name: Experiment whose step-wise and episode-wise results are scanned.
        """
        print(f"1. Scanning input experiment '{input_exp_name}' for step-wise and episode-wise metrics...")
        step_wise_results = self.scan_experiments_for_metric_type(input_exp_name, 'step_wise')
        episode_wise_results = self.scan_experiments_for_metric_type(input_exp_name, 'episode_wise')

        all_datasets = sorted(set(step_wise_results) | set(episode_wise_results))

        models_to_compare = ['rmsn', 'crn', 'ct', 'actin', 'vcip']
        gift_model = 'gift'
        strategies = (('step', step_wise_results), ('episode', episode_wise_results))

        # Display order: each baseline contributes a 'step' row then an 'episode' row,
        # and GIFT (episode-wise results, no strategy label) is appended last.
        row_specs = [(model, label, results) for model in models_to_compare for label, results in strategies]
        row_specs.append((gift_model, '', episode_wise_results))

        print("2. Generating comparison tables for each dataset...")
        for dataset_key in all_datasets:
            # Only empty keys are skipped; tumor datasets are processed like any other.
            if not dataset_key:
                continue
            print(f" - Processing dataset: {dataset_key}...")

            max_tau = self._get_max_tau_for_dataset(
                [step_wise_results, episode_wise_results], dataset_key, models_to_compare + [gift_model]
            )
            if max_tau == 0:
                print(f"   - No data found for dataset '{dataset_key}'. Skipping.")
                continue

            tau_range = range(1, max_tau + 1)
            tau_keys = [f'tau_{tau}' for tau in tau_range]
            table_columns = [self._tau_column_name(tau) for tau in tau_range]
            dataset_type = self._get_dataset_type(dataset_key)

            table_data = []
            raw_means = []
            for model, strategy, results in row_specs:
                row_data, row_means = self._extract_model_data(results, dataset_key, model, tau_keys, dataset_type)
                if not row_data:
                    continue
                index_fields = {'Model': model.upper(), 'Strategy': strategy}
                table_data.append({**index_fields, **row_data})
                raw_means.append({**index_fields, **row_means})

            if not table_data:
                print(f"   - No model data compiled for '{dataset_key}'. Skipping.")
                continue

            df_display = pd.DataFrame(table_data).set_index(['Model', 'Strategy'])
            df_means = pd.DataFrame(raw_means).set_index(['Model', 'Strategy'])

            df_display = df_display.reindex(columns=table_columns).fillna('N/A')
            df_means = df_means.reindex(columns=table_columns).fillna(np.inf)

            self._bold_column_minima(df_display, df_means)

            latex_string = self._to_latex_with_midrules(df_display, dataset_key)

            output_filename = f"{dataset_key}_optimize_comparison.tex"
            tex_path = self._get_optimize_output_path(output_exp_name, 'table', dataset_key, output_filename)
            with open(tex_path, "w", encoding="utf-8") as f:
                f.write(latex_string)
            print(f"   - Table saved to: {tex_path}")

    def run_analysis(self, output_exp_name: str = "optimize_study", input_exp_name: str = "main_comparison"):
        """
        Run the complete analysis suite for the 'optimize_study' experiment.

        Loads data from ``input_exp_name`` and saves the results under ``output_exp_name``.
        """
        print(f"\n--- Starting Analysis for '{output_exp_name}' ---")
        print(f"--- Reading data from experiment: '{input_exp_name}' ---")
        self.create_optimization_strategy_table(output_exp_name=output_exp_name, input_exp_name=input_exp_name)
        print(f"\n--- Analysis for '{output_exp_name}' complete. Outputs are in '{os.path.join(self.output_base_dir, output_exp_name)}' ---")