import json
import re
import pickle
from dataclasses import dataclass
from pprint import pprint
from pathlib import Path
from typing import Optional, List

import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import tyro
from tqdm import tqdm

from data import load_data, build_inputs
from ablations import get_ablation_args


@dataclass
class Config:
    """Command-line options for this evaluation script (parsed with ``tyro.cli`` in ``main``)."""

    # Number of prompts sent to each ``model.generate`` call.
    batch_size: int = 32
    # Forwarded to ``build_inputs``; may be overridden by the ablation settings.
    # NOTE(review): presumably controls whether numbers are tokenized digit by
    # digit — confirm in data.build_inputs.
    split_digit: bool = True
    # Optional ablation name resolved with ``get_ablation_args``; the result may
    # override ``split_digit`` and supplies ``use_cont_loss`` (default False).
    ablation: Optional[str] = None

    # Not read directly in this file — presumably consumed by ``load_data``; confirm.
    data_path: str = "../../../data/toycos/data"
    # Split name; presumably selects the split inside ``load_data`` — confirm.
    data_split: str = "test"
    # Checkpoint loaded with ``AutoModelForCausalLM.from_pretrained``. Its file stem
    # names the output files, and the text before the first "_" in the stem is
    # parsed with int() as the seed (the default "none" does not parse). A
    # case-insensitive "llama" anywhere in the path selects the Llama-3.2
    # tokenizer; otherwise the SmolLM2 tokenizer is used.
    ckpt: str = "none"
    # Directory that receives ``<ckpt stem>.json`` with the metrics (created if missing).
    output_dir: str = "./stats"

    # If True, also pickle generations, targets and the dataset to
    # ``<output_dir>/outputs/<ckpt stem>.pkl``.
    save_outputs: bool = False


def get_num(text: str, n_digits: Optional[int] = None) -> Optional[List[float]]:
    """Extract the decimal numbers in *text*, in order of appearance.

    Only tokens of the form ``<digits>.<digits>`` are recognised. Integers and
    scientific notation are skipped, and a leading minus sign is not captured,
    so ``"-1.5"`` yields ``1.5``.

    Args:
        text: String to scan.
        n_digits: If given, keep only the first ``n_digits`` numbers. Despite
            its name, this counts numbers, not digits.

    Returns:
        The parsed floats. When ``n_digits`` is given and *text* holds fewer
        than ``n_digits`` numbers, returns ``None`` instead.
    """
    values = list(map(float, re.findall(r"(\d+\.\d+)", text)))
    if n_digits is None:
        return values

    head = values[:n_digits]
    return head if len(head) >= n_digits else None


def np_evaluate(gt_output, pred_output):
    """Return ``(rmse, mae)`` between ground-truth and predicted values.

    Both arguments are forwarded to scikit-learn. For 2-D inputs, each error
    is computed per column and then averaged over columns. This matches
    scikit-learn's default ``multioutput="uniform_average"`` and the former
    ``mean_squared_error(..., squared=False)`` behaviour.

    Args:
        gt_output: Ground-truth values (array-like, 1-D or 2-D).
        pred_output: Predictions with the same shape as ``gt_output``.

    Returns:
        ``(rmse, mae)`` as Python floats.
    """
    # ``squared=False`` was deprecated in scikit-learn 1.4 and removed in 1.6.
    # Taking the square root of the per-output MSE ourselves gives the same
    # result on every version.
    per_output_mse = mean_squared_error(
        gt_output, pred_output, multioutput="raw_values"
    )
    rmse = float(np.mean(np.sqrt(per_output_mse)))
    mae = float(mean_absolute_error(gt_output, pred_output))

    return rmse, mae


def metric_with_missing_rate(gt_text, predicted_text):
    """Score generated texts against reference texts by the numbers they contain.

    The decimal numbers that ``get_num`` finds in a reference are paired, in
    order, with the first numbers found in the corresponding prediction. A
    prediction with fewer numbers than its reference is counted as *missing*
    and excluded from the error metrics; extra numbers are ignored. All scored
    numbers are pooled into a single column before computing the errors, so a
    reference may contain more than one number.

    Args:
        gt_text: Reference texts.
        predicted_text: Generated texts, paired with ``gt_text`` via ``zip``.

    Returns:
        ``(rmse, mae, missing_rate)``. ``missing_rate`` is the number of
        missing predictions divided by ``len(gt_text)``, or NaN when
        ``gt_text`` is empty. ``rmse`` and ``mae`` are NaN when no prediction
        could be scored.

    Raises:
        ValueError: If a reference text contains no decimal number.
    """
    pred_values = []
    gt_values = []
    missing_count = 0

    for gt, pred in zip(gt_text, predicted_text):
        gt_num = get_num(gt)
        # Without n_digits, get_num never returns None, so the real failure
        # mode for a reference is an empty list.
        if not gt_num:
            raise ValueError(f"No number found in ground-truth text: {gt!r}")
        pred_num = get_num(pred, n_digits=len(gt_num))
        if pred_num is None:
            missing_count += 1
            continue
        # extend (not append) keeps a flat column even for multi-number references.
        pred_values.extend(pred_num)
        gt_values.extend(gt_num)

    if gt_values:
        output = np.reshape(np.asarray(pred_values, dtype=float), (-1, 1))
        gt_output = np.reshape(np.asarray(gt_values, dtype=float), (-1, 1))
        rmse, mae = np_evaluate(gt_output, output)
    else:
        # Nothing to score: every prediction was missing, or there were no inputs.
        rmse = mae = float("nan")

    missing_rate = (
        missing_count / len(gt_text) if len(gt_text) > 0 else float("nan")
    )

    return rmse, mae, missing_rate


