import json
import re
import json
import os
import re
import argparse

import torch
from tqdm import tqdm

from ar import ActivationReasoning, LogicConfig
from ar.config import LogicConfig
from ProverQA_logic import transform

# Load the three ProverQA evaluation splits and the training set once, at
# import time. The files are read explicitly as UTF-8 so that parsing does not
# depend on the platform's default text encoding.
with open("data/ProverQA_easy.json", "r", encoding="utf-8") as f:
    test_easy = json.load(f)
with open("data/ProverQA_medium.json", "r", encoding="utf-8") as f:
    test_medium = json.load(f)
with open("data/ProverQA_hard.json", "r", encoding="utf-8") as f:
    test_hard = json.load(f)
# NOTE(review): `train` is not referenced elsewhere in the visible part of
# this file; it is still loaded here so the module-level name keeps existing.
with open("data/ProverQA_provergen-5000.json", "r", encoding="utf-8") as f:
    train = json.load(f)


def extract_fields(text: str):
    """Split a prompt into its context, question and options sections.

    The text is expected to look like
    ``Context: ...\\n\\nQuestion: ...\\n\\nOptions: ...``.

    Returns:
        A dict with the keys ``"context"``, ``"question"`` and ``"options"``.
        ``context`` and ``question`` are stripped strings, ``options`` is a
        list of non-empty, stripped lines. A section whose marker is not
        found is reported as ``None``.
    """

    def _section(pattern):
        # Return the stripped first capture group, or None without a match.
        found = re.search(pattern, text, re.DOTALL)
        return None if found is None else found.group(1).strip()

    context = _section(r"Context:\s*(.*?)\n\nQuestion:")
    question = _section(r"Question:\s*(.*?)\n\nOptions:")
    options_blob = _section(r"Options:\s*(.*)")

    options = None
    if options_blob is not None:
        # One option per line; blank lines and surrounding spaces are dropped.
        options = [
            line.strip() for line in options_blob.splitlines() if line.strip()
        ]

    return {"context": context, "question": question, "options": options}


def add_steering_concepts(rules: list[tuple]) -> dict:
    """Build the rule table and append the three answer rules.

    Every entry of ``rules`` becomes a key mapped to ``None``. The keys
    ``("true",)``, ``("false",)`` and ``("uncertain",)`` are then mapped to
    their ``"Answer: ..."`` strings, overwriting any value stored for the
    same key.
    """
    rule_table = dict.fromkeys(rules)
    rule_table[("true",)] = "Answer: True"
    rule_table[("false",)] = "Answer: False"
    rule_table[("uncertain",)] = "Answer: Uncertain"
    return rule_table


def plain_answer(letter: str) -> str:
    """Translate an option letter into its lowercase answer word.

    ``"A"`` -> ``"true"``, ``"B"`` -> ``"false"``, ``"C"`` -> ``"uncertain"``.
    Any other value yields ``None``.
    """
    answer_words = (("A", "true"), ("B", "false"), ("C", "uncertain"))
    for option, word in answer_words:
        if letter == option:
            return word
    return None


# Fixed pool of sentences. In this file it is passed to `eval`, which appends
# it to the per-sample concepts and question before calling `ar_model.search`.
sentences = [
    # 20 general narrative sentences
    "The sun dipped below the horizon, painting the sky in shades of orange and pink.",
    "She opened the old book and found a letter tucked between the pages.",
    "A sudden breeze carried the scent of jasmine through the open window.",
    "He had always dreamed of traveling the world, one city at a time.",
    "The cat leapt gracefully onto the windowsill, watching the street below.",
    "Despite the storm, the lighthouse stood tall and unwavering.",
    "The classroom buzzed with excitement on the last day of school.",
    "She carefully placed the freshly baked pie on the cooling rack.",
    "The distant sound of waves crashing soothed his restless thoughts.",
    "Under the oak tree, they buried a time capsule filled with memories.",
    "The train whistled as it pulled out of the station into the night.",
    "He scribbled notes furiously, afraid he might forget the brilliant idea.",
    "Snowflakes drifted lazily, settling into a soft white blanket on the ground.",
    "The puppy wagged its tail, overjoyed to see its owner return home.",
    "They lit a lantern and watched it float gently into the sky.",
    "The aroma of fresh coffee filled the small café on the corner.",
    "She laughed so hard that tears rolled down her cheeks.",
    "The mountain peak was shrouded in mist, mysterious and inviting.",
    "A violinist played a haunting melody in the quiet subway station.",
    "The little boy clutched his balloon tightly, afraid it might fly away.",
    # 10 sentences using the connective words "is", "then", "and", "or", "not"
    "Happiness is not something you buy, it is something you create.",
    "If the sky is clear, then we can see the stars tonight.",
    "She wanted tea and cookies, not coffee and cake.",
    "The choice is yours: take the risk or stay safe.",
    "This book is old and fragile, but its story is timeless.",
    "The plan is simple: work hard and then rest well.",
    "It is not the strongest or the fastest who always succeed.",
    "Friendship is built on trust and respect, not fear or doubt.",
    "There is always a way forward, even if it is not obvious at first.",
    "The movie is long and slow, but the ending is worth it.",
    # 9 sentences ending in "Answer: True/False/Uncertain" (three of each)
    "After reviewing the problem carefully. Answer: True.",
    "Based on the explanation provided. Answer: False.",
    "Considering all the given details. Answer: Uncertain.",
    "When you solve the equation step by step. Answer: True.",
    "Looking at the data in the chart. Answer: False.",
    "From the reasoning above. Answer: Uncertain.",
    "If you follow the logical sequence. Answer: True.",
    "Given the context of the passage. Answer: False.",
    "After comparing all possibilities. Answer: Uncertain.",
]


