import os
import shutil
from dataclasses import asdict
from typing import cast

import pandas as pd
import torch
from accelerate import PartialState
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    HfArgumentParser,
    PreTrainedTokenizer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
from trl.models.utils import unwrap_model_for_generation
from trl.trainer.ppo_trainer import PolicyAndValueWrapper, PPOConfig, PPOTrainer
from trl.trainer.utils import (
    SIMPLE_CHAT_TEMPLATE,
    batch_generation,
    get_reward,
    truncate_response,
)

import wandb
from datasets import Dataset
from dr.dataset import get_uf_ppo_dataset
from dr.ppo_trainer import DRPPOTrainer
from dr.utils import DRConfig, PPOLossType, get_rlhf_run_name


class LogCompletionsCallback(TrainerCallback):
    """Periodically generate, score, and save completions for an eval dataset.

    On global step 1 and on every global step that is a multiple of
    ``log_completions_interval``, the trainer's policy greedily completes every
    prompt in ``eval_dataset`` and each completion is scored with
    ``trainer.reward_model``. The main process then writes prompts, responses
    and scores to ``eval_rlhf/<run_name>/epoch<E>step<NNNNN>_completions.csv``
    and logs ``val/completions_count`` through ``trainer.log``.
    """

    trainer: PPOTrainer
    eval_dataset: Dataset
    save_dir: str

    # Number of ``on_step_end`` calls seen so far; used to number the CSV files.
    # Only incremented (and only read) on the main process.
    counter: int

    def __init__(
        self,
        trainer: PPOTrainer,
        eval_dataset: Dataset,
        run_name: str,
        log_completions_interval: int = 125,
        max_new_tokens: int = 512,
    ) -> None:
        """Create the callback and its output directory.

        Args:
            trainer: PPO trainer whose accelerator, reward model, data collator
                and args are used for generation and scoring.
            eval_dataset: Prompt dataset; batches produced by
                ``trainer.data_collator`` must contain ``input_ids``.
            run_name: Sub-directory name under ``eval_rlhf/``.
            log_completions_interval: Generate completions every this many
                global steps (and always on step 1). Must be >= 1.
            max_new_tokens: Maximum number of tokens generated per prompt.

        Raises:
            ValueError: If ``log_completions_interval`` is smaller than 1.
        """
        super().__init__()

        if log_completions_interval < 1:
            # An interval of 0 would otherwise only fail later, as a
            # ZeroDivisionError inside on_step_end.
            raise ValueError(f"log_completions_interval must be >= 1, got {log_completions_interval}")

        self.trainer = trainer
        self.save_dir = f"eval_rlhf/{run_name}/"
        self.eval_dataset = eval_dataset
        self.log_completions_interval = log_completions_interval
        self.max_new_tokens = max_new_tokens

        self.counter = 0

        os.makedirs(self.save_dir, exist_ok=True)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Generate and save eval completions on logging steps; otherwise just advance the counter.

        Reads ``kwargs["model"]`` (the policy/value wrapper) and
        ``kwargs["processing_class"]`` (the tokenizer) from the callback kwargs.
        """
        accelerator = self.trainer.accelerator
        trainer = self.trainer

        # Log on the very first step and on every multiple of the interval.
        should_log = state.global_step == 1 or state.global_step % self.log_completions_interval == 0
        if not should_log:
            if accelerator.is_main_process:
                self.counter += 1
            return

        model: PolicyAndValueWrapper = kwargs["model"]
        tokenizer: PreTrainedTokenizer = kwargs["processing_class"]

        full_prompts = []
        full_responses = []
        scores = []

        eval_data_loader = DataLoader(self.eval_dataset, batch_size=trainer.args.per_device_eval_batch_size, drop_last=False, collate_fn=trainer.data_collator)  # type: ignore
        eval_data_loader = cast(DataLoader, accelerator.prepare(eval_data_loader))

        generation_config = GenerationConfig(
            max_new_tokens=self.max_new_tokens,
            top_p=1.0,  # no nucleus truncation
            do_sample=False,  # greedy decoding
            # Forbid EOS as the very first generated token so responses are never empty.
            begin_suppress_tokens=[tokenizer.eos_token_id],
        )

        with unwrap_model_for_generation(model, accelerator) as unwrapped_model:  # type: ignore
            with torch.no_grad():
                progress = tqdm(eval_data_loader, desc="Generate Completions", disable=not accelerator.is_local_main_process)
                for batch in progress:
                    query = batch["input_ids"]
                    context_length = query.shape[1]
                    query_responses, _ = batch_generation(
                        unwrapped_model.policy,
                        query,
                        query.shape[0],  # generate the whole batch as a single chunk
                        tokenizer.pad_token_id,  # type: ignore
                        generation_config,
                    )

                    response = query_responses[:, context_length:]
                    postprocessed_response = response
                    # Compare against None (not truthiness) so a stop_token_id of 0 is still honored.
                    if trainer.args.stop_token_id is not None:
                        postprocessed_response = truncate_response(trainer.args.stop_token_id, tokenizer.pad_token_id, response)  # type: ignore

                    # Gather across processes so the main process holds the full eval set.
                    full_prompts.extend(accelerator.gather_for_metrics(tokenizer.batch_decode(query, skip_special_tokens=True)))
                    full_responses.extend(accelerator.gather_for_metrics(tokenizer.batch_decode(postprocessed_response)))

                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    _, score, _ = get_reward(trainer.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length)  # type: ignore
                    scores.extend(accelerator.gather_for_metrics(score).float().cpu().numpy())  # type: ignore

        torch.cuda.empty_cache()

        if not accelerator.is_main_process:
            return

        # Derive the epoch locally instead of writing into the shared TrainerState.
        epoch = int(state.epoch) if state.epoch is not None else 0
        csv_path = os.path.join(self.save_dir, f"epoch{epoch}step{self.counter:05d}_completions.csv")
        evaluation_result = {
            "prompts": full_prompts,
            "responses": full_responses,
            "scores": scores,
        }
        pd.DataFrame(evaluation_result).to_csv(csv_path)
        print(f"Saved to '{csv_path}'.")

        trainer.log({"val/completions_count": self.counter})

        self.counter += 1


def main() -> None:
    """Parse CLI arguments, then build, train and save a PPO (or DR-PPO) policy.

    Parses ``PPOConfig``, ``ModelConfig``, ``DRConfig`` and ``PPOLossType`` from
    the command line. Output goes to ``models/<run_name>``; any previous
    contents there are deleted first. Policy and reference models are loaded
    from ``config.sft_model_path``, and reward and value models from
    ``config.reward_model_path``. Trains on the UltraFeedback PPO split. When
    ``dr_config.eps != 0``, ``DRPPOTrainer`` is used instead of ``PPOTrainer``.

    Raises:
        RuntimeError: If no ``model_name_or_path`` is given.
    """
    parser = HfArgumentParser((PPOConfig, ModelConfig, DRConfig, PPOLossType))  # type: ignore
    config, model_config, dr_config, ppo_loss_type = parser.parse_args_into_dataclasses()

    config = cast(PPOConfig, config)
    model_config = cast(ModelConfig, model_config)
    dr_config = cast(DRConfig, dr_config)
    ppo_loss_type = cast(PPOLossType, ppo_loss_type)
    max_length = 1024  # passed to get_uf_ppo_dataset

    if model_config.model_name_or_path is None:
        raise RuntimeError("No model specified")

    print(dr_config, "\n", ppo_loss_type, "\n")

    run_name = get_rlhf_run_name("ppo", config.reward_model_path, model_config.model_name_or_path, dr_config, config, ppo_loss_type)
    output_dir = "models/" + run_name
    config.output_dir = output_dir

    print(f"Output directory: '{output_dir}'")
    print(f"wandb run name: '{run_name}'")

    state = PartialState()
    if state.is_main_process and config.report_to is not None and "wandb" in config.report_to:
        wandb_run = wandb.init(project="dr-rlhf", name=run_name, config=asdict(dr_config), allow_val_change=True)
        # Append to (rather than replace) any tags already on the run; guard
        # against the property returning None.
        wandb_run.tags = tuple(wandb_run.tags or ()) + ("ppo", "dataset-uf")

    # Remove output_dir if it exists. Only one process per node deletes it, and
    # the others wait at the barrier so ranks do not race on the same directory.
    # TODO: do we really want to remove without confirmation?
    if state.is_local_main_process:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    state.wait_for_everyone()

    ################
    # Model & Tokenizer
    ################
    torch_dtype = model_config.torch_dtype if model_config.torch_dtype in ["auto", None] else getattr(torch, model_config.torch_dtype)  # type: ignore
    quantization_config = get_quantization_config(model_config)
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        use_fast=True,
        clean_up_tokenization_spaces=True,
        trust_remote_code=model_config.trust_remote_code,
    )
    if tokenizer.pad_token is None:
        # A dedicated pad token (not EOS) is added so padding stays distinguishable
        # from the stop token. NOTE(review): model embeddings are not resized here,
        # which presumably relies on TRL masking pad ids before every forward — confirm.
        print("pad_token is None, adding a new '[PAD]' special token as pad token")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        print("chat_template is None, replacing chat template with SIMPLE_CHAT_TEMPLATE")
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    rew_kwargs = dict(
        attn_implementation=model_config.attn_implementation,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        use_cache=False,
    )

    policy_kwargs = dict(
        revision=model_config.model_revision,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        use_cache=False,
        attn_implementation=model_config.attn_implementation,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
    )
    # Value and reward models are both initialized from the reward checkpoint.
    value_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1, **rew_kwargs)
    reward_model = AutoModelForSequenceClassification.from_pretrained(config.reward_model_path, num_labels=1, **rew_kwargs)

    ref_policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path, **policy_kwargs)
    policy = AutoModelForCausalLM.from_pretrained(config.sft_model_path, **policy_kwargs)

    if config.sft_model_path.startswith("google/gemma-2"):
        print("Attention implementation:", policy.config._attn_implementation)

    ################
    # Dataset
    ################
    raw_datasets = get_uf_ppo_dataset(tokenizer, max_length, config.dataset_num_proc, version=dr_config.dataset_version, remove_columns=True)

    train_dataset = raw_datasets["train"]
    eval_dataset = raw_datasets["val"]

    ################
    # Training
    ################
    if dr_config.eps != 0:
        PPOv2TrainerClass = DRPPOTrainer
        extra_kwargs = {"eps": dr_config.eps, "dist_fn": dr_config.dist_fn, "loss_type": ppo_loss_type.loss_type}
        print(f"Using DR RLHF, epsilon={dr_config.eps}")
    else:
        PPOv2TrainerClass = PPOTrainer
        extra_kwargs = {}

    trainer = PPOv2TrainerClass(
        args=config,
        processing_class=tokenizer,
        model=policy,
        ref_model=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        **extra_kwargs,  # type: ignore
    )
    if dr_config.log_completions_interval is not None:
        trainer.add_callback(LogCompletionsCallback(trainer, eval_dataset, run_name, dr_config.log_completions_interval))

    trainer.train()
    trainer.save_model(config.output_dir)
    print(f"Saved model to '{config.output_dir}'")

    trainer.generate_completions()


if __name__ == "__main__":
    main()
