# analysis_modules/goal_threshold.py

import os
import re
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .base_analyzer import BaseAnalyzer

class GoalThresholdAnalyzer(BaseAnalyzer):
    """
    Handles the analysis for the 'goal_threshold_study' experiment.

    For every dataset with GIFT results, it writes:
      * a LaTeX table and a CSV (rows = tau, columns = goal threshold) of
        formatted mean/std values, and
      * a PDF line plot of mean +/- std versus goal threshold, one curve per
        requested tau.
    Both use the 'episode_wise' metric type.
    """

    # Taus plotted by `run_analysis` when the caller does not pass any.
    DEFAULT_TAUS_TO_PLOT = (1, 2, 4, 6)

    # Matches one well-formed decimal / scientific-notation number such as
    # "0.5", "-1", ".25" or "1e-3". A bare character class like
    # [0-9.eE+-]+ would also match stray letters (e.g. the 'e' in
    # "threshold"), producing keys that float() cannot parse.
    _GT_NUMBER_RE = re.compile(r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?')

    # Matches metric keys "tau_<int>" (optionally followed by "_<suffix>"),
    # capturing the integer tau.
    _TAU_KEY_RE = re.compile(r'^tau_([+-]?\d+)(?:_|$)')

    # LaTeX special characters that must be escaped when free text (such as a
    # dataset key containing '_') is interpolated into a caption.
    _LATEX_ESCAPES = str.maketrans({
        '\\': r'\textbackslash{}',
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
    })

    # ------------------------------------------------------------------ #
    # Parsing helpers
    # ------------------------------------------------------------------ #
    @classmethod
    def _extract_goal_threshold(cls, param_suffix: Any) -> Optional[str]:
        """Return the first well-formed number in ``str(param_suffix)``, or None.

        The returned string is always accepted by ``float()``.
        """
        match = cls._GT_NUMBER_RE.search(str(param_suffix))
        return match.group(0) if match else None

    def _parse_gift_goal_thresholds(self, model_results: Any) -> Tuple[Dict[str, Any], List[Any]]:
        """Key the 'gift' runs of ``model_results`` by their goal-threshold string.

        Args:
            model_results: mapping whose 'gift' entry maps run-directory
                suffixes to per-run data (caller has checked it is non-empty).

        Returns:
            ``(parsed, unparsed)``: ``parsed`` maps the number extracted from
            each suffix to that run's data; ``unparsed`` lists suffixes that
            contain no number (callers report these).
        """
        parsed: Dict[str, Any] = {}
        unparsed: List[Any] = []
        for param_suffix, data in model_results['gift'].items():
            gt_value_str = self._extract_goal_threshold(param_suffix)
            if gt_value_str is None:
                unparsed.append(param_suffix)
            else:
                parsed[gt_value_str] = data
        return parsed, unparsed

    def _sorted_tau_keys(self, gift_data: Dict[str, Any]) -> List[Tuple[Any, int]]:
        """Collect every 'tau_<n>' key across all thresholds as ``(key, n)``, sorted by n.

        Keys that do not match 'tau_<int>' are ignored rather than crashing
        the int conversion.
        """
        found: Dict[Any, int] = {}
        for gt_data in gift_data.values():
            for key in gt_data.keys():
                match = self._TAU_KEY_RE.match(str(key))
                if match:
                    found[key] = int(match.group(1))
        # Secondary sort on the key string keeps ties deterministic (set /
        # dict iteration order must not decide the row order).
        return sorted(found.items(), key=lambda item: (item[1], str(item[0])))

    @staticmethod
    def _has_valid_stats(tau_data: Any) -> bool:
        """True when ``tau_data`` is truthy and neither its 'mean' nor 'std' is NaN."""
        return bool(tau_data) and not (np.isnan(tau_data['mean']) or np.isnan(tau_data['std']))

    def _format_cell(self, tau_data: Any, dataset_type: Any):
        """Format one table cell via ``_format_mean_std``, or 'N/A' for missing/NaN data."""
        if self._has_valid_stats(tau_data):
            return self._format_mean_std(tau_data['mean'], tau_data['std'], dataset_type)
        return 'N/A'

    @classmethod
    def _latex_escape(cls, text: Any) -> str:
        """Escape LaTeX special characters in ``text`` for use in running text."""
        return str(text).translate(cls._LATEX_ESCAPES)

    # ------------------------------------------------------------------ #
    # Tables
    # ------------------------------------------------------------------ #
    def _create_goal_threshold_tables(self, exp_name: str):
        """Write a LaTeX table and a CSV per dataset comparing goal thresholds.

        Rows are taus ("$\\tau = n$"), columns are goal thresholds in
        ascending numeric order (4 decimals for 'tumor' datasets, else 1),
        and cells are formatted mean/std values or 'N/A'.
        """
        metric_type = 'episode_wise'
        results = self.scan_experiments_for_metric_type(exp_name, metric_type)
        if not results:
            print(f"No results found for experiment '{exp_name}'.")
            return
        print(f"\n1. Generating Goal Threshold Comparison Tables (metric: {metric_type})...")
        for dataset_key, model_results in results.items():
            if 'gift' not in model_results or not model_results['gift']:
                continue
            gift_data, unparsed = self._parse_gift_goal_thresholds(model_results)
            for param_suffix in unparsed:
                print(f" - Warning: Could not extract a numerical gt from dir '{param_suffix}'. Skipping.")
            if not gift_data:
                print(f" - No valid goal threshold data found for '{dataset_key}' after parsing.")
                continue

            # Every key came from _GT_NUMBER_RE, so numeric sorting cannot fail.
            goal_thresholds = sorted(gift_data, key=float)
            tau_items = self._sorted_tau_keys(gift_data)
            if not tau_items:
                continue

            dataset_type = self._get_dataset_type(dataset_key)
            rows = [
                [self._format_cell(gift_data[gt].get(tau_key), dataset_type) for gt in goal_thresholds]
                for tau_key, _ in tau_items
            ]
            column_fmt = '{:.4f}' if dataset_type == 'tumor' else '{:.1f}'
            df = pd.DataFrame(
                rows,
                index=[f'$\\tau = {tau_num}$' for _, tau_num in tau_items],
                columns=[column_fmt.format(float(gt)) for gt in goal_thresholds],
            )

            # escape=False keeps the "$\tau$" math in the index; the dataset
            # key, however, is free text and must be escaped for the caption.
            caption = (f"GIFT Performance vs. Goal Threshold on "
                       f"{self._latex_escape(dataset_key)} (RMSE$\\pm$STD)")
            latex_table = df.to_latex(escape=False, caption=caption, label=f"tab:gt_{dataset_key}")

            tex_path = self._get_output_path(exp_name, 'table', metric_type, f"{dataset_key}_gt_comparison.tex")
            csv_path = self._get_output_path(exp_name, 'csv', metric_type, f"{dataset_key}_gt_comparison.csv")
            with open(tex_path, "w", encoding="utf-8") as f:
                f.write(latex_table)
            df.to_csv(csv_path)
            print(f" - Table for '{dataset_key}' saved to: {tex_path}")
            print(f" - CSV for '{dataset_key}' saved to: {csv_path}")

    # ------------------------------------------------------------------ #
    # Plots
    # ------------------------------------------------------------------ #
    def _draw_tau_curves(self, ax, gift_data: Dict[str, Any], goal_thresholds_str: List[str],
                         goal_thresholds_num: List[float], taus_to_plot: List[int]) -> bool:
        """Draw a mean line and a +/- std band per tau; return True if any curve was drawn.

        Styles (color / linestyle / marker) are picked by the tau's position in
        ``taus_to_plot``, so taus without data still consume a style slot.
        """
        plot_created = False
        for i, tau_num in enumerate(taus_to_plot):
            tau_key = f'tau_{tau_num}'
            means, stds = [], []
            for gt in goal_thresholds_str:
                tau_data = gift_data[gt].get(tau_key)
                if self._has_valid_stats(tau_data):
                    means.append(tau_data['mean'])
                    stds.append(tau_data['std'])
                else:
                    means.append(np.nan)
                    stds.append(np.nan)
            means_series = pd.Series(means)
            if means_series.isnull().all():
                continue

            color = self.plot_colors[i % len(self.plot_colors)]
            ls = self.linestyles[i % len(self.linestyles)]
            marker = self.markers[i % len(self.markers)]
            ax.plot(goal_thresholds_num, means, color=color, linestyle=ls, marker=marker, label=f'$\\tau = {tau_num}$')

            # The mean line breaks at NaN gaps, but the band is drawn over
            # linearly interpolated values (limit_direction='both' also fills
            # leading/trailing NaNs) so it stays continuous.
            means_interp = np.array(means_series.interpolate(method='linear', limit_direction='both'))
            stds_interp = np.array(pd.Series(stds).interpolate(method='linear', limit_direction='both'))
            ax.fill_between(goal_thresholds_num, means_interp - stds_interp, means_interp + stds_interp,
                            color=color, alpha=0.15)
            plot_created = True
        return plot_created

    def _create_goal_threshold_plots(self, exp_name: str, taus_to_plot: List[int]):
        """Save one PDF per dataset plotting RMSE vs. goal threshold ("Hit Ratio").

        Args:
            exp_name: experiment whose results are scanned.
            taus_to_plot: tau values to draw, one curve each; taus with no
                valid data at any threshold are skipped, and datasets with no
                drawable curve produce no file.
        """
        metric_type = 'episode_wise'
        results = self.scan_experiments_for_metric_type(exp_name, metric_type)
        if not results:
            return
        print(f"\n2. Generating Goal Threshold Comparison Plots (metric: {metric_type})...")
        for dataset_key, model_results in results.items():
            if 'gift' not in model_results or not model_results['gift']:
                continue
            gift_data, unparsed = self._parse_gift_goal_thresholds(model_results)
            for param_suffix in unparsed:
                print(f" - Warning: Could not extract numerical gt from dir '{param_suffix}'. Skipping plot generation for this dir.")
            if not gift_data:
                print(f" - No valid goal threshold data to plot for '{dataset_key}'.")
                continue

            # Every key came from _GT_NUMBER_RE, so float() cannot fail here.
            goal_thresholds_str = sorted(gift_data, key=float)
            goal_thresholds_num = [float(gt) for gt in goal_thresholds_str]

            fig, ax = plt.subplots(1, 1, figsize=(12, 7))
            try:
                if not self._draw_tau_curves(ax, gift_data, goal_thresholds_str, goal_thresholds_num, taus_to_plot):
                    continue
                ax.tick_params(axis='x', labelsize=24)
                ax.tick_params(axis='y', labelsize=24)
                ax.set_xlabel('Hit Ratio', fontsize=35)
                if 'tumor' in dataset_key.lower():
                    ax.set_ylabel('RMSE (Normalized)', fontsize=35)
                else:
                    ax.set_ylabel('RMSE', fontsize=35)
                ax.legend(fontsize=24)
                ax.grid(True, linestyle='--', alpha=0.5)
                # Operate on this figure explicitly instead of pyplot's
                # "current figure" state.
                fig.tight_layout()
                pdf_path = self._get_output_path(exp_name, 'figure', metric_type, f"{dataset_key}_gt_comparison.pdf")
                fig.savefig(pdf_path, dpi=300)
            finally:
                # Release the figure even if drawing or saving raised, so
                # figures do not accumulate across datasets.
                plt.close(fig)
            print(f" - Plot for '{dataset_key}' (all taus) saved to: {pdf_path}")

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #
    def run_analysis(self, exp_name: str = "goal_threshold_study", taus_to_plot: Optional[List[int]] = None):
        """
        Runs the analysis suite for the 'goal_threshold_study' experiment.

        Args:
            exp_name: experiment name to scan and to name the output folder.
            taus_to_plot: taus drawn in the plots; defaults to
                ``DEFAULT_TAUS_TO_PLOT`` (1, 2, 4, 6). ``None`` is used as the
                default instead of a shared mutable list.
        """
        if taus_to_plot is None:
            taus_to_plot = list(self.DEFAULT_TAUS_TO_PLOT)
        print(f"\n--- Starting Analysis for Experiment: '{exp_name}' ---")
        self._create_goal_threshold_tables(exp_name)
        self._create_goal_threshold_plots(exp_name, taus_to_plot=taus_to_plot)
        print(f"\n--- Analysis for '{exp_name}' complete. Outputs are in '{os.path.join(self.output_base_dir, exp_name)}' ---")

