""" Task to train claim rewriter """

import os
import uuid
from overrides import overrides
from datasets import load_from_disk
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig
from tasker import BaseTask
from typing import (
    Any,
    List,
    Dict,
    Text,
    Tuple,
    Callable
)


@BaseTask.register('train-claim-rewriter')
class TrainClassRewriterTask(BaseTask):
    """Task that fine-tunes a causal language model as a claim rewriter.

    The constructor loads the ``train`` and ``test`` dataset splits from
    ``input_dir``, loads the base model and tokenizer, and builds a TRL
    ``SFTTrainer`` configured with a LoRA adapter and a completion-only data
    collator. Training itself is triggered by :meth:`_run`.

    Args:
        input_dir: Directory containing ``train`` and ``test`` sub-directories
            readable by ``datasets.load_from_disk``.
        output_dir: Forwarded to ``BaseTask.__init__``; the trainer writes
            its checkpoints to ``self._output_dir`` (presumably set by the
            base class from this value -- confirm against ``BaseTask``).
        learning_rate: Learning rate passed to ``SFTConfig``; the scheduler
            type is ``"constant"``.
        model_name: Model identifier to load. Must equal
            ``_SUPPORTED_MODEL_NAME``.

    Raises:
        AssertionError: If ``model_name`` is not the supported model.
    """

    __VERSION__ = '0.0.6'

    # The completion-only templates below are hard-coded token IDs, so they
    # are only meaningful for the tokenizer of this specific checkpoint.
    _SUPPORTED_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

    # Token-ID sequences marking where an instruction turn / a response turn
    # begins. NOTE(review): these look like the Llama-3 chat-template headers
    # for the "user" and "assistant" roles -- verify against the tokenizer.
    _INSTRUCTION_TEMPLATE_IDS = (128006, 882, 128007, 271)
    _RESPONSE_TEMPLATE_IDS = (128006, 78191, 128007, 271)

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        learning_rate: float,
        model_name: Text
    ):
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        self._learning_rate = learning_rate
        self._model_name = model_name
        self._partial_state = PartialState()

        # Explicit raise instead of an `assert` statement: asserts are removed
        # when Python runs with -O, which would let an unsupported model
        # through and pair it with the wrong template token IDs. The exception
        # type and message are kept so existing callers behave the same.
        if self._model_name != self._SUPPORTED_MODEL_NAME:
            raise AssertionError("Only Meta-Llama model is supported.")

        self._train_dataset = load_from_disk(os.path.join(self._input_dir, 'train'))
        self._test_dataset = load_from_disk(os.path.join(self._input_dir, 'test'))

        self._peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        self._model = AutoModelForCausalLM.from_pretrained(
            self._model_name,
            load_in_8bit=False,
            # Map the root module ("") to this process's index.
            device_map={
                "": self._partial_state.process_index
            },
        )
        self._tokenizer = AutoTokenizer.from_pretrained(self._model_name)
        # Padding token is set explicitly to the '<|end_of_text|>' token.
        self._tokenizer.pad_token = '<|end_of_text|>'

        self._trainer = SFTTrainer(
            self._model,
            data_collator=DataCollatorForCompletionOnlyLM(
                # Passed as fresh lists (as in the original literals) so the
                # collator receives list objects, not the shared tuples.
                instruction_template=list(self._INSTRUCTION_TEMPLATE_IDS),
                response_template=list(self._RESPONSE_TEMPLATE_IDS),
                tokenizer=self._tokenizer
            ),
            # dataset_text_field="text",
            processing_class=self._tokenizer,
            train_dataset=self._train_dataset,
            eval_dataset=self._test_dataset,
            args=SFTConfig(
                max_seq_length=256,
                metric_for_best_model="eval_loss",
                learning_rate=self._learning_rate,
                num_train_epochs=10,
                eval_strategy="epoch",
                output_dir=self._output_dir,
                save_total_limit=3,
                save_strategy="epoch",
                report_to="wandb",
                # A random UUID suffix keeps run names distinct across launches.
                run_name="claim_rewriter_training_sft" + str(uuid.uuid4()),
                lr_scheduler_type="constant",
            ),
            peft_config=self._peft_config
        )

    @overrides
    def _run(self):
        """Train the model and return the trainer.

        Sets the ``WANDB_PROJECT`` environment variable to
        ``"claim-rewriting"`` before calling ``train()`` (the trainer is
        configured with ``report_to="wandb"``).

        Returns:
            The ``SFTTrainer`` instance after ``train()`` has completed.
        """
        os.environ["WANDB_PROJECT"] = "claim-rewriting"
        self._trainer.train()
        return self._trainer

    @overrides
    def _write(self, outputs):
        """Intentionally a no-op; ``outputs`` is ignored.

        Nothing is written here. The trainer is configured with
        ``output_dir=self._output_dir`` and ``save_strategy="epoch"``.
        """
        ...