# analysis_modules/complexity_analyzer.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from .base_analyzer import BaseAnalyzer
from typing import Dict, Optional, List, Any

class ComplexityAnalyzer(BaseAnalyzer):
    """
    Analyze and visualize model complexity, producing publication-quality grouped bar charts.
    - Fill patterns and font sizes are tuned for aesthetics and readability.
    - Charts are meant to stay legible in black-and-white print; the legend is placed inside the chart.
    """
    def __init__(self, results_base_path: str = "results", output_base_dir: str = "analysis_output") -> None:
        """
        Configure the models to display and their per-model bar styling.

        Args:
            results_base_path: Forwarded unchanged to the base-class initializer.
            output_base_dir: Forwarded unchanged to the base-class initializer.
        """
        super().__init__(results_base_path, output_base_dir)
        # Models included in every chart; the order here is the bar order within a group.
        self.models_to_display = ['rmsn', 'crn', 'ct', 'actin', 'vcip', 'gift']

        # Style lookups are keyed by the upper-cased model name.
        model_keys_upper = [m.upper() for m in self.models_to_display]

        # One color per model, paired positionally with models_to_display.
        # NOTE(review): plot_colors is not defined in this class -- presumably set by
        # the base-class initializer; indices 0-5 are read here.
        palette = self.plot_colors
        colors = [palette[0], palette[3], palette[4], palette[2], palette[5], palette[1]]

        # One hatch pattern per model, also paired positionally. Repeated characters
        # give a denser fill. Backslashes are written as '\\' so no literal relies on
        # an unrecognized escape sequence (the last entry is backslash + '|').
        hatches = ['//', '\\\\', '|/', '-\\', '-', '\\|']

        self.model_colors = dict(zip(model_keys_upper, colors))
        self.model_hatches = dict(zip(model_keys_upper, hatches))

    def _find_complexity_csv(self, model_base_path: str) -> Optional[str]:
        """Flexibly find 'complexity_info.csv' to work with subdirectory structures."""
        direct_path = os.path.join(model_base_path, "raw_results", "complexity_info.csv")
        if os.path.exists(direct_path):
            return direct_path
        if os.path.isdir(model_base_path):
            for sub_dir in os.listdir(model_base_path):
                sub_dir_path = os.path.join(model_base_path, sub_dir)
                if os.path.isdir(sub_dir_path):
                    nested_path = os.path.join(sub_dir_path, "raw_results", "complexity_info.csv")
                    if os.path.exists(nested_path):
                        return nested_path
        return None

    def scan_and_aggregate_complexity(self, exp_name: str, datasets: List[str]) -> Dict[str, Any]:
        """Scan and aggregate all complexity data."""
        #(This function does not need to be changed, leave as is)
        print("--- Scanning and aggregating all complexity data ---")
        all_data = {ds: {} for ds in datasets}
        for dataset in datasets:
            dataset_path = os.path.join(self.results_base_path, exp_name, dataset)
            if not os.path.isdir(dataset_path):
                print(f"Warning: Dataset directory not found: {dataset_path}"); continue
            print(f"- Processing dataset: {dataset}")
            for model in self.models_to_display:
                model_path = os.path.join(dataset_path, model)
                csv_path = self._find_complexity_csv(model_path)
                if not csv_path:
                    print(f"- Warning: 'complexity_info.csv' not found for model '{model}'"); continue
                try:
                    df = pd.read_csv(csv_path)
                    grouped = df.groupby('optimize_by_step')
                    model_data = {}
                    for name, group in grouped:
                        model_data[name] = {
                            'params_mean': group['params'].mean(), 'params_std': group['params'].std(),
                            'mflops_mean': group['mflops'].max(), 'mflops_std': 0,
                            'train_time_mean': group['train_time'].mean(), 'train_time_std': group['train_time'].std(),
                            'test_time_mean': group['test_time'].mean(), 'test_time_std': group['test_time'].std(),
                        }
                    all_data[dataset][model.upper()] = model_data
                except Exception as e:
                    print(f"- Error: Processing {csv_path} failed: {e}")
        print("---Data aggregation complete---\ n")
        return all_data

    def create_grouped_bar_plots(self, exp_name: str, datasets: List[str]) -> None:
        """
        Render one grouped bar chart per (optimize_by_step, metric) pair and save it as a PDF.

        Data comes from ``scan_and_aggregate_complexity``. Each chart has one group
        per dataset and one bar per model in ``self.models_to_display``; a model with
        no data for a dataset is drawn with a mean and error of 0. Files are written
        to ``<output_base_dir>/<exp_name>/complexity_figures/`` as
        ``complexity_<metric>_optimize_<true|false>.pdf``.

        Args:
            exp_name: Experiment name, used for both the input and output directories.
            datasets: Dataset names; each becomes one bar group on the x axis.
        """
        all_complexity_data = self.scan_and_aggregate_complexity(exp_name, datasets)

        metrics_info = {
            'params': {'label': 'Kilo Parameters', 'unit_converter': lambda x: x / 1000},
            'mflops': {'label': 'MFLOPs', 'unit_converter': lambda x: x},
            'train_time': {'label': 'Time (Seconds)', 'unit_converter': lambda x: x},
            'test_time': {'label': 'Time (Seconds)', 'unit_converter': lambda x: x},
        }
        # Bars for these metrics start at y = -1 rather than 0.
        offset_metrics = ('test_time', 'mflops')

        # Everything below is identical for every chart, so compute it once.
        model_names_upper = [m.upper() for m in self.models_to_display]
        n_models = len(model_names_upper)
        bar_width = 0.8 / n_models
        group_width = bar_width * n_models
        dataset_indices = np.arange(len(datasets))
        xtick_labels = [ds.replace('_', ' ').replace('mimic3 synthetic', 'MIMIC-III').replace('tumor', 'Tumor').replace('gamma=', '($\\gamma$=') + (')' if 'gamma' in ds else '') for ds in datasets]

        output_directory = os.path.join(self.output_base_dir, exp_name, 'complexity_figures')
        os.makedirs(output_directory, exist_ok=True)

        for optimize_state in [True, False]:
            print(f"--- Generating chart for optimize_by_step = {optimize_state} ---")
            for metric, info in metrics_info.items():
                convert = info['unit_converter']
                bottom = -1 if metric in offset_metrics else 0

                fig, ax = plt.subplots(figsize=(8, 6))
                try:
                    for i, model_name_upper in enumerate(model_names_upper):
                        heights, stds = [], []
                        for dataset in datasets:
                            data = all_complexity_data.get(dataset, {}).get(model_name_upper, {}).get(optimize_state)
                            if data:
                                mean = convert(data[f'{metric}_mean'])
                                std_val = data[f'{metric}_std']
                                # std is NaN for single-row groups; draw no error bar then.
                                std = convert(std_val) if pd.notna(std_val) else 0
                            else:
                                mean, std = 0, 0
                            # Height is measured from `bottom`, so the bar top lands on the mean.
                            heights.append(mean - bottom)
                            stds.append(std)

                        # Center the group of bars on each dataset's tick.
                        bar_positions = dataset_indices - group_width / 2 + i * bar_width + bar_width / 2
                        # NOTE: self.model_hatches is not applied here; pass
                        # hatch=self.model_hatches[model_name_upper] to enable patterned fills.
                        ax.bar(bar_positions, heights, width=bar_width, label=model_name_upper,
                               color=self.model_colors[model_name_upper],
                               yerr=stds, capsize=4, edgecolor='black', linewidth=1.0, bottom=bottom)

                    if metric in offset_metrics:
                        ax.set_ylim(bottom=bottom)

                    ax.set_ylabel(info['label'], fontsize=36)
                    ax.set_xticks(dataset_indices)
                    ax.set_xticklabels(xtick_labels, fontsize=32, ha='center')
                    ax.tick_params(axis='y', labelsize=32)
                    ax.grid(axis='y', linestyle='--', alpha=0.6)
                    ax.legend(fontsize=25, loc='upper right')

                    fig.tight_layout()

                    filename = f"complexity_{metric}_optimize_{str(optimize_state).lower()}.pdf"
                    output_path = os.path.join(output_directory, filename)
                    fig.savefig(output_path, dpi=300)
                finally:
                    # Release the figure even if plotting or saving failed.
                    plt.close(fig)
                print(f"- Chart saved to: {output_path}")

    def run_analysis(self, exp_name: str = "main_comparison") -> None:
        """
        Run the full complexity analysis for a fixed pair of datasets.

        Generates the grouped bar charts for "mimic3_synthetic" and "tumor_gamma=4"
        and then prints the directory the figures were written to.

        Args:
            exp_name: Experiment name passed through to ``create_grouped_bar_plots``.
        """
        datasets_to_analyze = ["mimic3_synthetic", "tumor_gamma=4"]
        self.create_grouped_bar_plots(exp_name, datasets_to_analyze)
        # Must match the directory create_grouped_bar_plots writes to.
        output_directory = os.path.join(self.output_base_dir, exp_name, 'complexity_figures')
        print(f"\n---Complexity analysis complete. The output is located at '{output_directory}'---")

if __name__ == '__main__':
    # Script entry point: analyze the "main_comparison" experiment, reading from
    # "results" and writing to "analysis_output".
    ComplexityAnalyzer(
        results_base_path="results",
        output_base_dir="analysis_output",
    ).run_analysis(exp_name="main_comparison")