import os
from pathlib import Path
from datetime import datetime
import pandas as pd
from collections import defaultdict
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
from typing import List, Dict, Tuple


def get_latest_results_file(results_dir, pattern="rule_recall_summary_*.pkl"):
    """
    Find the latest results file in a directory based on its filename timestamp.

    The timestamp is read from the end of the file stem. Two layouts are
    recognised:

    * ``<prefix>_YYYYMMDD_HHMMSS`` -- full date and time.
    * ``<prefix>_HHMMSS``          -- time only (parsed with strptime's default
      date, so any fully dated file ranks after every time-only file).

    Files whose timestamp cannot be parsed are reported with a warning and
    ranked lowest (``datetime.min``).

    Args:
        results_dir: Directory (str or Path) to search.
        pattern (str): Glob pattern selecting candidate files.

    Returns:
        Path | None: The file with the greatest timestamp, or None when no
        file matches ``pattern``.
    """
    # Sorted so that ties between equal timestamps are resolved deterministically
    # instead of depending on filesystem enumeration order.
    result_files = sorted(Path(results_dir).glob(pattern))
    if not result_files:
        print(f"Warning: No file matching {pattern} was found in {results_dir}")
        return None

    def extract_timestamp(f):
        # Look at the last TWO underscore-separated segments: splitting on '_'
        # and keeping only the last one would drop the date part entirely.
        parts = f.stem.split('_')
        time_part = parts[-1]
        date_part = parts[-2] if len(parts) >= 2 else ""
        try:
            if len(date_part) == 8 and date_part.isdigit() and len(time_part) == 6:
                return datetime.strptime(f"{date_part}-{time_part}", "%Y%m%d-%H%M%S")
            if len(time_part) == 6:  # Just time (HHMMSS)
                return datetime.strptime(time_part, "%H%M%S")
            raise ValueError(f"unrecognised timestamp {time_part!r}")
        except ValueError as e:
            print(f"Warning: Could not parse timestamp from {f.name}: {e}")
            return datetime.min

    latest_file = max(result_files, key=extract_timestamp)
    if len(result_files) > 1:
        print(f''' \n  Among {result_files} , we have latest file  {latest_file}.''')
    return latest_file


def collect_latest_reports(example_dirs):
    """
    Map every example directory to its most recent report files.

    For each entry of ``example_dirs`` the ``results`` sub-directory is
    searched for the latest ``rule_recall_summary_*.pkl`` and the latest
    ``detailed_report_*.pkl``.

    Returns:
        dict: ``{dirname: (latest_global_report, latest_detailed_report)}``;
        either element of the tuple is None when no matching file exists.
    """
    def newest_pair(example_dir):
        # Summary is looked up before the detailed report.
        folder = Path(example_dir) / "results"
        summary_file = get_latest_results_file(folder, pattern="rule_recall_summary_*.pkl")
        detailed_file = get_latest_results_file(folder, pattern="detailed_report_*.pkl")
        return summary_file, detailed_file

    return {example_dir: newest_pair(example_dir) for example_dir in example_dirs}


def update_ci_in_global_reports(example_dirs):
    """
    Fix any missing CI columns in global reports by computing them from SD.

    For every directory in ``example_dirs`` the latest
    ``rule_recall_summary_*.pkl`` under ``<dir>/results`` is loaded. For each of
    the precision / recall / fscore metrics, if the ``*_ci`` column is absent
    but the ``*_sd`` column is present, the CI half-width is computed as
    ``1.96 * sd / sqrt(n)``, where ``n`` is the number of rows in the latest
    ``detailed_report_*.pkl`` of the same directory.

    The summary pickle is rewritten in place only when at least one column was
    actually added. Directories with a missing/empty summary or a missing/empty
    detailed report are skipped with a warning.

    Args:
        example_dirs: Iterable of example directory paths.
    """
    metrics = ("precision", "recall", "fscore")

    for d in example_dirs:
        results_dir = Path(d) / "results"
        latest_summary_file = get_latest_results_file(results_dir, pattern="rule_recall_summary_*.pkl")
        if latest_summary_file is None:
            print(f'Warning! NO rule_recall_summary in {d}')
            continue

        df = pd.read_pickle(latest_summary_file)
        if df.empty:
            print(f'Warning! EMPTY rule_recall_summary in {d}')
            continue

        # Estimate n from matching detailed_report
        latest_detailed_file = get_latest_results_file(results_dir, pattern="detailed_report_*.pkl")
        if latest_detailed_file is None:
            print(f'Warning! NO detailed_report in {d}')
            continue

        n = len(pd.read_pickle(latest_detailed_file))
        if n == 0:
            # Dividing by sqrt(0) would store inf/NaN confidence intervals.
            print(f'Warning! EMPTY detailed_report in {d}')
            continue

        added_columns = []
        for metric in metrics:
            sd_col = f"rule_recall_{metric}_sd"
            ci_col = f"rule_recall_{metric}_ci"
            if ci_col not in df.columns and sd_col in df.columns:
                df[ci_col] = 1.96 * df[sd_col] / (n ** 0.5)
                added_columns.append(ci_col)

        if not added_columns:
            # Nothing changed: leave the file untouched rather than rewriting it.
            print(f"No CI columns needed updating in: {latest_summary_file}")
            continue

        df.to_pickle(latest_summary_file)
        print(f"Updated CI columns in: {latest_summary_file}")

