import os
import sys
from core.api import *
from litgpt.config import Config as GptConfig
from configs import Config, get_config_from_name as get_config
from exprutils import Command, parg, karg, run as _run


def _parse_resume(resume: bool, path: str | Path | None):
    return False if resume is False else (
        'auto' if path is None else Path(path)
    )


def _parse_devices(devices: str | None, default: Auto | int = 'auto') -> str | int:
    if devices is None:
        return default
    try:
        return int(devices)
    except ValueError:
        return devices


def _get_reasoning_config(config: Config):
    reasoning_cfg = config.reasoning
    if reasoning_cfg is None:
        print(f"reasoning has not been configured in {config.name}.")
        exit()
    return reasoning_cfg


def _text_data_group(name: str, eos: str | None = None) -> TextFileGroup:
    token_map: dict[str, str] = {}

    if name == 'tiny_story':
        if eos:
            token_map['<|endoftext|>'] = eos
        return TextFileGroup(
            prefix='TinyStory',
            root='data/TinyStory',
            files={
                'train': ['TinyStories-train.txt'],
                'val': ['TinyStories-valid.txt'],
            },
            token_map=token_map,
            encoding='utf-8',
        )
    else:
        raise KeyError(name)


def _out_dir(config: Config, cmd: str):
    return config.out_path(cmd)


class Prepare(Command):
    """
    Prepare the CoT data, pretraining data, and tokenizer.
    """

    config = parg(str, "the name of experiment configuration")
    overwrite = karg(bool, "whether to overwrite existing files")

    def __call__(self):
        """Build the text corpus, train the tokenizer if needed, and prepare the data.

        Steps, in order:

        1. Resolve every entry of ``config.corpus`` into a ``TextFileGroup``;
           reasoning tasks first write their ``train``/``val`` text files.
        2. Pass all groups to ``setup_text_files`` under
           ``config.get_pretrain_data_path()``.
        3. Collect the special tokens required by the tasks and by the
           configured reflection approaches.
        4. Train a BPE tokenizer unless ``tokenizer.json`` already exists.
        5. Run ``TextFiles.prepare_data`` with that tokenizer.

        Raises:
            ValueError: if the reserved vocabulary size leaves no room for
                the tokenizer's own vocabulary.
            NotImplementedError: if the configuration uses a pretrained
                checkpoint.
        """
        config = get_config(self.config)
        overwrite = self.overwrite
        special_tokens = config.special_tokens.copy()
        eos = config.special_tokens.eos
        if eos is None:
            print("Warning: The EOS token has not been specified.")

        os.makedirs(config.out_path(), exist_ok=True)
        file_groups: list[TextFileGroup] = []

        for data_info in config.corpus:
            # A corpus entry is either a bare name or a dict of task arguments
            # that carries the name under the 'name' key.
            name = data_info['name'] if isinstance(data_info, dict) else data_info
            assert isinstance(name, str)
            try:
                if isinstance(data_info, dict):
                    task = reasoning_task(**data_info)
                elif reasoning_task_available(data_info):
                    task = reasoning_task(name)
                else:
                    file_groups.append(_text_data_group(name, eos))
                    task = None
            except KeyError:
                known = ('tiny_story',) + tuple(available_reasoning_tasks())
                supported = ', '.join("\"%s\"" % known_name for known_name in known)
                print(
                    f"Warning: {name} is not a recognized data and will be ignored.",
                    f"The supported are {supported}."
                )
                continue

            if task is None:
                # Plain-text corpus: its file group was already registered above.
                continue

            special_tokens.incorporate(task.special_tokens())
            task.prepare_data(
                files={
                    "train": [".txt"],  # we do not prepare SFT data
                    "val": [".txt"],
                },
                force_write=overwrite,
                eos=(eos if eos else ''),
            )
            file_groups.append(TextFileGroup(
                prefix=name,
                root=task.root,
                files={
                    'train': ['train.txt'],
                    'val': ['val.txt'],
                },
                token_map={},
                encoding='utf-8',
            ))

        setup_text_files(
            config.get_pretrain_data_path(),
            *file_groups,
            force_write=overwrite,
        )

        # incorporate special tokens for reflection
        reflection_approaches: set[ReflectionApproach] = set()
        if isinstance(config.reasoning, dict):
            for r in config.reasoning.values():
                if r.reflection is not None:
                    reflection_approaches.add(r.reflection.approach)
        elif (r := config.reasoning) is not None:
            if r.train_refl is not None:
                reflection_approaches.add(r.train_refl.approach)
        for reflection_approach in reflection_approaches:
            reflection_class = get_reflection_class(reflection_approach)
            special_tokens.incorporate(reflection_class.required_vocab)

        if config.use_pretrained_checkpoint:
            raise NotImplementedError(
                "preparing data with a pretrained checkpoint is not implemented"
            )

        tokenizer_path = config.out_path('tokenizer')
        if (tokenizer_path / 'tokenizer.json').exists():
            print("Tokenizer has already existed.")
        else:
            vocab_size = config.model.vocab_size
            reserved = config.reserved_vocab_size
            # ``>=`` rather than ``>``: reserving the whole vocabulary would
            # leave a BPE vocabulary of size 0, which the message below forbids.
            if reserved >= vocab_size:
                raise ValueError(f"vocabulary size ({vocab_size}) should be greater than the reserved size ({reserved})")
            vocab_size -= reserved
            train_tokenizer(
                r'(.*)\.txt',
                config.get_pretrain_data_path('train'),
                tokenizer_path,
                "BPE",
                special_tokens,
                kwargs=param.BPE(vocab_size=vocab_size)
            )

        # prepare pretraining data
        tokenizer = Tokenizer(tokenizer_path)
        textfiles = TextFiles(
            config.get_pretrain_data_path('train'),
            config.get_pretrain_data_path('val'),
        )
        textfiles.connect(tokenizer=tokenizer)
        # max_workers = 1 to avoid LitData deletion issue. This might be a LitData bug.
        textfiles.prepare_data(max_workers=1)


