from functools import partial
from dataclasses import dataclass

from tasks import Sudoku, register_task
from core.api import param, Progress
from ._base import Config, BILLION
from ._testing_arg_builder import BaseTestingArgBuilder


# ============================
# Register Tasks
# ============================

type SizeMap = dict[int, int]

split_sizes: dict[str, SizeMap] = {
    "train": {x: 800 for x in range(9, 54)},
    "val": {x: 50 for x in range(9, 54)},
    "test": {x: 50 for x in range(9, 63)},
}
split_sizes_direct: dict[str, SizeMap] = {
    "train": {x: 2000 for x in range(9, 54)},
    "val": {x: 100 for x in range(9, 54)},
    "test": {x: 100 for x in range(9, 63)},
}

register_task('sudoku', Sudoku)
register_task('sudoku-mdp-v0',
               partial(Sudoku, root='data/sudoku-mdp-v0', implementation="mdp-v0",
                       split_sizes=split_sizes, max_solution_depth=32, max_solving_time=1))
register_task('sudoku-direct',
               partial(Sudoku, root='data/sudoku-direct', implementation=None,
                       split_sizes=split_sizes_direct))


def difficulty(blanks: int):
    """Return the difficulty bucket label for a puzzle with `blanks` blanks.

    Buckets (upper bounds are exclusive):
        blanks < 9        -> "Easy OOD"
        9 <= blanks < 36  -> "Easy ID"
        36 <= blanks < 54 -> "Hard ID"
        blanks >= 54      -> "Hard OOD"

    "ID"/"OOD" presumably mean in-/out-of-distribution; the ID band 9..53
    coincides with the key range of the training split sizes in this file.
    """
    # Ascending exclusive upper bounds; the first bound that `blanks` falls
    # under decides the label.
    buckets = (
        (9, "Easy OOD"),
        (36, "Easy ID"),
        (54, "Hard ID"),
    )
    for upper_bound, label in buckets:
        if blanks < upper_bound:
            return label
    return "Hard OOD"


# ------------ testing args ----------- #

@dataclass
class SudokuTestingArgBuilder(BaseTestingArgBuilder):
    """Builds `Config.TestArgs` for one (decode, reflect, traceback) setting.

    Only `arg` is defined here; `self.key` (used below) is not defined in
    this class and presumably comes from `BaseTestingArgBuilder`.
    """

    batch_size: int
    context_size: int
    max_steps: int | None = None
    refl_budget: int | None = None

    def arg(self, decode, reflect, traceback) -> Config.TestArgs:
        """Assemble the test arguments for a single evaluation setting.

        Args:
            decode: "greedy" selects sampling temperature 0; any other value
                selects temperature 1.
            reflect: falsy disables reflection entirely. "self" means no
                external verifier; any other truthy value is forwarded as
                `external_verifier`. "random" additionally sets the revise
                temperature to 0 and uses a string random-reject ratio.
            traceback: truthy enables `max_retry=4`, otherwise `None`.

        Returns:
            A `Config.TestArgs` instance.
        """
        # `max_steps` is forwarded to the reasoner only when it is truthy.
        reasoner_opts = {"max_steps": self.max_steps} if self.max_steps else {}

        # For "random" the ratio is a string: the key of the matching "self"
        # setting plus ".eval.json" (presumably the path of that run's
        # evaluation results -- confirm with the reflection implementation).
        reject_ratio = 0
        if reflect == "random":
            reject_ratio = self.key(decode, "self", traceback) + ".eval.json"

        sampling = param.inference.sampling(
            temperature=0 if decode == "greedy" else 1,
        )

        reflection = None
        if reflect:
            reflection = param.reasoning.reflection.self_verify(
                budget=self.refl_budget,
                max_retry=4 if traceback else None,
                external_verifier=None if reflect == "self" else reflect,
                revise_temperature=0 if reflect == "random" else 1,
                reflect_temperature=0,
                enable_statistics=True,
                random_reject_ratio=reject_ratio,
            )

        return Config.TestArgs(
            batch_size=self.batch_size,
            context_size=self.context_size,
            inference=sampling,
            evaluator={"rounding": 3},
            reasoner=reasoner_opts,
            reflection=reflection,
        )


