import os as _os
import typing as _t
import json as _json
from dataclasses import dataclass as _dataclass
from dataclasses import field as _field
from pathlib import Path

from litgpt.args import EvalArgs, TrainArgs

import core.model
import core.pretrain
import core.reasoning
import core.rl
import core.sft
import core.tokenization

from core.tokenization import Tokenizer
from core.inference import make_inference, Approach as InferenceApproach
from core.evaluate import evaluate_reasoning
from core.model import (
    Auto, Logger, Precision, make_llm, Distrubute, as_fabric,
    Config as ModelConfig,
)
from core.reasoning import (
    ReasoningTask,
    task as _task,
    ThoughtImpl,
    make_reasoner,
    collect_reflection_data,
)
from core.reasoning.reflection import (
    Approach as ReflectionApproach,
    make_refletion_datamodule,
    get_reflection_class,
    ReflectiveSample as _ReflectiveSample
)
from core.rl.ppo import LitPPO as PPOModel
from core.utils import Progress, TextFiles


# Literal type listing three supervision names. It is not referenced anywhere
# else in this part of the file.
Supervision = _t.Literal['outcome', 'process', 'success']
# Anything this module accepts as a filesystem location: a string or a `Path`.
type _PathLike = str | Path
# Dataset split name. The three literals spell out the expected names, but the
# union with `str` means that any string is accepted by a type checker.
type _Split = _t.Literal['train', 'test', 'val'] | str


def _as_path[T](path: _PathLike | None, default: T = None) -> Path | T:
    if path is None:
        return default
    else:
        return Path(path) if not isinstance(path, Path) else path


@_dataclass
class NTPHyperParames:
    """
    Hyper parameters for Next-Token Prediction.

    Bundle of settings that the pretraining and SFT entry points of this
    module forward to ``core.pretrain.setup`` / ``core.sft.setup``. Instances
    can be built with ``param.pretrain(...)`` or ``param.sft(...)``.
    """

    # Forwarded as the ``train=`` argument of the setup functions.
    train: TrainArgs
    # Forwarded as the ``eval=`` argument of the setup functions.
    eval: EvalArgs
    # Forwarded as the ``optimizer=`` argument; either a name or a dict
    # (the factories in ``param`` default it to "AdamW").
    optimizer: str | dict


# Module-level aliases for helpers of `core.reasoning.task`.
# `reasoning_task` is what the entry points below call to turn a task given by
# name into a task object. Judging by their names, the other two report whether
# a task name is known and list the known names — confirm in core.reasoning.task.
reasoning_task = _task.task_from_name
reasoning_task_available = _task.is_available_task_name
available_reasoning_tasks = _task.available_task_names


