import dataclasses as dc
import json
import os
from pprint import pprint
from pathlib import Path
from typing import Literal, Any

import torch
import random
from . import inference
from .reasoning import make_reasoner, ThoughtImpl, task_from_name, ReasoningTask
from .model import make_llm, Distrubute, LLM


@dc.dataclass
class EvalArgs:
    """Bundle of settings describing one evaluation configuration.

    The field names mirror several parameters of ``evaluate_reasoning``;
    NOTE(review): how callers map these fields onto that function is not
    visible here -- confirm against the call sites.

    Attributes are documented inline below. ``infernece_type`` is misspelled
    but kept as the dataclass field for backward compatibility; new code
    should prefer the correctly spelled ``inference_type`` property alias.
    """

    # Context length for the reasoner.
    context_length: int
    # Inference approach (sic: misspelled field name, kept for compatibility).
    infernece_type: inference.Approach
    # Extra keyword arguments for the reasoner factory.
    reasoner_args: dict = dc.field(default_factory=dict)
    # Keyword arguments for inference.
    inference_args: dict = dc.field(default_factory=dict)
    # Keyword arguments for the task evaluator.
    evaluator_args: dict = dc.field(default_factory=dict)

    @property
    def inference_type(self) -> inference.Approach:
        """Correctly spelled alias of the ``infernece_type`` field."""
        return self.infernece_type

    @inference_type.setter
    def inference_type(self, value: inference.Approach) -> None:
        self.infernece_type = value


@torch.inference_mode()
def evaluate_reasoning(
    checkpoint: Path | str,
    task: ReasoningTask | str,
    impl: ThoughtImpl,
    context_length: int,
    batch_size: int = 64,
    split: Literal['train', 'val', 'test'] | str = 'val',
    max_n: int | None = None,
    tokenizer_dir: Path | str | None = None,
    distribute: Distrubute = "auto",
    save_results: Path | str | None = None,
    evaluator_args: dict | None = None,
    reasoner_args: dict | None = None,
    inference_args: dict | None = None,
    reflection_args: dict | None = None,
    verbose: int = 1,
    litckpt_name: str | None = None,  # test a specific (e.g. the best) Lightning checkpoint from RL.
):
    """Evaluate a reasoning LLM on one split of a task.

    The model is built from ``checkpoint`` (optionally overwritten with the
    weights of a Lightning ``.ckpt`` file), wrapped in a reasoner, and run over
    the split in batches. The task's evaluator is attached to the reasoner
    before the loop, and its results are returned at the end. The whole
    function runs under ``torch.inference_mode()``.

    Args:
        checkpoint: Passed (as ``str``) to ``make_llm``; also used as the
            directory searched for ``litckpt_name``.
        task: A ``ReasoningTask`` or a task name resolved by ``task_from_name``.
        impl: Thought implementation passed to ``make_reasoner``.
        context_length: Passed to ``make_reasoner``.
        batch_size: Number of instances per reasoner call.
        split: Split name passed to ``task.get_instances``.
        max_n: If given, evaluate at most ``max_n`` instances taken after a
            random shuffle; otherwise evaluate the whole split in order.
        tokenizer_dir: Forwarded to ``make_llm``.
        distribute: Forwarded to ``make_llm``.
        save_results: If truthy, passed to ``evaluator.save``.
        evaluator_args: Keyword arguments for ``task.evaluator``.
        reasoner_args: Extra keyword arguments for ``make_reasoner``.
        inference_args: Forwarded to ``make_reasoner``.
        reflection_args: Forwarded to ``make_reasoner``.
        verbose: ``>= 1`` prints progress and the final results; ``>= 2`` also
            calls ``evaluator.present_results()`` after every batch. Results are
            printed as well whenever ``save_results`` is not set.
        litckpt_name: If given, name (or file-name prefix) of a ``.ckpt`` file
            in ``checkpoint`` whose ``llm.``-prefixed weights are loaded into
            the model.

    Returns:
        ``evaluator.get_results()``, or ``None`` if the split has no instances.
    """
    # Avoid shared mutable defaults.
    evaluator_args = {} if evaluator_args is None else evaluator_args
    reasoner_args = {} if reasoner_args is None else reasoner_args

    llm = make_llm(str(checkpoint), tokenizer_dir=tokenizer_dir, distribute=distribute)
    if litckpt_name is not None:  # overwrite the weights with a Lightning checkpoint file
        _load_litckpt(Path(checkpoint), litckpt_name, llm=llm)

    if isinstance(task, str):
        task = task_from_name(task)

    instances = task.get_instances(split)
    if not instances:
        print(f"Warning: no {split} data for the task.")
        return None

    reasoner = make_reasoner(
        llm, impl, context_length,
        task=task,
        inference_args=inference_args,
        reflection_args=reflection_args,
        **reasoner_args
    )
    try:
        evaluator = task.evaluator(**evaluator_args)
        evaluator.attach(reasoner)

        if max_n is not None:
            # Shuffle a copy so the task's own instance list is left untouched,
            # and never try to evaluate more instances than the split holds.
            instances = list(instances)
            random.shuffle(instances)
            max_n = max(0, min(max_n, len(instances)))
        else:
            max_n = len(instances)

        if verbose >= 1:
            arg_text = ', '.join(f'{k}={v}' for k, v in reasoner_args.items())
            if arg_text:
                arg_text = " with " + arg_text
            print(f"Decode{arg_text}.")

        # Reference dicts of the current batch, exposed to the reasoner through
        # ``reasoner.ref``. The closure reads ``refs`` at call time, so it always
        # sees the batch most recently assigned in the loop below.
        refs: list[dict[str, Any]] = []

        def get_ref_dict(idx: tuple[int, ...]) -> dict[str, Any]:
            # NOTE(review): assumes idx[0] is the sample's position within the
            # current batch -- confirm against the reasoner implementation.
            return refs[idx[0]]

        reasoner.ref = get_ref_dict

        for start in range(0, max_n, batch_size):
            # Clip the final batch at max_n so a partial batch never evaluates
            # more samples than requested.
            stop = min(start + batch_size, max_n)
            cots = instances[start:stop]
            inputs = [task.style.apply_input(**cot.as_dict()) for cot in cots]
            refs = [task.get_ref_dict(cot) for cot in cots]
            input_tokens = reasoner.preprocess(inputs)
            # The outputs are not used here; presumably the attached evaluator
            # records them.
            reasoner(input_tokens)
            if verbose >= 1:
                print(f"Evaluated {stop} / {max_n} {split} samples.")
            if verbose >= 2:
                evaluator.present_results()

        if verbose >= 1:
            print(f"Complete: Evaluated {max_n} / {max_n} {split} samples.")

        if save_results:
            evaluator.save(save_results)
            if verbose >= 1:
                print(f"Results saved to {save_results}")

        # Computed unconditionally so the return value is always bound, even
        # when nothing is printed (save_results set and verbose < 1).
        results = evaluator.get_results()
        if not save_results or verbose >= 1:
            pprint(results)
    finally:
        # Release the reasoner's resources even if evaluation fails.
        reasoner.close()

    return results


