# analysis_modules/main_comparison.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import re  #Make sure the re module is imported
from .base_analyzer import BaseAnalyzer

class MainComparisonAnalyzer(BaseAnalyzer):
    """
    Handles the analysis for the 'main_comparison' experiment,
    generating tables and plots for MIMIC and Tumor datasets, with separate
    outputs for different shift states.

    For every metric type in ``self.metrics_types`` and every shift state in
    ``self.shift_states`` this produces:

    * a LaTeX + CSV comparison table for the MIMIC dataset,
    * a LaTeX + CSV comparison table for every ``tumor_gamma=`` dataset,
    * one PDF plot (mean +/- std vs. tau) per parsed tumor gamma value.

    Scanned results are read as
    ``results[dataset_key][model][shift_state]['tau_<k>'] -> {'mean': ..., 'std': ...}``;
    for ``'shift_False'`` a model's ``'default'`` entry is used when no explicit
    ``'shift_False'`` entry exists.
    """

    # Models compared in every table and plot, in row / legend order. The
    # position also selects the plot colour, linestyle and marker.
    _MODELS_TO_DISPLAY = ('rmsn', 'crn', 'ct', 'actin', 'vcip', 'gift')
    # Results key used as a fallback for the unshifted ('shift_False') state.
    _DEFAULT_CONFIG_KEY = 'default'
    # First numeric token (optionally signed / fractional / scientific) in a cell.
    _NUMBER_RE = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?')
    # Gamma value in dataset keys such as 'tumor_gamma=4' (integer or decimal).
    _GAMMA_RE = re.compile(r'gamma=(\d+(?:\.\d+)?)', re.IGNORECASE)

    def __init__(self, results_base_path: str = "results", output_base_dir: str = "analysis_output"):
        super().__init__(results_base_path, output_base_dir)
        # Every table/plot is produced once per shift state, into its own subdirectory.
        self.shift_states = ['shift_False', 'shift_True']

    # ------------------------------------------------------------------ paths

    def _get_main_comp_output_path(self, exp_name: str, output_type: str, metric_type: str, shift_state: str, filename: str) -> str:
        """
        Construct an output path that includes a subdirectory for the shift state.

        The parent directory is whatever ``BaseAnalyzer._get_output_path`` yields
        for (exp_name, output_type, metric_type); presumably something like
        ``analysis_output/main_comparison/table/<metric_type>/`` — confirm in the
        base class. The shift-state subdirectory is created if missing.
        """
        # Ask the base class for a path to a placeholder file and keep only its directory.
        base_dir = os.path.dirname(self._get_output_path(exp_name, output_type, metric_type, 'dummy.txt'))
        shift_specific_dir = os.path.join(base_dir, shift_state)
        os.makedirs(shift_specific_dir, exist_ok=True)
        return os.path.join(shift_specific_dir, filename)

    # ------------------------------------------------------------ cell helpers

    def _extract_mean_from_str(self, value_str: str) -> float:
        """
        Extract the mean value from a formatted 'mean ± std' cell string.

        The first numeric token anywhere in the string is taken as the mean, so
        leading markup such as ``$`` or ``{`` does not prevent parsing.
        Returns ``np.inf`` for non-strings, 'N/A' cells, or strings without a
        number, so such cells never win a "lowest value" comparison.
        """
        if not isinstance(value_str, str) or 'N/A' in value_str:
            return np.inf
        match = self._NUMBER_RE.search(value_str)
        return float(match.group(0)) if match else np.inf

    def _resolve_shift_data(self, model_configs, shift_state: str):
        """
        Return one model's per-tau results for ``shift_state``, or None.

        For 'shift_False' the model's 'default' entry is used when there is no
        explicit 'shift_False' entry. Empty entries are reported as None.
        """
        if not model_configs:
            return None
        data = model_configs.get(shift_state)
        if not data and shift_state == 'shift_False':
            data = model_configs.get(self._DEFAULT_CONFIG_KEY)
        return data or None

    def _format_cell(self, stats, dataset_type: str) -> str:
        """
        Format one ``{'mean': ..., 'std': ...}`` entry for a table cell.

        Returns 'N/A' when the entry is missing or either statistic is NaN/None;
        otherwise delegates to the inherited ``_format_mean_std``.
        """
        if not stats or pd.isna(stats.get('mean')) or pd.isna(stats.get('std')):
            return 'N/A'
        return self._format_mean_std(stats['mean'], stats['std'], dataset_type)

    # ------------------------------------------------------------ table helpers

    def _build_comparison_df(self, dataset_results: dict, shift_state: str, tau_range: range, dataset_type: str):
        """
        Build the models x tau table of formatted cells for one shift state.

        Models without data for the shift state get an all-'N/A' row. Returns
        None when no displayed model has any data for this shift state.
        """
        rows = {}
        found_data_for_shift = False
        for model in self._MODELS_TO_DISPLAY:
            data = self._resolve_shift_data(dataset_results.get(model), shift_state)
            if data:
                found_data_for_shift = True
                rows[model.upper()] = [self._format_cell(data.get(f'tau_{tau}'), dataset_type) for tau in tau_range]
            else:
                rows[model.upper()] = ['N/A'] * len(tau_range)

        if not found_data_for_shift:
            return None

        table_columns = [f'$\\tau = {tau}$' for tau in tau_range]
        # Rows = models, columns = prediction horizons.
        return pd.DataFrame(rows, index=table_columns).T

    def _bold_column_minima(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of ``df`` with the lowest-mean cell(s) of each column wrapped in \\textbf.

        Ties (equal parsed means) are all bolded. Columns without any parsable
        value are left untouched. ``df`` itself is not modified.
        """
        df_bold = df.copy()
        for col in df_bold.columns:
            numeric_col = df_bold[col].apply(self._extract_mean_from_str)
            finite_values = numeric_col[np.isfinite(numeric_col)]
            if finite_values.empty:
                continue
            best_rows = numeric_col[numeric_col == finite_values.min()].index
            for row in best_rows:
                df_bold.loc[row, col] = f"\\textbf{{{df_bold.loc[row, col]}}}"
        return df_bold

    def _write_comparison_table(self, df: pd.DataFrame, exp_name: str, dataset_key: str, metric_type: str, shift_state: str) -> str:
        """
        Write a comparison table as LaTeX (best cells bolded) and as raw CSV.

        Returns the path of the written ``.tex`` file.
        """
        df_bold = self._bold_column_minima(df)

        caption_suffix = f" ({shift_state.replace('_', ' ')})"
        # Captions are inserted verbatim (escape=False); an unescaped '_' in
        # text mode breaks LaTeX compilation, so escape it in the caption only.
        # Labels may contain '_' as-is.
        caption_name = dataset_key.replace('_', '\\_')

        latex_table = df_bold.to_latex(
            escape=False,
            caption=f"{caption_name} Performance (RMSE$\\pm$STD){caption_suffix}",
            label=f"tab:{dataset_key}_{metric_type}_{shift_state}",
            position='t!',
        )
        # Centre the tabular inside the float environment emitted by to_latex.
        latex_table = latex_table.replace('\\begin{table}[t!]', '\\begin{table}[t!]\n\\centering', 1)

        output_filename_base = f"{dataset_key}_comparison"
        tex_path = self._get_main_comp_output_path(exp_name, 'table', metric_type, shift_state, f"{output_filename_base}.tex")
        csv_path = self._get_main_comp_output_path(exp_name, 'csv', metric_type, shift_state, f"{output_filename_base}.csv")

        with open(tex_path, "w", encoding="utf-8") as f:
            f.write(latex_table)
        # The CSV keeps the un-bolded cell text.
        df.to_csv(csv_path)
        return tex_path

    # ------------------------------------------------------------ tables

    def create_comparison_table_real(self, exp_name: str, dataset_key: str = "mimic3_synthetic", max_tau: int = 6):
        """
        Generate comparison tables (tau = 1..max_tau) for one dataset, per metric type and shift state.

        Args:
            exp_name: experiment name passed to the results scanner / output paths.
            dataset_key: key of the dataset in the scanned results.
            max_tau: largest prediction horizon shown as a column (default 6).

        Raises:
            ValueError: if ``max_tau`` is smaller than 1.
        """
        if max_tau < 1:
            raise ValueError(f"max_tau must be >= 1, got {max_tau}")
        tau_range = range(1, max_tau + 1)

        for metric_type in self.metrics_types:
            results = self.scan_experiments_for_metric_type(exp_name, metric_type)
            dataset_results = results.get(dataset_key)
            if not dataset_results:
                print(f"No results found for dataset '{dataset_key}' in '{exp_name}' for metric '{metric_type}'.")
                continue

            for shift_state in self.shift_states:
                print(f"Generating table for {dataset_key} ({shift_state}, {metric_type})...")
                df = self._build_comparison_df(dataset_results, shift_state, tau_range, 'mimic')
                if df is None:
                    print(f" - No data found for shift state '{shift_state}'. Skipping table.")
                    continue

                tex_path = self._write_comparison_table(df, exp_name, dataset_key, metric_type, shift_state)
                print(f" - Table saved to: {tex_path}")

    @staticmethod
    def _is_tumor_dataset(dataset_key) -> bool:
        """Return True for dataset keys naming a tumor dataset (contain 'tumor_gamma=', case-insensitive)."""
        return 'tumor_gamma=' in dataset_key.lower()

    def create_tumor_comparison_table_real(self, exp_name: str):
        """
        Generate a comparison table for every tumor dataset, per metric type and shift state.

        The number of tau columns per dataset comes from the inherited
        ``_get_max_tau_length``; datasets for which it reports 0 are skipped.
        """
        models = list(self._MODELS_TO_DISPLAY)
        for metric_type in self.metrics_types:
            results = self.scan_experiments_for_metric_type(exp_name, metric_type)
            tumor_datasets = [key for key in results if self._is_tumor_dataset(key)]
            if not tumor_datasets:
                continue

            # _get_max_tau_length does not depend on the shift state, so compute it once per dataset.
            max_taus = {key: self._get_max_tau_length(results, key, models) for key in tumor_datasets}

            for shift_state in self.shift_states:
                for dataset_key in tumor_datasets:
                    print(f"Generating table for {dataset_key} ({shift_state}, {metric_type})...")
                    max_tau = max_taus[dataset_key]
                    if not max_tau:
                        continue

                    df = self._build_comparison_df(results[dataset_key], shift_state, range(1, max_tau + 1), 'tumor')
                    if df is None:
                        print(f" - No data found for '{dataset_key}' with shift state '{shift_state}'. Skipping table.")
                        continue

                    tex_path = self._write_comparison_table(df, exp_name, dataset_key, metric_type, shift_state)
                    print(f" - Table for '{dataset_key}' saved to: {tex_path}")

    # ------------------------------------------------------------ plots

    @classmethod
    def _parse_gamma(cls, dataset_key: str):
        """
        Parse the gamma value from a dataset key like 'tumor_gamma=4'.

        Returns an int for integral values (e.g. 4) and a float otherwise
        (e.g. 0.5), or None when no 'gamma=' value is present.
        """
        match = cls._GAMMA_RE.search(dataset_key)
        if not match:
            return None
        value = float(match.group(1))
        return int(value) if value.is_integer() else value

    def _group_tumor_datasets_by_gamma(self, tumor_datasets_data: dict) -> dict:
        """
        Map each parsed gamma value to ``(dataset_key, dataset_results)``.

        Keys whose gamma cannot be parsed are skipped with a warning. If two
        keys share a gamma value, the first one (in iteration order) is kept
        and a warning is printed.
        """
        by_gamma = {}
        for dataset_key, dataset_content in tumor_datasets_data.items():
            gamma = self._parse_gamma(dataset_key)
            if gamma is None:
                print(f"Warning: Could not parse gamma value from '{dataset_key}'. Skipping.")
                continue
            if gamma in by_gamma:
                print(f"Warning: gamma={gamma} appears in both '{by_gamma[gamma][0]}' and '{dataset_key}'; keeping the former.")
                continue
            by_gamma[gamma] = (dataset_key, dataset_content)
        return by_gamma

    def _plot_gamma_comparison(self, exp_name: str, metric_type: str, shift_state: str, gamma, dataset_key: str, dataset_results: dict):
        """
        Plot mean +/- std vs. tau for every displayed model on one tumor dataset and save it as PDF.

        Nothing is written if no model has data for ``shift_state``. The figure
        is always closed, even if plotting or saving raises.
        """
        fig, ax = plt.subplots(1, 1, figsize=(10, 7), constrained_layout=True)
        try:
            model_count = 0
            for i, model in enumerate(self._MODELS_TO_DISPLAY):
                data = self._resolve_shift_data((dataset_results or {}).get(model), shift_state)
                if not data:
                    continue

                tau_keys = sorted((k for k in data if k.startswith('tau_')), key=lambda k: int(k.split('_')[1]))
                if not tau_keys:
                    continue

                taus = [int(k.split('_')[1]) for k in tau_keys]
                # dtype=float turns missing (None) statistics into NaN gaps instead of crashing.
                means = np.array([data[k]['mean'] for k in tau_keys], dtype=float)
                stds = np.array([data[k]['std'] for k in tau_keys], dtype=float)

                color, ls, marker = self.plot_colors[i], self.linestyles[i], self.markers[i]
                ax.plot(taus, means, marker=marker, color=color, linestyle=ls, label=model.upper())
                ax.fill_between(taus, means - stds, means + stds, color=color, alpha=0.15)
                model_count += 1

            if model_count == 0:
                print(f" - No model data found to plot for gamma={gamma} on dataset '{dataset_key}'.")
                return

            # Large fonts for readability once the PDF is scaled into a paper.
            ax.tick_params(axis='x', labelsize=22)
            ax.tick_params(axis='y', labelsize=22)
            ax.set_xlabel(r'$\tau$', fontsize=45)
            ax.set_ylabel('RMSE (Normalized)', fontsize=35)
            ax.legend(fontsize=20, loc='upper right')
            ax.grid(True, linestyle='--', alpha=0.5)

            filename = f"tumor_gamma_{gamma}_comparison.pdf"
            pdf_path = self._get_main_comp_output_path(exp_name, 'figure', metric_type, shift_state, filename)
            # Save this figure explicitly rather than pyplot's implicit "current" figure.
            fig.savefig(pdf_path, dpi=300)
            print(f" - Plot for gamma={gamma} saved to: {pdf_path}")
        finally:
            plt.close(fig)

    def create_tumor_gamma_comparison_plots_real(self, exp_name: str):
        """
        Generate one comparison plot per tumor gamma value, per metric type and shift state.
        """
        for metric_type in self.metrics_types:
            results = self.scan_experiments_for_metric_type(exp_name, metric_type)
            tumor_datasets_data = {k: v for k, v in results.items() if self._is_tumor_dataset(k)}
            if not tumor_datasets_data:
                print(f"No tumor datasets found for metric type '{metric_type}'.")
                continue

            # Grouping by gamma does not depend on the shift state; do it once.
            gamma_datasets = self._group_tumor_datasets_by_gamma(tumor_datasets_data)

            for shift_state in self.shift_states:
                print(f"\nGenerating plots for shift state: {shift_state} ({metric_type})...")
                if not gamma_datasets:
                    print(f" - No data to plot for shift state '{shift_state}'.")
                    continue

                for gamma, (dataset_key, dataset_results) in sorted(gamma_datasets.items()):
                    self._plot_gamma_comparison(exp_name, metric_type, shift_state, gamma, dataset_key, dataset_results)

    # ------------------------------------------------------------ entry point

    def run_analysis(self, exp_name: str = "main_comparison"):
        """
        Run the full analysis suite for the 'main_comparison' experiment:
        MIMIC tables, tumor tables, then tumor gamma plots.
        """
        print(f"\n--- Starting Analysis for Experiment: '{exp_name}' ---")
        print("\n1. Generating MIMIC Data Comparison Tables...")
        self.create_comparison_table_real(exp_name=exp_name, dataset_key="mimic3_synthetic")
        print("\n2. Generating Tumor Data Comparison Tables...")
        self.create_tumor_comparison_table_real(exp_name=exp_name)
        print("\n3. Generating Tumor Gamma Comparison Plots...")
        self.create_tumor_gamma_comparison_plots_real(exp_name=exp_name)
        print(f"\n--- Analysis for '{exp_name}' complete. Outputs are in '{os.path.join(self.output_base_dir, exp_name)}' ---")