import argparse
from pathlib import Path
import re

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import scan_cache_dir

from utils import *

# Local folders: result CSVs are written below output/, stimuli are read from data/.
# runs/ is defined here but not referenced elsewhere in this file's visible code.
output_dir, data_dir, runs_dir = Path("output"), Path("data"), Path("runs")

# Every model ID accepted by --model_id (list_all_models_to_test presumably comes from utils).
all_model_ids = list_all_models_to_test()

###### FALSE BELIEF DATASET ######
fb_data = pd.read_csv(data_dir / "fb_stimuli.csv", delimiter=",")

###### BLIMP DATASET ######
blimp_dir = data_dir / "blimp"

# One JSON-lines file per BLiMP phenomenon. Hidden files (macOS ".DS_Store",
# "._*" resource forks, ...) are skipped by *name*: iterdir() yields Path
# objects, so testing membership of the string ".DS_Store" would never match.
phenomena = sorted(
    path for path in blimp_dir.iterdir() if not path.name.startswith(".")
)

individual_phenomena = [
    pd.read_json(phenomenon, lines=True, orient="records") for phenomenon in phenomena
]
blimp_df = pd.concat(individual_phenomena)
# Running row id across all files; the concatenated frame keeps each file's
# own index, which restarts at 0 per file.
blimp_df["index"] = range(0, len(blimp_df))

# 10 rows per UID, without replacement and with a fixed seed for reproducibility.
blimp_sample_df = blimp_df.groupby("UID").sample(10, replace=False, random_state=42)


def parse_model_id(model_id, model_stage):
    """Derive model-family and checkpoint metadata from a model ID and revision.

    Args:
        model_id: Repository name without namespace, e.g. "OLMo-2-0425-1B".
        model_stage: Revision being loaded. "main" means no checkpoint fields;
            otherwise K2-V2 revisions are delegated to ``extract_k2v2_stage``
            and all others are split on "-" and scanned for "stage<N>",
            "step<N>", "ingredient<N>" and "tokens<N>B" fields.

    Returns:
        ``(model_family, model_version, checkpoint_step, checkpoint_stage,
        ingredient_step, tokens_step)``. The version is "" when it cannot be
        read from the ID; checkpoint fields are None when absent.

    Raises:
        ValueError: if no known family marker occurs in ``model_id``, or a
            checkpoint field does not hold an integer.
    """
    lowered_id = model_id.lower()
    id_pieces = model_id.split("-")
    version_pattern = r"\d+(\.\d+)?"

    # Checked in priority order: the first marker found in the ID wins.
    family_markers = ("olmo", "gemma", "llama", "qwen", "kimi", "pythia", "k2-v2")
    marker = next((m for m in family_markers if m in lowered_id), None)
    if marker is None:
        raise ValueError(f"Unknown model family: {model_id}")
    model_family = "k2v2" if marker == "k2-v2" else marker

    model_version = ""
    if model_family in ("olmo", "gemma", "llama"):
        # Version is the second dash-separated field when numeric, e.g. "2" or "3.2".
        candidate = id_pieces[1] if len(id_pieces) > 1 else ""
        if candidate.replace(".", "").isdigit():
            model_version = candidate
    elif model_family == "qwen" and len(id_pieces) > 1:
        # The version is glued to the name itself, e.g. "Qwen2.5".
        found = re.search(version_pattern, id_pieces[0])
        if found:
            model_version = found.group()
    elif model_family == "kimi" and len(id_pieces) > 2:
        # The version sits in the second field, e.g. "K2".
        found = re.search(version_pattern, id_pieces[1])
        if found:
            model_version = found.group()

    checkpoint_step = checkpoint_stage = ingredient_step = tokens_step = None

    if model_stage == "main":
        return (
            model_family,
            model_version,
            checkpoint_step,
            checkpoint_stage,
            ingredient_step,
            tokens_step,
        )

    if model_family == "k2v2":
        checkpoint_stage, checkpoint_step = extract_k2v2_stage(model_stage, model_id)
    else:
        for field in model_stage.split("-"):
            if field.startswith("stage"):
                checkpoint_stage = int(field.replace("stage", ""))
            elif field.startswith("step"):
                checkpoint_step = int(field.replace("step", ""))
            elif field.startswith("ingredient"):
                ingredient_step = int(field.replace("ingredient", ""))
            elif field.startswith("tokens"):
                tokens_step = int(field.replace("tokens", "").replace("B", ""))

    return (
        model_family,
        model_version,
        checkpoint_step,
        checkpoint_stage,
        ingredient_step,
        tokens_step,
    )


