import logging
from functools import partial
from typing import Any, Dict, List

import numpy as np
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    set_seed,
)

from rl4lms.data_pools.text_generation_pool import Sample
from rl4lms.envs.text_generation.env import TextGenEnv
from rl4lms.envs.text_generation.evaluation_utils import evaluate_on_samples
from rl4lms.envs.text_generation.logging_utils import Tracker
from rl4lms.envs.text_generation.registry import (
    AlgorithmRegistry,
    DataPoolRegistry,
    MetricRegistry,
    PolicyRegistry,
    RewardFunctionRegistry,
    WrapperRegistry,
)
from rl4lms.envs.text_generation.reward import RewardFunction
from rl4lms.envs.text_generation.utils_supervised import EvalCallack
from rl4lms.envs.text_generation.utils_supervised import (
    evaluate_on_samples as evaluate_supervised,
)
from rl4lms.envs.text_generation.utils_supervised import (
    get_datasets_for_causal,
    get_datasets_for_seq2seq,
    tokenize_causal,
    tokenize_seq2seq,
)
from rl4lms.envs.text_generation.warm_start import TrainerWarmStartMixin

logger = logging.getLogger(__name__)


def build_tokenizer(tokenizer_config: Dict[str, Any]):
    """Load a pretrained tokenizer and apply the padding/truncation settings.

    Args:
        tokenizer_config: must contain ``model_name``; may contain
            ``pad_token_as_eos_token`` (default ``True``), ``padding_side``
            (default ``"left"``) and ``truncation_side`` (default ``"left"``).

    Returns:
        The configured tokenizer instance.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_config["model_name"])

    # Reuse EOS as the pad token only when the tokenizer defines no pad token
    # and the config does not opt out of this fallback.
    if tok.pad_token is None:
        if tokenizer_config.get("pad_token_as_eos_token", True):
            tok.pad_token = tok.eos_token

    # Both sides default to "left" when not given in the config.
    for side_option in ("padding_side", "truncation_side"):
        setattr(tok, side_option, tokenizer_config.get(side_option, "left"))

    return tok


def build_reward_fn(reward_config: Dict[str, Any]):
    """Look up the reward function named in the config.

    Args:
        reward_config: must contain ``id``; the optional ``args`` entry
            (default ``{}``) is forwarded to the registry lookup.

    Returns:
        Whatever ``RewardFunctionRegistry.get`` returns for that id and args.
    """
    reward_id = reward_config["id"]
    reward_args = reward_config.get("args", {})
    return RewardFunctionRegistry.get(reward_id, reward_args)


def build_metrics(metric_configs: List[Dict[str, Any]]):
    """Resolve every metric config through the metric registry.

    Args:
        metric_configs: one dict per metric, each with an ``id`` and an
            optional ``args`` entry (default ``{}``).

    Returns:
        A list with one registry result per config, in input order.
    """
    resolved = []
    for cfg in metric_configs:
        metric = MetricRegistry.get(cfg["id"], cfg.get("args", {}))
        resolved.append(metric)
    return resolved


def build_datapool(datapool_config: Dict[str, Any]):
    """Load the train/val/test splits of the configured data pool.

    Args:
        datapool_config: must contain ``id``; the optional ``args`` entry is
            forwarded to the registry together with a ``split`` key. The
            config dict itself is left untouched.

    Returns:
        A dict with keys ``"train"``, ``"val"`` and ``"test"``. The train
        entry is a list of ``(sample, weight)`` tuples; the val and test
        entries are lists of samples only (weights are dropped).
    """
    # `or {}` also covers an explicit `args: null` in the config.
    base_kwargs = datapool_config.get("args") or {}

    def _get_datapool_by_split(split: str):
        # Build fresh kwargs per split: writing "split" into the dict returned
        # by `.get("args")` would mutate the caller's config in place.
        kwargs = {**base_kwargs, "split": split}
        return DataPoolRegistry.get(datapool_config["id"], kwargs)

    train_datapool = _get_datapool_by_split("train")
    val_datapool = _get_datapool_by_split("val")
    test_datapool = _get_datapool_by_split("test")

    # Each pool is iterated as (sample, weight) pairs.
    samples_by_split = {
        "train": [(sample, weight) for sample, weight in train_datapool],
        "val": [sample for sample, _ in val_datapool],
        "test": [sample for sample, _ in test_datapool],
    }
    return samples_by_split


def build_env(
    env_config: Dict[str, Any],
    reward_fn: RewardFunction,
    tokenizer: AutoTokenizer,
    train_samples: List[Sample],
):
    """Create the vectorized text-generation environment.

    Args:
        env_config: optional ``n_envs`` (default ``1``) and optional ``args``
            that are merged over the default environment kwargs.
        reward_fn: passed to each environment as ``reward_function``.
        tokenizer: passed to each environment as ``tokenizer``.
        train_samples: passed to each environment as ``samples``.

    Returns:
        The vectorized env built by ``make_vec_env`` with ``SubprocVecEnv``.
    """
    # Defaults first; entries from the config's "args" override them.
    merged_kwargs = dict(
        reward_function=reward_fn,
        tokenizer=tokenizer,
        samples=train_samples,
    )
    merged_kwargs.update(env_config.get("args", {}))

    n_envs = env_config.get("n_envs", 1)
    return make_vec_env(
        TextGenEnv,
        n_envs=n_envs,
        vec_env_cls=SubprocVecEnv,
        env_kwargs=merged_kwargs,
    )


def build_alg(
    alg_config: Dict[str, Any],
    env: TextGenEnv,
    tracker: Tracker,
    policy_state: Dict[str, Any],
    alg_state: Dict[str, Any],
):
    """Instantiate the wrapped on-policy algorithm and restore its state.

    Args:
        alg_config: must contain ``id``, ``policy`` (with ``id`` and ``args``)
            and ``kl_div`` (with ``coeff``; optional ``target_kl`` and
            ``norm_reward``). The optional ``args`` entry is merged over the
            default algorithm kwargs. The config is not modified.
        env: environment handed to the algorithm.
        tracker: forwarded to the algorithm wrapper.
        policy_state: passed to the policy as the ``state_dict`` kwarg.
        alg_state: passed to ``alg.load_from_dict`` after construction.

    Returns:
        The wrapped algorithm instance.
    """
    # TBD - move these to a registry once the experimentation is done
    # Also switch to Sb3 algos when possible with minimal code adaptations
    policy_config = alg_config["policy"]
    policy_cls = PolicyRegistry.get(policy_config["id"])
    alg_cls = AlgorithmRegistry.get(alg_config["id"])

    # Copy the policy args so the state dict is not injected into the
    # caller's config dict.
    policy_args = {**policy_config["args"], "state_dict": policy_state}

    # Entries from the config's "args" override the defaults; a missing or
    # null "args" is treated as empty instead of raising on `**None`.
    alg_kwargs = {
        "policy": policy_cls,
        "env": env,
        "policy_kwargs": policy_args,
        **(alg_config.get("args") or {}),
    }

    kl_config = alg_config["kl_div"]
    wrapper = WrapperRegistry.get(alg_config["id"])
    alg = wrapper(
        alg_cls,
        alg_kwargs,
        kl_config["coeff"],
        tracker,
        kl_config.get("target_kl", None),
        kl_config.get("norm_reward", False),
    )
    alg.load_from_dict(alg_state)
    return alg


class OnPolicyTrainer(TrainerWarmStartMixin):
    """
    A generic trainer for training LMs with onpolicy algorithms from SB3

    Builds the tokenizer, reward function, metrics, data pool, vectorized
    environment and algorithm from their config dicts, then alternates
    rollout/learn iterations with optional resets, checkpointing and
    evaluation.

    ``train_eval_config`` must provide ``eval_batch_size`` and ``n_iters``.
    Optional keys read here: ``metrics``, ``seed``, ``save_every``,
    ``eval_every``, ``generation_kwargs``, ``eval_before_train``,
    ``reset_freq``, ``reset_ema``, ``reset_ema_freq``, ``reset_opt``,
    ``reset_critic``, ``freeze_policy``, ``freeze_value_epochs``,
    ``ref_causal_perplexity``, ``ref_learned_metric`` and
    ``separate_ema_model``. ``env_config["args"]`` must provide
    ``max_episode_length`` and ``max_prompt_length``.
    """

    def __init__(
        self,
        tokenizer_config: Dict[str, Any],
        datapool_config: Dict[str, Any],
        reward_config: Dict[str, Any],
        env_config: Dict[str, Any],
        on_policy_alg_config: Dict[str, Any],
        train_eval_config: Dict[str, Any],
        tracker: Tracker = None,
        experiment_name: str = "",
    ):
        """Store the configs and immediately build all components."""
        self._tokenizer_config = tokenizer_config
        self._datapool_config = datapool_config
        self._reward_config = reward_config
        self._env_config = env_config
        self._on_policy_alg_config = on_policy_alg_config
        self._train_eval_config = train_eval_config
        self._tracker = tracker
        self._experiment_name = experiment_name
        self._setup()

    @staticmethod
    def _parse_freeze_value_epochs(raw):
        """Normalize the ``freeze_value_epochs`` setting to a tuple of ints.

        Accepts ``None`` (returned as-is), a comma-separated string such as
        ``"2,10"``, a list/tuple of values, or a single value ``n`` (int or
        numeric string), which is interpreted as ``(0, n)``.
        """
        if raw is None:
            return None
        if isinstance(raw, (list, tuple)):
            epochs = tuple(int(epoch) for epoch in raw)
        elif isinstance(raw, str) and "," in raw:
            epochs = tuple(int(epoch) for epoch in raw.split(","))
        else:
            # Plain int or numeric string: a substring test on a non-string
            # value would raise TypeError, hence the explicit type checks.
            epochs = (int(raw),)
        if len(epochs) == 1:
            return (0, epochs[0])
        return epochs

    def _setup(self):
        """Build every component and parse the training/evaluation options."""
        # load trainer state from available previous checkpoint if available
        # NOTE(review): presumably this mixin call populates
        # self._trainer_state, self._policy_state_dict and
        # self._alg_state_dict, which are read below but never assigned in
        # this class - confirm against TrainerWarmStartMixin.
        self.load_trainer_state(self._tracker)

        # build components
        self._tokenizer = build_tokenizer(self._tokenizer_config)
        self._reward_fn = build_reward_fn(self._reward_config)
        self._metrics = build_metrics(self._train_eval_config.get("metrics", []))
        self._samples_by_split = build_datapool(self._datapool_config)
        self._env = build_env(
            self._env_config,
            self._reward_fn,
            self._tokenizer,
            self._samples_by_split["train"],
        )
        self._alg = build_alg(
            self._on_policy_alg_config,
            self._env,
            self._tracker,
            self._policy_state_dict,
            self._alg_state_dict,
        )

        # extract train params
        self._max_episode_length = self._env_config["args"]["max_episode_length"]
        self._max_prompt_length = self._env_config["args"]["max_prompt_length"]
        self._eval_batch_size = self._train_eval_config["eval_batch_size"]
        self._n_iters = int(self._train_eval_config["n_iters"])
        self._n_steps_per_iter = self._env.num_envs * self._alg.n_steps

        # seeding, checkpoint and evaluation intervals
        _seed = self._train_eval_config.get("seed", None)
        self._seed = None if _seed is None else int(_seed)
        self._save_every = self._train_eval_config.get("save_every")
        self._eval_every = self._train_eval_config.get("eval_every")
        # gen kwargs for evaluation (if it is different from rollout gen kwargs)
        self._eval_gen_kwargs = self._train_eval_config.get("generation_kwargs", None)

        self.eval_before_train = self._train_eval_config.get("eval_before_train", True)

        # resetting
        _reset_freq = self._train_eval_config.get("reset_freq", None)
        self._reset_freq = None if _reset_freq is None else int(_reset_freq)
        self._reset_ema = bool(self._train_eval_config.get("reset_ema", False))
        _reset_ema_freq = self._train_eval_config.get("reset_ema_freq", None)
        # Fall back to the int-normalized policy reset frequency; the raw
        # config value may be a string and would break the modulo check later.
        self._reset_ema_freq = (
            self._reset_freq if _reset_ema_freq is None else int(_reset_ema_freq)
        )
        self._reset_opt = bool(self._train_eval_config.get("reset_opt", False))
        self._reset_critic = bool(self._train_eval_config.get("reset_critic", False))
        self._freeze_policy = bool(self._train_eval_config.get("freeze_policy", False))
        # tuple of ints: (freeze epoch, unfreeze epoch, ...), or None
        self._freeze_value_epochs = self._parse_freeze_value_epochs(
            self._train_eval_config.get("freeze_value_epochs", None)
        )

        # ema
        self._ref_causal_perplexity = bool(
            self._train_eval_config.get("ref_causal_perplexity", False)
        )
        self._ref_learned_metric = bool(
            self._train_eval_config.get("ref_learned_metric", False)
        )
        self._separate_ema_model = bool(
            self._train_eval_config.get("separate_ema_model", False)
        )
        if self._separate_ema_model:
            # Independent copy: fresh model, moved to the reference model's
            # device and initialized with the reference model's weights.
            self._ema_model = AutoModelForCausalLM.from_pretrained(
                self._alg.policy._model_name
            )
            self._ema_model.to(self._alg.policy._ref_model.device)
            self._ema_model.load_state_dict(self._alg.policy._ref_model.state_dict())
            self._alg.set_ema_model(self._ema_model)
        else:
            # Share the policy's reference model; only register it when the
            # algorithm exposes the hook.
            self._ema_model = self._alg.policy._ref_model
            if hasattr(self._alg, "set_ema_model"):
                self._alg.set_ema_model(self._ema_model)

    def _evaluate_on_datapools(self, epoch: int, splits: List[str] = None):
        """Evaluate the current policy on the given splits.

        Args:
            epoch: epoch index forwarded to the evaluation helper.
            splits: split names to evaluate; defaults to ``["val", "test"]``.
        """
        if splits is None:
            splits = ["val", "test"]
        for split in splits:
            evaluate_on_samples(
                policy=self._alg.policy,
                tokenizer=self._tokenizer,
                samples=self._samples_by_split[split],
                batch_size=self._eval_batch_size,
                max_prompt_length=self._max_prompt_length,
                metrics=self._metrics,
                epoch=epoch,
                split_name=split,
                tracker=self._tracker,
                gen_kwargs=self._eval_gen_kwargs,
                ref_causal_perplexity=self._ref_causal_perplexity,
                ref_learned_metric=self._ref_learned_metric,
            )

    def train_and_eval(self):
        """Run the training loop with periodic resets, checkpoints and evals.

        Starts from ``self._trainer_state["current_iter"]`` and runs up to
        ``n_iters`` iterations, then evaluates on val and test and, when a
        tracker is set and ``n_iters > 0``, saves the language model and the
        value model.
        """
        if self._seed is not None:
            set_seed(self._seed)

        iter_start = self._trainer_state["current_iter"]
        # Keep `epoch` defined for the final evaluation even when the loop
        # below runs zero times.
        epoch = iter_start
        # evaluate on val and test set before fine-tuning once
        if self.eval_before_train:
            self._evaluate_on_datapools(epoch=iter_start)

        if self._freeze_policy:
            for p in self._alg.policy._policy_model.parameters():
                p.requires_grad = False

        # train for given number of iters
        for epoch in range(iter_start, self._n_iters):
            # current state
            self._trainer_state["current_iter"] = epoch
            if self._freeze_value_epochs is not None:
                # freeze at start epoch
                if epoch == self._freeze_value_epochs[0]:
                    for p in self._alg.policy._value_model.parameters():
                        p.requires_grad = False
                # unfreeze at end epoch
                elif epoch == self._freeze_value_epochs[1]:
                    for p in self._alg.policy._value_model.parameters():
                        p.requires_grad = True

            # inner rollout and learn loop for on-policy algorithm
            self._alg.learn(self._n_steps_per_iter)

            # NOTE(review): the reset checks use `epoch % freq`, so they fire
            # right after the first iteration (epoch 0), whereas save/eval
            # below use `(epoch + 1) % freq` - confirm this is intended.
            if self._reset_freq is not None and epoch % self._reset_freq == 0:
                logger.info("resetting policy to ref")
                self._alg.policy._policy_model.load_state_dict(
                    self._ema_model.state_dict(), strict=False
                )

                if self._reset_opt:
                    # rebuild the optimizer with the policy's stored settings
                    self._alg.policy._setup_optimizer(
                        optimizer_kwargs=self._alg.policy._optimizer_kwargs,
                        weight_decay=self._alg.policy._weight_decay,
                        optimizer_class=self._alg.policy._optimizer_class,
                    )

                if self._reset_critic:
                    self._alg.policy._value_model.load_state_dict(
                        self._alg.policy._ref_model.state_dict(), strict=False
                    )
                    if hasattr(self._alg.policy._value_head, "reset_parameters"):
                        self._alg.policy._value_head.reset_parameters()

            if (
                self._reset_ema
                and self._reset_ema_freq is not None
                and epoch % self._reset_ema_freq == 0
            ):
                logger.info("resetting ema to pretrained")
                pretrained_model = AutoModelForCausalLM.from_pretrained(
                    self._alg.policy._model_name
                )
                self._ema_model.load_state_dict(
                    pretrained_model.state_dict(), strict=False
                )
                # The weights were copied into the EMA model above, so the
                # temporary model can be released right away.
                del pretrained_model
                self._ema_model.to(self._alg.policy.device)
                self._ema_model.eval()

            # save the policy checkpoint
            if self._save_every is not None and (epoch + 1) % self._save_every == 0:
                self.save_trainer_state(
                    self._tracker, self._alg.policy, self._trainer_state
                )

            # evaluate on val set in the given intervals
            if self._eval_every is not None and (epoch + 1) % self._eval_every == 0:
                self._evaluate_on_datapools(epoch=epoch, splits=["val"])

        # finally evaluate on val and test samples
        self._evaluate_on_datapools(epoch=epoch)

        # save model here - we save only the language model
        if self._tracker is not None and self._n_iters > 0:
            self._tracker.save_auto_model(self._alg.policy.get_language_model())
            self._tracker.save_auto_model(self._alg.policy._value_model, "value_model")


class SupervisedTrainer:
    """
    A supervised trainer to train LMs (causal and seq2seq) on text generation tasks (wrapper on HF trainer)

    ``alg_config`` must provide ``model_type`` (``"causal"`` selects the
    causal-LM path, anything else the seq2seq path), ``model_name``,
    ``generation_kwargs`` and ``training_args``. ``train_eval_config`` must
    provide ``eval_batch_size`` and may provide ``metrics``.
    """

    def __init__(
        self,
        tokenizer_config: Dict[str, Any],
        datapool_config: Dict[str, Any],
        train_eval_config: Dict[str, Any],
        alg_config: Dict[str, Any],
        tracker: Tracker = None,
    ):
        """Store the configs and immediately build all components."""
        self._tokenizer_config = tokenizer_config
        self._datapool_config = datapool_config
        self._train_eval_config = train_eval_config
        self._alg_config = alg_config
        self._tracker = tracker
        self._setup()

    def _evaluate_on_datapools(self, epoch: int, splits: List[str] = None):
        """Evaluate the current model on the given splits.

        Args:
            epoch: epoch index forwarded to the evaluation helper.
            splits: split names to evaluate; defaults to ``["val", "test"]``.
        """
        if splits is None:
            splits = ["val", "test"]
        for split in splits:
            evaluate_supervised(
                model=self._model,
                tokenizer=self._tokenizer,
                samples=self._samples_by_split[split],
                batch_size=self._eval_batch_size,
                max_prompt_length=self._max_prompt_length,
                metrics_config_dict=self._metrics_config_dict,
                epoch=epoch,
                split_name=split,
                tracker=self._tracker,
                generation_kwargs=self._gen_kwargs,
            )

    def _setup(self):
        """Build tokenizer, datasets, model, eval callback and the HF trainer."""
        self._tokenizer = build_tokenizer(self._tokenizer_config)
        self._metrics_config_dict = self._train_eval_config.get("metrics")
        self._samples_by_split = build_datapool(self._datapool_config)

        # Every causal-vs-seq2seq choice below depends on this single flag.
        is_causal = self._alg_config["model_type"] == "causal"

        get_datasets = get_datasets_for_causal if is_causal else get_datasets_for_seq2seq
        self._train_dataset = get_datasets(self._samples_by_split["train"])

        preprocess_fn = partial(
            tokenize_causal if is_causal else tokenize_seq2seq,
            tokenizer=self._tokenizer,
        )
        self._tokenized_dataset = self._train_dataset.map(
            preprocess_fn, batched=True, remove_columns=self._train_dataset.column_names
        )

        model_cls = AutoModelForCausalLM if is_causal else AutoModelForSeq2SeqLM
        self._gen_kwargs = self._alg_config["generation_kwargs"]
        self._model = model_cls.from_pretrained(self._alg_config["model_name"])
        self._model.parallelize()
        self._eval_batch_size = self._train_eval_config["eval_batch_size"]

        # setting max prompt length
        self._max_prompt_length = self._tokenizer_config.get(
            "max_length", self._tokenizer.model_max_length
        )

        # For causal models the prompt and the generated tokens share one
        # context window, so shrink the prompt budget when the sum would
        # exceed the tokenizer's model_max_length.
        if is_causal and (
            (self._max_prompt_length + self._gen_kwargs["max_new_tokens"])
            > self._tokenizer.model_max_length
        ):
            self._max_prompt_length = (
                self._max_prompt_length - self._gen_kwargs["max_new_tokens"]
            )

        self._eval_callback = EvalCallack(
            self._samples_by_split["val"],
            self._gen_kwargs,
            self._eval_batch_size,
            self._tokenizer,
            self._metrics_config_dict,
            self._max_prompt_length,
            self._tracker,
        )

        # Work on a copy so the caller's config is not modified in place.
        train_args = dict(self._alg_config["training_args"])
        # `tracker` is optional (defaults to None); without one, any
        # output_dir present in the configured training_args is kept.
        if self._tracker is not None:
            train_args["output_dir"] = self._tracker.checkpoint_base_path
        # NOTE: any seed given in training_args is replaced by a fresh draw
        # from [0, 100) on every construction.
        train_args["seed"] = np.random.randint(100)  # random seed
        self._train_args = TrainingArguments(**train_args)

        data_collator = (
            DataCollatorForLanguageModeling(self._tokenizer, mlm=False)
            if is_causal
            else DataCollatorForSeq2Seq(self._tokenizer, self._model)
        )
        self._trainer = Trainer(
            model=self._model,
            tokenizer=self._tokenizer,
            args=self._train_args,
            data_collator=data_collator,
            train_dataset=self._tokenized_dataset,
            callbacks=[self._eval_callback],
        )

    def train_and_eval(self):
        """Evaluate, fine-tune with the HF trainer, evaluate again and save."""
        # evaluate on val and test set before fine-tuning once
        self._evaluate_on_datapools(epoch=0)

        # train using HF trainer
        self._trainer.train()

        # finally evaluate on val and test samples
        self._evaluate_on_datapools(epoch=self._train_args.num_train_epochs)

        # save model here - we save only the language model
        if self._tracker is not None:
            self._tracker.save_auto_model(self._model)
