""" Task to train claim rewriter (with DPO) """

import os
import uuid
from overrides import overrides
from datasets import load_from_disk
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import (
    DPOConfig,
    DPOTrainer,
    DataCollatorForCompletionOnlyLM,
    apply_chat_template
)
from peft import LoraConfig
from tasker import BaseTask
from typing import (
    Any,
    List,
    Dict,
    Text,
    Tuple,
    Callable
)


@BaseTask.register('dpo-train-claim-rewriter')
class DPOTrainClaimRewriterTask(BaseTask):
    """Train a claim rewriter with DPO on top of a LoRA-adapted causal LM.

    All setup (dataset loading, model/tokenizer loading and trainer
    construction) is done in ``__init__``; ``_run`` only starts training.
    """

    __VERSION__ = '0.0.2'

    # Columns dropped from every split before training, when present.
    _EXTRA_COLUMNS = ("source",)
    # Padding token this task was written for; only used when the loaded
    # tokenizer actually has it in its vocabulary.
    _PREFERRED_PAD_TOKEN = '<|end_of_text|>'

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        learning_rate: float,
        model_name: Text
    ):
        """Load data, model and tokenizer, and build the ``DPOTrainer``.

        Args:
            input_dir: Directory containing ``train`` and ``test``
                sub-directories readable by ``datasets.load_from_disk``.
            output_dir: Forwarded to ``BaseTask`` and used as the trainer's
                ``output_dir``.
            learning_rate: Learning rate passed to ``DPOConfig``; the
                scheduler is constant, so it is used throughout training.
            model_name: Name or path handed to ``from_pretrained`` for both
                the model and the tokenizer.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._learning_rate = learning_rate
        self._model_name = model_name
        self._partial_state = PartialState()

        # Load the data before the model so a missing or broken dataset
        # fails before the (much slower) model load.
        self._train_dataset = self._load_split('train')
        self._test_dataset = self._load_split('test')

        self._peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        self._model = AutoModelForCausalLM.from_pretrained(
            self._model_name,
            load_in_8bit=False,
            # The empty-string key maps the whole model to one device: the
            # one whose index equals this process's index.
            device_map={
                "": self._partial_state.process_index
            },
        )
        self._tokenizer = AutoTokenizer.from_pretrained(self._model_name)
        self._configure_pad_token()

        self._train_dataset = self._apply_chat_template(self._train_dataset)
        self._test_dataset = self._apply_chat_template(self._test_dataset)

        self._trainer = DPOTrainer(
            self._model,
            processing_class=self._tokenizer,
            train_dataset=self._train_dataset,
            eval_dataset=self._test_dataset,
            args=DPOConfig(
                max_length=512,
                max_prompt_length=256,
                metric_for_best_model="eval_loss",
                learning_rate=self._learning_rate,
                num_train_epochs=15,
                eval_strategy="epoch",
                output_dir=self._output_dir,
                save_total_limit=10,
                save_strategy="epoch",
                report_to="wandb",
                # Random suffix keeps run names distinct between launches.
                run_name="claim_rewriter_training_dpo" + str(uuid.uuid4()),
                lr_scheduler_type="constant",
            ),
            peft_config=self._peft_config
        )

    def _load_split(self, split: Text):
        """Load one split from ``input_dir`` and drop the extra columns.

        Columns listed in ``_EXTRA_COLUMNS`` are removed only when the split
        actually has them, so datasets exported without them still load.
        """
        dataset = load_from_disk(os.path.join(self._input_dir, split))
        present = [
            column for column in self._EXTRA_COLUMNS
            if column in dataset.column_names
        ]
        if present:
            dataset = dataset.remove_columns(present)
        return dataset

    def _configure_pad_token(self):
        """Give the tokenizer a padding token that exists in its vocabulary.

        The preferred token is used when the tokenizer knows it. Otherwise an
        already configured pad token is kept, and the EOS token is used as a
        last resort, instead of assigning a string the model has no id for.
        """
        if self._PREFERRED_PAD_TOKEN in self._tokenizer.get_vocab():
            self._tokenizer.pad_token = self._PREFERRED_PAD_TOKEN
        elif self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

    def _apply_chat_template(self, dataset):
        """Map TRL's ``apply_chat_template`` over ``dataset`` with the task tokenizer."""
        return dataset.map(
            apply_chat_template,
            fn_kwargs={
                "tokenizer": self._tokenizer,
            }
        )

    @overrides
    def _run(self):
        """Set the W&B project name, run training and return the trainer."""
        # Set here, immediately before training starts; the trainer is
        # configured with report_to="wandb".
        os.environ["WANDB_PROJECT"] = "claim-rewriting-dpo"
        self._trainer.train()
        return self._trainer

    @overrides
    def _write(self, outputs):
        """Write nothing; this method is intentionally a no-op.

        The trainer is configured with ``save_strategy="epoch"`` and
        ``output_dir`` set to this task's output directory, so checkpointing
        is left to it. NOTE(review): ``outputs`` is presumably the trainer
        returned by ``_run`` -- confirm against ``BaseTask``.
        """
        ...