import torch
import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable

import datasets
import numpy as np
from datasets import Dataset
from huggingface_hub import delete_repo, repo_exists
from pydantic import BaseModel
from transformers import TrainerCallback, AutoTokenizer, AutoModelForCausalLM
from transformers.utils import is_liger_kernel_available, is_peft_available
from transformers.trainer_utils import _re_checkpoint
from trl import (
    DataCollatorForCompletionOnlyLM,
    SFTConfig,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)
from trl import ModelConfig as TRLModelConfig
from trl.models.utils import ChatMlSpecialTokens
from trl.trainer import SFTTrainer



from src.data.utils import CustomColName
from src.trainer.callbacks import ExpInfoCallback
from src.utils.logging_utils import get_logger
from src.utils.model_config import ModelConfig

logger = get_logger()


@dataclass
class CustomSFTConfig(SFTConfig):
    """`SFTConfig` extended with one extra optional field, ``meta``."""

    # Not read anywhere in this file -- presumably free-form experiment
    # metadata consumed elsewhere; TODO confirm against callers.
    meta: dict | None = None


class ChatSFTTrainer(SFTTrainer):
    class Config(BaseModel):
        """Trainer-specific options, built from the keyword arguments given to ``ChatSFTTrainer.__init__``."""

        # When true, `_prepare_non_packed_dataloader` runs `_drop_long_examples`
        # on the tokenized dataset.
        drop_long_examples: bool = True
        # Worker count used by the map/filter calls in `_drop_long_examples`;
        # also forwarded to the parent trainer as `dataset_num_proc`.
        dataset_num_proc: int = 32
        # Columns kept after tokenization (together with the ID / DS_ID columns).
        signature_columns: list[str] = ["input_ids", "labels", "attention_mask"]

    def __init__(
        self,
        *,
        model_config: ModelConfig,
        model_init_config: TRLModelConfig,
        args: CustomSFTConfig,
        use_unsloth: bool = False,
        callbacks: list[TrainerCallback] | None = None,
        submit_eval_job_callback: Callable | None = None,
        **kwargs,
    ):
        """Load the model/tokenizer and initialise the underlying ``SFTTrainer``.

        Args:
            model_config: Project model description (name/path, revision, ...).
            model_init_config: TRL model arguments (dtype, quantization, PEFT, ...).
            args: Training arguments.
            use_unsloth: Load the model through unsloth instead of transformers.
            callbacks: Extra trainer callbacks. The list is copied, so the
                caller's list is never modified.
            submit_eval_job_callback: Optional zero-argument factory; when given
                it is called and the result is registered as a callback.
            **kwargs: Keys that match ``Config`` fields configure this trainer;
                every other key (except ``model``) is forwarded to
                ``SFTTrainer.__init__``. A ``model`` entry is accepted for
                compatibility but the model is always loaded from
                ``model_config``.
        """
        self.config = self.Config(**kwargs)
        # Computed once instead of once per forwarded keyword argument.
        config_keys = set(self.config.model_dump())

        # Work on a private copy: a shared default list (or the caller's list)
        # would otherwise accumulate callbacks across trainer instances.
        callbacks = [] if callbacks is None else list(callbacks)

        self.exp_info_callback = ExpInfoCallback()
        callbacks.append(self.exp_info_callback)

        if submit_eval_job_callback is not None:
            self.submit_eval_job_callback = submit_eval_job_callback()
            callbacks.append(self.submit_eval_job_callback)

        # Must exist *before* the parent constructor runs: `_drop_long_examples`
        # stores the real count, and resetting the attribute afterwards would
        # discard a value recorded during dataset preparation.
        self._droped_long_ex_cnt = 0

        loaded = self._load_model_and_tokenizer(
            # The passed model is ignored downstream, so it is optional here.
            model=kwargs.pop("model", None),
            use_unsloth=use_unsloth,
            model_config=model_config,
            model_init_config=model_init_config,
            args=args,
        )
        trainer_kwargs = {k: v for k, v in kwargs.items() if k not in config_keys}

        super().__init__(
            **loaded,
            args=args,
            callbacks=callbacks,
            dataset_num_proc=self.config.dataset_num_proc,
            **trainer_kwargs,
        )

        self._set_data_collator()
        self._check_tokenized_datasets()

    def init_hf_repo(self):
        """Delete a pre-existing Hub model repo, then let the parent trainer create it.

        WARNING: destructive -- an existing repo named ``args.hub_model_id`` is
        removed. The check/delete only runs on the main process so that, in
        distributed runs, other ranks cannot delete the repo after the main
        process has re-created it (or fail on an already-deleted repo).
        """
        wants_hub_repo = self.args.push_to_hub and self.args.hub_model_id is not None
        if wants_hub_repo and self.is_world_process_zero():
            if repo_exists(self.args.hub_model_id, token=self.args.hub_token, repo_type="model"):
                delete_repo(self.args.hub_model_id, token=self.args.hub_token, repo_type="model")
                logger.warning(f"Deleted existing repo: {self.args.hub_model_id}")

        return super().init_hf_repo()

    def _merge_callbacks(self, callbacks):
        """Return a new list: ``callbacks`` followed by this trainer's ``exp_info_callback``."""
        own_callbacks = [self.exp_info_callback]
        return callbacks + own_callbacks

    def _prepare_non_packed_dataloader(
        self,
        processing_class,
        dataset,
        dataset_text_field: str,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
    ):
        """Tokenize ``dataset`` and post-process it for non-packed training.

        Steps, in order: tokenize via `_custom_prepare_non_packed_dataloader`,
        optionally filter with `_drop_long_examples` (when
        ``config.drop_long_examples`` is set), drop every column that is neither
        a signature column nor the ID / DS_ID column, and finally append the
        positional index column.
        """
        prepare_kwargs = dict(
            processing_class=processing_class,
            dataset=dataset,
            dataset_text_field=dataset_text_field,
            max_seq_length=max_seq_length,
            formatting_func=formatting_func,
            add_special_tokens=add_special_tokens,
            remove_unused_columns=remove_unused_columns,
        )
        prepared = self._custom_prepare_non_packed_dataloader(**prepare_kwargs)

        if self.config.drop_long_examples:
            prepared = self._drop_long_examples(prepared, max_seq_length, processing_class)

        # Everything outside this set is removed before the index column is added.
        columns_to_keep = {
            *self.config.signature_columns,
            CustomColName.ID.value,
            CustomColName.DS_ID.value,
        }
        columns_to_drop = [name for name in prepared.column_names if name not in columns_to_keep]
        prepared = prepared.remove_columns(columns_to_drop)

        return self._add_idx_column(prepared)

    def _custom_prepare_non_packed_dataloader(
        self,
        processing_class,
        dataset,
        dataset_text_field: str,
        max_seq_length,
        formatting_func=None,
        add_special_tokens=True,
        remove_unused_columns=True,
    ):
        """Tokenize ``dataset`` in batches, recording which rows were truncated.

        Adapted from the TRL implementation (itself inspired by
        https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt). Differences
        kept from the previous version of this method: overflowing tokens are
        not returned as extra rows, a boolean truncation column is emitted
        instead, and the upstream warning about extra columns is omitted
        because columns are cleaned up later by the caller.

        Args:
            processing_class: Tokenizer used to encode the text.
            dataset: Dataset to tokenize.
            dataset_text_field: Column holding the text when no
                ``formatting_func`` is given.
            max_seq_length: Truncation length passed to the tokenizer.
            formatting_func: Optional callable mapping a batch to a list of strings.
            add_special_tokens: Forwarded to the tokenizer.
            remove_unused_columns: When true, all original columns are removed
                by the ``map`` call.

        Returns:
            The mapped dataset with ``input_ids``, ``attention_mask`` and the
            truncation flag column.

        Raises:
            ValueError: If ``formatting_func`` does not return a list.
        """

        def tokenize(element):
            if formatting_func is None:
                texts = element[dataset_text_field]
            else:
                # Call the formatter once and validate *before* tokenizing, so a
                # bad return type fails fast instead of after an expensive pass.
                texts = formatting_func(element)
                if not isinstance(texts, list):
                    raise ValueError(
                        "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
                    )

            outputs = processing_class(
                texts,
                add_special_tokens=add_special_tokens,
                truncation=True,
                padding=False,
                max_length=max_seq_length,
                return_overflowing_tokens=False,  # MODIFIED vs upstream
                return_length=False,
            )

            return {
                "input_ids": outputs["input_ids"],
                "attention_mask": outputs["attention_mask"],
                # MODIFIED vs upstream: flag rows whose encoding has overflow.
                CustomColName.TRUNCATED.value: [
                    len(encoding.overflowing) > 0 for encoding in outputs.encodings
                ],
            }

        map_kwargs = {
            "batched": True,
            "remove_columns": dataset.column_names if remove_unused_columns else None,
            "batch_size": self.dataset_batch_size,
        }
        if isinstance(dataset, datasets.Dataset):
            map_kwargs["num_proc"] = self.dataset_num_proc  # this arg is not available for IterableDataset
        tokenized_dataset = dataset.map(tokenize, **map_kwargs)

        return tokenized_dataset

    def _drop_long_examples(self, tokenized_dataset, max_seq_length: int, processing_class):
        """Filter out examples whose tokens contain no assistant-response marker.

        The ChatML assistant tag (plus newline) is tokenized once; an example is
        tagged when that token sequence does not occur anywhere in its
        ``input_ids``. The tag is written to the TRUNCATED column (overwriting
        any value already stored there) and tagged rows are filtered out. The
        number of removed rows is stored in ``self._droped_long_ex_cnt``.

        ``max_seq_length`` is only used in log messages.
        """
        logger.info(f"Dropping examples longer than {max_seq_length}")

        response_template = ChatMlSpecialTokens().assistant + "\n"
        template_ids = processing_class.encode(response_template, add_special_tokens=False)
        template_len = len(template_ids)
        flag_column = CustomColName.TRUNCATED.value

        def _contains_response(token_ids):
            # Candidate positions are those matching the first template token;
            # each one is then checked against the full template.
            candidate_starts = np.flatnonzero(np.array(token_ids) == template_ids[0])
            return any(
                template_ids == token_ids[start : start + template_len] for start in candidate_starts
            )

        def _tag_long_examples(examples):
            examples[flag_column] = [not _contains_response(ids) for ids in examples["input_ids"]]
            return examples

        def _is_kept(example):
            return not example[flag_column]

        tagged = tokenized_dataset.map(
            _tag_long_examples,
            batched=True,
            num_proc=self.config.dataset_num_proc,
            desc="Tagging long examples",
        )

        size_before = len(tagged)
        # Rows without a response marker are removed to avoid nan loss.
        kept = tagged.filter(
            _is_kept,
            num_proc=self.config.dataset_num_proc,
            desc="Filtering long examples",
        )

        self._droped_long_ex_cnt = size_before - len(kept)
        if self._droped_long_ex_cnt:
            logger.info(f"Examples longer than {max_seq_length} dropped cnt={self._droped_long_ex_cnt}")

        return kept

    def _add_idx_column(self, dataset: Dataset):
        """Return ``dataset`` with an extra IDX column holding each row's position (0..len-1)."""
        row_positions = list(range(len(dataset)))
        return dataset.add_column(CustomColName.IDX.value, row_positions)

    def _check_tokenized_datasets(self):
        """Log the prepared datasets and spot-check the train index column.

        One random train row is read and its IDX value must equal its position.

        Raises:
            ValueError: If the train dataset is empty (previously surfaced as an
                opaque "empty range" error from ``random.randint``).
            AssertionError: If the sampled row's IDX value does not match its
                position. Raised explicitly so the check survives ``python -O``.
        """
        if self.train_dataset is not None:
            logger.info(f"Train dataset: {self.train_dataset}")

            num_examples = len(self.train_dataset)
            if num_examples == 0:
                raise ValueError(
                    "Train dataset is empty after preprocessing; "
                    "check the input data and the `drop_long_examples` / `max_seq_length` settings."
                )

            rand_idx = random.randint(0, num_examples - 1)
            stored_idx = self.train_dataset[rand_idx][CustomColName.IDX.value]
            if stored_idx != rand_idx:
                raise AssertionError(
                    f"Train dataset index column is inconsistent: row {rand_idx} stores {stored_idx}"
                )

        if self.eval_dataset is not None:
            logger.info(f"Eval dataset: {self.eval_dataset}")

    def _load_model_and_tokenizer(
        self,
        *,
        model: Any,
        use_unsloth: bool,
        model_config: ModelConfig,
        model_init_config: TRLModelConfig,
        args: CustomSFTConfig,
    ):
        """Load the model and tokenizer and return them as ``SFTTrainer`` kwargs.

        Args:
            model: Ignored; the model is always loaded from ``model_config``.
            use_unsloth: Load through unsloth (no PEFT config is returned in
                that case) instead of transformers.
            model_config: Project model description.
            model_init_config: TRL model arguments.
            args: Training arguments.

        Returns:
            A dict with the keys ``model``, ``processing_class`` (the tokenizer)
            and ``peft_config``.
        """
        if use_unsloth:
            model, tokenizer = self._load_unsloth_model(
                model_config=model_config,
                model_init_config=model_init_config,
                train_args=args,
            )

            peft_config = None
        else:
            model = self.load_model(
                model_config=model_config,
                model_init_config=model_init_config,
                train_args=args,
            )
            peft_config = get_peft_config(model_init_config)
            # Use the same revision / remote-code / auth settings as the model
            # load so both artefacts come from the same snapshot and gated or
            # private repos work for the tokenizer too.
            tokenizer = AutoTokenizer.from_pretrained(
                model_config.name_or_path,
                use_fast=True,
                revision=model_config.revision,
                trust_remote_code=model_config.trust_remote_code,
                token=args.hub_token,
            )

        if tokenizer.padding_side != "right":
            tokenizer.padding_side = "right"
            logger.info("Padding side set to 'right'")

        if tokenizer.chat_template is None:
            logger.info("Setting up chat format")
            model, tokenizer = setup_chat_format(model, tokenizer, resize_to_multiple_of=64)

        return {
            "model": model,
            "processing_class": tokenizer,
            "peft_config": peft_config,
        }

    def _set_data_collator(self):
        """Install a completion-only collator keyed on the ChatML user/assistant tags."""
        special_tokens = ChatMlSpecialTokens()
        user_prefix = special_tokens.user + "\n"
        assistant_prefix = special_tokens.assistant + "\n"

        collator = DataCollatorForCompletionOnlyLM(
            instruction_template=user_prefix,
            response_template=assistant_prefix,
            tokenizer=self.processing_class,
            mlm=False,
        )
        self.data_collator = collator
        logger.info("Data collator set")

    def _load_unsloth_model(
        self,
        model_config: ModelConfig,
        model_init_config: TRLModelConfig,
        train_args: CustomSFTConfig,
    ):
        """Load model and tokenizer through unsloth and apply the ChatML template.

        8-bit loading is rejected. After loading, warning filters are
        re-registered with action ``"always"`` for the listed modules.

        Returns:
            ``(model, tokenizer)``.
        """
        # Imported lazily: importing unsloth modifies env vars.
        from unsloth import FastLanguageModel
        from unsloth.chat_templates import get_chat_template

        assert not model_init_config.load_in_8bit, "8-bit loading is not supported for unsloth models"

        load_kwargs = dict(
            model_name=model_config.name_or_path,
            revision=model_config.revision,
            max_seq_length=train_args.max_seq_length,
            dtype=None,  # None lets unsloth auto-detect the dtype
            load_in_4bit=model_init_config.load_in_4bit,
        )
        model, tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)
        tokenizer = get_chat_template(tokenizer, chat_template="chatml")

        # Re-enable warnings that unsloth disables (registration order preserved).
        reenabled_warnings = (
            (UserWarning, "torch"),
            (UserWarning, "huggingface_hub"),
            (UserWarning, "trl"),
            (UserWarning, "transformers"),
            (RuntimeWarning, "multiprocessing"),
            (RuntimeWarning, "multiprocess"),
        )
        for warning_category, module_name in reenabled_warnings:
            warnings.filterwarnings(action="always", category=warning_category, module=module_name)

        return model, tokenizer

    def _gen_model_init_kwargs(
        self,
        model_config: ModelConfig,
        model_init_config: TRLModelConfig,
        train_args: CustomSFTConfig,
    ):
        """Build the keyword arguments passed to ``from_pretrained`` in `load_model`.

        Quantization is expressed solely through ``quantization_config``
        (derived from ``model_init_config`` by ``get_quantization_config``).
        The raw ``load_in_4bit`` / ``load_in_8bit`` flags are deliberately NOT
        forwarded: transformers rejects them when a ``quantization_config`` is
        passed at the same time, so forwarding both broke every quantized load.

        Returns:
            A dict of ``from_pretrained`` keyword arguments. The KV cache is
            disabled when gradient checkpointing is on, and a k-bit device map
            is only set when a quantization config exists.
        """
        quantization_config = get_quantization_config(model_init_config)
        model_init_kwargs = dict(
            revision=model_config.revision,
            trust_remote_code=model_config.trust_remote_code,
            attn_implementation=model_init_config.attn_implementation,
            torch_dtype=model_init_config.torch_dtype,
            use_cache=not train_args.gradient_checkpointing,
            device_map=get_kbit_device_map() if quantization_config is not None else None,
            quantization_config=quantization_config,
            token=train_args.hub_token,
        )
        return model_init_kwargs

    def get_sorted_checkpoints_w_steps(self):
        """Return ``(step, checkpoint_path)`` pairs in the order given by ``_sorted_checkpoints``.

        The step is parsed from the checkpoint directory *name*: the
        ``_re_checkpoint`` pattern is anchored at the start of the string, so it
        has to be applied to the basename rather than to the full path. Entries
        whose name does not match the pattern are skipped with a warning.
        """
        import os  # local import: `os` is not imported at module level in this file

        checkpoint_w_step = []
        # `output_dir` is passed explicitly rather than relying on the helper's default.
        for checkpoint in self._sorted_checkpoints(output_dir=self.args.output_dir):
            dir_name = os.path.basename(os.path.normpath(checkpoint))
            match = _re_checkpoint.search(dir_name)
            if match is None:
                logger.warning(f"Skipping checkpoint with unexpected name: {checkpoint}")
                continue
            checkpoint_w_step.append((int(match.group(1)), checkpoint))

        return checkpoint_w_step

    def load_model(self, train_args: CustomSFTConfig, model_config: ModelConfig, model_init_config: TRLModelConfig):
        """Instantiate the causal LM described by ``model_config``.

        The ``torch_dtype`` entry of the init kwargs is normalised first: a
        string other than ``"auto"`` is looked up as an attribute of ``torch``.
        The model is then loaded through the Liger kernel wrapper when
        ``train_args.use_liger`` is set, otherwise through
        ``AutoModelForCausalLM``.

        Raises:
            ValueError: If the dtype is neither ``"auto"`` nor a ``torch.dtype``,
                or if Liger is requested but not installed.
        """
        init_kwargs = self._gen_model_init_kwargs(
            model_config=model_config,
            model_init_config=model_init_config,
            train_args=train_args,
        )

        requested_dtype = init_kwargs.get("torch_dtype")
        if requested_dtype is not None:
            torch_dtype = requested_dtype
            # A dtype name such as "bfloat16" is resolved to the torch object.
            if isinstance(requested_dtype, str) and requested_dtype != "auto":
                torch_dtype = getattr(torch, requested_dtype)

            dtype_is_valid = torch_dtype == "auto" or isinstance(torch_dtype, torch.dtype)
            if not dtype_is_valid:
                raise ValueError(
                    f"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                )
            init_kwargs["torch_dtype"] = torch_dtype

        if not train_args.use_liger:
            return AutoModelForCausalLM.from_pretrained(model_config.name_or_path, **init_kwargs)

        if not is_liger_kernel_available():
            raise ValueError("Liger kernel is not available. Please install it via `pip install liger-kernel`.")

        from liger_kernel.transformers import AutoLigerKernelForCausalLM

        return AutoLigerKernelForCausalLM.from_pretrained(model_config.name_or_path, **init_kwargs)
