import os
import json
import re
from typing import List, Dict, Any, Optional


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Parse a JSON Lines file into a list of decoded objects.

    Lines that are empty or whitespace-only are ignored; every other line is
    decoded independently with ``json.loads``.
    """
    records: List[Dict[str, Any]] = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            records.append(json.loads(raw_line))
    return records


def save_jsonl(file_path: str, data: List[Dict[str, Any]]):
    """Write ``data`` to ``file_path`` as JSON Lines (one object per line).

    Missing parent directories are created. Non-ASCII characters are written
    as-is (``ensure_ascii=False``) in UTF-8. An existing file is overwritten.
    """
    parent_dir = os.path.dirname(file_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is a parent.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as f:
        for item in data:
            json.dump(item, f, ensure_ascii=False)
            f.write("\n")


def load_json(file_path: str) -> List[Dict[str, Any]]:
    """Decode a whole UTF-8 JSON document from ``file_path`` and return it."""
    with open(file_path, encoding='utf-8') as stream:
        payload = json.load(stream)
    return payload


# MATH dataset preprocessing
def preprocess_math_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Normalize MATH-style records to ``question``/``answer``/``problem_type``.

    The question comes from ``problem`` (falling back to ``question``) and the
    answer from ``answer``. Non-string answers (e.g. plain integers) are
    converted with ``str`` like the other math loaders do, instead of
    crashing on ``.strip()``. Records whose question or answer is empty after
    stripping are dropped.
    """
    processed = []
    for item in dataset:
        question = item.get("problem") or item.get("question") or ""
        question = str(question).strip()

        raw_answer = item.get("answer")
        # None means "no answer"; anything else is stringified before stripping.
        answer = "" if raw_answer is None else str(raw_answer).strip()

        # Check the stripped values so whitespace-only questions are dropped
        # rather than emitted as empty prompts.
        if question and answer:
            processed.append({
                "question": question,
                "answer": answer,
                "problem_type": "math"
            })
    return processed

# MATHQA dataset preprocessing


def preprocess_mathqa_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    processed = []
    for item in dataset:
        question = item.get("Problem", "").strip()
        correct_option = item.get("correct", "").strip().lower()
        options = item.get("options", "")
        answer_text = ""
        if options and correct_option:
            option_map = {}
            for opt in options.split(","):
                opt = opt.strip()
                if ")" in opt:
                    key, val = opt.split(")", 1)
                    option_map[key.strip().lower()] = val.strip()
            answer_text = option_map.get(correct_option, "")

        if question and answer_text:
            processed.append({
                "question": question,
                "answer": answer_text,
                "problem_type": "math"
            })
    return processed

# SVAMP dataset preprocessing


def preprocess_svamp_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Turn SVAMP records into question/answer pairs.

    ``Body`` and ``Question`` are joined with a newline (just the question when
    the body is empty); ``Answer`` is stringified. Records with an empty
    question or answer are skipped.
    """
    processed = []
    for entry in dataset:
        body_text = entry.get("Body", "").strip()
        question_text = entry.get("Question", "").strip()
        if body_text:
            full_question = "\n".join((body_text, question_text))
        else:
            full_question = question_text

        answer_text = str(entry.get("Answer", "")).strip()

        if not (full_question and answer_text):
            continue
        processed.append({
            "question": full_question,
            "answer": answer_text,
            "problem_type": "math",
        })
    return processed

# ASDIV dataset preprocessing


def preprocess_asdiv_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Turn ASDiv records into question/answer pairs.

    ``body`` and ``question`` are joined with a newline (just the question when
    the body is empty). The answer keeps only the text before the first
    parenthesis, so ``"9 (apples)"`` becomes ``"9"``. Records with an empty
    question or cleaned answer are skipped.
    """
    def clean_answer(raw: str) -> str:
        # Keep the text before the first "(" or ")".
        return re.split(r"[\(\)]", raw.strip(), maxsplit=1)[0].strip()

    processed = []
    for entry in dataset:
        body_text = entry.get("body", "").strip()
        ask = entry.get("question", "").strip()
        combined = "\n".join((body_text, ask)) if body_text else ask

        answer = clean_answer(entry.get("answer", ""))

        if not combined or not answer:
            continue
        processed.append({
            "question": combined,
            "answer": answer,
            "problem_type": "math",
        })
    return processed