class Pretrain(Command):
    """Pretrain a next-token predicting model."""

    config = parg(str, "the name of experiment configuration")
    from_task = karg(bool, "use data from specific reasoning task")
    resume = karg(bool, "resume previous experiment")
    resume_path = karg(str, "the checkpoint to resume")
    devices = karg(str, "the number of devices")
    nodes = karg(int, default=1)
    seed = karg(int, default=42)

    def __call__(self):
        """Launch pretraining on the prepared text data or on a reasoning task.

        With ``from_task`` unset, training reads the train/val locations given
        by ``config.get_pretrain_data_path``. With it set, the data comes from
        the configured reasoning task, using the task's data options merged
        with (and overridden by) ``reasoning.data_options``.
        """
        config = get_config(self.config)

        def launch_options():
            # Built lazily, at the call site, so the configuration lookups
            # happen after the positional arguments have been evaluated.
            return dict(
                out_dir=config.out_path('pretrain'),
                resume=_parse_resume(self.resume, self.resume_path),
                precision=config.precision,
                devices=_parse_devices(self.devices),
                num_nodes=self.nodes,
                seed=self.seed,
            )

        if self.from_task:
            reasoning = _get_reasoning_config(config)
            task = reasoning_task(reasoning.task)
            pretrain_on_reasoning_task(
                task,
                config.model,
                config.out_path('tokenizer'),
                config.pretrain_args,
                data=(task.data_options | reasoning.data_options),
                **launch_options(),
            )
            return

        data_path = config.get_pretrain_data_path
        pretrain_on_text(
            (data_path("train"), data_path("val")),
            config.model,
            config.out_path('tokenizer'),
            config.pretrain_args,
            **launch_options(),
        )


class SFT(Command):
    """Supervised Fine-tuning that strengthens a model certain prompt-to-output prediction in
    reasoning. This requires a pretrained model."""

    config = parg(str, "the name of experiment configuration")
    startpoint = karg(str, "the checkpoint to start SFT", default="pretrain/final")
    from_file = karg(bool, "whether to load SFT data from existing json file.")
    resume = karg(bool, "resume previous experiment")
    resume_path = karg(str, "the checkpoint to resume SFT")
    devices = karg(str, "the number of devices")
    nodes = karg(int, default=1)
    seed = karg(int, default=1337)
    out = karg(str, "the output checkpoint", short='o')

    def __call__(self):
        """Run supervised fine-tuning on the configured reasoning task.

        The data is either the literal ``"file"`` (after the task has prepared
        its ``.sft.json`` files) or the task's data options merged with
        ``reasoning.data_options``. The run starts from
        ``config.out_path(startpoint)`` and writes to the ``out`` location,
        ``'sft'`` by default.

        Raises:
            SystemExit: with a non-zero status when the configuration has no
                reasoning section or no SFT arguments.
        """
        config = get_config(self.config)
        reasoning = _get_reasoning_config(config)
        # Validate before touching any data, so that an unsupported
        # configuration does not leave freshly written files behind.
        if reasoning.sft is None:
            sys.exit(f"config \"{config.name}\" does not support SFT.")

        task = reasoning_task(reasoning.task)
        if self.from_file:
            task.prepare_data(
                files={"train": [".sft.json"], "val": [".sft.json"]},
                **reasoning.data_options
            )
            data = "file"
        else:
            data = task.data_options | reasoning.data_options

        sft_on_reasoning_task(
            task,
            config.out_path(self.startpoint),
            reasoning.sft,
            data=data,
            out_dir=_out_dir(config, self.out or 'sft'),
            resume=_parse_resume(self.resume, self.resume_path),
            precision=config.precision,
            devices=_parse_devices(self.devices),
            num_nodes=self.nodes,
            seed=self.seed
        )