def pretrain_on_reasoning_task(
    task: ReasoningTask | str,
    model: str | ModelConfig,
    tokenizer_dir: _PathLike,
    hparams: NTPHyperParames,
    data: _t.Literal["file"] | dict | None = None,
    init_checkpoint: _PathLike | None = None,
    out_dir: _PathLike | None = None,
    compile_model: bool = False,
    precision: Precision | None = None,
    resume: bool | Auto | Path = False,
    devices: int | str = "auto",
    num_nodes: int = 1,
    logger_name: Logger = "tensorboard",
    seed: int = 42,
):
    """Pretrain a model on the "pretrain" data module of a reasoning task.

    Thin wrapper that resolves its arguments and delegates to
    ``core.pretrain.setup`` (called with ``data_mode='sft'``).

    Args:
        task: Task object, or a task name resolved through ``reasoning_task``.
        model: Model name, or a ``ModelConfig`` whose ``name`` is used and
            which is forwarded as ``model_config``. For a plain name an empty
            dict is forwarded as ``model_config`` instead.
        tokenizer_dir: Tokenizer location, forwarded as a ``Path``.
        hparams: Source of the ``train``, ``eval`` and ``optimizer`` settings.
        data: ``"file"`` requests ``task.data_module(..., from_file=True)``;
            a dict is expanded into keyword arguments of ``task.data_module``;
            ``None`` means no extra keyword arguments.
        init_checkpoint: Optional initial checkpoint directory.
        out_dir: Output directory; defaults to ``out/<model name>/pretrain``.
        compile_model, precision, resume, devices, num_nodes, logger_name:
            Forwarded unchanged to ``core.pretrain.setup``.
        seed: Used for the data module and forwarded to the setup call.
    """
    if isinstance(model, ModelConfig):
        model_name, model_config = model.name, model
    else:
        model_name, model_config = model, {}
    # An empty / missing name falls back to this placeholder.
    model_name = model_name or "nomane"

    target_dir = _as_path(out_dir, default=Path(f'out/{model_name}/pretrain'))

    resolved_task = reasoning_task(task) if isinstance(task, str) else task

    # Extra keyword arguments for the task's data module.
    if data == "file":
        module_kwargs = {"from_file": True}
    else:
        module_kwargs = data or {}
    data_module = resolved_task.data_module("pretrain", seed=seed, **module_kwargs)

    core.pretrain.setup(
        model_name,
        model_config=model_config,
        data=data_module,
        data_mode='sft',
        train=hparams.train,
        eval=hparams.eval,
        optimizer=hparams.optimizer,
        initial_checkpoint_dir=_as_path(init_checkpoint),
        tokenizer_dir=_as_path(tokenizer_dir),
        out_dir=target_dir,
        compile_model=compile_model,
        precision=precision,
        resume=resume,
        devices=devices,
        num_nodes=num_nodes,
        logger_name=logger_name,
        seed=seed,
    )


def pretrain_on_text(
    data_path: _PathLike | tuple[_PathLike, _PathLike],
    model: str | ModelConfig,
    tokenizer_dir: _PathLike,
    hparams: NTPHyperParames,
    seq_separator: str | tuple[str, ...] | None = None,
    seq_drop_size: int | None = None,
    init_checkpoint: _PathLike | None = None,
    out_dir: _PathLike | None = None,
    compile_model: bool = False,
    precision: Precision | None = None,
    resume: bool | Auto | Path = False,
    devices: int | str = "auto",
    num_nodes: int = 1,
    logger_name: Logger = "tensorboard",
    seed: int = 42,
):
    """Pretrain a model on plain text files.

    Wraps the given path(s) in a ``TextFiles`` data module and delegates to
    ``core.pretrain.setup`` (called with ``data_mode='pretrain'``).

    Args:
        data_path: Training path, or a ``(train, val)`` tuple. Without a
            tuple no validation path is given to ``TextFiles``.
        model: Model name, or a ``ModelConfig`` whose ``name`` is used and
            which is forwarded as ``model_config``. For a plain name an empty
            dict is forwarded as ``model_config`` instead.
        tokenizer_dir: Tokenizer location, forwarded as a ``Path``.
        hparams: Source of the ``train``, ``eval`` and ``optimizer`` settings.
        seq_separator: Forwarded to ``TextFiles`` as ``sep``.
        seq_drop_size: Forwarded to ``TextFiles`` as ``drop_size``.
        init_checkpoint: Optional initial checkpoint directory.
        out_dir: Output directory; defaults to ``out/pretrain/<model name>``.
        compile_model, precision, resume, devices, num_nodes, logger_name:
            Forwarded unchanged to ``core.pretrain.setup``.
        seed: Used for the data module and forwarded to the setup call.
    """
    if isinstance(model, ModelConfig):
        model_name, model_config = model.name, model
    else:
        model_name, model_config = model, {}
    # An empty / missing name falls back to this placeholder.
    model_name = model_name or "nomane"

    target_dir = _as_path(out_dir, default=Path(f'out/pretrain/{model_name}'))

    if isinstance(data_path, tuple):
        train_path, val_path = data_path
    else:
        train_path, val_path = data_path, None

    text_data = TextFiles(
        Path(train_path),
        _as_path(val_path),
        seed=seed,
        num_workers=4,
        sep=seq_separator,
        drop_size=seq_drop_size,
    )
    # NOTE(review): the data module is neither connected to a tokenizer nor
    # prepared here; presumably core.pretrain.setup takes care of that — confirm.

    core.pretrain.setup(
        model_name,
        model_config=model_config,
        data=text_data,
        data_mode='pretrain',
        train=hparams.train,
        eval=hparams.eval,
        optimizer=hparams.optimizer,
        initial_checkpoint_dir=_as_path(init_checkpoint),
        tokenizer_dir=_as_path(tokenizer_dir),
        out_dir=target_dir,
        compile_model=compile_model,
        precision=precision,
        resume=resume,
        devices=devices,
        num_nodes=num_nodes,
        logger_name=logger_name,
        seed=seed,
    )


