from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import torch
import transformers
import trl
from struq import SupervisedDataset
from config import (
    IGNORE_INDEX,
    DEFAULT_TOKENS,
    SPECIAL_DELM_TOKENS,
    TEXTUAL_DELM_TOKENS,
)

from transformers import TrainerCallback, TrainerState, TrainerControl
import os


class GradientNormCallback(TrainerCallback):
    """Request a checkpoint whenever the logged gradient norm reaches a threshold.

    Args:
        threshold: Gradient-norm value at which a save is requested. The test is
            ``grad_norm >= threshold``, so a NaN norm (which never compares true)
            does not trigger a save.
    """

    def __init__(self, threshold):
        self.threshold = threshold

    def on_log(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Set ``control.should_save`` when ``logs["grad_norm"]`` is at/above the threshold.

        Returns the (possibly modified) ``control`` object.
        """
        if "logs" not in kwargs:
            return control
        logs = kwargs["logs"]
        if "grad_norm" not in logs:
            return control
        grad_norm = logs["grad_norm"]
        # Written as ``not (a >= b)`` rather than ``a < b`` so NaN keeps returning early.
        if not grad_norm >= self.threshold:
            return control
        print(
            f"[!] Gradient norm {grad_norm:.2f} exceeds threshold {self.threshold}, saving model..."
        )
        control.should_save = True
        return control


@dataclass
class ModelArguments:
    """Command-line options selecting the base model and tokenizer padding side."""

    # Hugging Face hub id or local path of the causal-LM checkpoint.
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # Copied onto ``model.config.window`` by ``train`` when positive.
    window_size: int = field(
        metadata={"help": "Window size for the sliding window attention."},
        default=0,
    )
    # Forwarded to ``AutoTokenizer.from_pretrained(padding_side=...)``.
    padding_side: str = field(
        metadata={"help": "Padding side for tokenization."},
        default="right",
    )


@dataclass
class DataArguments:
    """Command-line options describing the training data."""

    # Defaults to ``None``, so the annotation must admit ``None``; HfArgumentParser
    # unwraps ``Optional[str]`` to ``str`` for CLI parsing, so parsing is unchanged.
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class AttackArguments:
    """Command-line options selecting the attack and alignment configuration.

    ``attack`` is copied onto the data arguments by ``train``; ``alignment`` is
    parsed but not referenced elsewhere in this file.
    """

    attack: str = field(
        metadata={"help": "Attack type for SFT/Align"},
        default="TextTextText_None",
    )
    alignment: str = field(
        metadata={"help": "Alignment type."},
        default="none",
    )


@dataclass
class TrainingArguments(trl.ORPOConfig):
    """``trl.ORPOConfig`` extended with the extra options this script reads.

    ``model_max_length`` is passed to the tokenizer; ``downsample`` is forwarded
    to the dataset and, together with ``lr_scale``, controls whether ``train``
    divides the learning rate by the dataset's copy count.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
        default=512,
    )
    downsample: Optional[bool] = True
    lr_scale: Optional[bool] = True
    beta: float = 0.1
    ref_model_init_kwargs: Optional[str] = None
    precompute_ref_log_probs: Optional[bool] = False
    desirable_weight: Optional[float] = 1
    undesirable_weight: Optional[float] = 1


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate tokenized examples into a right-padded batch for supervised fine-tuning.

    Each instance must provide ``input_ids`` and ``labels`` tensors (assumed 1-D
    and of equal length per instance — TODO confirm against SupervisedDataset).
    ``input_ids`` are padded with the tokenizer's pad id and ``labels`` with
    ``IGNORE_INDEX``. The returned ``attention_mask`` is a bool tensor that is
    True exactly on each sequence's real (unpadded) positions.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        padded_input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # Derive the mask from true lengths instead of ``input_ids != pad_token_id``:
        # the latter also masks genuine tokens whenever the pad id coincides with a
        # real token id (e.g. pad == eos).
        device = padded_input_ids.device
        lengths = torch.tensor([seq.size(0) for seq in input_ids], device=device)
        positions = torch.arange(padded_input_ids.size(1), device=device)
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
        return dict(
            input_ids=padded_input_ids,
            labels=padded_labels,
            attention_mask=attention_mask,
        )