def merge_detailed_reports_and_summarize(new_dir_name, example_dirs):
    """
    Combine the latest detailed report of every example directory and summarize it.

    The newest ``detailed_report_*.pkl`` of each ``<dir>/results`` is loaded,
    tagged with a ``source_dir`` column and concatenated. The merged table is
    pickled into ``new_dir_name`` (created if needed), then a one-row global
    summary (mean / SD / 1.96*SD/sqrt(n) for precision, recall and f_score,
    plus the query-completion success rate with a normal-approximation CI) is
    pickled next to it under the same timestamp.

    Returns None; prints a message and returns early when no detailed report
    was found.
    """
    from math import sqrt

    os.makedirs(new_dir_name, exist_ok=True)
    output_dir = Path(new_dir_name)

    frames = []
    for example_dir in example_dirs:
        report_path = get_latest_results_file(
            Path(example_dir) / "results", pattern="detailed_report_*.pkl"
        )
        if not report_path:
            continue
        frame = pd.read_pickle(report_path)
        frame["source_dir"] = example_dir
        frames.append(frame)

    if len(frames) == 0:
        print("No detailed reports found.")
        return

    merged_df = pd.concat(frames, ignore_index=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    merged_path = output_dir / f"detailed_report_{timestamp}.pkl"
    # The merged report is written before the summary is computed.
    merged_df.to_pickle(merged_path)

    # Build the single summary row; insertion order fixes the column order.
    n_rows = len(merged_df)
    stats = {}
    for label, column in (("precision", "precision"), ("recall", "recall"), ("fscore", "f_score")):
        mean_value = merged_df[column].mean()
        sd_value = merged_df[column].std()
        stats[f"rule_recall_{label}_mean"] = mean_value
        stats[f"rule_recall_{label}_sd"] = sd_value
        stats[f"rule_recall_{label}_ci"] = 1.96 * sd_value / (n_rows ** 0.5)

    success_rate = merged_df["successful_query_completion"].mean()
    stats["query_completion_success_rate"] = success_rate
    stats["query_completion_success_rate_ci"] = 1.96 * sqrt(success_rate * (1 - success_rate) / n_rows)

    summary = pd.DataFrame([stats])
    summary_path = output_dir / f"rule_recall_summary_{timestamp}.pkl"
    summary.to_pickle(summary_path)

    print(f"Saved merged detailed report: {merged_path}")
    print(f"Saved global summary: {summary_path}")


def augment_df_with_original_columns(df_path, mother_dir, column_list):
    """
    Add columns from the original CSV files to a pickled DataFrame.

    Rows are matched on (source_file, query_edge, query_relation, story_index):
    for each distinct ``source_file`` in the input, the CSV of that name inside
    ``mother_dir`` is read in chunks and the requested columns are copied from
    the rows whose key also occurs in the input. A CSV that cannot be processed
    is reported and skipped.

    Args:
        df_path (str): Path to pickle file containing the input DataFrame
        mother_dir (str): Directory containing the original CSV files
        column_list (list): List of columns to fetch from the original CSVs

    Returns:
        pd.DataFrame: The input DataFrame left-merged with the recovered
        columns. The input is returned unmerged when every requested column
        already exists in it or when nothing was recovered.

    Raises:
        ValueError: If the pickle cannot be loaded or a key column is missing.
    """
    lookup_fields = ['query_edge', 'query_relation', 'story_index']
    join_fields = ['source_file'] + lookup_fields

    # Read the pickled input table
    try:
        with open(df_path, 'rb') as handle:
            df = pickle.load(handle)
    except Exception as e:
        raise ValueError(f"Could not load DataFrame from {df_path}: {e}")

    # Every join field must be present; report the first one that is not
    absent = [name for name in join_fields if name not in df.columns]
    if absent:
        raise ValueError(f"Input DataFrame missing required column: {absent[0]}")

    # Only fetch columns the input does not already carry (avoids duplicates)
    wanted = [name for name in column_list if name not in df.columns]
    if len(wanted) == 0:
        print("All requested columns already exist in input DataFrame")
        return df

    # Bring the key fields to canonical dtypes
    df['story_index'] = df['story_index'].astype(int)
    df['query_edge'] = df['query_edge'].astype(str)
    df['query_relation'] = df['query_relation'].astype(str)

    # source_file -> set of (query_edge, query_relation, story_index) to look up
    keys_by_file = defaultdict(set)
    for src, edge, relation, story in zip(
        df['source_file'], df['query_edge'], df['query_relation'], df['story_index']
    ):
        keys_by_file[src].add((edge, relation, story))

    # (source_file, query_edge, query_relation, story_index) -> {column: value}
    recovered = defaultdict(dict)

    for file_name, keys in keys_by_file.items():
        full_path = os.path.join(mother_dir, file_name)
        print(f"Processing {file_name} with {len(keys)} keys...")

        try:
            for chunk in pd.read_csv(full_path, chunksize=8000):
                # Same dtype normalisation as applied to the input table
                chunk['query_edge'] = chunk['query_edge'].astype(str)
                chunk['query_relation'] = chunk['query_relation'].astype(str)
                chunk['story_index'] = chunk['story_index'].astype(int)

                # Composite key column used to select the rows of interest
                chunk['__key__'] = list(zip(
                    chunk['query_edge'],
                    chunk['query_relation'],
                    chunk['story_index']
                ))
                hits = chunk[chunk['__key__'].isin(keys)]
                if hits.empty:
                    continue

                # Column availability is identical for every row of the chunk
                present = [col for col in wanted if col in hits.columns]
                if not present:
                    continue

                for _, row in hits.iterrows():
                    key = (file_name, row['query_edge'], row['query_relation'], row['story_index'])
                    for col in present:
                        recovered[key][col] = row[col]

        except Exception as e:
            print(f"\033[91m[ERROR]\033[0m Could not process {file_name}: {e}")

    if len(recovered) == 0:
        print("\033[91m[WARNING]\033[0m No matching rows found in original files")
        return df

    # Turn the recovered mapping into a flat table keyed by the join fields
    recovered_df = pd.DataFrame.from_dict(recovered, orient='index')
    recovered_df.index = pd.MultiIndex.from_tuples(recovered_df.index, names=join_fields)
    recovered_df = recovered_df.reset_index()

    # Left merge keeps every input row, matched or not
    return pd.merge(df, recovered_df, on=join_fields, how='left')


def extract_chain_len_and_opec(directory: str) -> Tuple[int, int]:
    """
    Parses a directory path to extract the OPEC value and the chain length.

    The first path component containing ``"OPEC"`` supplies the OPEC value: the
    text ``"OPEC"`` is removed and the part before the first ``"_"`` is parsed
    as an integer. The first path component containing ``"chain_len"`` supplies
    the chain length: ``"chain_len"`` and any ``".0"`` are removed and the rest
    is parsed as an integer.

    Args:
        directory (str): Path whose components encode both values.

    Returns:
        Tuple[int, int]: ``(opec_val, chain_len)``. NOTE: despite the function
        name, the OPEC value comes FIRST.

    Raises:
        IndexError: If no path component contains ``"OPEC"`` or ``"chain_len"``.
        ValueError: If the extracted text is not a valid integer.
    """
    parts = Path(directory).parts

    opec_parts = [p for p in parts if "OPEC" in p]
    if not opec_parts:
        raise IndexError(f"No path component containing 'OPEC' in {directory!r}")
    chain_parts = [p for p in parts if "chain_len" in p]
    if not chain_parts:
        raise IndexError(f"No path component containing 'chain_len' in {directory!r}")

    opec_val = int(opec_parts[0].replace("OPEC", "").split("_")[0])
    # Dropping ".0" lets float-formatted lengths such as "chain_len4.0" parse.
    chain_len = int(chain_parts[0].replace("chain_len", "").replace(".0", ""))
    return opec_val, chain_len


def load_detailed_reports(reports: Dict[str, Tuple[Path,Path]]) -> pd.DataFrame:
    """
    Loads detailed_report pickles and tags every row with metadata.

    Args:
        reports: Mapping ``{directory: (summary_path, detailed_path)}``. Only
            the detailed path is read; entries whose detailed path is None or
            cannot be unpickled are skipped (the latter with a printed error).

    Returns:
        pd.DataFrame: Concatenation of the metric/key columns of every loaded
        report plus these metadata columns:
        - model_variant: 'O4Mini' if that text occurs in the directory, else 'O3'
        - opec_value, chain_len: parsed from the directory path
        - source_dir, source_filename
        - group: '<variant>_OPEC_0', '<variant>_OPEC_3' or '<variant>_OPEC_g4'
        When nothing could be loaded, an empty DataFrame with the same columns
        is returned.

    Raises:
        KeyError: If a loaded report lacks one of the expected columns.
    """
    cols = ['precision','recall','f_score','successful_query_completion',
            'story_index','source_file','query_edge','query_relation']
    meta_cols = ['model_variant','opec_value','chain_len',
                 'source_dir','source_filename','group']

    rows = []
    for d, (_s, det) in reports.items():
        if det is None:
            continue
        try:
            df = pd.read_pickle(det)
        except Exception as e:
            print(f"Error reading {det}: {e}")
            continue

        opec_val, chain_len = extract_chain_len_and_opec(d)
        # str() so that Path keys work as well as plain strings.
        variant = "O4Mini" if "O4Mini" in str(d) else "O3"
        # NOTE(review): every OPEC value other than 0 and 3 (including 1 and 2)
        # lands in the "g4" bucket -- confirm that this is intended.
        if opec_val == 0:
            group = f"{variant}_OPEC_0"
        elif opec_val == 3:
            group = f"{variant}_OPEC_3"
        else:
            group = f"{variant}_OPEC_g4"

        sub = df[cols].copy()
        sub['model_variant']  = variant
        sub['opec_value']     = opec_val
        sub['chain_len']      = chain_len
        sub['source_dir']     = d
        # Path(...) so that a plain string path is accepted too.
        sub['source_filename']= Path(det).name
        sub['group']          = group
        rows.append(sub)

    if not rows:
        # pd.concat([]) raises ValueError; return an empty, well-formed frame.
        print("Warning: no detailed reports could be loaded.")
        return pd.DataFrame(columns=cols + meta_cols)
    return pd.concat(rows, ignore_index=True)

def sample_balanced_data(df: pd.DataFrame) -> Dict[str,pd.DataFrame]:
    """
    Balance every group across its chain lengths.

    Within each value of ``group``, the smallest row count over all
    ``chain_len`` values is determined and exactly that many rows are drawn
    (without replacement, fixed seed 42) from every chain length.

    Returns:
        Dict[str, pd.DataFrame]: ``{group: balanced sample of that group}``.
    """
    def draw_equal_per_chain_len(label, members):
        # Smallest chain_len bucket decides how many rows each bucket keeps.
        per_len = members['chain_len'].value_counts()
        quota = per_len.min()
        print(f"Sampling {quota} per chain_len for group {label}.. here are the counts: {per_len} \n ")
        return members.groupby('chain_len', group_keys=False).sample(n=quota, random_state=42)

    return {
        label: draw_equal_per_chain_len(label, members)
        for label, members in df.groupby('group')
    }

def compute_summary_stats(grouped: Dict[str,pd.DataFrame], n_bootstrap:int=1000, random_state=None) -> pd.DataFrame:
    """
    Bootstraps mean and 95% CI for each metric in each group.

    For every group and each of precision, recall, f_score and
    successful_query_completion, the column is resampled with replacement
    (same size as the group) ``n_bootstrap`` times. The reported ``mean`` is
    the average of the bootstrap means and the CI bounds are their 2.5th and
    97.5th percentiles.

    Args:
        grouped: Mapping ``{group name: DataFrame}`` containing the metric columns.
        n_bootstrap: Number of bootstrap resamples per metric.
        random_state: Optional seed (or ``numpy.random.RandomState``) making
            the resampling reproducible. ``None`` keeps it non-deterministic.

    Returns:
        pd.DataFrame: One row per (group, metric) with columns
        ``group, metric, mean, ci_lower, ci_upper``.
    """
    # One shared generator: a fixed seed reproduces the whole table while the
    # individual resamples still differ from each other.
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    rec = []
    for grp, gdf in grouped.items():
        print(f"\nProcessing {grp}.. with {gdf.shape[0]} samples.")
        for m in ['precision','recall','f_score','successful_query_completion']:
            values = gdf[m]
            boots = [values.sample(frac=1, replace=True, random_state=rng).mean()
                     for _ in range(n_bootstrap)]
            mu = np.mean(boots)
            lo, hi = np.percentile(boots,[2.5,97.5])
            print(f" {m}: mean={mu:.3f}, CI=[{lo:.3f},{hi:.3f}] ")
            rec.append({'group':grp,'metric':m,'mean':mu,'ci_lower':lo,'ci_upper':hi})
    return pd.DataFrame(rec)

def plot_summary(summary_df: pd.DataFrame, drop_incomplete: bool = False):
    """
    Draws grouped bar charts (with CI error bars) of the bootstrapped summary.

    1) successful_query_completion: OPEC on x-axis, hue=model_variant
    2) precision, recall, f_score similarly in 3 subplots.

    Args:
        summary_df: Table with columns ``group`` (formatted as
            ``<variant>_OPEC_<cond>``), ``metric``, ``mean``, ``ci_lower`` and
            ``ci_upper``. It is not modified; plotting works on a copy.
        drop_incomplete: When True, only OPEC conditions for which both the
            'O3' and 'O4Mini' variants are present are shown. When False the
            x-axis always shows '0', '3' and '>3'.

    Returns None. Both figures are displayed with ``plt.show()``.
    """
    if summary_df.empty:
        # The split/assign below cannot produce two columns from zero rows.
        print("Warning: summary_df is empty; nothing to plot.")
        return

    sns.set(style="whitegrid")
    # Work on a copy so the helper columns do not leak into the caller's frame.
    plot_df = summary_df.copy()
    # split group into variant & opec
    plot_df[['variant','opec_cond']] = plot_df['group'].str.split('_OPEC_', expand=True)
    plot_df['opec_cond'] = plot_df['opec_cond'].replace({'g4': '>3'})
    variants = ['O3','O4Mini']  # consistent order
    if drop_incomplete:
        # find which (variant, opec_cond) pairs actually exist
        existing = set(map(tuple, plot_df[['variant','opec_cond']].values))
        # keep only opecs for which both variants appear
        opecs = [o for o in ['0','3','>3']
                 if all((v, o) in existing for v in variants)]
    else:
        opecs = ['0','3','>3']    # original default

    def plot_metric(ax, metric, add_title: bool = True):
        data = plot_df[plot_df['metric']==metric]
        bar_w = 0.8 / len(variants)
        for j, var in enumerate(variants):
            dsub = data[data['variant']==var]
            dsub = dsub[dsub['opec_cond'].isin(opecs)]
            # Centre the bars of the different variants around each x tick.
            xs = [opecs.index(o) + (j - (len(variants)-1)/2)*bar_w for o in dsub['opec_cond']]
            ax.bar(xs, dsub['mean'], width=bar_w, label=var)
            for x,y,lo,hi in zip(xs, dsub['mean'], dsub['ci_lower'], dsub['ci_upper']):
                # Asymmetric error bar: distance down to ci_lower, up to ci_upper.
                ax.errorbar(x, y, yerr=[[y-lo],[hi-y]], fmt='none', capsize=5, color='black')
        ax.set_xticks(range(len(opecs)))
        ax.set_xticklabels(opecs, fontsize=18)
        ax.tick_params(axis='y', labelsize=18)
        if add_title:
            ax.set_title(metric.replace('_',' ').title())
        ax.set_xlabel("OPEC Value", fontsize= 24)
        if metric=='successful_query_completion':
            ax.set_ylabel("Success Rate", fontsize= 24)
        else:
            ax.set_ylabel("Mean")
        ax.legend(title="Model Variant",
                  loc='upper center',
                  bbox_to_anchor=(0.5, 1.15),
                  ncol=len(variants),
                  fontsize=18,
                  title_fontsize=18)

    # -- Plot 1: successful_query_completion (no title) --
    _, ax = plt.subplots(figsize=(6, 5))
    plot_metric(ax, 'successful_query_completion', add_title=False)
    plt.tight_layout()
    plt.show()

    # -- Plot 2: precision, recall, f_score (with titles) --
    _, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
    for ax, metric in zip(axes, ['precision', 'recall', 'f_score']):
        plot_metric(ax, metric, add_title=True)
    plt.tight_layout()
    plt.show()




