import os
import re
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from typing import List, Optional
from .base_analyzer import BaseAnalyzer

class KParameterStudyAnalyzer(BaseAnalyzer):
    """Plot a metric against a swept numeric parameter, one curve per tau.

    For every (dataset, model) pair the analyzer takes the per-run results
    returned by ``scan_experiments_for_metric_type``, extracts a number from
    each run's parameter suffix to use as the x coordinate, and draws the
    mean of each requested ``tau_<n>`` entry with a shaded +/- std band.

    The attributes ``results_base_path``, ``output_base_dir``,
    ``plot_colors``, ``linestyles`` and ``markers`` are read from ``self``;
    they are presumably set up by ``BaseAnalyzer`` (not visible here).
    """

    # First decimal or scientific-notation number in a string.  A plain
    # character class such as ``[0-9.eE+-]+`` is not enough: it also matches
    # stray letters/signs (the "e" in "steps_10"), which then fail float().
    _PARAM_NUMBER_RE = re.compile(r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?')

    @classmethod
    def _extract_param_number(cls, param_suffix: str) -> Optional[str]:
        """Return the first numeric token in ``param_suffix``, or None.

        Anything returned is guaranteed to be accepted by ``float()``.
        """
        match = cls._PARAM_NUMBER_RE.search(param_suffix)
        return match.group(0) if match else None

    @staticmethod
    def _to_float(value) -> float:
        """Convert ``value`` to float, mapping missing/non-numeric to NaN."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return float('nan')

    def _get_models(self, exp_name: str, dataset_key: str, models: Optional[List[str]] = None):
        """Return the models to analyse for one dataset.

        Args:
            exp_name: Experiment directory name under ``results_base_path``.
            dataset_key: Dataset directory name under the experiment.
            models: Explicit model list; returned unchanged when not None.

        Returns:
            ``models`` if given, otherwise the sorted names of the
            sub-directories of ``<results_base_path>/<exp_name>/<dataset_key>``.
            An empty list if that directory cannot be listed.
        """
        if models is not None:
            return models

        base_path = os.path.join(self.results_base_path, exp_name, dataset_key)
        try:
            # Sorted so that processing order does not depend on the
            # arbitrary order os.listdir happens to return.
            return sorted(
                d for d in os.listdir(base_path)
                if os.path.isdir(os.path.join(base_path, d))
            )
        except OSError as e:
            print(f"Failed to list models in {base_path}: {e}")
            return []

    def scan_experiments_for_metric_type(self, exp_name: str, metric_type: str):
        """Delegate to the base-class scan without modification."""
        return super().scan_experiments_for_metric_type(exp_name, metric_type)

    def _plot_metric_for_model(self, exp_name: str, dataset_key: str, model: str,
                               metric_type: str, taus_to_plot: List[int],
                               results: Optional[dict] = None):
        """Plot mean +/- std versus the swept parameter for one model.

        Args:
            exp_name: Experiment name; used for scanning and the output path.
            dataset_key: Key into the scan results; also selects the y label
                ("RMSE (Normalized)" when it contains ``'tumor'``).
            model: Key into ``results[dataset_key]``.
            metric_type: Metric type passed to the scan and used in the
                output path / file name.
            taus_to_plot: Tau numbers; each is looked up as ``'tau_<n>'``.
            results: Optional pre-computed return value of
                ``scan_experiments_for_metric_type(exp_name, metric_type)``.
                When None the scan is performed here.

        Returns:
            None. On success a PDF is written to
            ``<output_base_dir>/<exp_name>/figure/<metric_type>/<dataset_key>/``;
            otherwise a message is printed and nothing is saved.
        """
        if results is None:
            results = self.scan_experiments_for_metric_type(exp_name, metric_type)
        if not results:
            print(f"No results found for experiment '{exp_name}', metric '{metric_type}'.")
            return
        if dataset_key not in results or model not in results[dataset_key]:
            print(f"No data found for dataset '{dataset_key}', model '{model}'.")
            return

        model_results = results[dataset_key][model]

        # Map the numeric token of every run suffix to that run's data.
        parsed_param_data = {}
        for param_suffix, data in model_results.items():
            number = self._extract_param_number(param_suffix)
            if number is None:
                print(f"Warning: Could not parse numerical parameter from '{param_suffix}', skipping.")
                continue
            parsed_param_data[number] = data

        if not parsed_param_data:
            print(f"No valid parameter data found for dataset '{dataset_key}', model '{model}'.")
            return

        # Every key came from _PARAM_NUMBER_RE, so float() cannot fail here.
        param_values_str = sorted(parsed_param_data, key=float)
        param_values_float = [float(x) for x in param_values_str]

        ratio = 0.85
        fig = plt.figure(figsize=(12 * ratio, 7 * ratio))
        try:
            plot_created = False

            for i, tau_num in enumerate(taus_to_plot):
                tau_key = f'tau_{tau_num}'
                means, stds = [], []
                for param_str in param_values_str:
                    tau_data = parsed_param_data[param_str].get(tau_key) or {}
                    mean = self._to_float(tau_data.get('mean'))
                    std = self._to_float(tau_data.get('std'))
                    # A point is only usable when both statistics exist.
                    if np.isnan(mean) or np.isnan(std):
                        mean = std = np.nan
                    means.append(mean)
                    stds.append(std)

                means_series = np.array(means, dtype=float)
                stds_series = np.array(stds, dtype=float)
                if np.isnan(means_series).all():
                    continue

                # The style index follows the position in taus_to_plot, so a
                # tau keeps its style even when an earlier tau has no data.
                color = self.plot_colors[i % len(self.plot_colors)]
                ls = self.linestyles[i % len(self.linestyles)]
                marker = self.markers[i % len(self.markers)]

                plt.plot(param_values_float, means_series, color=color, linestyle=ls,
                         marker=marker, label=f"$\\tau = {tau_num}$")

                # The std band is drawn over linearly interpolated values so
                # that isolated missing points do not split the shaded area.
                means_interp = pd.Series(means_series).interpolate(
                    method='linear', limit_direction='both').to_numpy()
                stds_interp = pd.Series(stds_series).interpolate(
                    method='linear', limit_direction='both').to_numpy()
                plt.fill_between(param_values_float,
                                 means_interp - stds_interp,
                                 means_interp + stds_interp,
                                 color=color, alpha=0.15)

                plot_created = True

            if not plot_created:
                print(f"No plot data for model '{model}', dataset '{dataset_key}', metric '{metric_type}'")
                return

            plt.xlabel("Inference Optimization Steps", fontsize=28)
            if 'tumor' in dataset_key:
                plt.ylabel('RMSE (Normalized)', fontsize=28)
            else:
                plt.ylabel('RMSE', fontsize=28)
            plt.legend(fontsize=22, loc='upper right')
            plt.grid(True, linestyle="--", alpha=0.5)
            plt.xticks(fontsize=24)
            plt.yticks(fontsize=24)
            plt.tight_layout()

            output_dir = os.path.join(self.output_base_dir, exp_name, "figure", metric_type, dataset_key)
            os.makedirs(output_dir, exist_ok=True)
            filename = f"{model}_parameter_{metric_type}.pdf"
            save_path = os.path.join(output_dir, filename)
            plt.savefig(save_path, dpi=300)
            print(f" - Plot saved to {save_path}")
        finally:
            # Always release the figure, including when plotting or saving
            # raises, so repeated calls do not accumulate open figures.
            plt.close(fig)

    def run_analysis(self, exp_name: str = "k_parameter_study", datasets: Optional[List[str]] = None,
                     metric_types: Optional[List[str]] = None, taus_to_plot: Optional[List[int]] = None,
                     models: Optional[List[str]] = None):
        """Generate the parameter-study plots for an experiment.

        Args:
            exp_name: Experiment directory name under ``results_base_path``.
            datasets: Dataset keys to process. When None, every
                sub-directory of the experiment directory is used.
            metric_types: Metric types to process. Defaults to
                ``['episode_wise', 'step_wise']``.
            taus_to_plot: Tau numbers to draw. Defaults to ``[1, 2, 4, 6]``.
            models: Models to process for every dataset. When None, models
                are discovered per dataset via ``_get_models``.

        Returns:
            None. Progress is printed and plots are written to disk.
        """
        if taus_to_plot is None:
            # Built per call rather than as a shared mutable default.
            taus_to_plot = [1, 2, 4, 6]

        if datasets is None:
            # Discover datasets from the experiment directory.
            exp_path = os.path.join(self.results_base_path, exp_name)
            if not os.path.isdir(exp_path):
                print(f"Experiment path does not exist: {exp_path}")
                return
            datasets = sorted(
                d for d in os.listdir(exp_path)
                if os.path.isdir(os.path.join(exp_path, d))
            )
            if not datasets:
                print(f"No datasets found under {exp_path}")
                return

        if metric_types is None:
            metric_types = ['episode_wise', 'step_wise']

        print(f"\n--- Starting k_parameter_study Analysis for experiment '{exp_name}' ---")
        for metric_type in metric_types:
            print(f"\nProcessing metric type: {metric_type}")
            # The scan depends only on (exp_name, metric_type); do it once
            # here instead of once per dataset/model pair.
            results = self.scan_experiments_for_metric_type(exp_name, metric_type)
            for ds in datasets:
                current_models = self._get_models(exp_name, ds, models)
                if not current_models:
                    print(f"No models detected for dataset '{ds}', skipping")
                    continue
                print(f"\nDataset: {ds}")
                for model in current_models:
                    print(f" - Plotting model: {model}")
                    self._plot_metric_for_model(exp_name, ds, model, metric_type,
                                                taus_to_plot, results=results)

        print(f"\n--- Analysis complete. Output directory: {os.path.join(self.output_base_dir, exp_name)} ---")
