import argparse
import yaml
import json
import sys
import time
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import pickle

sys.path.append(str(Path(__file__).parent))
from incremental_summarizer import IncrementalSummarizer


def load_progress(progress_file):
    """Load pickled progress state, or return a fresh state if none exists.

    Args:
        progress_file: path (``Path`` or ``str``) of the pickle file.

    Returns:
        The unpickled object, or ``{"completed": set()}`` when the file does
        not exist.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only load
    progress files this script wrote itself.
    """
    progress_file = Path(progress_file)
    if not progress_file.exists():
        return {"completed": set()}
    # Context manager closes the handle instead of leaving it to the GC.
    with open(progress_file, "rb") as handle:
        return pickle.load(handle)


def save_progress(progress, progress_file):
    """Pickle *progress* to *progress_file*, overwriting any previous content.

    The file is closed (and therefore flushed) before returning, so the data
    is on disk even if the process is interrupted right afterwards.
    """
    with open(progress_file, "wb") as handle:
        pickle.dump(progress, handle)


def extract_notes_dict(row, note_columns):
    """Merge the JSON-encoded notes stored in several columns of one row.

    Each non-missing cell in *note_columns* is parsed as a JSON object that
    maps note ids to note data. If the data is a list with at least two
    elements, its second element is used as the text. Otherwise the data is
    stringified. When the same note id appears in several columns, the later
    column wins.

    Cells that are missing/NaN are skipped silently. Cells that are not valid
    JSON, or whose JSON is not an object, are reported and skipped.

    Args:
        row: a mapping-like row (e.g. a pandas ``Series``) supporting ``.get``.
        note_columns: names of the columns to read.

    Returns:
        dict of note id -> text, or ``None`` if no notes were found.
    """
    combined_notes = {}
    for column in note_columns:
        cell = row.get(column)
        if pd.isna(cell):
            continue
        try:
            notes_raw = json.loads(cell)
        except (ValueError, TypeError) as e:
            # ValueError covers json.JSONDecodeError; TypeError a non-string cell.
            print(f"Error processing {column}: {e}")
            continue
        if not isinstance(notes_raw, dict):
            print(
                f"Error processing {column}: expected a JSON object, "
                f"got {type(notes_raw).__name__}"
            )
            continue
        for note_id, note_data in notes_raw.items():
            if isinstance(note_data, list) and len(note_data) >= 2:
                combined_notes[note_id] = note_data[1]
            else:
                combined_notes[note_id] = str(note_data)
    return combined_notes or None


def _serialize_summary(value):
    """Return *value* as JSON text when it is a dict, otherwise unchanged."""
    return json.dumps(value) if isinstance(value, dict) else value


def run_summarizer(notes_dict, model, method, entity_name, config):
    """Summarize one entity's notes with a single model/method pair.

    Args:
        notes_dict: mapping of note id -> note text.
        model: model config dict; its "model_name" is used in error messages.
        method: method name forwarded to ``IncrementalSummarizer.summarize``.
        entity_name: label forwarded to the summarizer.
        config: full run configuration passed to ``IncrementalSummarizer``.

    Returns:
        dict with "summary", "partial_summaries" and "execution_time"
        (seconds spent in ``summarize``). Dict-valued summaries are
        JSON-encoded. On any exception the three values are ``None`` and an
        extra "error" key holds the exception message.
    """
    try:
        summarizer = IncrementalSummarizer(config, model)
        # perf_counter is monotonic, unlike time.time, so durations can't go
        # negative or jump when the wall clock is adjusted.
        start_time = time.perf_counter()
        result = summarizer.summarize(
            method=method, notes=notes_dict, entity_name=entity_name
        )
        execution_time = time.perf_counter() - start_time

        # Only a real 2-item sequence is a (summary, partial_summaries) pair.
        # A bare len(result) == 2 check would also match a 2-character string
        # or a 2-key dict summary and then index into it incorrectly.
        if isinstance(result, (tuple, list)) and len(result) == 2:
            summary, partial_summaries = result
        else:
            summary, partial_summaries = result, None

        return {
            "summary": _serialize_summary(summary),
            "partial_summaries": _serialize_summary(partial_summaries),
            "execution_time": execution_time,
        }
    except Exception as e:
        # Per-call boundary: report and return an error record so the batch
        # continues with the next model/method.
        print(f"Error with {model['model_name']}/{method}: {e}")
        return {
            "error": str(e),
            "execution_time": None,
            "summary": None,
            "partial_summaries": None,
        }


def get_clean_model_name(model_name):
    """Make *model_name* column-safe by mapping both ':' and '-' to '_'."""
    separator_map = str.maketrans({":": "_", "-": "_"})
    return model_name.translate(separator_map)