def get_embedding_indices(tokenizer):
    """Return seed token ids for the delimiter embeddings and the ids decoding to "#".

    Returns:
        ``(init_values, ignore_values)`` where ``init_values[i]`` is the first id
        produced by encoding ``TEXTUAL_DELM_TOKENS[i]`` without special tokens,
        and ``ignore_values`` lists every id in ``range(len(tokenizer))`` whose
        individual decoding is exactly ``"#"``.

    Note: ``ignore_values`` is built by decoding every id once, so the cost grows
    linearly with ``len(tokenizer)``.
    """
    first_token_ids = []
    for delimiter_text in TEXTUAL_DELM_TOKENS:
        encoded = tokenizer.encode(delimiter_text, add_special_tokens=False)
        first_token_ids.append(encoded[0])

    vocab_size = len(tokenizer)
    hash_token_ids = [
        token_id
        for token_id in range(vocab_size)
        if tokenizer.decode(token_id) == "#"
    ]
    return first_token_ids, hash_token_ids


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer``, resize ``model`` and initialise new rows.

    Every newly added embedding row (input embeddings and, when the model has
    them, output embeddings) is set to the mean of the pre-existing rows. Each
    delimiter ``SPECIAL_DELM_TOKENS[i]`` that was newly added is then seeded
    with the row of the first token id of ``TEXTUAL_DELM_TOKENS[i]``. When the
    new tokens are exactly ``[pad, *SPECIAL_DELM_TOKENS]`` this matches the
    earlier offset-based initialisation.

    Raises:
        ValueError: if ``TEXTUAL_DELM_TOKENS`` has fewer entries than
            ``SPECIAL_DELM_TOKENS``.

    Note: This is the unoptimized version that may make your embedding size not
    be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    # After the resize the matrices have len(tokenizer) rows; new tokens occupy
    # ids [first_new_id, len(tokenizer)).
    first_new_id = len(tokenizer) - num_new_tokens

    # Only the first sub-token of each textual delimiter is needed. This is
    # computed inline because get_embedding_indices() also decodes the whole
    # vocabulary for a "#" list that would be discarded here.
    init_ids = [
        tokenizer.encode(text, add_special_tokens=False)[0]
        for text in TEXTUAL_DELM_TOKENS
    ]
    # Resolve delimiter ids by name rather than assuming fixed offsets from the
    # end: offsets wrapped to non-negative indices (overwriting existing rows
    # such as row 0) whenever fewer than len(SPECIAL_DELM_TOKENS) + 1 tokens
    # were actually new.
    delimiter_ids = tokenizer.convert_tokens_to_ids(list(SPECIAL_DELM_TOKENS))
    if len(init_ids) < len(delimiter_ids):
        raise ValueError(
            f"Need an init token for each of the {len(delimiter_ids)} special "
            f"delimiters, got {len(init_ids)} textual delimiters."
        )

    matrices = [model.get_input_embeddings().weight.data]
    output_layer = model.get_output_embeddings()
    if output_layer is not None:
        matrices.append(output_layer.weight.data)

    # Only rows >= first_new_id are written, so the mean over the old rows is
    # unchanged between passes (safe even when input/output weights are tied).
    for weights in matrices:
        weights[first_new_id:] = weights[:first_new_id].mean(dim=0, keepdim=True)
        for delim_id, init_id in zip(delimiter_ids, init_ids):
            # Never overwrite an embedding the tokenizer already had.
            if delim_id >= first_new_id:
                weights[delim_id] = weights[init_id]


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args, downsample=True
) -> Dict:
    """Build the dataset/collator keyword arguments handed to the Trainer.

    Args:
        tokenizer: Tokenizer shared by the dataset and the collator.
        data_args: Object exposing ``data_path`` and ``attack`` attributes.
        downsample: Forwarded unchanged to ``SupervisedDataset``.

    Returns:
        A dict with ``train_dataset``, ``eval_dataset`` (always ``None``) and
        ``data_collator`` entries.
    """
    dataset = SupervisedDataset(
        tokenizer=tokenizer,
        data_path=data_args.data_path,
        attack=data_args.attack,
        downsample=downsample,
    )
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": DataCollatorForSupervisedDataset(tokenizer=tokenizer),
    }


