# Standard library imports
import json
import json5
import re
from pathlib import Path
from collections import Counter
from typing import List, Dict
import argparse

# Numerical libraries
import numpy as np

# Scikit-learn: metrics and estimators
from sklearn.metrics import (
    accuracy_score,
    roc_curve,
    auc,
    f1_score
)
from sklearn.ensemble import RandomForestRegressor
# Linear classifier used for treatment-match prediction
from sklearn.linear_model import LogisticRegression

def get_model_configs():
    """Build the estimators used by this module.

    Returns:
        A ``(regressors, classifiers)`` tuple of dicts, each mapping a display
        name to a newly constructed, unfitted scikit-learn estimator.
    """
    random_forest = RandomForestRegressor(
        n_estimators=500,
        max_depth=None,
        min_samples_split=5,
        min_samples_leaf=2,
        max_features="sqrt",
        random_state=42,
        n_jobs=-1,
    )
    # Linear / regularized
    logistic_l2 = LogisticRegression(
        penalty="l2",
        C=0.1,
        solver="lbfgs",
        max_iter=10000,
        class_weight="balanced",
    )

    regressors = {"RandomForest": random_forest}
    classifiers = {"Logistic Regression (L2, balanced)": logistic_l2}
    return regressors, classifiers

# --------------------------
# Utilities / parsing helpers
# --------------------------

# Object wrapped in a Markdown code fence (optionally tagged "json"); the
# closing fence anchors the lazy match, so nested braces are captured whole.
_FENCED_OBJECT_RE = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)


def _first_balanced_object(text):
    """Return the first brace-balanced ``{...}`` span in *text*, or None.

    Braces inside single- or double-quoted strings are ignored so that a
    value such as ``"a } b"`` does not end the object early.
    """
    start = text.find("{")
    if start == -1:
        return None
    depth = 0
    quote = None      # the quote character of the string we are inside, if any
    escaped = False   # True when the previous character was a backslash
    for idx in range(start, len(text)):
        ch = text[idx]
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch in "\"'":
            quote = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:idx + 1]
    return None


def _parse_object(candidate):
    """Parse *candidate* as strict JSON, then as JSON5; return None on failure."""
    try:
        return json.loads(candidate)
    except ValueError:
        pass
    try:
        # Lenient fallback for model output that is not strict JSON.
        return json5.loads(candidate)
    except Exception:
        # Best-effort loader: an unparsable file is reported as "no data".
        return None