# MAWPS dataset preprocessing


def preprocess_mawps_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Map MAWPS ``input``/``target`` records to question/answer pairs.

    ``target`` is stringified; records with an empty question or answer are
    skipped.
    """
    pairs = (
        (entry.get("input", "").strip(), str(entry.get("target", "")).strip())
        for entry in dataset
    )
    return [
        {"question": q_text, "answer": a_text, "problem_type": "math"}
        for q_text, a_text in pairs
        if q_text and a_text
    ]

# AIME dataset preprocessing


def preprocess_aime24_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Map AIME 2024 ``Question``/``Answer`` records to question/answer pairs.

    ``Answer`` is stringified; records with an empty question or answer are
    skipped.
    """
    records = []
    for entry in dataset:
        record = {
            "question": entry.get("Question", "").strip(),
            "answer": str(entry.get("Answer", "")).strip(),
            "problem_type": "math",
        }
        if record["question"] and record["answer"]:
            records.append(record)
    return records


# FOLIO dataset preprocessing
def preprocess_folio_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format FOLIO records as numbered-facts + conclusion prompts.

    ``premises`` may be a newline-separated string or a list of strings; each
    non-blank line becomes one numbered fact. Records without premises, a
    conclusion, or a label are skipped. The label is passed through unchanged
    (after stripping) as the gold answer.
    """
    processed = []
    for item in dataset:
        raw_premises = item.get("premises") or ""
        # Accept both the single-string and the list-of-strings layouts.
        if isinstance(raw_premises, str):
            raw_premises = [raw_premises]
        prem_lines = []
        for chunk in raw_premises:
            for line in str(chunk).splitlines():
                if line.strip():
                    prem_lines.append(line.strip())

        conclusion_nl = (item.get("conclusion") or "").strip()
        # NOTE(review): the prompt offers True / False / Uncertain; confirm
        # the data file's labels use exactly these strings.
        label = (item.get("label") or "").strip()

        if not (prem_lines and conclusion_nl and label):
            continue

        facts_text = "Facts:\n" + "\n".join(
            f"{number}. {line}" for number, line in enumerate(prem_lines, start=1)
        )

        question = (
            f"{facts_text}\n\n"
            f"Conclusion:\n{conclusion_nl}\n\n"
            "Question: Does the conclusion logically follow from the facts? "
            "Answer with exactly one of: True, False, Uncertain. Do not output other words."
        )

        processed.append({
            "question": question,
            "answer": label,
            "problem_type": "folio"
        })
    return processed

# AR-LSAT dataset preprocessing


def preprocess_arlsat_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format AR-LSAT records as five-way multiple-choice prompts.

    The gold answer is the 0-based index of the correct option, as a string.
    Records are skipped when the question is empty, ``answers`` is not a list
    of exactly five options, or ``label`` is missing, non-integer or out of
    range (previously such labels raised from ``int(label)``).
    """
    num_choices = 5
    processed = []
    for item in dataset:
        premises_nl = (item.get("context") or "").strip()
        question = (item.get("question") or "").strip()
        options = item.get("answers") or []

        if not question or not isinstance(options, list) or len(options) != num_choices:
            continue

        try:
            idx = int(item.get("label"))
        except (TypeError, ValueError):
            # No usable gold label (e.g. an unlabeled split): cannot be scored.
            continue
        if not 0 <= idx < num_choices:
            continue

        opts_text = "\n".join(
            f"{i}) {str(opt).strip()}" for i, opt in enumerate(options))
        header = f"{premises_nl}\n" if premises_nl else ""
        user_prompt = (
            header
            + f"{question}\n\nOptions:\n{opts_text}\n\n"
            "Answer with exactly one digit (0-4). Output only the number. Do not output other words."
        )

        processed.append({
            "question": user_prompt,
            "answer": str(idx),
            "problem_type": "arlsat",
        })
    return processed

# LogiQA dataset preprocessing