def train():
    """Parse CLI arguments, fine-tune a causal LM and save the result.

    Steps:
      1. Parse ``ModelArguments``, ``DataArguments``, ``TrainingArguments`` and
         ``AttackArguments`` from the command line.
      2. Load the model and its slow (``use_fast=False``) tokenizer, register the
         pad/eos/bos/unk and delimiter special tokens, and resize embeddings.
      3. Build the supervised dataset/collator; when the dataset is not
         downsampled and ``lr_scale`` is set, divide the learning rate by the
         dataset's ``data_copy_count``.
      4. Train with a ``GradientNormCallback`` that requests a checkpoint when
         the logged gradient norm reaches a threshold, then save the trainer
         state and the final model to ``output_dir``.

    Raises:
        ValueError: if the model name contains "Instruct" and the attack string
            contains "SpclSpclSpcl".
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, AttackArguments)
    )
    model_args, data_args, training_args, attack_args = (
        parser.parse_args_into_dataclasses()
    )
    os.makedirs(training_args.output_dir, exist_ok=True)
    if training_args.gradient_checkpointing:
        # Forces non-reentrant checkpointing; replaces any user-supplied kwargs.
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
    # The dataset reads the attack type from data_args.
    data_args.attack = attack_args.attack
    # Explicit exception instead of ``assert``: asserts are stripped under ``python -O``.
    if (
        "Instruct" in model_args.model_name_or_path
        and "SpclSpclSpcl" in data_args.attack
    ):
        raise ValueError(
            f"Attack {data_args.attack!r} uses 'SpclSpclSpcl' delimiters, which are "
            f"not allowed with Instruct model {model_args.model_name_or_path!r}."
        )
    print("\n\n" + training_args.output_dir + "\n\n")

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    if model_args.window_size > 0:
        model.config.window = model_args.window_size

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side=model_args.padding_side,
        use_fast=False,
    )

    special_tokens_dict = dict()
    special_tokens_dict["pad_token"] = DEFAULT_TOKENS["pad_token"]
    special_tokens_dict["eos_token"] = DEFAULT_TOKENS["eos_token"]
    special_tokens_dict["bos_token"] = DEFAULT_TOKENS["bos_token"]
    special_tokens_dict["unk_token"] = DEFAULT_TOKENS["unk_token"]
    special_tokens_dict["additional_special_tokens"] = SPECIAL_DELM_TOKENS
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict, tokenizer=tokenizer, model=model
    )

    data_module = make_supervised_data_module(
        tokenizer=tokenizer, data_args=data_args, downsample=training_args.downsample
    )
    if not training_args.downsample and training_args.lr_scale:
        # Without downsampling the dataset holds data_copy_count copies of the
        # data, so the learning rate is scaled down by the same factor.
        training_args.learning_rate /= data_module["train_dataset"].data_copy_count

    # Hard-coded gradient-norm save threshold: 7 for Mistral checkpoints (by name), else 4.5.
    threshold = 4.5 if "mistral" not in model_args.model_name_or_path.lower() else 7
    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=[GradientNormCallback(threshold=threshold)],
        **data_module,
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)


if __name__ == "__main__":
    # Launch training only when run as a script, never on import.
    train()