def add_result_columns(df, models, methods):
    """Create any missing result columns on *df* in place, filled with None.

    Every (model, method) pair gets "<model>_<method>" and
    "<model>_<method>_time". The "generate_merge" method additionally gets
    "_partial" and "_merged" columns. Columns that already exist are left
    untouched.
    """
    common_suffixes = ("", "_time")
    merge_suffixes = common_suffixes + ("_partial", "_merged")

    for model_cfg in models:
        prefix = get_clean_model_name(model_cfg["model_name"])
        for method_name in methods:
            if method_name == "generate_merge":
                wanted = merge_suffixes
            else:
                wanted = common_suffixes
            for tail in wanted:
                column = f"{prefix}_{method_name}{tail}"
                if column in df.columns:
                    continue
                df[column] = None


def load_existing_output(output_path):
    """Read the previous run's output CSV, or return None if there is none.

    Args:
        output_path: ``Path`` of the output CSV.

    Returns:
        The loaded DataFrame, or ``None`` when the file does not exist.
    """
    if not output_path.exists():
        print("No existing output file found. Starting fresh.")
        return None

    print(f"Loading existing output file: {output_path}")
    return pd.read_csv(output_path)


def find_incomplete_rows(df, models, methods):
    """Return the sorted index labels of rows missing at least one summary.

    Only the "<model>_<method>" summary columns that already exist in *df*
    are inspected. A row counts as incomplete if any of them is NaN/None.
    """
    summary_columns = []
    for model_cfg in models:
        prefix = get_clean_model_name(model_cfg["model_name"])
        summary_columns.extend(f"{prefix}_{method_name}" for method_name in methods)

    pending = set()
    for column in summary_columns:
        if column not in df.columns:
            continue
        missing_mask = df[column].isna()
        pending |= set(df[missing_mask].index)
    return sorted(pending)


def _prepare_dataframe(config, output_path, resume):
    """Return the working DataFrame (resumed output or fresh input) with writable result columns."""
    df = load_existing_output(output_path) if resume else None
    if df is None:
        if resume:
            print("No output file found, starting from scratch")
        else:
            print("--no-resume given, starting from scratch")
        df = pd.read_csv(Path(config["dataframe_path"]))
    else:
        print(f"Output file found: {str(output_path)}")

    # Run unconditionally: a resumed output may predate models/methods added
    # to the config since. add_result_columns only creates missing columns.
    add_result_columns(df, config["llms"], config["methods"])

    # read_csv gives all-empty columns a float64 dtype. Writing summary
    # strings into them hits pandas' incompatible-dtype setitem deprecation,
    # so switch the summary columns to object dtype up front.
    for model in config["llms"]:
        model_clean = get_clean_model_name(model["model_name"])
        for method in config["methods"]:
            col_base = f"{model_clean}_{method}"
            df[col_base] = df[col_base].astype(object)
    return df


def main():
    """Run every configured model/method pair over the input rows and save results.

    The YAML config (positional argument) must provide "output" (CSV path),
    "dataframe_path" (input CSV), "llms" (dicts with "model_name"), "methods"
    and "note_columns". Unless --no-resume is given, an existing output CSV is
    reloaded and only cells that are still empty are computed. The output is
    written after every processed row and on Ctrl-C.
    """
    parser = argparse.ArgumentParser(description="Run summarizers on all data")
    parser.add_argument("config", help="Path to configuration file")
    parser.add_argument("--no-resume", action="store_true", help="Start fresh")

    args = parser.parse_args()
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)

    output_path = Path(config["output"])
    df = _prepare_dataframe(config, output_path, resume=not args.no_resume)

    print(f"Processing {len(df)} rows")

    incomplete_rows = find_incomplete_rows(df, config["llms"], config["methods"])
    print(f"Found {len(incomplete_rows)} rows with incomplete results.")

    try:
        for idx in tqdm(incomplete_rows, desc="Processing rows"):
            notes_dict = extract_notes_dict(df.loc[idx], config["note_columns"])
            if not notes_dict:
                # Nothing to summarize; the row simply stays incomplete.
                continue

            for model in config["llms"]:
                model_clean = get_clean_model_name(model["model_name"])
                for method in config["methods"]:
                    col_base = f"{model_clean}_{method}"

                    if pd.notna(df.loc[idx, col_base]):
                        continue  # already computed in an earlier run

                    result = run_summarizer(
                        notes_dict, model, method, f"PATIENT_{idx}", config
                    )

                    summary = result.get("summary")
                    # Store failures as a real missing value, not the string
                    # "None", so they are retried on the next resume.
                    df.loc[idx, col_base] = None if summary is None else str(summary)
                    df.loc[idx, f"{col_base}_time"] = result.get("execution_time")
                    # NOTE: result["partial_summaries"] is currently not
                    # stored. The "_partial" column created for
                    # generate_merge stays empty.

            # Checkpoint after every row so an interrupted run can resume.
            df.to_csv(output_path, index=False)

        print(f"Complete! Results saved to {output_path}")

    except KeyboardInterrupt:
        df.to_csv(output_path, index=False)
        print("Interrupted - progress saved")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
