# analysis_modules/base_analyzer.py

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import re
from pathlib import Path
import json
from typing import Dict, List, Tuple, Optional

# --- Matplotlib and LaTeX Configuration ---
# These assignments run once, at import time, and mutate matplotlib's global
# rcParams, so they become process-wide defaults for every figure drawn
# afterwards. With 'text.usetex' enabled, matplotlib renders text through an
# external LaTeX toolchain, which therefore has to be installed wherever
# figures are produced.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': 'Times New Roman',
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{times}',
})

class BaseAnalyzer:
    """
    Core functionality shared by the experiment-analysis modules.

    Locates raw result CSVs under ``results_base_path``, loads and normalises
    them, aggregates per-seed rows into mean/std summaries, and builds output
    paths under ``output_base_dir``. Intended to be subclassed by concrete
    analysis modules.
    """

    def __init__(self, results_base_path: str = "results", output_base_dir: str = "analysis_output"):
        """
        Store the input/output roots and the static analysis configuration.

        Args:
            results_base_path: Root directory laid out as
                ``<exp_name>/<dataset>/<model>[/<param_suffix>]/raw_results/``.
            output_base_dir: Root directory under which analysis outputs are written.
        """
        self.results_base_path = results_base_path
        self.output_base_dir = output_base_dir
        # Number of decimals used by _format_mean_std, keyed by dataset type
        # (see _get_dataset_type).
        self.decimal_places = {'mimic': 2, 'tumor': 2}
        # For each metric type: the model directories that are scanned, and the
        # sub-folder name used for outputs of that metric type.
        self.metrics_types = {
            'step_wise': {
                'models': ['vcip', 'rmsn', 'crn', 'ct', 'actin', 'gift'],
                'folder_suffix': 'step_wise'
            },
            'episode_wise': {
                'models': ['vcip', 'rmsn', 'crn', 'ct', 'actin', 'gift'],
                'folder_suffix': 'episode_wise'
            }
        }

        # Plotting styles; presumably indexed in parallel by subclasses (one
        # entry per plotted series) -- confirm against the plotting modules.
        self.plot_colors = ['#E53935', '#AB47BC', '#7CB342', '#FB8C00', '#1E88E5', '#26A69A', '#FF7043', '#785EF0']
        self.linestyles = ['-', '--', '-.', (0, (5, 10, 3)), (0, (5, 10)), (0, (5, 1)), ':', (0, (3, 5, 1, 5))]
        self.markers = ['o', 's', '^', 'D', 'x', '*', 'p', 'h']

    def _get_output_path(self, exp_name: str, output_type: str, metric_type: str, filename: str) -> str:
        """
        Build ``<output_base_dir>/<exp_name>/<output_type>/<folder_suffix>/<filename>``.

        The parent directory is created if it does not already exist.

        Raises:
            KeyError: if ``metric_type`` is not a key of ``self.metrics_types``.
        """
        folder_suffix = self.metrics_types[metric_type]['folder_suffix']
        output_directory = os.path.join(self.output_base_dir, exp_name, output_type, folder_suffix)
        os.makedirs(output_directory, exist_ok=True)
        return os.path.join(output_directory, filename)

    def _get_csv_filename(self, model: str, metric_type: str) -> str:
        """
        Return the raw-results CSV name for a model/metric-type pair.

        GIFT always uses ``avg_rmse.csv``. Other models use ``avg_rmse.csv`` for
        ``step_wise`` and ``mse.csv`` for ``episode_wise``. Any other metric type
        prints a warning and falls back to ``mse.csv``.
        """
        model_lower = model.lower()
        if model_lower == 'gift':
            return 'avg_rmse.csv'
        if metric_type == 'step_wise':
            return 'avg_rmse.csv'
        elif metric_type == 'episode_wise':
            return 'mse.csv'
        else:
            print(f"Warning: Unknown metric_type '{metric_type}' for model '{model}'. Defaulting to 'mse.csv'.")
            return 'mse.csv'

    def _standardize_column_names(self, df: pd.DataFrame, csv_type: str) -> pd.DataFrame:
        """
        Return a copy of ``df`` with ``tau=<n>`` headers renamed to ``tau_<n>``.

        ``avg_rmse`` files are returned as an unmodified copy. This assumes they
        already use ``tau_<n>`` headers -- TODO confirm against the writers.
        """
        df = df.copy()
        if csv_type == 'avg_rmse':
            return df
        column_mapping = {col: col.replace('tau=', 'tau_') for col in df.columns if col.startswith('tau=')}
        df = df.rename(columns=column_mapping)
        return df

    def load_experiment_data(self, exp_name: str, dataset: str, model: str,
                             metric_type: str, param_suffix: Optional[str] = None) -> Optional[pd.DataFrame]:
        """
        Load one experiment's raw-results CSV and normalise its column names.

        The file read is
        ``<results_base_path>/<exp_name>/<dataset>/<model>[/<param_suffix>]/raw_results/<csv>``,
        where ``<csv>`` comes from :meth:`_get_csv_filename`.

        Returns:
            The DataFrame, or ``None`` if the file is missing or could not be
            read. Read failures are printed rather than raised, so a bad file is
            treated like a missing one.
        """
        base_path_parts = [self.results_base_path, exp_name, dataset, model]
        if param_suffix:
            base_path_parts.append(param_suffix)
        data_path = os.path.join(*base_path_parts, "raw_results")
        csv_filename = self._get_csv_filename(model, metric_type)
        csv_path = os.path.join(data_path, csv_filename)
        if not os.path.exists(csv_path):
            return None
        try:
            df = pd.read_csv(csv_path)
            csv_type = 'avg_rmse' if 'avg_rmse' in csv_filename else 'mse'
            df = self._standardize_column_names(df, csv_type)
            return df
        except Exception as e:
            # Best-effort boundary: report and treat the file as unavailable.
            print(f"Failed to read data from {csv_path}: {e}")
            return None

    def aggregate_seeds_data(self, df: pd.DataFrame) -> Dict[str, Dict[str, float]]:
        """
        Summarise every ``tau_*`` column of ``df`` as ``{'mean': ..., 'std': ...}``.

        Each row is presumably one seed's result (TODO confirm). ``std`` is the
        pandas default sample standard deviation (ddof=1), so it is NaN when
        ``df`` has a single row. Returns ``{}`` for ``None`` or an empty frame.
        """
        if df is None or df.empty:
            return {}
        tau_columns = [col for col in df.columns if col.startswith('tau_')]
        aggregated = {
            col: {'mean': df[col].mean(), 'std': df[col].std()}
            for col in tau_columns
        }
        return aggregated

    def scan_experiments_for_metric_type(self, exp_name: str, metric_type: str) -> Dict:
        """
        Load and aggregate every result of one experiment for one metric type.

        Visits each ``<dataset>/<model>`` directory under
        ``<results_base_path>/<exp_name>`` for the models listed in
        ``self.metrics_types[metric_type]``. A model's results come from each of
        its parameter sub-directories that contains ``raw_results``. If it has
        none, they come from a ``raw_results`` directly under the model
        (stored under the key ``'default'``).

        Returns:
            ``{dataset: {model: {param_suffix_or_'default': aggregated}}}``, where
            ``aggregated`` is the output of :meth:`aggregate_seeds_data`. Datasets,
            models and parameter sub-directories are visited in sorted order, so
            the result is deterministic. Returns ``{}`` if the experiment
            directory does not exist.

        Raises:
            KeyError: if ``metric_type`` is not a key of ``self.metrics_types``.
        """
        exp_path = os.path.join(self.results_base_path, exp_name)
        if not os.path.exists(exp_path):
            print(f"Experiment directory not found: {exp_path}")
            return {}
        allowed_models = set(self.metrics_types[metric_type]['models'])
        results = {}
        # os.listdir returns single path components in arbitrary order: every
        # dataset is a direct child of exp_path (nested layouts are not
        # traversed), and sorting keeps downstream tables/plots stable.
        for dataset in sorted(os.listdir(exp_path)):
            dataset_path = os.path.join(exp_path, dataset)
            if not os.path.isdir(dataset_path):
                continue
            dataset_results = {}
            results[dataset] = dataset_results

            for model in sorted(os.listdir(dataset_path)):
                model_path = os.path.join(dataset_path, model)
                if model not in allowed_models or not os.path.isdir(model_path):
                    continue
                model_results = {}
                dataset_results[model] = model_results
                param_subdirs = sorted(
                    d for d in os.listdir(model_path)
                    if os.path.isdir(os.path.join(model_path, d, "raw_results"))
                )

                if param_subdirs:
                    for param_suffix in param_subdirs:
                        df = self.load_experiment_data(exp_name, dataset, model, metric_type, param_suffix)
                        if df is not None:
                            model_results[param_suffix] = self.aggregate_seeds_data(df)
                elif os.path.isdir(os.path.join(model_path, "raw_results")):
                    df = self.load_experiment_data(exp_name, dataset, model, metric_type)
                    if df is not None:
                        model_results['default'] = self.aggregate_seeds_data(df)
        return results

    def _get_max_tau_length(self, results: Dict, dataset_key: str, available_models: List[str]) -> int:
        """
        Return the largest horizon ``n`` among the ``tau_<n>`` keys of a dataset.

        Looks at every configuration of each model in ``available_models`` under
        ``results[dataset_key]``, which has the shape returned by
        :meth:`scan_experiments_for_metric_type`. Keys whose suffix is not an
        integer are skipped instead of aborting the scan. Returns 0 when no
        ``tau_<n>`` key is found.
        """
        prefix = 'tau_'
        max_tau = 0
        dataset_results = results.get(dataset_key, {})
        for model in available_models:
            for config_data in dataset_results.get(model, {}).values():
                for key in config_data:
                    if not key.startswith(prefix):
                        continue
                    try:
                        tau = int(key[len(prefix):])
                    except ValueError:
                        continue
                    max_tau = max(max_tau, tau)
        return max_tau

    def _format_mean_std(self, mean: float, std: float, dataset_type: str) -> str:
        """
        Format ``mean`` and ``std`` as LaTeX ``"<mean>$\\pm$<std>"``.

        The number of decimals comes from ``self.decimal_places[dataset_type]``
        and defaults to 2 for unknown dataset types.
        """
        dp = self.decimal_places.get(dataset_type, 2)
        return f"{mean:.{dp}f}$\\pm${std:.{dp}f}"

    def _format_percentage(self, value: float) -> str:
        """Format ``value`` with an explicit sign, one decimal and a LaTeX-escaped ``\\%``."""
        return f"{value:+.1f}\\%"

    def _get_dataset_type(self, dataset_name: str) -> str:
        """Return ``'tumor'`` if the name contains "tumor" (case-insensitive), else ``'mimic'``."""
        return 'tumor' if 'tumor' in dataset_name.lower() else 'mimic'