def preprocess_logiqa_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format LogiQA premise/hypothesis records as entailment prompts.

    Every record is kept, even when fields are empty. The stripped ``label``
    is used unchanged as the gold answer.
    """
    processed = []
    for item in dataset:
        context = (item.get("premise") or "").strip()
        hypothesis = (item.get("hypothesis") or "").strip()
        label = (item.get("label") or "").strip()

        question = (
            f"{context}\n\n"
            f"Hypothesis:\n{hypothesis}\n\n"
            # The trailing space stops the two adjacent literals from fusing
            # into "hypothesis.Answer ...".
            "Based on the facts above, choose the correct label for the hypothesis. "
            "Answer with exactly one of: entailment, not-entailment . Do not output other words."
        )

        processed.append({
            "question": question,
            "answer": label,
            "problem_type": "logiqa"
        })
    return processed

# RECLOR dataset preprocessing


def preprocess_reclor_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format ReClor records as four-way multiple-choice prompts.

    The gold answer is the 0-based index of the correct option, as a string.
    Records are skipped when the question is empty, ``answers`` is not a list
    of exactly four options, or ``label`` is missing, non-integer or out of
    range. Previously a record without ``label`` crashed on ``int(None)``.
    """
    num_choices = 4
    processed = []
    for item in dataset:
        premises_nl = (item.get("context") or "").strip()
        question = (item.get("question") or "").strip()
        options = item.get("answers") or []

        if not question or not isinstance(options, list) or len(options) != num_choices:
            continue

        try:
            idx = int(item.get("label"))
        except (TypeError, ValueError):
            # No usable gold label (e.g. an unlabeled split): cannot be scored.
            continue
        if not 0 <= idx < num_choices:
            continue

        opts_text = "\n".join(
            f"{i}) {str(opt).strip()}" for i, opt in enumerate(options))
        header = f"{premises_nl}\n" if premises_nl else ""
        user_prompt = (
            header
            + f"{question}\n\nOptions:\n{opts_text}\n\n"
            "Answer with exactly one digit (0-3). Output only the number. Do not output other words."
        )

        processed.append({
            "question": user_prompt,
            "answer": str(idx),
            "problem_type": "reclor"
        })
    return processed

# AbductionR dataset preprocessing