def delete_local_cache(model_id, model_stage):
    """Remove one cached revision of a model from the local Hugging Face cache.

    Scans the cache for *model* repositories whose name is exactly ``model_id``
    (either the bare name, e.g. "OLMo-2-0425-1B", or the full
    "namespace/name" ID). Every revision of those repos that carries the ref
    ``model_stage`` is deleted. Nothing happens if no such revision is cached.

    The exact-name match matters: a plain substring test would also hit
    sibling repos such as "allenai/OLMo-2-0425-1B-Instruct" and delete their
    revisions with the same ref.
    """
    cache_info = scan_cache_dir()

    revisions_to_delete = []
    for entry in cache_info.repos:
        if entry.repo_type != "model":
            continue
        if entry.repo_id != model_id and not entry.repo_id.endswith(f"/{model_id}"):
            continue
        for revision in entry.revisions:
            # refs holds the branch/tag names pointing at this commit.
            if model_stage in revision.refs:
                print(
                    f"Found revision to delete: {model_id} {model_stage} at {revision.commit_hash}"
                )
                revisions_to_delete.append(revision.commit_hash)

    if revisions_to_delete:
        delete_strategy = cache_info.delete_revisions(*revisions_to_delete)
        print(f"Will free {delete_strategy.expected_freed_size_str}.")
        delete_strategy.execute()


def is_valid_csv(file_path):
    """Return True if *file_path* is a readable CSV with at least one data row.

    Missing or unreadable files, empty files, header-only files and malformed
    CSVs all yield False, so callers can treat them as "not computed yet".
    """
    try:
        frame = pd.read_csv(file_path)
    except (OSError, ValueError):
        # OSError: missing/unreadable path. ValueError also covers pandas'
        # EmptyDataError and ParserError, and UnicodeDecodeError.
        return False
    return len(frame) > 0


def check_output_files(model_id, model_stage, experiment_name):
    """Locate the per-task result CSVs for a checkpoint and decide what still has to run.

    Returns:
        ``(filenames, do_blimp, do_fb, do_rm)``. ``filenames`` maps the task
        keys "blimp", "fb" and "rm" to their output paths. Each flag is False
        when that task's CSV already exists and passes ``is_valid_csv``.
    """
    experiment_root = output_dir / experiment_name
    suffix = f"{model_id}_{model_stage}.csv"

    filenames = {
        "blimp": experiment_root / "blimp_results" / f"blimp_model_{suffix}",
        "fb": experiment_root / "fb_results" / f"fb_model_{suffix}",
        "rm": experiment_root / "rm_results" / f"rm_model_{suffix}",
    }
    # (display name opening the message, abbreviation used after "skipping")
    task_labels = {
        "blimp": ("BLIMP", "BLIMP"),
        "fb": ("False Belief", "FB"),
        "rm": ("Recursive Mindreading", "RM"),
    }

    still_to_run = {}
    for task, csv_path in filenames.items():
        finished = csv_path.exists() and is_valid_csv(csv_path)
        if finished:
            full_name, short_name = task_labels[task]
            print(
                f"{full_name} output file {csv_path} already exists, skipping {short_name} evaluation."
            )
        still_to_run[task] = not finished

    return filenames, still_to_run["blimp"], still_to_run["fb"], still_to_run["rm"]


# Hugging Face namespace ("organisation/") hosting each supported model family.
_REPO_BY_FAMILY = {
    "olmo": "allenai/",
    "gemma": "google/",
    "llama": "meta-llama/",
    "qwen": "qwen/",
    "kimi": "moonshotai/",
    "pythia": "EleutherAI/",
    "k2v2": "LLM360/",
}


def _select_precision(model_family, model_size):
    """Return the torch dtype used to load a model of *model_size* billion parameters.

    Models above 16B are loaded in half precision: bfloat16 for Gemma,
    float16 for every other family. Smaller models stay in float32.
    """
    if model_size > 16:
        return torch.bfloat16 if model_family == "gemma" else torch.float16
    return torch.float32


def _collect_bow_token_ids(tokenizer):
    """Return IDs (below ``tokenizer.vocab_size``) whose decoded text looks word-initial.

    A token qualifies when its decoded text starts with "Ġ" or "▁", or with a
    space followed by at least one more character.
    """
    # NOTE(review): decode() usually turns the "Ġ"/"▁" markers into a plain space
    # (and some SentencePiece tokenizers drop a leading space on single tokens),
    # so the first two checks may rarely fire. Confirm the resulting set is what
    # obtain_outputs expects, e.g. by comparing against convert_ids_to_tokens().
    bow_token_ids = []
    for token_id in range(tokenizer.vocab_size):
        token_str = tokenizer.decode([token_id])
        if (
            token_str.startswith("Ġ")
            or token_str.startswith("▁")
            or (token_str.startswith(" ") and len(token_str) > 1)
        ):
            bow_token_ids.append(token_id)
    return bow_token_ids


def _save_results(records, output_file):
    """Write evaluation *records* to *output_file* as CSV with a header and no index column."""
    pd.DataFrame(records).to_csv(output_file, header=True, index=False)