def sft_on_reasoning_task(
    task: ReasoningTask | str,
    checkpoint: _PathLike,
    hparams: NTPHyperParames,
    data: _t.Literal['file'] | dict | None = None,
    out_dir: _PathLike | None = None,
    precision: Precision | None = None,
    resume: bool | Auto | Path = False,
    devices: int | str = "auto",
    num_nodes: int = 1,
    logger_name: _t.Literal["wandb", "tensorboard", "csv"] = "csv",
    seed: int = 1337,
    access_token: str | None = None,
):
    """Fine-tune a checkpoint on the "sft" data module of a reasoning task.

    Thin wrapper that resolves its arguments and delegates to
    ``core.sft.setup``.

    Args:
        task: Task object, or a task name resolved through ``reasoning_task``.
        checkpoint: Checkpoint to start from.
        hparams: Source of the ``train``, ``eval`` and ``optimizer`` settings.
        data: ``"file"`` requests ``task.data_module(..., from_file=True)``;
            a dict is expanded into keyword arguments of ``task.data_module``;
            ``None`` means no extra keyword arguments.
        out_dir: Output directory; defaults to a sibling of ``checkpoint``
            named ``<checkpoint name>+sft``.
        precision, resume, devices, num_nodes, logger_name, access_token:
            Forwarded unchanged to ``core.sft.setup``.
        seed: Used for the data module and forwarded to the setup call.
    """
    resolved_task = reasoning_task(task) if isinstance(task, str) else task

    ckpt_path = Path(checkpoint)
    sibling_dir = ckpt_path.parent / (ckpt_path.name + "+sft")
    target_dir = _as_path(out_dir, sibling_dir)

    # Extra keyword arguments for the task's data module.
    if data == "file":
        module_kwargs = {"from_file": True}
    else:
        module_kwargs = data or {}
    data_module = resolved_task.data_module("sft", seed=seed, **module_kwargs)

    core.sft.setup(
        ckpt_path,
        out_dir=target_dir,
        data=data_module,
        train=hparams.train,
        eval=hparams.eval,
        optimizer=hparams.optimizer,
        precision=precision,
        devices=devices,
        num_nodes=num_nodes,
        resume=resume,
        logger_name=logger_name,
        seed=seed,
        access_token=access_token,
    )


# Names accepted as the `approach` argument of `rl_on_reasoning_task`, which
# has one branch per value listed here.
RLApproach = _t.Literal["ppo", "grpo", "estimate-value"]