def eval(
    ar_model,
    sample_ds,
    num_samples,
    sentences,
    config,
    model_hyp,
    concept_dict,
    cot: bool = False,
    tracking: dict = None,
):
    """Evaluate ``ar_model`` on the first ``num_samples`` entries of ``sample_ds``.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    callers in this file use it.

    For every sample the concept search is re-run, the reasoner is reset with
    the sample's rules, a prompt is built and generated, and both the generated
    text and the solver's first rule are compared with the gold answer.

    Args:
        ar_model: Model wrapper; the code uses its ``search``,
            ``_reset_reasoner``, ``configure``, ``reset_conv``, ``generate``,
            ``tokenizer``, ``config`` and ``_al_concepts`` members.
        sample_ds: Indexable dataset of dicts with the keys
            ``"conclusion_fol"``, ``"question"``, ``"nl2fol"``, ``"context"``
            and ``"answer"``.
        num_samples: Upper bound on the number of samples to evaluate. It is
            clamped to ``len(sample_ds)`` so a short dataset no longer raises
            ``IndexError`` after all the work has been done.
        sentences: Extra sentences appended to the search inputs.
        config: Configuration re-applied through ``ar_model.configure`` when
            steering is active.
        model_hyp: Generation hyper-parameters forwarded to
            ``ar_model.generate``.
        concept_dict: Concept entries copied into
            ``ar_model._al_concepts.concept_dict`` when steering is active.
        cot: When steering is off, selects the "think step by step" prompt.
        tracking: Only used as the progress-bar description.

    Returns:
        A list with one result dict per evaluated sample.
    """
    # Never index past the end of the dataset.
    num_samples = min(num_samples, len(sample_ds))
    tqdm_bar = tqdm(
        range(num_samples), total=num_samples, desc=f"{tracking}", unit="iter"
    )
    results = []
    correct_samples = []
    correct_solver = []
    for i in tqdm_bar:
        sample = sample_ds[i]
        concepts = transform([sample["conclusion_fol"]], mode="phrases") + [
            "is",
            "do",
            "have",
            "but not both",
            "then",
            "and",
            "or",
            "not",
            "Answer: True",
            "Answer: False",
            "Answer: Uncertain",
        ]
        # Only the text up to (and including) the first "?" is searched.
        question_only = [sample["question"].split("?")[0] + "?"]
        search_inputs = concepts + question_only + sentences
        rules = add_steering_concepts(
            transform(sample["nl2fol"].values(), mode="flatten")
        )

        ar_model.concepts = concepts
        ar_model.search(inputs=search_inputs, reset_cache=True, batch_size=20)

        ar_model._reset_reasoner(rules)
        steering_off = ar_model.config.steering_factor == 0
        if steering_off and not cot:
            question = sample["context"] + " " + sample["question"] + " Answer:"
        elif steering_off and cot:
            question = (
                sample["context"]
                + " "
                + sample["question"]
                + " Think step by step and finish with: 'Answer:'\n"
            )
        else:
            # Steering run: the prompt carries the question without the context.
            question = sample["question"] + " Answer:"

            ar_model.configure(config=config, concepts=concepts)
            for concept in concept_dict:
                ar_model._al_concepts.concept_dict[concept] = concept_dict[concept]

        q1 = question.split("?")[0]
        q_len = ar_model.tokenizer([q1], return_tensors="pt").input_ids.shape[1]
        total_len = ar_model.tokenizer([question], return_tensors="pt").input_ids.shape[
            1
        ]
        # Detection mask: 0 for the first q_len token positions (the text
        # before the first "?"), 1 for the remaining positions of the prompt.
        detection_masks = [torch.tensor([0] * q_len + [1] * (total_len - q_len))]
        ar_model.reset_conv()

        out = ar_model.generate(
            [question],
            detection_masks=detection_masks,
            verbose=False,
            model_hyp=model_hyp,
            return_meta_data=True,
        )

        # Score once and reuse the flags for both the running accuracy and
        # the stored result.
        gold = plain_answer(sample["answer"])
        solver_rules = out[1][0]["rules"][0]
        generation_correct = 1 if gold in out[0][0].lower() else 0
        solver_is_correct = (
            1
            if solver_rules and gold in solver_rules[0].split(" ")[-1].lower()
            else 0
        )

        correct_samples.append(generation_correct)
        correct_solver.append(solver_is_correct)
        tqdm_bar.set_postfix(
            Gen_Acc=f"{sum(correct_samples) / len(correct_samples):.4f}",
            Sol_Acc=f"{sum(correct_solver) / len(correct_solver):.4f}",
        )

        results.append(
            {
                "id": i,
                "rules": {str(r): s for r, s in rules.items()},
                "context": sample["context"],
                "question": question,
                "detection": out[1][0]["concepts"],
                "Solver": solver_rules,
                "DS_GT": sample["answer"],
                "Model": out[0],
                "Generation_correct": generation_correct,
                "Solver correct": solver_is_correct,
            }
        )
    return results