def load_json(file_path: Path):
    """Extract and parse the first JSON object found in a text file.

    Only the text after the last ``</think>`` marker (if any) is searched.
    A fenced code block is preferred; otherwise the first brace-balanced
    object in the text is used, so nested objects are returned whole instead
    of being cut at the first ``}``.

    Args:
        file_path: Path of the UTF-8 text file to read.

    Returns:
        The parsed object, or None when no object is found or it cannot be
        parsed.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()

    # str.split returns the whole text when the marker is absent.
    region = text.split("</think>")[-1]

    fenced = _FENCED_OBJECT_RE.search(region)
    if fenced:
        candidates = [fenced.group(1)]
    else:
        candidates = []
        balanced = _first_balanced_object(region)
        if balanced is not None:
            candidates.append(balanced)
        # Legacy match (up to the first "}") kept as a last resort so inputs
        # that parsed before this change still parse.
        loose = re.search(r"(\{.*?\})", region, re.DOTALL)
        if loose and loose.group(1) not in candidates:
            candidates.append(loose.group(1))

    for candidate in candidates:
        parsed = _parse_object(candidate)
        if parsed is not None:
            return parsed
    return None


def normalize_path_list(x):
    """Normalize a path given as a sequence or an arrow-separated string.

    Args:
        x: None, a list/tuple of node labels, or a string whose nodes are
            separated by ``->``, ``>`` or ``→``. Any other value is
            converted with ``str`` and split the same way.

    Returns:
        A list of stripped, non-empty node labels (empty list for None).
    """
    if x is None:
        return []
    if isinstance(x, (list, tuple)):
        # Coerce items to str so numeric labels do not crash on .strip();
        # None items carry no label and are dropped.
        items = (str(item).strip() for item in x if item is not None)
    else:
        items = (part.strip() for part in re.split(r"\s*(?:->|>|→)\s*", str(x)))
    # Both branches drop blank entries so downstream "last node" logic never
    # sees an empty label.
    return [item for item in items if item]


def filter_path_17_node(lst):
    """Strip the special trailing NSCL-17 nodes from a path.

    A trailing "NSCL-17-10" is removed first, then a trailing "NSCL-17-1",
    each only while more than one node remains. Falsy input is returned
    unchanged; otherwise a new list is returned.
    """
    if not lst:
        return lst
    trimmed = list(lst)
    # Order matters: "-10" is checked before "-1", each at most once.
    for terminal in ("NSCL-17-10", "NSCL-17-1"):
        if len(trimmed) > 1 and trimmed[-1] == terminal:
            trimmed.pop()
    return trimmed


def compare_lists(lists, return_final_consistency=True):
    """Measure how well a collection of paths agree with each other.

    Each path first has its special terminal nodes removed with
    ``filter_path_17_node``. The path score is the number of nodes common to
    every non-empty path divided by the number of nodes seen in any of them.

    Args:
        lists: Iterable of paths (sequences of hashable node labels).
        return_final_consistency: When True, also report agreement on the
            final node and the most common path.

    Returns:
        None (after printing an error) if no non-empty path remains.
        Otherwise ``path_match_fraction`` alone, or the tuple
        ``(path_match_fraction, treatment_match_fraction, path_mode,
        treatment_mode)`` when ``return_final_consistency`` is True.
    """
    # Clean up special terminal cases
    trimmed = [filter_path_17_node(path) for path in lists]

    # --- Path match fraction: shared nodes over all observed nodes ---
    node_sets = [set(path) for path in trimmed if path]
    if not node_sets:
        print("Error in comparing paths, no paths found")
        return None
    shared = set.intersection(*node_sets) if len(node_sets) > 1 else node_sets[0]
    seen = set.union(*node_sets)
    path_match_fraction = len(shared) / len(seen) if seen else 1.0

    if not return_final_consistency:
        return path_match_fraction

    # --- Agreement on the final node of each non-empty path ---
    endpoints = [path[-1] for path in trimmed if path]
    if not endpoints:
        treatment_match_fraction, treatment_mode = 1.0, None
    elif len(endpoints) == 1:
        treatment_match_fraction, treatment_mode = 1.0, endpoints[0]
    else:
        treatment_mode, top_count = Counter(endpoints).most_common(1)[0]
        treatment_match_fraction = top_count / len(endpoints)

    # --- Most common complete path (empty paths are counted here too) ---
    path_counts = Counter(tuple(path) for path in trimmed)
    path_mode = list(path_counts.most_common(1)[0][0]) if path_counts else None

    return path_match_fraction, treatment_match_fraction, path_mode, treatment_mode

def load_trend_data(heatmap_results: str, total_patients: int = 121) -> Dict[str, Dict[str, float]]:
    """Load benchmark results and extract per-model trend metrics.

    For every entry, ``average_treatment_match`` and ``average_path_match``
    are stored under the key ``"<benchmark>_<metric>"``. Entries of the
    ``path_filter`` / ``treatment_filter`` benchmarks are scaled by the
    fraction of patients answered
    (``total_patients_matched / total_patients``).

    Args:
        heatmap_results: Path of the JSON file holding a list of entries, each
            with at least ``"model"`` and ``"benchmark"`` keys.
        total_patients: Denominator of the answered-fraction penalty applied
            to filter benchmarks. Defaults to 121, the value previously
            hard-coded here.

    Returns:
        Dict[str, Dict[str, float]]:
            model -> { "benchmark_metric": value, ... }

    Raises:
        ValueError: If ``total_patients`` is not positive.
    """
    if total_patients <= 0:
        raise ValueError("total_patients must be a positive integer")

    skipped_benchmarks = {"human_new_prompt", "treatment_aggregation", "path_aggregation"}
    filter_benchmarks = {"path_filter", "treatment_filter"}
    metrics = ("average_treatment_match", "average_path_match")

    with open(Path(heatmap_results), "r", encoding="utf-8") as f:
        data = json.load(f)

    reformatted: Dict[str, Dict[str, float]] = {}
    for entry in data:
        model = entry["model"]
        benchmark = entry["benchmark"]

        # Skip unwanted benchmarks
        if benchmark in skipped_benchmarks:
            continue

        model_metrics = reformatted.setdefault(model, {})

        # The penalty depends only on the entry, so compute it once.
        penalty = None
        if benchmark in filter_benchmarks:
            answered = entry.get("total_patients_matched", 0)
            penalty = answered / total_patients if answered else 0

        for metric in metrics:
            value = entry.get(metric)
            if value is None:
                continue  # metric not present for this entry
            if penalty is not None:
                value *= penalty
            # Unique key: benchmark + metric
            model_metrics[f"{benchmark}_{metric}"] = float(value)
    return reformatted


class AccuracyAnalyzer:
    """Build feature/target datasets from rollout results and ground truth.

    Ground-truth paths are read from ``patient_<id>.json`` annotation files,
    and per-model trend features from the heatmap results file.
    """

    # Number of rollout iterations read per patient when building datasets.
    DEFAULT_NUM_ITERATIONS = 10

    # Feature-set names accepted by build_dataset.
    _FEATURE_SETS = ("base", "base_aggregated", "trend_only", "aggregated_only", "internal", "all")

    # Trend features used by the "internal" feature set, in column order.
    _INTERNAL_TREND_FEATURES = (
        "unstructured_average_treatment_match",
        "unstructured_average_path_match",
        "structured_average_treatment_match",
        "structured_average_path_match",
        "path_filter_average_treatment_match",
        "path_filter_average_path_match",
        "treatment_filter_average_treatment_match",
        "treatment_filter_average_path_match",
    )

    def __init__(self, annotations_path: Path, heatmap_results: str, patient_id_pool: int = 360):
        """Load ground truth and trend features.

        Args:
            annotations_path: Directory holding ``patient_<id>.json`` files.
            heatmap_results: Path of the heatmap results JSON file.
            patient_id_pool: Highest patient id (inclusive) to look for.
        """
        self.annotations_path = Path(annotations_path)
        self.patient_id_pool = patient_id_pool
        self.ground_truth = self.load_ground_truth()
        self.model_trend_features = load_trend_data(heatmap_results)

    def load_ground_truth(self) -> Dict[str, List[str]]:
        """Return ``patient id (str) -> normalized label path`` for every
        annotation file that exists and contains a ``"label"`` key.

        Files that fail to load are reported on stdout and skipped.
        """
        ground_truth = {}
        for patient_id in range(self.patient_id_pool + 1):
            file_path = self.annotations_path / f"patient_{patient_id}.json"
            if not file_path.exists():
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if "label" in data:
                    ground_truth[str(patient_id)] = normalize_path_list(data["label"])
            except Exception as e:
                # Best-effort: one bad annotation must not abort the load.
                print(f"Error loading GT for patient {patient_id}: {e}")
        return ground_truth

    @staticmethod
    def results_dir_for_model(base_path: Path, model_name: str) -> Path:
        """Return the rollout results directory of *model_name* under *base_path*."""
        return base_path / "results" / "rollout_results_0815_benchmark" / f"rollout_experiment_{model_name}"

    def load_patient_paths(self, patient_dir: Path, pid: str, num_iterations: int = 10):
        """Collect the normalized ``final_path`` of each available iteration.

        Iterations whose file is missing, unparsable, or lacks ``final_path``
        are skipped.
        """
        paths = []
        for i in range(num_iterations):
            fpath = patient_dir / f"patient_{pid}_iteration_{i}.json"
            if not fpath.exists():
                continue
            data = load_json(fpath)
            if not data:
                continue
            if "final_path" in data:
                paths.append(normalize_path_list(data["final_path"]))
        return paths

    @staticmethod
    def _summarize_paths(paths):
        """Return ``(filtered_paths, mode_path, mode_count)`` for *paths*,
        or None when no non-empty path is available."""
        filtered = [tuple(filter_path_17_node(p)) for p in paths if p]
        if not filtered:
            return None
        mode_path, mode_count = Counter(filtered).most_common(1)[0]
        return filtered, list(mode_path), mode_count

    def collect_rows_for_model(self, results_dir: Path, model_name: str, num_iterations: int,
                               all_model_paths: Dict[str, Dict[str, List[str]]] = None):
        """Yield one ``(feats, (path_score, treatment_match), pid, model_name)``
        row per patient that has both ground truth and rollout paths.

        Args:
            results_dir: Directory holding the model's ``patient_*`` folders.
            model_name: Name recorded in the row and used for trend features.
            num_iterations: Number of iteration files to look for per patient.
            all_model_paths: Optional mapping ``model_name -> {pid: mode_path}``
                used to compute the cross-model ``aggregated`` feature.
        """
        if not results_dir.exists():
            print(f"[WARN] Missing results_dir for {model_name}: {results_dir}")
            return

        for patient_dir in results_dir.glob("patient_*"):
            pid = patient_dir.name.split("_")[1]
            if pid not in self.ground_truth:
                continue
            gt_path = filter_path_17_node(self.ground_truth[pid])
            if not gt_path:
                # An empty ground truth would make compare_lists see a single
                # path and report a perfect match, producing a false label.
                continue

            summary = self._summarize_paths(
                self.load_patient_paths(patient_dir, pid, num_iterations=num_iterations)
            )
            if summary is None:
                continue
            paths_filtered, mode_path, mode_count = summary

            # Self-consistency across iterations (features) ...
            path_match_fraction_prediction, treatment_match_fraction_prediction, _, _ = compare_lists(paths_filtered)
            # ... and agreement of the modal path with ground truth (targets).
            path_score_target, treatment_match_target, _, _ = compare_lists([mode_path, gt_path])

            mode_frac = mode_count / max(1, num_iterations)

            # ---- Aggregated feature: share of models whose modal path ends
            # on the same final node as this model's modal path ----
            aggregated_score = None
            if all_model_paths is not None:
                agreements = [
                    patient_paths[pid][-1] == mode_path[-1]
                    for patient_paths in all_model_paths.values()
                    if patient_paths.get(pid)
                ]
                aggregated_score = sum(agreements) / len(agreements) if agreements else 0.0

            feats = {
                "path_match_fraction": path_match_fraction_prediction,
                "treatment_match_fraction": treatment_match_fraction_prediction,
                "mode_frac": mode_frac,
                "aggregated": aggregated_score,
                "num_iterations": num_iterations,
                "model_name": model_name,
                "pid": pid,
                **self.model_trend_features.get(model_name, {})
            }

            # With two paths the treatment fraction is 0.5 or 1.0, so int()
            # turns it into a 0/1 label.
            yield feats, (path_score_target, int(treatment_match_target)), pid, model_name

    def _trend_feature_names(self) -> List[str]:
        """Return every trend feature name seen for any model, in first-seen order."""
        return list(dict.fromkeys(
            name for model_feats in self.model_trend_features.values() for name in model_feats
        ))

    def _feature_vector(self, feats: Dict, mname: str, feature_set: str, trend_names: List[str] = None) -> List[float]:
        """Build the feature vector of one row for *feature_set*.

        Trend features are looked up by name (0.0 when missing) so every model
        produces a vector of the same length and column meaning.

        Raises:
            ValueError: If *feature_set* is not a known name.
        """
        model_trends = self.model_trend_features.get(mname, {})
        base = [feats["path_match_fraction"], feats["treatment_match_fraction"]]
        aggregated = feats["aggregated"] if feats["aggregated"] is not None else 0.0

        if feature_set == "base":
            return base
        if feature_set == "base_aggregated":
            return base + [aggregated]
        if feature_set == "aggregated_only":
            return [aggregated]
        if feature_set == "internal":
            return base + [model_trends.get(name, 0.0) for name in self._INTERNAL_TREND_FEATURES]
        if feature_set in ("trend_only", "all"):
            if trend_names is None:
                trend_names = self._trend_feature_names()
            trends = [model_trends.get(name, 0.0) for name in trend_names]
            return trends if feature_set == "trend_only" else base + [aggregated] + trends
        raise ValueError(f"Unknown feature_set: {feature_set}")

    def build_dataset(self, base_path: Path, all_model_names: List[str], target_model_names: List[str], feature_set: str):
        """Assemble feature matrix and targets for *target_model_names*.

        Args:
            base_path: Project root that contains the ``results`` directory.
            all_model_names: Models whose modal paths feed the ``aggregated``
                feature.
            target_model_names: Models that contribute rows.
            feature_set: One of ``base``, ``base_aggregated``, ``trend_only``,
                ``aggregated_only``, ``internal`` or ``all``.

        Returns:
            ``(X, y_path, y_treat, groups, ordered_models, model_labels)``, or
            None when no rows were gathered.

        Raises:
            ValueError: If *feature_set* is unknown (checked before any file
                is read).
        """
        if feature_set not in self._FEATURE_SETS:
            raise ValueError(f"Unknown feature_set: {feature_set}")

        n_iter = self.DEFAULT_NUM_ITERATIONS

        # Modal path of every model for every patient (for "aggregated").
        all_model_paths = {}
        for mname in all_model_names:
            rdir = self.results_dir_for_model(base_path, mname)
            patient_paths = {}
            for patient_dir in rdir.glob("patient_*"):
                pid = patient_dir.name.split("_")[1]
                summary = self._summarize_paths(
                    self.load_patient_paths(patient_dir, pid, num_iterations=n_iter)
                )
                if summary is not None:
                    patient_paths[pid] = summary[1]
            all_model_paths[mname] = patient_paths

        rows = []
        for model_name in target_model_names:
            print("target model:", model_name)
            rdir = self.results_dir_for_model(base_path, model_name)
            rows.extend(self.collect_rows_for_model(rdir, model_name, n_iter, all_model_paths))

        if not rows:
            print("No data gathered from the selected models.")
            return None

        # Model names in order of first appearance.
        ordered_models = list(dict.fromkeys(mname for *_, mname in rows))
        trend_names = self._trend_feature_names()

        X_list, y_path, y_treat, groups, model_labels = [], [], [], [], []
        for feats, (ps, tm), pid, mname in rows:
            X_list.append(self._feature_vector(feats, mname, feature_set, trend_names))
            y_path.append(ps)
            y_treat.append(float(tm))
            groups.append(pid)
            model_labels.append(mname)

        X = np.asarray(X_list, dtype=float)
        y_path = np.asarray(y_path, dtype=float)
        y_treat = np.asarray(y_treat, dtype=int)  # classification
        groups = np.asarray(groups)
        return X, y_path, y_treat, groups, ordered_models, model_labels

    def build_train_test_split(self, base_path: Path, train_models: List[str], test_model: str,
                              train_split: float, feature_set: str = "all"):
        """Build train/test datasets split by model and by patient.

        Training rows come from *train_models*, test rows from *test_model*.
        Patient ids are shuffled with a fixed seed; the first ``train_split``
        fraction is kept for training rows and the rest for test rows, so no
        patient appears on both sides.

        Returns:
            ``((X_train, y_path_train, y_treat_train, groups_train),
            (X_test, y_path_test, y_treat_test, groups_test))``, or None when
            either dataset is empty.
        """
        # The aggregated feature is computed from the training models only.
        all_models = train_models
        ds_train = self.build_dataset(base_path, all_models, train_models, feature_set)
        print("TRAINING DATASTE BUILT")
        ds_test = self.build_dataset(base_path, all_models, [test_model], feature_set)
        print("TESTING DS BUILT")

        if ds_train is None or ds_test is None:
            return None

        X_train, y_path_train, y_treat_train, groups_train, _, _ = ds_train
        X_test, y_path_test, y_treat_test, groups_test, _, _ = ds_test

        unique_patients = np.unique(np.concatenate([groups_train, groups_test]))

        # A local RandomState(42) yields the same permutation as seeding the
        # global generator with 42, without clobbering global RNG state.
        shuffled_patients = np.random.RandomState(42).permutation(unique_patients)

        # The two slices of a permutation of unique ids are disjoint.
        n_train_patients = int(len(shuffled_patients) * train_split)
        train_patients = shuffled_patients[:n_train_patients]
        test_patients = shuffled_patients[n_train_patients:]

        train_mask = np.isin(groups_train, train_patients)
        test_mask = np.isin(groups_test, test_patients)

        train_part = (X_train[train_mask], y_path_train[train_mask],
                      y_treat_train[train_mask], groups_train[train_mask])
        test_part = (X_test[test_mask], y_path_test[test_mask],
                     y_treat_test[test_mask], groups_test[test_mask])
        return train_part, test_part


class GridSearchAnalyzer(AccuracyAnalyzer):
    """Extended analyzer for comprehensive grid search of hyperparameters"""

    # Default model pool used when evaluate_single_config gets no model_names.
    DEFAULT_MODEL_NAMES = ("gpt-4.1", "gpt-5", "gpt-5-med", "o4-mini", "o3", 'deepseek', 'gpt-5-high', 'llama')

    # Result keys that hold ROC arrays rather than scalar metrics.
    _ROC_KEYS = ("fpr", "tpr", "roc_thresholds")

    def __init__(self, annotations_path: Path, heatmap_results: str, patient_id_pool: int = 360, test_model: str = "deepseek"):
        super().__init__(annotations_path, heatmap_results, patient_id_pool)
        self.test_model = test_model

    @staticmethod
    def _empty_metrics(n_samples: int = 0, n_features: int = 0) -> Dict:
        """Return the metrics dict used when a configuration cannot be scored."""
        return {
            "auroc": np.nan, "accuracy": np.nan, "f1": np.nan, "n_samples": n_samples,
            "fpr": None, "tpr": None, "roc_thresholds": None,
            "n_features": n_features,
        }

    def evaluate_single_config(self, test_model: str, train_size: float = 1.0, feature_set: str = "base",
                                model_names: list = None, base_path: Path = None) -> Dict:
        """Evaluate a single configuration and return metrics including ROC curve data.

        The classifier is trained on every model in *model_names* except
        *test_model* and scored on *test_model*, using the patient split
        produced by ``build_train_test_split``.

        Args:
            test_model: Model whose rows form the test set.
            train_size: Fraction of patients assigned to the training side.
            feature_set: Feature set name passed to the dataset builder.
            model_names: Candidate models; defaults to ``DEFAULT_MODEL_NAMES``.
            base_path: Project root with the ``results`` directory. When
                omitted, the module-level ``BASE_PATH`` (defined by the CLI
                entry point) is used.

        Returns:
            Dict with ``auroc``, ``accuracy``, ``f1``, ``n_samples``, ``fpr``,
            ``tpr``, ``roc_thresholds`` and ``n_features``. Metrics are NaN and
            ROC arrays None when the configuration cannot be scored.
        """
        if model_names is None:
            model_names = list(self.DEFAULT_MODEL_NAMES)
        if base_path is None:
            # Legacy fallback; raises NameError if the module was imported
            # without BASE_PATH being defined.
            base_path = BASE_PATH

        _, classifiers = get_model_configs()
        clf = next(iter(classifiers.values()))  # Use first classifier

        # Every model except the held-out one is used for training.
        train_models = [m for m in model_names if m != test_model]

        split_data = self.build_train_test_split(base_path, train_models, test_model, train_size, feature_set=feature_set)
        if split_data is None:
            return self._empty_metrics()

        (X_train, _y_path_train, y_treat_train, _groups_train), \
        (X_test, _y_path_test, y_treat_test, _groups_test) = split_data

        n_samples = len(y_treat_test)
        n_features = X_train.shape[1] if X_train.ndim == 2 else 0

        # ROC/AUC need both classes present in the test labels.
        if n_samples == 0 or len(np.unique(y_treat_test)) < 2:
            return self._empty_metrics(n_samples, n_features)

        try:
            clf.fit(X_train, y_treat_train)

            # Get predictions and scores
            y_pred = clf.predict(X_test)

            if hasattr(clf, "predict_proba"):
                y_scores = clf.predict_proba(X_test)[:, 1]
            elif hasattr(clf, "decision_function"):
                y_scores = clf.decision_function(X_test)
            else:
                y_scores = y_pred.astype(float)

            fpr, tpr, thresholds = roc_curve(y_treat_test, y_scores)
            auroc = auc(fpr, tpr)
            accuracy = accuracy_score(y_treat_test, y_pred)
            f1 = f1_score(y_treat_test, y_pred, average="binary")

        except Exception as e:
            # Best-effort: a failing configuration is recorded as NaN.
            print(f"Holdout evaluation failed: {e}")
            return self._empty_metrics(n_samples, n_features)

        return {
            "auroc": auroc, "accuracy": accuracy, "f1": f1, "n_samples": n_samples,
            "fpr": fpr, "tpr": tpr, "roc_thresholds": thresholds,
            "n_features": n_features,
        }

    def _results_to_dataframe(self, results: List[Dict]):
        """Split grid-search results into a scalar table and ROC arrays.

        Minimal implementation: ``run_grid_search`` calls this method, but
        neither this class nor its base defined it.

        Returns:
            ``(results_df, roc_data)`` where ``results_df`` is a pandas
            DataFrame with one row per result (ROC arrays removed) and
            ``roc_data`` maps the row index to
            ``{"fpr": ..., "tpr": ..., "roc_thresholds": ...}`` for results
            that have ROC data.
        """
        import pandas as pd  # local import: only needed when exporting results

        rows, roc_data = [], {}
        for idx, res in enumerate(results):
            rows.append({k: v for k, v in res.items() if k not in self._ROC_KEYS})
            if all(res.get(k) is not None for k in self._ROC_KEYS):
                roc_data[idx] = {k: res[k] for k in self._ROC_KEYS}
        return pd.DataFrame(rows), roc_data

    def run_grid_search(self, base_path: Path, model_names: List[str],
                       train_sizes: List[float] = None,
                       feature_sets: List[str] = None,
                       methods: List[str] = None,
                       outdir: Path = Path("./grid_search_results"),
                       test_model: str = "deepseek"):
        """Run comprehensive grid search

        Every combination of train size, feature set and method is scored
        with ``evaluate_single_config``. For method ``"cv"`` the full dataset
        is additionally built first and the configuration is skipped when it
        is empty. NOTE(review): both methods are currently scored by the same
        held-out-model evaluation; no cross-validation is performed here.

        Args:
            base_path: Project root with the ``results`` directory.
            model_names: Models taking part in the search.
            train_sizes: Defaults to ``[0.2, 0.4, 0.6, 0.8, 1.0]``.
            feature_sets: Defaults to
                ``["base", "base_aggregated", "trend_only", "all"]``.
            methods: Defaults to ``["cv", "holdout"]``.
            outdir: Directory receiving ``roc_<test_model>.json`` and
                ``grid_search_results.csv``.
            test_model: Held-out model.

        Returns:
            The results table produced by ``_results_to_dataframe``.
        """
        if train_sizes is None:
            train_sizes = [0.2, 0.4, 0.6, 0.8, 1.0]
        if feature_sets is None:
            feature_sets = ["base", "base_aggregated", "trend_only", "all"]
        if methods is None:
            methods = ["cv", "holdout"]

        print("="*60)
        print("RUNNING COMPREHENSIVE GRID SEARCH")
        print("="*60)

        results = []
        total_configs = len(train_sizes) * len(feature_sets) * len(methods)
        config_count = 0

        for train_size in train_sizes:
            for feature_set in feature_sets:
                for method in methods:
                    config_count += 1
                    print(f"\nConfig {config_count}/{total_configs}: "
                          f"train_size={train_size}, features={feature_set}, method={method}")

                    try:
                        X = None
                        if method == "cv":
                            # Build dataset with specified feature set
                            ds = self.build_dataset(base_path, model_names, model_names, feature_set)
                            if ds is None:
                                print("  -> No data, skipping")
                                continue
                            X = ds[0]

                        metrics = self.evaluate_single_config(
                            test_model, float(train_size), feature_set=feature_set,
                            model_names=model_names, base_path=base_path
                        )

                        # Holdout has no full dataset; use the width of the
                        # training matrix instead of reporting 0.
                        n_features = X.shape[1] if X is not None else metrics.get("n_features", 0)
                        result = {
                            "train_size": train_size,
                            "feature_set": feature_set,
                            "method": method,
                            "n_features": n_features,
                            **{k: v for k, v in metrics.items() if k != "n_features"}
                        }
                        results.append(result)

                        print(f"  -> AUROC: {metrics['auroc']:.3f}, "
                              f"Acc: {metrics['accuracy']:.3f}, "
                              f"F1: {metrics['f1']:.3f}, "
                              f"N: {metrics['n_samples']}")

                    except Exception as e:
                        # One failing configuration must not stop the search.
                        print(f"  -> Error: {e}")
                        continue

        # Save results
        outdir = Path(outdir)
        outdir.mkdir(parents=True, exist_ok=True)
        results_df, roc_data = self._results_to_dataframe(results)
        data_serializable = {
            int_key: {str_key: array.tolist() if hasattr(array, 'tolist') else list(array)
                    for str_key, array in inner_dict.items()}
            for int_key, inner_dict in roc_data.items()
        }

        with open(outdir / f"roc_{test_model}.json", "w", encoding="utf-8") as f:
            json.dump(data_serializable, f)
        results_df.to_csv(outdir / "grid_search_results.csv", index=False)

        # Plotting is an optional hook: it is not defined in this class, so
        # only call it when a subclass or caller provides it.
        plot_results = getattr(self, "_plot_grid_results", None)
        if callable(plot_results):
            plot_results(results_df, outdir)
        else:
            print("[WARN] _plot_grid_results is not defined; skipping plots")

        return results_df


# --------------------------
# Extended CLI
# --------------------------

if __name__ == "__main__":
    # Project layout: results and annotations live under ./CancerGUIDE.
    BASE_PATH = Path("./CancerGUIDE")
    ANNOTATIONS_PATH = BASE_PATH / "data" / "benchmarks" / "human_annotations"

    parser = argparse.ArgumentParser(description="Extended accuracy prediction with comprehensive grid search.")
    parser.add_argument("--mode", choices=["cv", "split", "grid_search"], required=True,
                        help="Evaluation mode: 'cv', 'split', or 'grid_search'")

    # For CV mode
    parser.add_argument("--models", nargs="+",
                        help='List of model names for CV mode')
    parser.add_argument("--cv-splits", type=int, default=5,
                        help="Number of GroupKFold splits for CV mode.")

    # For split mode
    parser.add_argument("--train-models", nargs="+",
                        help='List of model names for training in split mode')
    parser.add_argument("--testmodel", type=str,
                        help='Single model name for testing in split mode')

    # For grid search mode
    parser.add_argument("--train-sizes", nargs="+", type=float,
                        default=[0.2, 0.4, 0.6, 0.8, 1.0],
                        help="Training sizes to test in grid search")
    parser.add_argument("--feature-sets", nargs="+",
                        default=["base", "base_aggregated", "trend_only", "all", "aggregated_only"],
                        help="Feature sets to test in grid search")
    parser.add_argument("--methods", nargs="+",
                        choices=["cv", "holdout"],
                        default=["cv", "holdout"],
                        help="Evaluation methods to test in grid search")
    parser.add_argument("--heatmap_results", type=str, default="./CancerGUIDE/results/heatmap_results.json")

    # Common arguments
    parser.add_argument("--outdir", type=Path, default=Path("./pred_eval_outputs"),
                        help="Output directory for plots and results.")
    parser.add_argument("--train-split", type=float, default=0.7)
    parser.add_argument("--feature-set", choices=["base", "base_aggregated", "trend_only", "all"],
                        default="all", help="Feature set to use for cv/split modes")

    args = parser.parse_args()

    # Only grid search is wired up in this script; fail loudly instead of
    # exiting silently when another accepted mode is requested.
    if args.mode != "grid_search":
        parser.error(f"--mode {args.mode} is not implemented in this script; use --mode grid_search")

    if not args.models:
        parser.error("--models is required for grid search mode")

    test_model = args.testmodel if args.testmodel else "deepseek"
    analyzer = GridSearchAnalyzer(
        annotations_path=ANNOTATIONS_PATH,
        heatmap_results=args.heatmap_results,
        patient_id_pool=360,
        test_model=test_model,
    )

    print("Running grid search with:")
    print(f"  Models: {args.models}")
    print(f"  Training sizes: {args.train_sizes}")
    print(f"  Feature sets: {args.feature_sets}")
    print(f"  Methods: {args.methods}")

    # Run comprehensive grid search
    results_df = analyzer.run_grid_search(
        base_path=BASE_PATH,
        model_names=args.models,
        train_sizes=args.train_sizes,
        feature_sets=args.feature_sets,
        methods=args.methods,
        outdir=args.outdir,
        test_model=test_model,
    )
    