def rl_on_reasoning_task(
    task: ReasoningTask | str,
    checkpoint: _PathLike,
    impl: ThoughtImpl,
    approach: RLApproach,
    args: core.rl.RLArgs,
    reflection: dict | None = None,
    precision: Precision | None = None,
    devices: int | str = "auto",
    num_nodes: int = 1,
    out_dir: _PathLike | None = None,
    resume: Auto | Path | bool = False,
):
    """Run a reinforcement-learning stage on a reasoning task.

    Picks the trainer class matching ``approach``, checks that ``args`` has
    the matching argument type, builds the trainer and runs ``fit``.

    Args:
        task: Task object, or a task name resolved through ``reasoning_task``.
        checkpoint: Checkpoint to start from (passed on as a string).
        impl: Forwarded to the trainer as ``impl``.
        approach: ``"ppo"``, ``"estimate-value"`` or ``"grpo"``.
        args: Must be an instance of ``core.rl.PPOArgs``,
            ``core.rl.ValueIterArgs`` or ``core.rl.GRPOArgs`` respectively.
        reflection: Forwarded to the trainer as ``reflection``.
        precision, devices, num_nodes, resume: Forwarded to ``fit``.
        out_dir: Output checkpoint location. Defaults to ``out/<approach>/``
            for ``"ppo"`` and ``"grpo"``, and to ``checkpoint`` itself for
            ``"estimate-value"``.

    Raises:
        TypeError: ``args`` does not match ``approach``.
        NotImplementedError: ``approach`` is not one of the supported names.
    """
    if isinstance(task, str):
        task = reasoning_task(task)

    # Per-approach settings: expected args type (plus its display name for the
    # error message), trainer class and default output location. The default
    # values are deliberately kept in their original form (Path / str).
    if approach == 'ppo':
        args_type, args_name = core.rl.PPOArgs, "PPOArgs"
        trainer_cls = core.rl.PPO
        default_out = Path(f"out/{approach}/")
    elif approach == 'estimate-value':
        args_type, args_name = core.rl.ValueIterArgs, "ValueIterArgs"
        trainer_cls = core.rl.ValueEstimation
        default_out = checkpoint
    elif approach == "grpo":
        args_type, args_name = core.rl.GRPOArgs, "GRPOArgs"
        trainer_cls = core.rl.GRPO
        default_out = f"out/{approach}/"
    else:
        raise NotImplementedError(approach)

    if not isinstance(args, args_type):
        raise TypeError(f"`args` is not an instance of `{args_name}`.")

    out_dir = _as_path(out_dir, default=default_out)
    rl = trainer_cls(
        task=task,
        checkpoint=str(checkpoint),
        args=args,
        impl=impl,
        reflection=reflection,
        out_checkpoint=str(out_dir),
    )

    # Saved once before training and once after it. Presumably the first call
    # makes sure an output checkpoint exists even if `fit` is interrupted —
    # TODO confirm against the trainer implementation.
    rl.save()
    rl.fit(
        precision=precision,
        devices=devices,
        num_nodes=num_nodes,
        resume=resume,
    )
    rl.save()


# Module-level alias of `core.tokenization.train_tokenizer`, so that it can be
# reached from this module as well.
train_tokenizer = core.tokenization.train_tokenizer


class TextFileGroup(_t.TypedDict):
    """Description of one group of source text files for `setup_text_files`."""

    # Leading part of every output file name (`<prefix>_<split>_<index>`).
    prefix: str
    # Directory that every entry of `files` is joined onto.
    root: _PathLike
    # Split name -> source files belonging to that split.
    files: dict[_Split, list[_PathLike]]
    # Substring replacements (source -> replacement) applied to the text of
    # this group. Merged over the global map, so entries given here win.
    token_map: dict[str, str]
    # Encoding used to read the source files of this group.
    encoding: str
    