def init(
    model_name: str,
    steering: bool = False,
    dataset_type: str = "easy",
    cot: bool = False,
):
    """Build the model, its configuration and the dataset for one run.

    Args:
        model_name: One of ``"gemma_base"``, ``"gemma_it"``,
            ``"llama_31_base"`` or ``"llama_31_it"``.
        steering: Enables a non-zero steering factor in the configuration and
            selects the ``AL`` cache directory instead of ``Base``.
        dataset_type: ``"easy"``, ``"medium"`` or ``"hard"``; selects the
            module-level ProverQA split.
        cot: Chain-of-thought mode; raises ``max_new_tokens`` to 512 and adds
            a ``_cot`` suffix to the cache directory.

    Returns:
        The tuple ``(ar_model, model_name, experiment_dir, config,
        concept_dict, sample_ds, model_hyp)``.

    Raises:
        ValueError: If ``model_name`` or ``dataset_type`` is not one of the
            supported values. This is checked before any directory is created
            or model is loaded.
    """
    concept_dict = {}
    if model_name == "gemma_base":
        model_config = {
            "model_name": "google/gemma-2-9b",
            "sae_name": "gemma-scope-9b-pt-res-canonical",
            "hookpoint": "layer_20/width_131k/canonical",
            "layer": 20,
        }
        concept_dict = {
            "Answer: False": {
                "indices": [90687, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                "weights": [34.66799545288086, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            },
            "Answer: True": {
                "indices": [88978, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                "weights": [29.999292373657227, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            },
            "Answer: Uncertain": {
                "indices": [43101, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                "weights": [100, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            },
        }
    elif model_name == "gemma_it":
        model_config = {
            "model_name": "google/gemma-2-9b-it",
            "sae_name": "gemma-scope-9b-pt-res-canonical",
            "hookpoint": "layer_20/width_131k/canonical",
            "layer": 20,
        }
    elif model_name == "llama_31_base":
        model_config = {
            "model_name": "meta-llama/Meta-Llama-3.1-8B",
            "sae_name": "EleutherAI/sae-llama-3.1-8b-64x",
            "layer": 23,
            "hookpoint": "layers.23",
        }
        concept_dict = {
            "Answer: False": {
                "indices": [172278, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                "weights": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            },
            "Answer: True": {
                "indices": [225923, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                "weights": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            },
            "Answer: Uncertain": {
                "indices": [47613, 117731, 0, 0, 0, 0, 0, 0, 0, 0],
                "weights": [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            },
        }
    elif model_name == "llama_31_it":
        model_config = {
            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
            "sae_name": "EleutherAI/sae-llama-3.1-8b-64x",
            "layer": 23,
            "hookpoint": "layers.23",
        }
    else:
        # Previously an unknown name only failed later with an unbound-local
        # error, after the output directory had already been created.
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'gemma_base', 'gemma_it', 'llama_31_base', 'llama_31_it'"
        )

    if dataset_type not in ("easy", "medium", "hard"):
        raise ValueError(
            f"Unknown dataset_type {dataset_type!r}; expected one of "
            "'easy', 'medium', 'hard'"
        )

    experiment_dir = f"output/experiments/proverqa/{model_name}"
    cache_dir = (
        f"{experiment_dir}/sae_latents/"
        f"{'AL' if steering else 'Base'}{'_cot' if cot else ''}"
    )
    os.makedirs(experiment_dir, exist_ok=True)
    os.makedirs(cache_dir, exist_ok=True)

    # Settings shared by all four models; the Llama base model overrides the
    # steering-related entries below.
    config_kwargs = dict(
        search_concept_type="word",
        search_concept_token="all",
        search_strategy="top_k",
        search_top_k=10,
        search_top_k_order="unique_first",
        detection_top_k_concepts=3,
        detection_top_k_output=10,
        reasoner_rules_checking="open_world",
        steering_factor=10.0 if steering else 0,
        steering_top_k_rule=1,
        steering_methodology="mean_shift",
    )
    # A single if/elif chain: the Llama base branch used to be a separate
    # `if`, so it also fell through to the "No AL config loaded" message.
    if model_name == "llama_31_base":
        print("Loading Llama3.1 8B configs")
        config_kwargs.update(
            steering_factor={
                "Answer: True": 0.5,
                "Answer: False": 0.5,
                "Answer: Uncertain": 4,
            }
            if steering
            else 0,
            steering_top_k_rule=2,
            steering_weighting_function="uniform",
        )
    elif model_name == "llama_31_it":
        print("Loading Llama3.1 8B it configs")
    elif model_name == "gemma_it":
        print("Loading Gemma2 9B it configs")
    # "gemma_base" uses the shared settings and prints nothing, as before.
    config = LogicConfig(**config_kwargs)

    # The model is constructed with the concepts of the first easy sample,
    # regardless of the selected dataset_type.
    concepts = transform([test_easy[0]["conclusion_fol"]], mode="phrases")
    ar_model = ActivationReasoning(
        rules={},
        concepts=concepts,
        config=config,
        cache_dir=cache_dir,
        **model_config,
        verbose=False,
    )
    ar_model.search(inputs=concepts, reset_cache=True, batch_size=20)

    if dataset_type == "easy":
        sample_ds = test_easy
    elif dataset_type == "medium":
        sample_ds = test_medium
    else:
        sample_ds = test_hard

    # Without chain-of-thought only a very short answer is generated:
    # 2 new tokens for Llama models, 1 for the others.
    short_answer_tokens = 2 if "llama" in model_config["model_name"] else 1
    model_hyp = {
        "do_sample": False,
        "temperature": None,
        "top_k": None,
        "top_p": None,
        "max_new_tokens": short_answer_tokens if not cot else 512,
    }
    return (
        ar_model,
        model_name,
        experiment_dir,
        config,
        concept_dict,
        sample_ds,
        model_hyp,
    )


def main():
    """Command-line entry point.

    Reads the model name, the ``--cot`` and ``--steering`` flags and the list
    of difficulties, then for each difficulty calls ``init`` and ``eval`` and
    writes the results as JSON into the experiment directory.
    """
    cli = argparse.ArgumentParser(
        description="A script that accepts model configuration via command line."
    )
    cli.add_argument(
        "--model_name",
        type=str,
        default="gemma_base",
        help='The name of the model to use (e.g., "gemma_base").',
    )
    cli.add_argument(
        "--cot",
        action="store_true",
        default=False,
        help="Enable Chain of Thought (sets cot to True if present).",
    )
    cli.add_argument(
        "--steering",
        action="store_true",
        default=False,
        help="Enable steering (sets steering to True if present).",
    )
    cli.add_argument(
        "--difficulties",
        type=str,
        nargs="+",  # one or more values, collected into a list
        default=["easy", "medium", "hard"],
        help="A space-separated list of difficulty levels (e.g., --difficulties easy medium hard).",
    )
    cli.set_defaults(cot=False, steering=False)

    args = cli.parse_args()

    # Echo the effective configuration before any model is loaded.
    for line in (
        "--- Script Configuration ---",
        f"Model Name: {args.model_name}",
        f"CoT Enabled: {args.cot}",
        f"Steering Enabled: {args.steering}",
        f"Difficulties: {args.difficulties}",
        "--------------------------",
    ):
        print(line)

    num_samples = 500
    model_name = args.model_name
    cot = args.cot
    steering = args.steering

    # Suffix of the result file: "", "_AL", "_cot" or "_AL_cot".
    run_suffix = ("_AL" if steering else "") + ("_cot" if cot else "")

    for dataset_type in args.difficulties:
        (
            ar_model,
            model_name,
            experiment_dir,
            config,
            concept_dict,
            sample_ds,
            model_hyp,
        ) = init(
            model_name=model_name,
            steering=steering,
            dataset_type=dataset_type,
            cot=cot,
        )
        run_info = {
            "model": model_name,
            "cot": cot,
            "steering": steering,
            "dataset": dataset_type,
        }
        results = eval(
            ar_model=ar_model,
            sample_ds=sample_ds,
            num_samples=num_samples,
            sentences=sentences,
            config=config,
            model_hyp=model_hyp,
            concept_dict=concept_dict,
            cot=cot,
            tracking=run_info,
        )
        output_path = f"{experiment_dir}/{dataset_type}{run_suffix}.json"
        with open(output_path, "w") as f:
            json.dump(results, f)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()