def preprocess_abductionr_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format AbductionR context/conclusion records as True/False prompts.

    The gold answer is derived from ``QCat``: ``"0"`` -> ``"True"``,
    ``"0_0"`` -> ``"False"``; any other value yields an empty answer, but the
    record is still emitted.
    """
    qcat_to_label = {"0": "True", "0_0": "False"}

    processed = []
    for entry in dataset:
        context = (entry.get("context") or "").strip()
        conclusion = (entry.get("text") or "").strip()
        qcat = (entry.get("QCat") or "").strip()

        prompt = (
            f"{context}\n\n"
            f"Conclusion:\n{conclusion}\n\n"
            "Question: Does the conclusion logically follow from the context? "
            "Answer with exactly one of: True, False. Do not output other words. "
        )
        processed.append({
            "question": prompt,
            "answer": qcat_to_label.get(qcat, ""),
            "problem_type": "abductionr",
        })
    return processed


def preprocess_rulearena_dataset(dataset: List[Dict[str, Any]], args) -> List[Dict[str, Any]]:
    """Convert a RuleArena split into ``question``-keyed records.

    ``args.data_split`` selects the split:

    * ``"airline"`` / ``"nba"``: each item's ``prompt`` field is renamed to
      ``question``; every other field is kept.
    * ``"tax"``: the form templates from ``datasets.rulearena.tax_forms`` are
      assembled according to the flags in ``item["tax_payer_dict"]`` and
      filled with that taxpayer's values; the result keeps only ``question``,
      ``answer`` and ``problem_type``.

    The input items are never modified.

    Raises:
        ValueError: if ``args.data_split`` is not one of the splits above.
        KeyError: if an item lacks a field its split needs.
    """
    def rename_prompt(item: Dict[str, Any]) -> Dict[str, Any]:
        # Work on a shallow copy so the caller's records stay untouched;
        # pop + re-insert keeps "question" as the last key, as before.
        record = dict(item)
        record["question"] = record.pop("prompt")
        return record

    def tax_preprocess(item: Dict[str, Any]) -> Dict[str, Any]:
        # The form templates (basic_forms, ...) are bound by the import in the
        # "tax" branch below, which runs before this helper is first called.
        tax_payer = item['tax_payer_dict']
        sections = [basic_forms]
        if tax_payer["itemized"]:
            sections.append(itemized_forms)
        if tax_payer["self_employed"]:
            sections.append(self_employ_forms)
        if tax_payer["has_student_loans_or_education_expenses"]:
            sections.append(edu_forms)
        if tax_payer["child_and_dependent"]:
            sections.append(schedule_8812)
        forms = "".join(sections)

        data = tax_payer["data"]
        # Longest keys first: if one key is a prefix of another, replacing the
        # shorter "$key" first would corrupt the longer placeholder.
        for key in sorted(data, key=len, reverse=True):
            value = data[key]
            # Strings (e.g. the "$TBD" marker) go in verbatim; numbers are
            # rendered as dollar amounts with thousands separators.
            filled = value if isinstance(value, str) else "$" + f"{value:,}"
            forms = forms.replace("$" + key, filled)
        # Values the model has to compute are shown as blanks.
        forms = forms.replace("$TBD", "[__]")

        # Taxpayer header fields.
        forms = forms.replace("$name", tax_payer["name"])
        forms = forms.replace("$age", str(tax_payer["age"]))
        forms = forms.replace("$spouse_age", str(tax_payer["spouse_age"]))
        forms = forms.replace("$blind", str(tax_payer["blind"]))
        forms = forms.replace("$spouse_blind", str(tax_payer["spouse_blind"]))
        forms = forms.replace("$filing_status", tax_payer["filing_status"])
        forms = forms.replace("$itemized", str(tax_payer["itemized"]))
        forms = forms.replace("$num_qualifying_children", str(
            tax_payer["num_qualifying_children"]))
        forms = forms.replace("$num_other_dependents", str(
            tax_payer["num_other_dependents"]))

        return {
            "question": forms,
            "answer": item["answer"],
            "problem_type": item["problem_type"]
        }

    if args.data_split in ('airline', 'nba'):
        return [rename_prompt(item) for item in dataset]
    elif args.data_split == 'tax':
        from datasets.rulearena.tax_forms import basic_forms, itemized_forms, self_employ_forms, edu_forms, schedule_8812
        return [tax_preprocess(item) for item in dataset]
    else:
        raise ValueError(f"Unknown RuleArena split: {args.data_split}")


# FLD dataset preprocessing
def preprocess_fld_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format FLD context/hypothesis records as three-way labelling prompts.

    Every record is kept. The stripped ``proof_label`` is used unchanged as
    the gold answer.
    """
    processed = []
    for item in dataset:
        context = (item.get("context") or "").strip()
        hypothesis = (item.get("hypothesis") or "").strip()
        label = (item.get("proof_label") or "").strip()

        question = (
            f"{context}\n\n"
            f"Hypothesis:\n{hypothesis}\n\n"
            # The trailing space stops the two adjacent literals from fusing
            # into "hypothesis.Answer ...".
            "Based on the facts above, choose the correct label for the hypothesis. "
            "Answer with exactly one of: proved, disproved, unknown. Do not output other words."
        )
        processed.append({
            "question": question,
            "answer": label,
            "problem_type": "fld"
        })
    return processed

# ProofWriter dataset preprocessing