def main(args=None):
    """Evaluate one model checkpoint on the False Belief, BLiMP and Recursive Mindreading data.

    A task is skipped when ``check_output_files`` reports that its result CSV
    already exists and is valid. If all three are done, the model is never
    loaded. After the remaining tasks run, the checkpoint's revision is removed
    from the local Hugging Face cache.

    Args:
        args: Parsed command-line namespace with ``model_id``, ``model_stage``,
            ``model_size`` and ``experiment_name``. Required; the ``None``
            default is kept only for signature compatibility.

    Raises:
        ValueError: if ``args`` is missing or the model family has no known host.
    """
    if args is None:
        raise ValueError(
            "main() requires parsed command-line arguments "
            "(model_id, model_stage, model_size, experiment_name)."
        )

    (
        model_family,
        model_version,
        checkpoint_step,
        checkpoint_stage,
        ingredient_step,
        tokens_step,
    ) = parse_model_id(args.model_id, args.model_stage)

    # Checkpoint metadata handed to the evaluation helpers below.
    model_details = {
        "model_family": model_family,
        "model_version": model_version,
        "checkpoint_step": checkpoint_step,
        "checkpoint_stage": checkpoint_stage,
        "ingredient_step": ingredient_step,
        "tokens_step": tokens_step,
        "model_size": args.model_size,
    }

    print(
        "Computing probabilities for model:\n\n"
        f"        Family: {model_family}\n "
        f"        Version: {model_version}\n "
        f"        Checkpoint Step: {checkpoint_step}\n "
        f"        Checkpoint Stage: {checkpoint_stage}\n "
        f"        Ingredient Step: {ingredient_step}\n "
        f"        Tokens Step: {tokens_step}\n"
    )

    # Family (detected case-insensitively) rather than a case-sensitive substring
    # test on the raw ID, so e.g. "Gemma-..." IDs also get bfloat16.
    model_precision = _select_precision(model_family, args.model_size)

    repo = _REPO_BY_FAMILY.get(model_family)
    if repo is None:
        raise ValueError(f"Unknown model host: {args.model_id}")

    print("Checking which evaluations to run ...")
    output_filenames, do_blimp, do_fb, do_rm = check_output_files(
        args.model_id, args.model_stage, args.experiment_name
    )
    if not (do_blimp or do_fb or do_rm):
        print("All output files already exist and are valid, exiting ...")
        return

    print("Overview of tasks to run: ")
    if do_blimp:
        print("  - BLIMP")
    if do_fb:
        print("  - False Belief")
    if do_rm:
        print("  - Recursive Mindreading")

    repo_id = f"{repo}{args.model_id}"
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        revision=args.model_stage,
        dtype=model_precision,
        device_map="auto",
        trust_remote_code=True,
    )
    # NOTE(review): the tokenizer is loaded from the default revision, not
    # args.model_stage; confirm the tokenizer is identical across checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

    bow_token_ids = _collect_bow_token_ids(tokenizer)

    model.eval()
    print("Done loading model ...")

    # False-belief dataset
    if do_fb:
        print("Running the False Belief dataset ...")
        fb_output = obtain_outputs(
            model, tokenizer, fb_data, bow_token_ids, model_details, task="fb"
        )
        _save_results(fb_output, output_filenames["fb"])

    # BLIMP dataset
    if do_blimp:
        print("Running the BLIMP dataset ...")
        blimp_output = obtain_outputs(
            model, tokenizer, blimp_sample_df, bow_token_ids, model_details, task="blimp"
        )
        _save_results(blimp_output, output_filenames["blimp"])

    # Recursive Mindreading dataset
    if do_rm:
        print("Running the Recursive Mindreading dataset ...")
        rm_data = read_rm()
        rm_output = obtain_pmis(model, tokenizer, rm_data, model_details)
        _save_results(rm_output, output_filenames["rm"])

    print("All done, deleting my own cache (myself) :)")
    delete_local_cache(args.model_id, args.model_stage)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate language models on False Belief and BLIMP datasets"
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="OLMo-2-0425-1B",
        choices=all_model_ids,
        help="Model ID to evaluate",
    )
    parser.add_argument(
        "--model_stage",
        type=str,
        default="main",
        help="If given, load not the 'main' model checkpoints but a particular stage ID to evaluate",
    )
    parser.add_argument(
        "--model_size",
        type=float,
        default=1,
        help="Choose the model size in billions, in floats",
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        default="default_experiment",
        help="Name of the experiment (should correspond to a folder in 'runs')",
    )

    args = parser.parse_args()

    # Create every per-task results folder before main() writes CSVs into it.
    experiment_output_dir = output_dir / args.experiment_name
    for results_subdir in ("fb_results", "blimp_results", "rm_results"):
        (experiment_output_dir / results_subdir).mkdir(parents=True, exist_ok=True)

    main(args=args)