def setup_text_files(
    target_dir: _PathLike,
    *files: TextFileGroup,
    global_token_map: dict[str, str] | None = None,
    encoding: str = 'utf-8',
    force_write: bool = False,
):
    """Copy groups of text files into ``target_dir/<split>/`` with token replacement.

    For each file of each group the text is read with the group's encoding,
    every ``token_map`` source substring is replaced by its target, and the
    result is written to ``target_dir/<split>/<prefix>_<split>_<index>.txt``.

    Args:
        target_dir: Root of the output tree; one sub-directory per split is
            created on demand.
        *files: File groups to process, in order.
        global_token_map: Replacements applied to every group. A group's own
            ``token_map`` is merged on top, so it wins on conflicting keys.
            ``None`` (the default) means no global replacements.
        encoding: Encoding of the written files.
        force_write: When false, a file whose target already exists is
            skipped; when true it is overwritten.
    """
    target_dir = Path(target_dir)
    # Private copy, so the caller's mapping is never shared with the merged maps.
    shared_map = dict(global_token_map) if global_token_map else {}

    for group in files:
        root = Path(group['root'])
        token_map = {**shared_map, **group['token_map']}

        for split, file_paths in group['files'].items():
            split_dir = target_dir / split
            split_dir.mkdir(parents=True, exist_ok=True)

            for i, file_path in enumerate(file_paths):
                file_path = root / file_path
                # Append ".txt" explicitly. `with_suffix` would instead replace
                # everything after the last dot, so a prefix or split name
                # containing "." would map all files of the group onto the
                # same target.
                name = group['prefix'] + '_' + split + '_' + str(i) + '.txt'
                targ_path = split_dir / name

                if not force_write and targ_path.exists():
                    print(f"Skip {file_path} since {targ_path} already exists.")
                    continue

                with open(file_path, 'rt', encoding=group['encoding']) as f:
                    print(f'processing: {file_path}', end=' ')
                    text = f.read()
                for src, dst in token_map.items():
                    text = text.replace(src, dst)

                with open(targ_path, 'wt', encoding=encoding) as f:
                    f.write(text)
                    print(f'(saved to {targ_path})')

    print("All operations have terminated.")


def sft_for_reflection(
    approach: ReflectionApproach,
    checkpoint: _PathLike,
    hparams: NTPHyperParames,
    task: ReasoningTask | str | None = None,
    data_args: dict | None = None,
    data_path: _PathLike | None = None,
    out_dir: _PathLike | None = None,
    precision: Precision | None = None,
    resume: bool | Auto | Path = False,
    devices: int | str = "auto",
    num_nodes: int = 1,
    logger_name: _t.Literal["wandb", "tensorboard", "csv"] = "csv",
    seed: int = 1337,
    access_token: str | None = None,
):
    """Fine-tune a checkpoint on reflection data.

    Builds a data module with ``make_refletion_datamodule`` and delegates to
    ``core.sft.setup``.

    Args:
        approach: Reflection approach, forwarded to the data-module factory.
        checkpoint: Checkpoint to start from.
        hparams: Source of the ``train``, ``eval`` and ``optimizer`` settings.
        task: Optional task object, or a task name resolved through
            ``reasoning_task``.
        data_args: Extra keyword arguments for the data-module factory;
            ``None`` (the default) means none.
        data_path: Location of the reflection data; defaults to
            ``<checkpoint>/reflection_data``.
        out_dir: Output directory. Defaults to a sibling named
            ``<name>+refl``, where ``<name>`` is the checkpoint directory, or
            its parent when the checkpoint directory is called ``final``.
        precision, resume, devices, num_nodes, logger_name, seed, access_token:
            Forwarded unchanged to ``core.sft.setup``.
    """
    checkpoint = Path(checkpoint)
    if isinstance(task, str):
        task = reasoning_task(task)

    # For ".../<run>/final" the default is named after <run>, not "final".
    base = checkpoint.parent if checkpoint.name == "final" else checkpoint
    default_outdir = base.parent / (base.name + "+refl")

    out_dir = _as_path(out_dir, default_outdir)
    data_path = _as_path(data_path, checkpoint / "reflection_data")
    data_module = make_refletion_datamodule(
        approach, data_path, task, **(data_args or {})
    )
    core.sft.setup(
        checkpoint,
        out_dir=out_dir,
        precision=precision,
        devices=devices,
        num_nodes=num_nodes,
        resume=resume,
        data=data_module,
        train=hparams.train,
        eval=hparams.eval,
        optimizer=hparams.optimizer,
        logger_name=logger_name,
        seed=seed,
        access_token=access_token,
    )


