# experiments/analysis_modules/case_analyzer.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import re
from sklearn.decomposition import PCA  #Introducing PCA
from .main_comparison import MainComparisonAnalyzer #Inherit MainComparisonAnalyzer to reuse its path and shift_state logic

class CaseAnalyzer(MainComparisonAnalyzer):
    """
    Load, merge, and plot per-patient case-study trajectories for several models.

    The analyzer can:
    - Load the `case_study_trajectories.pkl` file from the results directory of a given
      experiment, dataset, and model.
    - Combine the trajectories that different models stored for the same patient.
    - Generate comparison charts for every `shift_state` and `metric_type`
      (episode_wise/step_wise) combination.
    - Draw a model's predicted trajectory, when present, as a dashed line in the same
      colour as its simulated trajectory.
    - Reduce multi-dimensional outcomes to their first principal component before drawing.
    - Save the generated figures to a structured output directory.
    """

    # File name of the pickled case-study trajectories inside a model's results folder.
    _CASE_STUDY_FILENAME = "case_study_trajectories.pkl"
    # Keys under which a model entry stores its trajectories (simulated first, then predicted).
    _TRAJECTORY_KEYS = ('simulated_outcome_trajectory', 'predicted_outcome_trajectory')
    # Models merged by `run_analysis` when the caller does not name any.
    _DEFAULT_MODELS = ('vcip', 'rmsn', 'gift')
    # Patients plotted by `plot_all_merged_cases` when the caller does not name any.
    _DEFAULT_PATIENT_IDS = (12, 42)

    def __init__(self, results_base_path: str = "results", output_base_dir: str = "analysis_output"):
        """Initialize the analyzer by delegating both paths to MainComparisonAnalyzer."""
        super().__init__(results_base_path, output_base_dir)
        print("CaseAnalyzer initialized for trajectory plotting.")

    def _find_case_study_file(self, exp_name: str, dataset: str, model: str, metric_type: str, shift_state: str):
        """
        Return the path of the case-study pickle for one model, or None if no file exists.

        Lookup order:
        1. `<results>/<exp>/<dataset>/<model>/<shift_state>/raw_results/[<metric_type>/]<file>`;
           the `metric_type` level is omitted for models whose name contains 'gift'.
        2. Only for `shift_state == 'shift_False'`: the model's `default/raw_results/<file>`.
        """
        base_model_path = os.path.join(self.results_base_path, exp_name, dataset, model)

        # Decide the layout from the model name only: testing the whole path would switch
        # every model to the gift layout whenever the results root, experiment or dataset
        # name happens to contain "gift".
        if 'gift' in model.lower():
            candidate = os.path.join(base_model_path, shift_state, "raw_results", self._CASE_STUDY_FILENAME)
        else:
            candidate = os.path.join(base_model_path, shift_state, "raw_results", metric_type, self._CASE_STUDY_FILENAME)
        if os.path.exists(candidate):
            return candidate

        if shift_state == 'shift_False':
            default_path = os.path.join(base_model_path, "default", "raw_results", self._CASE_STUDY_FILENAME)
            if os.path.exists(default_path):
                return default_path

        return None

    def load_and_merge_case_studies(self, exp_name: str, dataset: str, metric_type: str, shift_state: str, models_to_merge: list) -> dict:
        """
        Load the case-study pickle of every requested model and merge them per patient.

        Args:
            exp_name: Experiment folder name below `self.results_base_path`.
            dataset: Dataset folder name below the experiment folder.
            metric_type: Metric sub-folder used for non-gift models.
            shift_state: Setting sub-folder (e.g. 'shift_False').
            models_to_merge: Model folder names to look up.

        Returns:
            A dict keyed by patient id. Each value holds 'initial_outcome', 'goal' and
            'ground_truth_outcomes' (taken from the first file that mentions the patient)
            and 'models', the union of the 'models' dicts of all loaded files.
            Models without a file are skipped; files that fail to load or merge are
            reported on stdout and skipped.
        """
        merged_data = {}
        print(f"  - Loading and merging for: {dataset} / {metric_type} / {shift_state}")

        for model in models_to_merge:
            pkl_path = self._find_case_study_file(exp_name, dataset, model, metric_type, shift_state)
            if pkl_path is None:
                print(f"    - No case study file found for model '{model}', skipping.")
                continue

            # Best-effort: one broken file must not abort the merge of the remaining models.
            try:
                # NOTE: pickle.load can execute arbitrary code embedded in the file;
                # only point this analyzer at result directories from trusted runs.
                with open(pkl_path, 'rb') as f:
                    loaded_data = pickle.load(f)

                for patient_id, patient_data in loaded_data.items():
                    if patient_id not in merged_data:
                        merged_data[patient_id] = {
                            'initial_outcome': patient_data['initial_outcome'],
                            'goal': patient_data['goal'],
                            'ground_truth_outcomes': patient_data['ground_truth_outcomes'],
                            'models': {}
                        }

                    merged_data[patient_id]['models'].update(patient_data['models'])

            except Exception as e:
                print(f"    - ERROR: Failed to load or merge data from {pkl_path}: {e}")

        return merged_data

    @classmethod
    def _parse_model_trajectories(cls, models: dict) -> dict:
        """
        Normalize each model entry to a dict of numpy arrays (or None).

        A model entry is either a dict holding the trajectory keys, or a bare sequence
        that is interpreted as the simulated trajectory.
        """
        sim_key, pred_key = cls._TRAJECTORY_KEYS
        parsed = {}
        for model_name, model_data in models.items():
            if isinstance(model_data, dict):
                simulated = model_data.get(sim_key)
                predicted = model_data.get(pred_key)
            else:
                simulated, predicted = model_data, None

            parsed[model_name] = {
                sim_key: None if simulated is None else np.array(simulated),
                pred_key: None if predicted is None else np.array(predicted),
            }
        return parsed

    @classmethod
    def _project_to_first_component(cls, initial_outcome, goal, ground_truth, parsed_models: dict):
        """
        Project all outcomes of one patient onto their first principal component.

        The PCA is fitted on the initial outcome, the goal (when it has more than one
        element), the ground truth and every model trajectory with more than one element,
        so that all curves share the same 1-D axis. Model trajectories are replaced
        in place inside `parsed_models`.

        Returns:
            Tuple `(initial_outcome, goal, ground_truth)` where the first two are scalars
            and the last is a 1-D array.
        """
        # The feature count comes from ground_truth because it is the array that
        # triggered the multi-dimensional branch in the caller.
        num_features = ground_truth.shape[-1]
        goal_is_vector = goal.ndim > 0 and goal.size > 1
        ground_truth_2d = ground_truth.reshape(-1, num_features)

        samples = [initial_outcome.reshape(1, -1)]
        if goal_is_vector:
            samples.append(goal.reshape(1, -1))
        samples.append(ground_truth_2d)
        for data in parsed_models.values():
            for key in cls._TRAJECTORY_KEYS:
                trajectory = data[key]
                if trajectory is not None and trajectory.size > 1:
                    samples.append(trajectory.reshape(-1, num_features))

        pca = PCA(n_components=1)
        pca.fit(np.vstack(samples))

        initial_1d = pca.transform(initial_outcome.reshape(1, -1)).flatten()[0]
        ground_truth_1d = pca.transform(ground_truth_2d).flatten()
        if goal_is_vector:
            goal_1d = pca.transform(goal.reshape(1, -1)).flatten()[0]
        else:
            # A single-valued goal is used as-is (it cannot be projected).
            goal_1d = goal.item()

        for data in parsed_models.values():
            for key in cls._TRAJECTORY_KEYS:
                trajectory = data[key]
                if trajectory is not None and trajectory.size > 1:
                    data[key] = pca.transform(trajectory.reshape(-1, num_features)).flatten()

        return initial_1d, goal_1d, ground_truth_1d

    def plot_outcome_trajectories(self, patient_data: dict, patient_id: int, output_path: str):
        """
        Plot the merged trajectory comparison for a single patient and save it.

        The ground truth is drawn as a black dashed line, each model's simulated
        trajectory as a solid line, and its predicted trajectory (if any) as a dashed
        line of the same colour, truncated to the simulated trajectory's length. Every
        curve starts at the initial outcome at x = 0; the goal is a horizontal dotted line.
        If the ground truth has more than one feature, all data is first reduced to
        its first principal component.

        Args:
            patient_data: Dict with 'initial_outcome', 'goal', 'ground_truth_outcomes'
                and 'models' (model name -> trajectory dict or bare sequence).
            patient_id: Patient identifier, used only for log output.
            output_path: Target file of the figure; its directory is created if missing.

        Errors while saving are reported on stdout; errors while preparing or drawing
        propagate to the caller. The figure is closed in every case.
        """
        # --- 1. Data extraction and preparation ---
        # Done before the figure is created so that malformed data cannot leak a figure.
        initial_outcome = np.array(patient_data['initial_outcome'])
        ground_truth = np.array(patient_data['ground_truth_outcomes'])
        goal = np.array(patient_data['goal'])
        parsed_models = self._parse_model_trajectories(patient_data['models'])

        # --- 2. PCA dimension reduction (if needed) ---
        # Multi-dimensional means: ground_truth has a feature axis with more than one entry.
        is_multidimensional = ground_truth.ndim > 1 and ground_truth.shape[-1] > 1
        y_label = 'Outcome'

        if is_multidimensional:
            print(f"  - Detected multi-dimensional outcomes for patient {patient_id}. Applying PCA.")
            y_label = 'Outcome (1st Principal Component)'
            initial_outcome, goal, ground_truth = self._project_to_first_component(
                initial_outcome, goal, ground_truth, parsed_models
            )
        else:
            # Collapse single-element arrays to scalars and the ground truth to 1-D.
            initial_outcome = initial_outcome.item() if initial_outcome.size == 1 else initial_outcome
            goal = goal.item() if goal.size == 1 else goal
            ground_truth = ground_truth.flatten()

        # --- 3. Drawing ---
        sim_key, pred_key = self._TRAJECTORY_KEYS
        ratio = 0.85
        fig, ax = plt.subplots(figsize=(13 * ratio, 7 * ratio))
        try:
            num_steps = len(ground_truth)
            y_true_with_initial = np.insert(ground_truth, 0, initial_outcome)
            ax.plot(range(num_steps + 1), y_true_with_initial,
                    label='Ground Truth', linestyle='--', color='black', marker='x', zorder=10)

            for i, (model_name, model_data) in enumerate(parsed_models.items()):
                simulated_outcomes = model_data[sim_key]
                if simulated_outcomes is None:
                    continue

                # Colours/markers are indexed by model position so that they stay stable
                # even when an earlier model has no simulated trajectory.
                color = self.plot_colors[i % len(self.plot_colors)]
                marker = self.markers[i % len(self.markers)]

                simulated_outcomes = simulated_outcomes.flatten()
                y_simulated_with_initial = np.insert(simulated_outcomes, 0, initial_outcome)
                ax.plot(range(len(y_simulated_with_initial)), y_simulated_with_initial,
                        label=model_name.upper(), marker=marker, color=color, linestyle='-', alpha=0.9)

                predicted_outcomes = model_data[pred_key]
                if predicted_outcomes is not None and predicted_outcomes.size > 0:
                    # Never draw more predicted steps than simulated ones.
                    predicted_outcomes = predicted_outcomes.flatten()[:len(simulated_outcomes)]
                    y_predicted_with_initial = np.insert(predicted_outcomes, 0, initial_outcome)
                    ax.plot(range(len(y_predicted_with_initial)), y_predicted_with_initial,
                            color=color, linestyle='--', alpha=0.8)

            ax.axhline(y=goal, color='grey', linestyle=':', label='Goal')

            # No title; the y-axis tick labels are removed below.
            ax.set_xlabel(r'$\tau$', fontsize=33)
            ax.set_ylabel(y_label, fontsize=25)
            ax.tick_params(axis='x', which='major', labelsize=24)
            ax.tick_params(axis='y', which='major', labelsize=24)
            ax.set_yticklabels([])

            # Roughly six x ticks, starting at 0.
            ax.set_xticks(np.arange(0, num_steps + 1, step=max(1, (num_steps + 1) // 6)))

            ax.legend(fontsize=24, loc='upper left')
            ax.grid(True, which='both', linestyle='--', linewidth=0.5)
            fig.tight_layout()

            try:
                output_dir = os.path.dirname(output_path)
                # A bare file name has no directory part; os.makedirs('') would raise.
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                fig.savefig(output_path, dpi=300, bbox_inches='tight')
            except Exception as e:
                print(f"    - ERROR: Failed to save plot to {output_path}: {e}")
        finally:
            # Always release the figure, including when drawing itself fails.
            plt.close(fig)

    def plot_all_merged_cases(self, merged_data: dict, exp_name: str, dataset_key: str, metric_type: str, shift_state: str,
                              patient_ids: "tuple | list | None" = _DEFAULT_PATIENT_IDS):
        """
        Generate and save one trajectory plot per selected patient.

        Args:
            merged_data: Output of `load_and_merge_case_studies`; nothing happens if empty.
            exp_name: Experiment name, forwarded to `_get_main_comp_output_path`.
            dataset_key: Dataset name; used as the last directory level of the output.
            metric_type: Metric type, forwarded to `_get_main_comp_output_path`.
            shift_state: Shift setting, forwarded to `_get_main_comp_output_path`.
            patient_ids: Ids of the patients to plot. Defaults to patients 12 and 42;
                pass None to plot every patient in `merged_data`.

        Each figure is written to `<setting dir>/<dataset_key>/case_study_patient_<id>.pdf`.
        """
        if not merged_data:
            return

        if patient_ids is None:
            selected = merged_data
        else:
            wanted = tuple(patient_ids)
            selected = {pid: data for pid, data in merged_data.items() if pid in wanted}

        print(f"  - Plotting {len(selected)} of {len(merged_data)} merged trajectories for {dataset_key}/{metric_type}/{shift_state}...")

        # Only the directory of the returned path is needed; the file name is a placeholder.
        base_dir_for_setting = os.path.dirname(self._get_main_comp_output_path(
            exp_name=exp_name,
            output_type='case_study',
            metric_type=metric_type,
            shift_state=shift_state,
            filename='dummy.txt'
        ))
        dataset_specific_dir = os.path.join(base_dir_for_setting, dataset_key)

        for patient_id, patient_data in selected.items():
            filename = f"case_study_patient_{patient_id}.pdf"
            output_path = os.path.join(dataset_specific_dir, filename)
            self.plot_outcome_trajectories(patient_data, patient_id, output_path)

    def run_analysis(self,
                     exp_name: str,
                     datasets: "list | None" = None,
                     models_to_merge: "list | None" = None):
        """
        Run the full case-study analysis: load, merge, and plot.

        Args:
            exp_name: Experiment to analyze.
            datasets: Dataset names to process. When empty or None, the datasets are
                discovered via `scan_experiments_for_metric_type(exp_name, 'episode_wise')`
                and restricted to names containing 'mimic' or 'tumor'.
            models_to_merge: Model names whose trajectories are merged per patient.
                When empty or None, 'vcip', 'rmsn' and 'gift' are used.

        For every dataset, every entry of `self.metrics_types` and every entry of
        `self.shift_states`, the case studies are loaded, merged and plotted.
        """
        if not models_to_merge:
            models_to_merge = list(self._DEFAULT_MODELS)

        if not datasets:
            print("No datasets specified, automatically scanning for available datasets...")
            try:
                results_scan = self.scan_experiments_for_metric_type(exp_name, 'episode_wise')
                datasets = [key for key in results_scan.keys() if 'mimic' in key or 'tumor' in key]
                if not datasets:
                    print("Warning: Could not find any 'mimic' or 'tumor' datasets automatically.")
                    return
            except Exception as e:
                print(f"Error during automatic dataset scan: {e}")
                return

        print(f"\n--- Starting Case Study Trajectory Analysis for Experiment: '{exp_name}' ---")
        print(f"Datasets to analyze: {datasets}")
        print(f"Models to merge: {models_to_merge}")

        for dataset in datasets:
            print(f"\n1. Processing Dataset: '{dataset}'")
            for metric_type in self.metrics_types:
                for shift_state in self.shift_states:
                    merged_case_data = self.load_and_merge_case_studies(
                        exp_name=exp_name,
                        dataset=dataset,
                        metric_type=metric_type,
                        shift_state=shift_state,
                        models_to_merge=models_to_merge
                    )

                    self.plot_all_merged_cases(
                        merged_data=merged_case_data,
                        exp_name=exp_name,
                        dataset_key=dataset,
                        metric_type=metric_type,
                        shift_state=shift_state
                    )

        print(f"\n--- Case Study Analysis for '{exp_name}' complete. ---")