class RL(Command):
    """Run a RL algorithm to maximize the reward of a reasoning model."""

    config = parg(str, "the name of experiment configuration")
    alg = parg(str, "the algorithm to run", choices=("ppo", "grpo", "trainvf"))
    startpoint = karg(str, "the checkpoint to start RL", default="sft/final")
    reflective = karg(bool, "enable reflective reasoning")
    resume = karg(bool, "resume previous experiment")
    resume_path = karg(str, "the checkpoint to resume RL")
    devices = karg(str, "the number of devices")
    nodes = karg(int, default=1)
    out = karg(str, "the output checkpoint", short='o')

    def __call__(self):
        """Run the selected RL procedure on the configured reasoning task.

        ``alg`` selects both the argument set read from the reasoning
        configuration and the approach passed to ``rl_on_reasoning_task``:
        ``ppo`` -> ``reasoning.ppo``, ``grpo`` -> ``reasoning.grpo``,
        ``trainvf`` -> ``reasoning.train_vf`` with approach
        ``"estimate-value"``. For ``trainvf`` the output directory is the
        start checkpoint itself and ``out`` is ignored.

        Raises:
            SystemExit: with a non-zero status when the configuration lacks a
                reasoning section or the arguments of the selected algorithm.
            ValueError: if ``alg`` is not one of the declared choices.
        """
        config = get_config(self.config)
        reasoning = _get_reasoning_config(config)
        task = reasoning_task(reasoning.task)

        startpoint = self.startpoint
        if startpoint is None:
            startpoint = "sft/final"
        startpoint = config.out_path(startpoint)

        # Reflection stays disabled unless it is requested and configured.
        reflection = None
        if self.reflective:
            if reasoning.train_refl is not None:
                reflection = reasoning.train_refl.rl
            if reflection is None:
                print("Warning: the reasoning configuration does not support RL training "
                      "for reflection. Therefore, reflection will not be enabled.")

        approach: RLApproach
        if self.alg == "ppo":
            rlarg = reasoning.ppo
            approach = "ppo"
            out_dir = _out_dir(config, self.out or 'ppo')
        elif self.alg == "grpo":
            rlarg = reasoning.grpo
            approach = "grpo"
            out_dir = _out_dir(config, self.out or 'grpo')
        elif self.alg == "trainvf":
            rlarg = reasoning.train_vf
            approach = "estimate-value"
            out_dir = startpoint
            if self.out is not None:
                print(f"Warning: You have provided \"--out={self.out}\". This will not "
                      "have any effect, as `trainvf` does not require an output checkpoint.")
        else:
            # Unreachable while ``choices`` is enforced; an explicit raise is
            # used because ``assert`` is stripped under ``python -O``.
            raise ValueError(f"unsupported RL algorithm: {self.alg!r}")

        if rlarg is None:
            sys.exit(f"{config.name} does not support \"{approach}\" for reasoning.")

        rl_on_reasoning_task(
            task,
            startpoint,
            reasoning.impl,
            approach,
            rlarg,
            reflection=reflection,
            precision=config.precision,
            num_nodes=self.nodes,
            devices=_parse_devices(self.devices),
            out_dir=out_dir,
            resume=_parse_resume(self.resume, self.resume_path)
        )
        