def preprocess_proofwriter_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format ProofWriter theory/question records as True/False/Unknown prompts.

    Every record is kept. The gold answer is ``answer`` converted to a string,
    so boolean answers become ``"True"``/``"False"``. Previously ``False`` was
    silently turned into ``""`` and ``True`` crashed on ``.strip()``. A missing
    answer gives ``""``.
    """
    processed = []
    for item in dataset:
        context = (item.get("theory") or "").strip()
        hypothesis = (item.get("question") or "").strip()
        raw_label = item.get("answer")
        label = "" if raw_label is None else str(raw_label).strip()

        question = (
            f"{context}\n\n"
            f"Hypothesis:\n{hypothesis}\n\n"
            # The trailing space stops the two adjacent literals from fusing
            # into "hypothesis.Answer ...".
            "Based on the facts above, choose the correct label for the hypothesis. "
            "Answer with exactly one of: True, False, Unknown. Do not output other words."
        )
        processed.append({
            "question": question,
            "answer": label,
            "problem_type": "proofwriter"
        })
    return processed

# RuleTaker dataset preprocessing


def preprocess_ruletaker_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Format RuleTaker context/question records as entailment prompts.

    Every record is kept. The stripped ``label`` is used unchanged as the
    gold answer.
    """
    processed = []
    for item in dataset:
        context = (item.get("context") or "").strip()
        hypothesis = (item.get("question") or "").strip()
        label = (item.get("label") or "").strip()

        question = (
            f"{context}\n\n"
            f"Hypothesis:\n{hypothesis}\n\n"
            # The trailing space stops the two adjacent literals from fusing
            # into "hypothesis.Answer ...".
            "Based on the facts above, choose the correct label for the hypothesis. "
            "Answer with exactly one of: entailment, not entailment . Do not output other words."
        )
        processed.append({
            "question": question,
            "answer": label,
            "problem_type": "ruletaker"
        })
    return processed


# Data loaders for the supported datasets
def _split_path(data_dir: str, args, *subdirs: str, ext: str = ".jsonl") -> str:
    """Return ``<data_dir>/<subdirs...>/<args.data_split><ext>``."""
    return os.path.join(data_dir, *subdirs, f"{args.data_split}{ext}")


def _load_gsm8k(args, data_dir: str) -> List[Dict[str, Any]]:
    """Load GSM8K; its records already carry ``question``/``answer``, so only tag them."""
    rows = load_jsonl(_split_path(data_dir, args, "gsm8k", "main"))
    return [
        {"question": row["question"], "answer": row["answer"], "problem_type": "math"}
        for row in rows
    ]


# Every loader is called as loader(args, data_dir) and returns the processed
# examples for args.data_split.
DATASET_LOADERS = {
    "gsm8k": _load_gsm8k,
    "math": lambda args, data_dir: preprocess_math_dataset(
        load_jsonl(_split_path(data_dir, args, "math"))),
    "mathqa": lambda args, data_dir: preprocess_mathqa_dataset(
        load_json(_split_path(data_dir, args, "mathqa", ext=".json"))),
    "svamp": lambda args, data_dir: preprocess_svamp_dataset(
        load_jsonl(_split_path(data_dir, args, "SVAMP", "data"))),
    "asdiv": lambda args, data_dir: preprocess_asdiv_dataset(
        load_jsonl(_split_path(data_dir, args, "asdiv"))),
    "mawps": lambda args, data_dir: preprocess_mawps_dataset(
        load_jsonl(_split_path(data_dir, args, "mawps"))),
    "aime24": lambda args, data_dir: preprocess_aime24_dataset(
        load_jsonl(_split_path(data_dir, args, "aime24"))),
    "folio": lambda args, data_dir: preprocess_folio_dataset(
        load_jsonl(_split_path(data_dir, args, "folio"))),
    "arlsat": lambda args, data_dir: preprocess_arlsat_dataset(
        load_jsonl(_split_path(data_dir, args, "arlsat"))),
    "logiqa": lambda args, data_dir: preprocess_logiqa_dataset(
        load_jsonl(_split_path(data_dir, args, "logiqa"))),
    "reclor": lambda args, data_dir: preprocess_reclor_dataset(
        load_json(_split_path(data_dir, args, "reclor", ext=".json"))),
    "abductionr": lambda args, data_dir: preprocess_abductionr_dataset(
        load_jsonl(_split_path(data_dir, args, "abductionr", "data"))),
    "rulearena": lambda args, data_dir: preprocess_rulearena_dataset(
        load_jsonl(_split_path(data_dir, args, "rulearena")), args),
    "fld": lambda args, data_dir: preprocess_fld_dataset(
        load_jsonl(_split_path(data_dir, args, "fld", "data"))),
    "proofwriter": lambda args, data_dir: preprocess_proofwriter_dataset(
        load_jsonl(_split_path(data_dir, args, "proofwriter", "data"))),
    "ruletaker": lambda args, data_dir: preprocess_ruletaker_dataset(
        load_jsonl(_split_path(data_dir, args, "ruletaker", "data"))),
}