# ------------- 4m -------------- #

# Base "sudoku-4m" configuration. The other sudoku configs in this file are
# obtained from it (directly or through one of its derivatives) via .derive().
sudoku_4m = Config(
    "sudoku-4m",
    model=param.model(
        name="sudoku-4m",
        vocab_size=128,
        block_size=2048,
        n_layer=5,
        n_embd=256,
        n_head=8,
        padding_multiple=128,
    ),
    corpus=[dict(name='sudoku-mdp-v0')],
    pretrain_args=param.pretrain(
        max_tokens=1*BILLION,
        batch_size=128,
    ),
    special_tokens=param.vocabulary(pad='<|_|>', eos='<|EOS|>', unk='<|?|>'),
    reasoning=Config.Reasoning(
        task='sudoku-mdp-v0',
        impl='mdp',
        # Task-level supervised fine-tuning.
        sft=param.sft(
            epochs=5,
            batch_size=128,
            max_seq_length=1024,
            eval_max_new_tokens=512,
            save_interval=10000,
        ),
        # PPO and GRPO below share every setting from `epochs` through
        # `val_scale` except `inference_batch_size` (512 vs 1024); they differ
        # only in the algorithm-specific tail after `val_scale`.
        ppo=param.rl.ppo(
            epochs=512,
            train_epoch_size=128,
            train_epoch_repeat=4,
            train_batch_size=64,
            inference_batch_size=512,
            context_length=512,
            supervision="outcome",
            source_of_inputs="data",
            enable_abortion=True,
            temperature=1.25,
            max_steps=32,
            optim=param.optim_lit(
                optimizer_args=dict(lr=5e-5),
                scheduler=param.lr_progress(0.2, 512)
            ),
            trainer_args=dict(accumulate_grad_batches=2),
            save_each_n_epochs=16,
            save_top_k=4,
            val_each_n_epochs=4,
            val_temperature=1,
            max_val_size=512,
            val_scale=1,
            # PPO-specific settings.
            collect_shape=4,
            clip=0.1,
            discount=1,
            target_method='td',
            ignore_truncated=True,
            prompt_value_weight=0.5,
            epochs_warmup=64,
        ),
        grpo=param.rl.grpo(
            epochs=512,
            train_epoch_size=128,
            train_epoch_repeat=4,
            train_batch_size=64,
            inference_batch_size=1024,
            context_length=512,
            supervision="outcome",
            source_of_inputs="data",
            enable_abortion=True,
            temperature=1.25,
            max_steps=32,
            optim=param.optim_lit(
                optimizer_args=dict(lr=5e-5),
                scheduler=param.lr_progress(0.2, 512)
            ),
            trainer_args=dict(accumulate_grad_batches=2),
            save_each_n_epochs=16,
            save_top_k=4,
            val_each_n_epochs=4,
            val_temperature=1,
            max_val_size=512,
            val_scale=1,
            # GRPO-specific settings.
            group_size=8,
            clip=0.1,
        ),
        # Reflective ("self-verify") training: data collection, dataset
        # construction, its own SFT stage, and an RL setting.
        train_refl=Config.ReflectiveTraining(
            "self-verify",
            collect=param.reasoning.reflection.collect(
                context_length=512,
                batch_size=512,
                n_branches=8,
                max_steps=16,
                propose_temperature=1.25,
                solve_temperature=0.75,
                rollout_temperature=0.5,
                evaluation=param.reasoning.reflection.eval("process", rollout_length=0)
            ),
            n_train_samples=200000,
            n_val_samples=20000,
            sft_data=param.reasoning.reflection.self_verify_data(
                non_reflective_data="task",
                reflection_label="detailed",
                reflection_frequency=1,
                use_weighted_sampler=True,
            ),
            # Fewer epochs (3 vs 5) and a shorter max_seq_length (512 vs 1024)
            # than the task-level SFT above.
            sft=param.sft(
                epochs=3,
                batch_size=128,
                max_seq_length=512,
                eval_max_new_tokens=512,
                save_interval=10000,
            ),
            rl=param.rl.reflective.revise_error(reflect_temperature=1)
        ),
        # Positional args presumably map to (batch_size, context_size,
        # max_steps, refl_budget) -- this holds only if BaseTestingArgBuilder
        # declares no dataclass fields of its own; TODO confirm.
        test=SudokuTestingArgBuilder(256, 512, 32, 64).build(True),
    ),
    precision='bf16-true',
)