class Eval(Command):
    """Evaluate the model's reasoning performance in the task."""

    config = parg(str, "the name of experiment configuration")
    checkpoint = parg(str, "the checkpoint to evaluate")
    alg = karg(str, "the test-time execution algorithm")
    split = karg(str, "the data split used to evaluate", default="val")
    num = karg(int, "the maximal number of tested samples", short='n')
    verbose = karg(int, "the verbosity level", default=1)
    devices = karg(str, "the number of devices")
    litckpt_name = karg(str,
                        "load the model parameter from lightning .ckpt file instead of the default file. "
                        "This can be the exact name or its unique prefix.")
    nodes = karg(int, default=1)
    # NOTE(review): ``seed`` is accepted but never forwarded by ``__call__``.
    seed = karg(int, default=1337)

    def __call__(self):
        """Evaluate a checkpoint on the configured reasoning task.

        The test arguments come from ``reasoning.get_test_args(alg)``. Results
        are saved inside the checkpoint directory as ``{alg}.eval.json``, or
        ``eval.json`` when no ``alg`` is given.

        Raises:
            SystemExit: with a non-zero status when the configuration has no
                reasoning section, offers no test arguments, or the
                checkpoint does not exist.
        """
        config = get_config(self.config)
        reasoning = _get_reasoning_config(config)
        task = reasoning_task(reasoning.task)

        test = reasoning.get_test_args(self.alg)
        if test is None:
            sys.exit("The configuration does not support testing.")

        checkpoint = config.out_path(self.checkpoint)
        if not checkpoint.exists():
            sys.exit(f"{checkpoint} does not exist.")

        # Drop the algorithm component from the file name when none is given.
        out_file_name = '.'.join(
            item
            for item in [self.alg, "eval", "json"]
            if item is not None
        )
        out_path = checkpoint / out_file_name

        evaluate_reasoning(
            checkpoint,
            task,
            reasoning.impl,
            test.context_size,
            test.batch_size,
            self.split,
            max_n=self.num,
            distribute=as_fabric(
                devices=_parse_devices(self.devices, default=1),
                precision=config.precision or "auto",
                num_nodes=self.nodes,
            ),
            save_results=out_path,
            evaluator_args=test.evaluator,
            reasoner_args=test.reasoner,
            inference_args=test.inference,
            reflection_args=test.reflection,
            verbose=self.verbose,
            litckpt_name=self.litckpt_name,
        )


class ReflData(Command):
    """Use a model to collect reasoning trajectories, which are used to generate reflective examples."""

    config = parg(str, "the name of experiment configuration")
    checkpoint = parg(str, "the checkpoint to evaluate")
    verbose = karg(int, "the verbosity level", default=1)
    append = karg(bool, "append new data to the existing file, if any.")
    devices = karg(str, "the number of devices")
    nodes = karg(int, default=1)
    # NOTE(review): ``seed`` is accepted but never forwarded by ``__call__``.
    seed = karg(int, default=1337)

    def __call__(self):
        """Collect reflection data for the train and val splits of the task.

        The output goes to ``<checkpoint>/reflection_data``; sample and
        instance counts are read from ``reasoning.train_refl``, whose
        ``collect`` mapping is forwarded as extra keyword arguments.

        Raises:
            SystemExit: with a non-zero status when the configuration has no
                reasoning section or no ``train_refl`` arguments, or when the
                checkpoint does not exist.
        """
        config = get_config(self.config)
        reasoning = _get_reasoning_config(config)
        task = reasoning_task(reasoning.task)
        if (refl_args := reasoning.train_refl) is None:
            sys.exit("Arguments of collecting reflection data is not provided in "
                     "the reasoning configuration.")

        checkpoint = config.out_path(self.checkpoint)
        # Report a missing checkpoint the same way ``Eval`` does, instead of
        # letting ``mkdir`` fail with a FileNotFoundError traceback.
        if not checkpoint.exists():
            sys.exit(f"{checkpoint} does not exist.")
        out_path = checkpoint / "reflection_data"
        out_path.mkdir(exist_ok=True)

        collect_reflection_data(
            task, checkpoint, reasoning.impl,
            [
                {
                    "split": "train",
                    "n_samples": refl_args.n_train_samples,
                    "n_instances": refl_args.n_train_instances,
                },
                {
                    "split": "val",
                    "n_samples": refl_args.n_val_samples,
                    "n_instances": refl_args.n_val_instances,
                }
            ],
            devices=_parse_devices(self.devices, default=1),
            precision=config.precision,
            num_nodes=self.nodes,
            append_file=self.append,
            out_path=out_path,
            verbose=self.verbose,
            **refl_args.collect
        )