def _load_litckpt(ckpt_dir: Path, file_name: str, strict=True, **models: torch.nn.Module):
    """Load weights from a Lightning ``.ckpt`` file into one or more modules.

    Args:
        ckpt_dir: Directory containing the checkpoint file (``str`` is also
            accepted).
        file_name: Either an exact file name ending in ``.ckpt``, or a prefix
            that must match exactly one ``.ckpt`` file in ``ckpt_dir``.
        strict: Forwarded to ``torch.nn.Module.load_state_dict``.
        **models: Modules to load, keyed by prefix. ``models[name]`` receives
            every entry of the checkpoint's ``"state_dict"`` whose key starts
            with ``f"{name}."``, with that prefix removed.

    Raises:
        FileNotFoundError: If ``file_name`` is a prefix matching zero or
            several ``.ckpt`` files.
    """
    assert ckpt_dir is not None
    ckpt_dir = Path(ckpt_dir)
    if file_name.endswith(".ckpt"):
        file_path = ckpt_dir / file_name
    else:
        # Sorted so the error message is deterministic when several files match.
        files = sorted(
            f for f in os.listdir(ckpt_dir)
            if f.endswith(".ckpt") and f.startswith(file_name)
        )
        if len(files) != 1:
            raise FileNotFoundError(f"Require one .ckpt file starts with \"{file_name}\", "
                                    f"but {len(files)} were found.")
        file_path = ckpt_dir / files[0]
    # weights_only=True restricts unpickling to tensors and primitive containers.
    ckpt_data = torch.load(file_path, weights_only=True)
    full_state_dict: dict[str, Any] = ckpt_data["state_dict"]
    assert isinstance(full_state_dict, dict)
    for name, model in models.items():
        prefix = name + '.'
        # Always filter from the full checkpoint. Reassigning the dict being
        # filtered would hand every model after the first an already-stripped
        # subset (and fail under strict loading).
        sub_state = {
            k.removeprefix(prefix): v
            for k, v in full_state_dict.items()
            if k.startswith(prefix)
        }
        model.load_state_dict(sub_state, strict=strict)