# --- configure hyper parameters --- #
# --------------- 1M --------------- #
# ---------------------------------- #


# 1M variant, derived from the 4M base and then overridden field by field.
# NOTE(review): several overrides below mutate nested objects in place
# (e.g. `reasoning.grpo.temperature`); this assumes derive() returns a copy
# that does not share those nested objects with sudoku_4m -- confirm.
sudoku_1m = sudoku_4m.derive("sudoku-1m")
# Narrower model than the base: n_embd 256 -> 128, n_head 8 -> 4.
sudoku_1m.model = param.model(
    name="sudoku-1m",
    vocab_size=128,
    block_size=2048,
    n_layer=5,
    n_embd=128,
    n_head=4,
    padding_multiple=128,
)
# The asserts check that the optional sub-configs are present before they
# are accessed below.
assert sudoku_1m.reasoning is not None
assert sudoku_1m.reasoning.train_refl is not None
# Restates the same values as the base task-level SFT.
sudoku_1m.reasoning.sft=param.sft(
    epochs=5,
    batch_size=128,
    max_seq_length=1024,
    eval_max_new_tokens=512,
    save_interval=10000,
)
# Reflective SFT: epochs 3 -> 5 and max_seq_length 512 -> 1024 vs. the base.
sudoku_1m.reasoning.train_refl.sft = param.sft(
    epochs=5,
    batch_size=128,
    max_seq_length=1024,
    eval_max_new_tokens=512,
    save_interval=10000,
)
assert sudoku_1m.reasoning.grpo is not None
assert sudoku_1m.reasoning.ppo is not None
# RL sampling temperature lowered from the base value 1.25 to 1.0.
sudoku_1m.reasoning.grpo.temperature = 1.
sudoku_1m.reasoning.ppo.temperature = 1.
# Same collection settings as the base except propose_temperature
# (1.25 -> 1.0).
sudoku_1m.reasoning.train_refl.collect = param.reasoning.reflection.collect(
    context_length=512,
    batch_size=512,
    n_branches=8,
    max_steps=16,
    propose_temperature=1.0,
    solve_temperature=0.75,
    rollout_temperature=0.5,
    evaluation=param.reasoning.reflection.eval("process", rollout_length=0)
)


# --- configure hyper parameters --- #
# -------------- 16M --------------- #
# ---------------------------------- #