def load_data(dataset_name: str,
              args,
              max_samples: Optional[int] = None,
              data_dir: str = "./datasets") -> List[Dict[str, Any]]:
    """Load and preprocess ``dataset_name`` through ``DATASET_LOADERS``.

    Args:
        dataset_name: key into ``DATASET_LOADERS``.
        args: forwarded to the loader (the loaders read ``args.data_split``).
        max_samples: when truthy, keep only the first ``max_samples`` examples
            (plain slice semantics, so a negative value drops from the end);
            ``None`` or ``0`` returns everything.
        data_dir: root directory holding the per-dataset folders.

    Raises:
        NotImplementedError: if ``dataset_name`` has no registered loader.
    """
    if dataset_name not in DATASET_LOADERS:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name}")

    loader = DATASET_LOADERS[dataset_name]
    examples = loader(args, data_dir)
    if not max_samples:
        return examples
    return examples[:max_samples]

# General dialogue constructor


def build_conversation(system_prompt: str, question: str) -> List[Dict[str, str]]:
    """Return a two-message chat: the system prompt, then ``Q: <question>\\nA:``."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Q: {question}\nA:"},
    ]


def prepare_math_dataset_conversation(data_item: Dict[str, Any], args) -> List[Dict[str, str]]:
    """Build the chat for a math item, asking for a ``\\boxed{...}`` final answer.

    ``args`` is accepted for interface parity with the other builders and is
    not used.
    """
    system_prompt = "Please answer the question. Return the final result as: Answer: \\boxed{...}."
    question = data_item["question"]
    return build_conversation(system_prompt, question)


def prepare_logic_dataset_conversation(data_item: Dict[str, Any], args) -> List[Dict[str, str]]:
    """Build the chat for a logic item.

    The answer-format requirements live in ``data_item["question"]`` itself;
    the system prompt only sets the role. ``args`` is accepted for interface
    parity and is not used.
    """
    # Fixes the run-together "reasoner,Answer" in the previous system prompt.
    system_prompt = (
        "You are a careful logical reasoner. Answer the question according to the requirements."
    )
    return build_conversation(
        system_prompt,
        data_item["question"]
    )


def prepare_rulearena_conversation(data_item: Dict[str, Any], args) -> List[Dict[str, str]]:
    """Build the system/user chat for a RuleArena item.

    The split's prompt template from ``datasets.rulearena.prompt`` is filled
    in. The airline and NBA templates receive their reference rules and the
    item's question. The tax template receives ``data_item["question"]`` in
    place of ``$forms``. Example slots are always replaced with "".

    Raises:
        ValueError: if ``args.data_split`` is not airline, nba or tax.
    """
    split = args.data_split
    if split == 'airline':
        from datasets.rulearena.prompt import system_prompt_airline, prompt_template_airline, reference_rules_airline
        system_content = system_prompt_airline
        user_content = (
            prompt_template_airline
            .replace("$reference_rules", reference_rules_airline)
            .replace("$question_prompt", data_item['question'])
            .replace("$example_prompt", "")
        )
    elif split == "nba":
        from datasets.rulearena.prompt import system_prompt_nba, prompt_template_nba, reference_rules_nba
        system_content = system_prompt_nba
        user_content = (
            prompt_template_nba
            .replace("$reference_rules", reference_rules_nba)
            .replace("$question", data_item["question"])
            .replace("$example", "")
        )
    elif split == "tax":
        # NOTE(review): only $forms and $example are filled here; confirm the
        # tax template has no other placeholders (e.g. $name) left unfilled.
        from datasets.rulearena.prompt import system_prompt_tax, prompt_template_tax
        system_content = system_prompt_tax
        user_content = (
            prompt_template_tax
            .replace("$forms", data_item["question"])
            .replace("$example", "")
        )
    else:
        raise ValueError(
            f"Unknown split: {args.data_split} when prepare conversation for rulearena")

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