class ReflSFT(Command):
    """Perform reflective SFT on a model. Reflective data should be collected beforehand."""

    config = parg(str, "the name of experiment configuration")
    startpoint = parg(str, "the checkpoint to start SFT")
    datapath = karg(str, "the path of reflection data, by default \"{startpoint}/reflection_data\".", )
    resume = karg(bool, "resume previous experiment")
    resume_path = karg(str, "the checkpoint to resume SFT")
    devices = karg(str, "the number of devices")
    nodes = karg(int, default=1)
    seed = karg(int, default=1337)
    out = karg(str, "the output checkpoint, by default \"{startpoint}+refl\"", short='o')

    def __call__(self):
        """Fine-tune a checkpoint on reflection data.

        SFT arguments are taken from ``reasoning.train_refl.sft`` and fall
        back to ``reasoning.sft``. ``datapath`` and ``out`` are forwarded
        as-is; when ``out`` is empty, ``out_dir`` is passed as ``None``.

        Raises:
            SystemExit: with a non-zero status when the configuration has no
                reasoning section, no ``train_refl`` arguments, or no SFT
                arguments.
        """
        config = get_config(self.config)
        reasoning = _get_reasoning_config(config)
        if (refl_args := reasoning.train_refl) is None:
            sys.exit("Arguments of collecting reflection data is not provided in "
                     "the reasoning configuration.")

        # Reflection-specific SFT arguments win over the generic ones.
        sft_args = refl_args.sft or reasoning.sft
        if sft_args is None:
            sys.exit(f"config \"{config.name}\" does not support SFT.")

        checkpoint = config.out_path(self.startpoint)
        out_dir = config.out_path(self.out) if self.out else None
        sft_for_reflection(
            refl_args.approach,
            checkpoint,
            sft_args,
            reasoning.task,
            data_args=refl_args.sft_data,
            data_path=self.datapath,
            out_dir=out_dir,
            resume=_parse_resume(self.resume, self.resume_path),
            precision=config.precision,
            devices=_parse_devices(self.devices),
            num_nodes=self.nodes,
            seed=self.seed
        )


class ReflEval(Command):
    """Evaluate a model's reflection accuracy on the validation reflection data."""

    config = parg(str, "the name of experiment configuration")
    checkpoint = parg(str, "the checkpoint to evaluate")
    refldata = parg(str, "the path of reflection data")

    def __call__(self):
        """Load ``val.json`` from the reflection data and score the checkpoint on it.

        Both ``checkpoint`` and ``refldata`` are resolved through
        ``config.out_path``. The counts of accept/reject examples are printed
        before ``evaluate_reflection_accuracy`` runs.
        """
        # Imported lazily: only this command needs the self-verify module.
        from core.reasoning.reflection.self_verify import evaluate_reflection_accuracy, SVReflSFTData, Symbol

        config = get_config(self.config)
        reasoning = _get_reasoning_config(config)

        llm = make_llm(config.out_path(self.checkpoint))
        task = reasoning_task(reasoning.task)
        data = SVReflSFTData(llm.tokenizer, label_type="binary", task=task)

        data.load_file(config.out_path(self.refldata, "val.json"))

        print("accept:", data.type.count(Symbol.accept))
        print("reject:", data.type.count(Symbol.reject))
        evaluate_reflection_accuracy(
            llm, data, 1024,
            batch_size=128,
            verbose=2,
            sampling_args=dict(temperature=0),
        )