# 16M variant, derived from the 4M base and then overridden field by field.
# NOTE(review): the RL overrides at the end mutate nested objects in place;
# this assumes derive() returns a copy that does not share those nested
# objects with sudoku_4m -- confirm.
sudoku_16m = sudoku_4m.derive("sudoku-16m")
# Wider model than the base: n_embd 256 -> 512 (n_head stays 8).
sudoku_16m.model = param.model(
    name="sudoku-16m",
    vocab_size=128,
    block_size=2048,
    n_layer=5,
    n_embd=512,
    n_head=8,
    padding_multiple=128,
)
# The asserts check that the optional sub-configs are present (and have the
# expected type) before they are accessed below.
assert isinstance(sudoku_16m.reasoning, Config.Reasoning)
assert sudoku_16m.reasoning.train_refl is not None
assert isinstance(sudoku_16m.reasoning.test, dict)
# Task-level SFT: same as the base except for the tuple batch size
# (presumably (global, micro) batch sizes -- confirm against param.sft).
sudoku_16m.reasoning.sft = param.sft(
    epochs=5,
    batch_size=(128, 32),
    max_seq_length=1024,
    eval_max_new_tokens=512,
    save_interval=10000,
)
# Reflective SFT: tuple batch size and max_seq_length 512 -> 1024 vs. base.
sudoku_16m.reasoning.train_refl.sft = param.sft(
    epochs=3,
    batch_size=(128, 32),
    max_seq_length=1024,
    eval_max_new_tokens=512,
    save_interval=10000,
)
sudoku_16m.pretrain_args = param.pretrain(
    max_tokens=1*BILLION,
    batch_size=(128, 32),
)
# Test batch size halved relative to the base (256 -> 128).
sudoku_16m.reasoning.test = SudokuTestingArgBuilder(128, 512, 32, 64).build(True)
# Same collection settings as the base except propose_temperature
# (1.25 -> 1.5).
sudoku_16m.reasoning.train_refl.collect = param.reasoning.reflection.collect(
    context_length=512,
    batch_size=512,
    n_branches=8,
    max_steps=16,
    propose_temperature=1.5,
    solve_temperature=0.75,
    rollout_temperature=0.5,
    evaluation=param.reasoning.reflection.eval("process", rollout_length=0)
)
# RL: smaller inference batches than the base (PPO 512 -> 256,
# GRPO 1024 -> 256); temperature is restated at the base value 1.25.
assert sudoku_16m.reasoning.grpo is not None
sudoku_16m.reasoning.grpo.inference_batch_size = 256
sudoku_16m.reasoning.grpo.temperature = 1.25
assert sudoku_16m.reasoning.ppo is not None
sudoku_16m.reasoning.ppo.inference_batch_size = 256
sudoku_16m.reasoning.ppo.temperature = 1.25


# ----------- direct ---------- #

# Build a "<name>-direct" variant of each model size. Each variant trains on
# the 'sudoku-direct' corpus/task with token-wise reasoning and replaces the
# whole `reasoning` section (so no RL / reflective training is configured).
# The derived configs are not bound to module-level names; presumably
# derive() registers them -- confirm.
for _cfg, _sft_batchsize in zip(
    [sudoku_1m, sudoku_4m, sudoku_16m],
    # Per-config SFT batch size; the 16M entry is a tuple (presumably
    # (global, micro) batch sizes -- confirm against param.sft).
    [256, 256, (256, 128)],
):
    _cfg_d = _cfg.derive(_cfg.name + '-direct')
    _cfg_d.corpus = [dict(name='sudoku-direct')]
    _cfg_d.reasoning = Config.Reasoning(
        task='sudoku-direct',
        impl='tokenwise',
        sft=param.sft(
            epochs=10,
            # Use the per-config value from the zipped list above instead of
            # a fixed 256, so the 16M variant gets its (256, 128) setting.
            batch_size=_sft_batchsize,
            max_seq_length=256,
            eval_max_new_tokens=512,
            save_interval=10000,
        ),
        # No max_steps / refl_budget are given; build(False) presumably
        # builds the test settings without reflection -- confirm.
        test=SudokuTestingArgBuilder(512, 256).build(False),
    )


# ------ concatenated state ------ #

# Build a "<name>-concat" variant of each model size: the 'sudoku-mdp-v0'
# task trained token-wise, with only the `token_level` data option enabled.
# Each entry pairs a base config with its SFT batch size.
for _cfg, _sft_batchsize in (
    (sudoku_1m, 128),
    (sudoku_4m, 128),
    (sudoku_16m, (128, 64)),
):
    _cfg_cat = _cfg.derive(name=(_cfg.name + "-concat"))
    _cfg_cat.reasoning = Config.Reasoning(
        task='sudoku-mdp-v0',
        impl='tokenwise',
        data_options={
            "token_level": True,
            "step_level": False,
            "state_level": False,
            "policy_level": False,
            "predict_outcome": False,
            "predict_init": False,
        },
        sft=param.sft(
            epochs=5,
            batch_size=_sft_batchsize,
            # Sequence length is taken from the derived model's block size.
            max_seq_length=_cfg_cat.model.block_size,
            eval_max_new_tokens=512,
            save_interval=10000,
        ),
        test=SudokuTestingArgBuilder(256, 512).build(False),
    )
