from .variant_trainers import *

def validate_cfg(cfg: dict) -> bool:
    """Report whether the trainer-selection flags in *cfg* form a legal combination.

    Each of ``use_fkl``, ``use_kl``, ``use_random``, ``use_token`` and
    ``use_branch`` must be present in *cfg* and hold a real ``bool``
    (``0``/``1`` ints are rejected). Accepted combinations:

    * ``use_fkl`` with every other flag False;
    * ``use_kl`` together with at most one of ``use_random`` /
      ``use_token`` / ``use_branch``;
    * every flag False.

    Args:
        cfg: Configuration mapping holding the boolean flags.

    Returns:
        True if the combination is accepted, False otherwise.
    """
    flag_names = ("use_fkl", "use_kl", "use_random", "use_token", "use_branch")
    all_flags_ok = all(
        name in cfg and isinstance(cfg[name], bool) for name in flag_names
    )
    if not all_flags_ok:
        return False

    wants_fkl = cfg["use_fkl"]
    wants_kl = cfg["use_kl"]
    # Number of KL-based variants requested (bools sum as 0/1).
    variant_count = sum(cfg[name] for name in ("use_random", "use_token", "use_branch"))

    # FKL may not be combined with any other flag.
    if wants_fkl and (wants_kl or variant_count):
        return False
    # Every variant requires use_kl to be switched on.
    if variant_count and not wants_kl:
        return False
    # At most one variant may be active at a time.
    return variant_count <= 1


def _build_entropy_compressor(cfg):
    """Build the EntropyCompressor for BranchTrainer from ``cfg["rl"]``."""
    rl_cfg = cfg["rl"]
    branch_cfg = rl_cfg["branch"]
    return EntropyCompressor(
        ratio=branch_cfg["ratio"],
        threshold=branch_cfg["threshold"],
        use_source=branch_cfg["use_source"],
        q_min=rl_cfg["q_min"],
        q_max=rl_cfg["q_max"],
    )


def _build_random_sampler(cfg):
    """Build the scheduler-driven Sampler for RandomTrainer from ``cfg["rl"]``."""
    rl_cfg = cfg["rl"]
    random_cfg = rl_cfg["random"]
    scheduler = CosineAnnealingScheduler(
        max_steps=rl_cfg["max_steps"],
        init_eps=random_cfg["eps_ub"],
        lb_eps=random_cfg["eps_lb"],
        init_sig=random_cfg["sig_ub"],
        lb_sig=random_cfg["sig_lb"],
        N=random_cfg["N"],
        decay_rate=random_cfg["decay_rate"],
    )
    return Sampler(
        scheduler=scheduler,
        q_min=rl_cfg["q_min"],
        q_max=rl_cfg["q_max"],
    )


def _build_token_sampler(cfg):
    """Build the SurprisalAwareSampler for TokenTrainer from ``cfg["token"]``."""
    # NOTE(review): unlike the random/branch variants, these settings live in
    # a top-level cfg["token"] section rather than cfg["rl"]["token"]; only
    # max_steps comes from cfg["rl"]. Confirm this layout is intentional.
    token_cfg = cfg["token"]
    scheduler = CosineAnnealingScheduler(
        max_steps=cfg["rl"]["max_steps"],
        # The alpha bounds are passed through the scheduler's eps slots.
        init_eps=token_cfg["alpha_ub"],
        lb_eps=token_cfg["alpha_lb"],
        init_sig=token_cfg["sig_ub"],
        lb_sig=token_cfg["sig_lb"],
        N=token_cfg["N"],
        decay_rate=token_cfg["decay_rate"],
    )
    return SurprisalAwareSampler(
        scheduler=scheduler,
        use_source=token_cfg["use_source"],
        norm=token_cfg["norm"],
        q_min=token_cfg["q_min"],
        q_max=token_cfg["q_max"],
    )


def get_trainer(model, tokenizer, training_args, dataset, reward, cfg):
    """Construct the trainer variant selected by the boolean flags in *cfg*.

    Selection order (first match wins):

    * invalid flag combination (``validate_cfg`` returns False): print a
      notice and fall back to ``GRPOTrainer``;
    * ``use_fkl``    -> ``FKLTrainer``;
    * ``use_branch`` -> ``BranchTrainer`` with an ``EntropyCompressor``
      configured from ``cfg["rl"]["branch"]`` and ``cfg["rl"]["q_min"/"q_max"]``;
    * ``use_random`` -> ``RandomTrainer`` with a ``Sampler`` driven by a
      ``CosineAnnealingScheduler`` configured from ``cfg["rl"]["random"]``;
    * ``use_token``  -> ``TokenTrainer`` with a ``SurprisalAwareSampler``
      configured from ``cfg["token"]``;
    * otherwise      -> ``GRPOTrainer``.

    ``cfg["rl"]["max_steps"]`` is read only by the random and token
    variants. The FKL and plain-GRPO paths therefore do not need an
    ``"rl"`` section.

    Args:
        model: Passed to the trainer as ``model``.
        tokenizer: Passed to the trainer as ``processing_class``.
        training_args: Passed to the trainer as ``args``.
        dataset: Passed to the trainer as ``train_dataset``.
        reward: Single reward function, passed as ``reward_funcs=[reward]``.
        cfg: Configuration mapping with the ``use_*`` flags and the
            per-variant sections described above.

    Returns:
        The constructed trainer instance.

    Raises:
        KeyError: If the section required by the selected variant is missing
            from *cfg*.
    """
    # Keyword arguments shared by every trainer variant.
    common_kwargs = dict(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[reward],
        args=training_args,
        train_dataset=dataset,
    )

    if not validate_cfg(cfg):
        print("Invalid configuration combination. Automatically defaulting to GRPOTrainer.")
        return GRPOTrainer(**common_kwargs)

    if cfg["use_fkl"]:
        return FKLTrainer(**common_kwargs)
    if cfg["use_branch"]:
        return BranchTrainer(compressor=_build_entropy_compressor(cfg), **common_kwargs)
    if cfg["use_random"]:
        return RandomTrainer(sampler=_build_random_sampler(cfg), **common_kwargs)
    if cfg["use_token"]:
        return TokenTrainer(sampler=_build_token_sampler(cfg), **common_kwargs)
    return GRPOTrainer(**common_kwargs)
        