class Debug(Command):
    """Ad-hoc debugging command with hard-coded checkpoint and task.

    Runs a reflective reasoner on two random batches of training instances,
    then prints the collected trajectories and the evaluator's results.
    """

    def __call__(self):

        # Imports are local so that they are only paid for by this command.
        import torch
        from pprint import pprint
        from typing import Iterable
        from core.reasoning.reasoners.markov_chain import ThoughtSteps
        from core.rl import Collector
        from core.reasoning import CoT
        from core.reasoning.reflection.self_verify import SVRewardWrapper
        from random import choices

        torch.set_printoptions(precision=2, sci_mode=False, linewidth=100)

        # Hard-coded checkpoint and task: edit these when debugging another run.
        llm = make_llm('out/mult-4m/sft-r1-debug/final')
        task = reasoning_task('mult-mdp-v1')
        reflection_args = param.reasoning.reflection.self_verify(64, 1024, "revise", enable_statistics=True)
        r = make_reasoner(
            llm, "mdp", 2048,
            task=task,
            inference_args=param.inference.sampling(),
            reflection_args=reflection_args,
            max_steps=32,
            require_thought=True
        )

        # reference function
        # NOTE: ``cots`` is resolved when ``ref`` is called, i.e. it refers to
        # the batch currently bound by the ``for cots in cot_batches`` loop below.
        def ref(idx: tuple[int, ...]):
            i = idx[0]
            return task.get_ref_dict(cots[i])
        
        r.ref = ref

        # prepare trajectory collector
        collector = Collector('cpu',
                              Collector.Require(probs=True, rewards=True),
                              allow_abortion=False)
        # Outcome reward model of the task, wrapped for self-verification.
        reward_model = task.reward_model("outcome")
        reward_model = SVRewardWrapper(reward_model)
        collector.reward_model = reward_model
        collector.attach(r)
        
        # prepare evaluator
        evaluator = task.evaluator()
        evaluator.attach(r)

        # prepare samples
        all_cots = task.get_instances('train')
        cot_batches: Iterable[list[CoT]] # the input cases
        # Two batches of four instances, drawn with replacement.
        cot_batches = [
            choices(all_cots, k=4),
            choices(all_cots, k=4),
        ]

        for cots in cot_batches:

            text_inputs: list[str] = []
            ground_truths: list = []

            # Render each instance as an input prompt and remember its outcome.
            for cot in cots:
                text_inputs.append(task.style.apply_input(**cot.as_dict()))
                ground_truths.append(cot.outcome)

            if len(text_inputs) == 0:
                break

            with torch.inference_mode():
                input_tokens = r.preprocess(text_inputs)
                thought, out = r(input_tokens)

            # NOTE: rebinds ``cot`` (the inner loop variable) to the batch result.
            cot: CoT = CoT(input_tokens, thought, out)

            # NOTE(review): the trailing backslash on the print line below only
            # continues onto the following blank line; it looks unintentional.
            if isinstance(thought, ThoughtSteps):
                for i, step in enumerate(thought):
                    print(i, llm.detokenize(step), sep='\n')\

            print(r.detokenize(cot))
            print('ground truth:', ground_truths, sep='\n')

        # Dump everything the collector recorded while the reasoner ran.
        data = collector.data
        print(data)
        for i in range(len(data)):
            batch = data[i]
            # Pop the fields that are printed separately; the remainder of
            # ``batch`` is shown under "data".
            tokens: torch.Tensor = batch.pop('tokens')
            length: torch.Tensor = batch.pop('lengths')
            prompt_length = batch.pop('prompt_lengths')
            probs =  batch.pop('probs', None)

            # Split the token sequence into prompt and generated output.
            prompt = llm.preprocessor.decode(tokens[:prompt_length], False)
            output = llm.preprocessor.decode(tokens[prompt_length:length], False)
            if probs is not None:
                probs = probs[prompt_length:length]
            print('-----------')
            pprint({"prompt": prompt,
                    "output": output,
                    # "feedback": checker(prompt, output),
                    "data": batch,
                    "probs": probs}, sort_dicts=False)

        pprint(evaluator.get_results())


def get_command(name: str) -> type[Command]:
    """Look up a ``Command`` subclass of this module by case-insensitive name.

    Args:
        name: the command name; it is compared against the lower-cased names
            of all ``Command`` subclasses found in the module globals.

    Returns:
        The matching ``Command`` subclass.

    Raises:
        SystemExit: with a non-zero status (and the list of supported
            commands on stderr) when no command matches ``name``.
    """
    commands = {
        global_name.lower(): obj
        for global_name, obj in globals().items()
        if isinstance(obj, type) and issubclass(obj, Command) and obj is not Command
    }
    try:
        return commands[name.lower()]
    except KeyError:
        # NOTE(review): the 'do_' prefix is stripped for display only; no
        # class in this file uses it - presumably a leftover, confirm.
        supported = ', '.join(f'"{key.removeprefix("do_")}"' for key in commands)
        sys.exit(f"{name} is not a supported command to run. "
                 f"The supported commands are {supported}.")


if __name__ == "__main__":

    # Script entry point: hand this module's namespace to ``exprutils.run``
    # (imported as ``_run``). Presumably it discovers the ``Command``
    # subclasses defined above and dispatches on the command line - confirm
    # against exprutils.
    _run(globals())