def main():
    """Generate answers with a checkpoint and write numeric error metrics.

    Parses ``Config`` from the command line and generates a response for each
    conversation in the loaded data. The numbers in the responses are scored
    against the targets with ``metric_with_missing_rate``. The stats are
    printed and written to ``<output_dir>/<ckpt stem>.json``. When
    ``save_outputs`` is set, the generations, targets and data are also
    pickled to ``<output_dir>/outputs/<ckpt stem>.pkl``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if the checkpoint
            stem does not start with an integer seed.
    """
    args = tyro.cli(Config)
    # The batching loop below requires a positive batch size.
    if args.batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {args.batch_size}")

    # Extra attributes are attached to ``args`` because it is handed to load_data.
    args.max_length = 2048
    args.key = Path(args.ckpt).stem
    # The text before the first "_" in the checkpoint stem is the seed.
    # NOTE(review): seed is not used in this file; presumably read by load_data — confirm.
    args.seed = int(args.key.split("_")[0])

    ablation_args = {} if args.ablation is None else get_ablation_args(args.ablation)
    if "split_digit" in ablation_args:
        args.split_digit = ablation_args["split_digit"]
    args.use_cont_loss = ablation_args.get("use_cont_loss", False)

    # The tokenizer comes from the hub base model, selected by the checkpoint path.
    is_llama = "llama" in args.ckpt.lower()
    if is_llama:
        args.model = "meta-llama/Llama-3.2-1B-Instruct"
    else:
        args.model = "HuggingFaceTB/SmolLM2-135M-Instruct"

    root = Path(args.output_dir)
    root.mkdir(exist_ok=True, parents=True)

    model = AutoModelForCausalLM.from_pretrained(
        args.ckpt,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    if not tokenizer.pad_token:
        if is_llama:
            tokenizer.pad_token = "<|finetune_right_pad_id|>"
        else:
            tokenizer.pad_token = tokenizer.unk_token

    data, _ = load_data(args, tokenizer)
    data = data.data

    all_outs = {"hypo": [], "tgt": []}

    @torch.no_grad()
    def run_batch(queries):
        """Return the decoded model response for each user query."""
        tokenizer.padding_side = "left"
        texts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": query}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for query in queries
        ]
        inputs = build_inputs(
            tokenizer,
            texts,
            padding_side="left",
            split_digit=args.split_digit,
            device="cuda",
        )

        generate_ids = model.generate(
            **inputs,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=40,
            eos_token_id=tokenizer.eos_token_id,
            # Force greedy decoding. Otherwise generate() follows the
            # checkpoint's generation_config, which may enable sampling and
            # make the metrics non-reproducible.
            do_sample=False,
        )
        # generate() returns prompt + continuation; keep only the new tokens.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
        return tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def run_and_collect(convs):
        # Assumes each conv is a (query, target) pair — confirm against load_data.
        all_outs["hypo"].extend(run_batch([conv[0] for conv in convs]))
        all_outs["tgt"].extend(conv[1] for conv in convs)

    batch = []
    for row in tqdm(data, total=len(data)):
        batch.append(row["conv"])
        if len(batch) == args.batch_size:
            run_and_collect(batch)
            batch = []
    if batch:
        run_and_collect(batch)

    # RMSE / MAE of the parsed numbers, plus the share of unparseable answers.
    rmse, mae, mr = metric_with_missing_rate(all_outs["tgt"], all_outs["hypo"])
    stats = {
        "rmse": float(rmse),
        "mae": float(mae),
        "missing_rate": float(mr),
    }
    pprint(stats)

    with open(root / f"{args.key}.json", "w") as f:
        json.dump(stats, f, indent=4)

    if args.save_outputs:
        save_root = root / "outputs"
        save_root.mkdir(exist_ok=True, parents=True)
        res = {"hypo": all_outs["hypo"], "tgt": all_outs["tgt"], "data": data}
        with open(save_root / f"{args.key}.pkl", "wb") as f:
            pickle.dump(res, f)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