class param:
    """Namespace of factories for the argument objects used by this module.

    Nothing here is meant to be instantiated; every member is a static
    method, a nested namespace or an alias of something defined in ``core``.
    """

    @staticmethod
    def model(
        name: str = "",
        *,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_embd: int,
        n_head: int,
        padding_multiple: int = 128,
    ):
        """Build a ``ModelConfig`` from the given name and size settings."""
        settings = dict(
            name=name,
            vocab_size=vocab_size,
            block_size=block_size,
            n_layer=n_layer,
            n_embd=n_embd,
            n_head=n_head,
            padding_multiple=padding_multiple,
        )
        return ModelConfig(**settings)

    @staticmethod
    def pretrain(
        max_tokens: int,
        save_interval: int | None = 1000,
        log_interval: int = 1,
        batch_size: int | _t.Sequence[int] = (64, 16),
        lr_warmup: int | str | float = 100,
        max_seq_length: int | None = None,
        tie_embeddings: bool | None = None,
        max_norm: float = 1.,
        min_lr: float = 6e-5,
        eval_interval: int = 1000,
        eval_iters: int = 100,
        optimizer: str | dict = "AdamW",
    ) -> NTPHyperParames:
        """Build the hyper parameters of a pretraining run.

        Args:
            batch_size: One int (used for both the global and the micro batch
                size) or a ``(global, micro)`` pair.
            lr_warmup: An int is a number of warm-up steps; a float is a
                fraction within [0, 1]. A string is parsed first: ``"5%"``
                becomes ``0.05``, a string containing ``"."`` becomes a
                float, anything else is parsed as an int.
            eval_interval, eval_iters: Passed to ``EvalArgs`` as ``interval``
                and ``max_iters``.
            optimizer: Stored on the result unchanged.
            Remaining arguments are passed to ``TrainArgs`` under the same
            name.

        Raises:
            ValueError: ``batch_size`` is a sequence whose length is not 2, or
                a float ``lr_warmup`` lies outside [0, 1].
        """
        if isinstance(batch_size, int):
            global_batch_size = micro_batch_size = batch_size
        elif len(batch_size) != 2:
            raise ValueError("batch size must be one integar or a sequence of two integars.")
        else:
            global_batch_size, micro_batch_size = batch_size

        # Normalise a textual warm-up into a float fraction or an int step count.
        if isinstance(lr_warmup, str):
            if lr_warmup.endswith('%'):
                lr_warmup = float(lr_warmup.rstrip('%')) / 100
            else:
                lr_warmup = float(lr_warmup) if '.' in lr_warmup else int(lr_warmup)

        warmup_fraction = None
        if isinstance(lr_warmup, float):
            # Written as a negated chained comparison so that NaN is rejected too.
            if not (0. <= lr_warmup <= 1.):
                raise ValueError(f"lr_warmup accepts a float value within [0, 1], but encountered {lr_warmup}.")
            if lr_warmup == 1.0:
                print("[Warning] Using lr_warmup = 1.0, which means a whole-process warmup. Do you mean only one step?")
            warmup_fraction = lr_warmup
        warmup_steps = lr_warmup if isinstance(lr_warmup, int) else None

        train_args = TrainArgs(
            save_interval=save_interval,
            log_interval=log_interval,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            lr_warmup_steps=warmup_steps,
            lr_warmup_fraction=warmup_fraction,
            max_tokens=max_tokens,
            max_seq_length=max_seq_length,
            tie_embeddings=tie_embeddings,
            max_norm=max_norm,
            min_lr=min_lr,
        )
        eval_args = EvalArgs(interval=eval_interval, max_iters=eval_iters)

        return NTPHyperParames(train_args, eval_args, optimizer)

    @staticmethod
    def sft(
        epochs: int,
        save_interval: int | None = 1000,
        log_interval: int = 1,
        batch_size: int | _t.Sequence[int] = (64, 16),
        lr_warmup: int | None = 100,
        max_seq_length: int | None = None,
        min_lr: float = 6e-5,
        eval_interval: int = 1000,
        eval_iters: int = 100,
        eval_max_new_tokens: int = 100,
        optimizer: str | dict = "AdamW",
    ) -> NTPHyperParames:
        """Build the hyper parameters of a supervised fine-tuning run.

        Args:
            batch_size: One int (used for both the global and the micro batch
                size) or a ``(global, micro)`` pair.
            lr_warmup: Number of warm-up steps; any value that is not an int
                results in ``lr_warmup_steps=None``.
            eval_interval, eval_iters, eval_max_new_tokens: Passed to
                ``EvalArgs`` as ``interval``, ``max_iters`` and
                ``max_new_tokens``.
            optimizer: Stored on the result unchanged.
            Remaining arguments are passed to ``TrainArgs`` under the same
            name.

        Raises:
            ValueError: ``batch_size`` is a sequence whose length is not 2.
        """
        if isinstance(batch_size, int):
            global_batch_size = micro_batch_size = batch_size
        elif len(batch_size) != 2:
            raise ValueError("batch size must be one integar or a sequence of two integars.")
        else:
            global_batch_size, micro_batch_size = batch_size

        warmup_steps = lr_warmup if isinstance(lr_warmup, int) else None

        train_args = TrainArgs(
            save_interval=save_interval,
            log_interval=log_interval,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
            lr_warmup_steps=warmup_steps,
            epochs=epochs,
            max_seq_length=max_seq_length,
            min_lr=min_lr,
        )
        eval_args = EvalArgs(
            interval=eval_interval,
            max_iters=eval_iters,
            max_new_tokens=eval_max_new_tokens,
        )

        return NTPHyperParames(train_args, eval_args, optimizer)

    class rl:
        """Aliases of the RL argument types and helpers in ``core.rl``."""
        ppo = core.rl.PPOArgs
        grpo = core.rl.GRPOArgs
        train_vf = core.rl.ValueIterArgs
        reflective = core.rl.param_for_refl

    optim_lit = core.model.LitLLM.OptimArgs
    vocabulary = core.tokenization.Vocabulary

    @staticmethod
    def lr_progress(
        target_coef: float,
        stop: int,
        start: int = 0,
        func: Progress._Supported = 'linear'
    ):
        """Return ``Progress(1., target_coef, start, stop, func)``.

        Presumably a schedule for a learning-rate coefficient that moves from
        1.0 to ``target_coef`` between ``start`` and ``stop`` — TODO confirm
        against ``core.utils.Progress``.
        """
        initial_coef = 1.
        return Progress(initial_coef, target_coef, start, stop, func)

    @staticmethod
    def __kwargs(**kwargs):
        """Return the keyword arguments whose value is not ``None`` as a dict."""
        kept = {}
        for key, value in kwargs.items():
            if value is not None:
                kept[key] = value
        return kept

    @staticmethod
    def BPE(
        vocab_size: int | None = None,
        min_frequency: int | None = None,
        end_of_word_suffix: str|None = '</w>',
        continuing_subword_prefix: str|None = None,
    ):
        """Build the keyword-argument sections of a BPE tokenizer setup.

        Returns:
            A dict with the keys ``model``, ``trainer``, ``decoder`` and
            ``pre_tokenizer``. Each value is a dict of keyword arguments from
            which settings that are ``None`` have been left out;
            ``pre_tokenizer`` is always empty.
        """
        drop_none = param.__kwargs
        # The affix settings are shared by the model and the trainer sections.
        affixes = dict(
            end_of_word_suffix=end_of_word_suffix,
            continuing_subword_prefix=continuing_subword_prefix,
        )
        return {
            "model": drop_none(**affixes),
            "trainer": drop_none(
                vocab_size=vocab_size,
                min_frequency=min_frequency,
                **affixes,
            ),
            "decoder": drop_none(suffix=end_of_word_suffix),
            "pre_tokenizer": {},
        }

    # Parameter namespaces of other `core` packages, exposed as class attributes.
    from core.inference import param as inference
    from core.reasoning import param as reasoning
