import copy
import logging
import math
import os
import shutil
from collections import deque
from dataclasses import field, dataclass
from typing import Optional, Sequence

import torch
import transformers
from accelerate import Accelerator
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments as HfTrainingArguments,
    AutoModelForCausalLM,
    get_scheduler,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from src.data_conversion import load_truth_file
from src.prompt_functions import prepare_source_and_target
from src.utility import DataMode

# Label value skipped by the loss; -100 is the default ``ignore_index`` of
# torch's cross-entropy loss. Used to mask prompt tokens and label padding.
IGNORE_INDEX: int = -100

# Module logger; handlers and levels are configured by logging.basicConfig in train().
logger = logging.getLogger(name=__name__)

# Seed shared by transformers.set_seed and the train/eval split generator.
random_seed: int = 666


@dataclass
class ModelArguments:
    """Command-line options selecting the pretrained model."""

    # Hub id or local path; train() loads both the tokenizer and the model from it.
    model_name_or_path: Optional[str] = field(default=None)


@dataclass
class DataArguments:
    """Command-line options controlling which data is loaded and how it is split."""

    # One or more truth files. Defaults to None, but ToolCallingDataset iterates
    # over it, so in practice it must be supplied.
    truth_paths: Optional[list[str]] = field(
        default=None, metadata={"help": "Path to truth files."}
    )
    # Converted with DataMode(...) in make_data_module.
    data_gen_mode: int = field(default=1, metadata={"help": "Data generation mode."})
    chat_format: str = field(
        default="general",
        metadata={
            "help": "General chat messages or format that need apply_chat_template."
        },
    )
    # Fraction of examples used for training; the remainder becomes the eval set.
    split_ratio: float = field(
        default=0.95, metadata={"help": "Split the train/validation set"}
    )

    def __post_init__(self):
        # Fail fast with a clear message: a ratio <= 0 leaves no training data, and
        # a ratio > 1 would hand random_split a negative eval length.
        if not 0.0 < self.split_ratio <= 1.0:
            raise ValueError(
                f"split_ratio must be in the interval (0, 1], got {self.split_ratio}"
            )


@dataclass
class TrainingArguments(HfTrainingArguments):
    """
    Hugging Face ``TrainingArguments`` extended with options used by the custom
    training loop in ``train()``.
    """

    # Passed as cache_dir to AutoModelForCausalLM.from_pretrained.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Cache dir to use."}
    )
    # Name handed to create_optimizer.
    optim: str = field(default="adamw_torch", metadata={"help": "Optimizer."})
    # Passed to the tokenizer as model_max_length; preprocess truncates to it.
    model_max_length: int = field(
        default=512, metadata={"help": "Maximum sequence length."}
    )
    # Accelerator project_dir (tensorboard logs). Optional: the default is None.
    log_dir: Optional[str] = field(default=None, metadata={"help": "Log dir."})
    # If True, the best checkpoint is copied into output_dir when training ends.
    save_best_model_at_the_end: bool = field(
        default=True, metadata={"help": "Save best model to output dir at the end."}
    )
    # An eval loss must beat the best so far by more than this to count as improvement.
    early_stopping_threshold: float = field(
        default=0.0005, metadata={"help": "Min change to eval loss"}
    )
    # Number of consecutive non-improving evaluations before training stops.
    early_stopping_patience: int = field(
        default=5, metadata={"help": "Validation steps to wait until early stop"}
    )


## ------------------- Data Preprocessing and Loading ------------------- ##


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> dict:
    """
    Tokenize source/target pairs into causal-LM training examples.

    Each example is ``source + target`` tokenized as one string. Its labels are a
    copy of the input ids whose leading source-prompt positions are replaced by
    ``IGNORE_INDEX``, so only target tokens contribute to the loss. No padding is
    applied here.

    :param sources: Prompt strings.
    :param targets: Completion strings, paired with ``sources`` by position.
    :param tokenizer: Tokenizer; ``tokenizer.model_max_length`` bounds both
        tokenizations (truncation is enabled).
    :return: dict with ``input_ids`` and ``labels``, each a list of 1-D
        ``torch.long`` tensors of equal length per example.
    """
    examples = [s + t for s, t in zip(sources, targets)]

    # Tokenize combined examples and sources in batches. No padding is applied here.
    # The source tokenization only serves to count how many leading positions to mask.
    examples_tokenized = tokenizer(
        examples, truncation=True, max_length=tokenizer.model_max_length
    )["input_ids"]
    sources_tokenized = tokenizer(
        sources, truncation=True, max_length=tokenizer.model_max_length
    )["input_ids"]

    input_ids_list = []
    labels_list = []
    n_without_targets = 0

    for example_ids, source_ids in zip(examples_tokenized, sources_tokenized):
        # Clamp: the joint tokenization can be shorter than the source tokenized on
        # its own (e.g. tokens merging across the boundary with a very short target).
        # Unclamped, the slice assignment below would make `labels` longer than
        # `example_ids` and break batching/loss computation.
        source_len = min(len(source_ids), len(example_ids))
        labels = list(example_ids)
        labels[:source_len] = [IGNORE_INDEX] * source_len
        if source_len == len(example_ids):
            n_without_targets += 1

        input_ids_list.append(torch.tensor(example_ids, dtype=torch.long))
        labels_list.append(torch.tensor(labels, dtype=torch.long))

    if n_without_targets:
        # Typically the prompt alone already fills model_max_length.
        logger.warning(
            f"{n_without_targets} of {len(input_ids_list)} examples have no target "
            f"tokens after truncation/masking and contribute nothing to the loss."
        )

    return dict(input_ids=input_ids_list, labels=labels_list)


class ToolCallingDataset(Dataset):
    """Map-style dataset of tokenized tool-calling examples built from truth files."""

    def __init__(
        self,
        truth_paths: list[str],
        mode: DataMode,
        tokenizer: transformers.PreTrainedTokenizer,
        chat_format: str,
    ):
        """
        Load, format and tokenize all truth entries that carry an API call.

        :param truth_paths: Path to truth files.
        :param mode: Train model to generate API call directly or a summarized text.
        :param tokenizer: Tokenizer to use.
        :param chat_format: General chat messages or format that need apply_chat_template.
        """
        super().__init__()
        logger.info(f"Loading data with {mode=}, {truth_paths=}")

        # Merge all files into one mapping, skipping entries without an API call.
        # A key seen again in a later file overwrites the earlier value.
        truth = {
            key: entry
            for path in truth_paths
            for key, entry in load_truth_file(path).items()
            if entry.api_call is not None
        }

        logger.info("Formatting and tokenizing inputs...")
        sources, targets = prepare_source_and_target(
            mode=mode, truth=truth, chat_format=chat_format, tokenizer=tokenizer
        )
        encoded = preprocess(sources, targets, tokenizer)

        self.input_ids = encoded["input_ids"]
        self.labels = encoded["labels"]
        logger.info(f"Loaded {len(self)} examples.")

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.input_ids)

    def __getitem__(self, i) -> dict[str, torch.Tensor]:
        """Return example ``i`` as ``{"input_ids": ..., "labels": ...}``."""
        return {"input_ids": self.input_ids[i], "labels": self.labels[i]}


@dataclass
class DataCollator:
    """
    Right-pad a list of variable-length examples into one batch.

    ``input_ids`` are padded with ``tokenizer.pad_token_id`` and ``labels`` with
    ``IGNORE_INDEX``. Padding is always appended on the right (``pad_sequence``),
    regardless of ``tokenizer.padding_side``.

    The attention mask is derived from each example's real length rather than by
    comparing ids with the pad id. ``train()`` sets ``pad_token = eos_token`` when
    the tokenizer has no pad token, and an id-based mask would then also hide every
    genuine EOS token inside a sequence.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[dict]) -> dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        lengths = torch.tensor([len(ids) for ids in input_ids], dtype=torch.long)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        # Boolean mask (same dtype as before): True for real tokens, False for the
        # right padding added above.
        positions = torch.arange(input_ids.size(1))
        attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )


def make_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> dict:
    """
    Build the train/eval datasets and the batch collator.

    The full dataset is split with a generator seeded by ``random_seed``, so the
    split is reproducible across runs.

    :param tokenizer: Tokenizer used for tokenization and padding.
    :param data_args: Object with ``truth_paths``, ``data_gen_mode``, ``chat_format``
        and ``split_ratio`` (as in ``DataArguments``).
    :return: dict with ``train_dataset``, ``eval_dataset`` and ``data_collator``.
    """
    dataset = ToolCallingDataset(
        tokenizer=tokenizer,
        truth_paths=data_args.truth_paths,
        mode=DataMode(data_args.data_gen_mode),
        chat_format=data_args.chat_format,
    )

    total = len(dataset)
    train_size = int(total * data_args.split_ratio)
    split_rng = torch.Generator().manual_seed(random_seed)
    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset, [train_size, total - train_size], generator=split_rng
    )
    logger.info(f"Split dataset: {len(train_dataset)=}, {len(eval_dataset)=}")

    return {
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": DataCollator(tokenizer=tokenizer),
    }


def create_optimizer(name, params, lr=5e-5, weight_decay=0.0):
    """
    Instantiate a torch optimizer by (case-insensitive) name.

    Supported names: ``adamw`` / ``adamw_torch`` (AdamW), ``adam`` (Adam) and
    ``sgd`` (SGD with momentum 0.9).

    :param name: Optimizer name.
    :param params: Parameters (or parameter groups) to optimize.
    :param lr: Learning rate.
    :param weight_decay: Weight decay.
    :return: The constructed optimizer.
    :raises ValueError: If ``name`` is not supported.
    """
    key = name.lower()
    builders = {
        "adamw": lambda: torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay),
        "adamw_torch": lambda: torch.optim.AdamW(
            params, lr=lr, weight_decay=weight_decay
        ),
        "adam": lambda: torch.optim.Adam(params, lr=lr, weight_decay=weight_decay),
        "sgd": lambda: torch.optim.SGD(
            params, lr=lr, weight_decay=weight_decay, momentum=0.9
        ),
    }
    build = builders.get(key)
    if build is None:
        raise ValueError(f"Unsupported optimizer: {key}")
    return build()


def save_model(
    accelerator: Accelerator,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    ckpt_dir: str,
):
    """
    Write ``model`` (and, from the main process, ``tokenizer``) to ``ckpt_dir``.

    Intended to be called on every process: the state dict comes from
    ``accelerator.get_state_dict``, which may need cross-process communication
    (e.g. gathering FSDP shards). The unwrapped model's ``save_pretrained`` receives
    that state dict together with ``accelerator.save`` and
    ``accelerator.is_main_process``.

    :param accelerator: Accelerator object used for training.
    :param model: HF model already prepared by accelerator.
    :param tokenizer: Tokenizer saved next to the model.
    :param ckpt_dir: Dir to save the checkpoint.
    """
    inner_model = accelerator.unwrap_model(model)
    inner_model.save_pretrained(
        ckpt_dir,
        state_dict=accelerator.get_state_dict(model),
        save_function=accelerator.save,
        is_main_process=accelerator.is_main_process,
    )
    if not accelerator.is_main_process:
        return
    tokenizer.save_pretrained(ckpt_dir)


def _resolve_steps(value, max_steps: int) -> Optional[int]:
    """
    Interpret a Hugging Face style ``*_steps`` setting as a whole number of steps.

    :param value: ``None``, 0 or negative -> disabled; a value in ``(0, 1)`` -> that
        fraction of ``max_steps`` (rounded up, at least 1); otherwise an absolute
        step count.
    :param max_steps: Total number of optimizer steps of the run.
    :return: A positive step count, or ``None`` when disabled.
    """
    if value is None or value <= 0:
        return None
    if value < 1:
        return max(1, math.ceil(max_steps * value))
    return int(value)


def _evaluate(accelerator, model, eval_dataloader, per_device_batch_size: int) -> Optional[float]:
    """
    Compute the mean eval loss over all processes.

    The model is switched to eval mode for the pass and back to train mode after.

    :return: The mean loss as a float (identical on every rank), or ``None`` when
        the eval dataloader yields no batches.
    """
    model.eval()
    losses = []
    for eval_batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**eval_batch)
        # Repeat the per-batch scalar once per sample so gather_for_metrics can drop
        # the samples accelerate duplicates to even out the last batch.
        losses.append(
            accelerator.gather_for_metrics(outputs.loss.repeat(per_device_batch_size))
        )
    model.train()
    if not losses:
        return None
    return torch.cat(losses).mean().item()


def _register_checkpoint(
    saved_ckpt: list, ckpt_dir: str, save_total_limit: Optional[int]
) -> list:
    """
    Append ``ckpt_dir`` as the newest checkpoint and return the ones to delete.

    With ``save_total_limit=None`` nothing is rotated out. Otherwise at most
    ``max(save_total_limit, 1)`` entries are kept, so the newest checkpoint always
    survives. ``saved_ckpt`` is modified in place.
    """
    saved_ckpt.append(ckpt_dir)
    if save_total_limit is None:
        return []
    keep = max(save_total_limit, 1)
    stale = saved_ckpt[:-keep]
    del saved_ckpt[:-keep]
    return stale


def train():
    """
    Fine-tune a causal LM with Accelerate using a hand-written training loop.

    Arguments are parsed from the command line into ``ModelArguments``,
    ``DataArguments`` and ``TrainingArguments``.

    Every ``eval_steps`` optimizer steps the eval loss is computed. If it beats
    the best loss so far by more than ``early_stopping_threshold``, a checkpoint
    is written to ``<output_dir>/checkpoint-<step>``, keeping at most
    ``save_total_limit`` checkpoints. Otherwise the patience counter grows.

    Training stops after ``early_stopping_patience`` evaluations without
    improvement, or after ``max_steps`` optimizer steps. At the end, either the
    best checkpoint is copied into ``output_dir`` or the final model is saved
    there.
    """
    # env setup
    transformers.set_seed(random_seed)

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        log_with="tensorboard",
        # The tensorboard tracker requires a directory; fall back to HF's
        # logging_dir when --log_dir is not given.
        project_dir=training_args.log_dir
        or getattr(training_args, "logging_dir", None),
    )
    accelerator.init_trackers(project_name="tool_calling")
    logging.basicConfig(
        level=logging.INFO if accelerator.is_main_process else logging.ERROR,
        format="[%(asctime)s][%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger.info(f"Start training with:{model_args=}, {data_args=}, {training_args=}")

    # load data set and tokenizer
    with accelerator.main_process_first():
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            model_max_length=training_args.model_max_length,
            use_fast=False,
            padding_side="left",
        )
        logger.info(
            f"Tokenizer: {tokenizer.name_or_path}, size: {len(tokenizer.get_vocab())}, "
            f"model_max_length:{tokenizer.model_max_length}, padding_side: {tokenizer.padding_side}"
        )

        if not tokenizer.pad_token:
            logger.info(f"Using eos_token {tokenizer.eos_token} as pad_token")
            tokenizer.pad_token = tokenizer.eos_token
        if not tokenizer.bos_token:
            logger.info(f"Using eos_token {tokenizer.eos_token} as bos_token")
            tokenizer.bos_token = tokenizer.eos_token

        data_module = make_data_module(tokenizer=tokenizer, data_args=data_args)

        train_dataloader = DataLoader(
            data_module["train_dataset"],
            shuffle=True,
            collate_fn=data_module["data_collator"],
            batch_size=training_args.per_device_train_batch_size,
        )
        eval_dataloader = DataLoader(
            data_module["eval_dataset"],
            collate_fn=data_module["data_collator"],
            batch_size=training_args.per_device_eval_batch_size,
        )

    train_dataloader, eval_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader
    )

    # calculate training parameters, accelerator already takes care of distributed env
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / training_args.gradient_accumulation_steps
    )

    if training_args.max_steps <= 0:
        training_args.max_steps = int(
            training_args.num_train_epochs * num_update_steps_per_epoch
        )
    # Derive the number of passes from max_steps so the loop can actually reach it.
    # Otherwise a fractional num_train_epochs, or an explicit max_steps larger than
    # num_train_epochs worth of steps, would end training early.
    num_epochs = math.ceil(
        training_args.max_steps / max(num_update_steps_per_epoch, 1)
    )

    # An explicit --warmup_steps (absolute, or a ratio when < 1) takes precedence
    # over --warmup_ratio.
    warmup_step = _resolve_steps(
        getattr(training_args, "warmup_steps", 0), training_args.max_steps
    )
    if warmup_step is None:
        warmup_step = int(training_args.max_steps * training_args.warmup_ratio)
    logger.info(f"max step: {training_args.max_steps}, warm up step: {warmup_step}")

    # eval_steps may be None (HF only fills it in when eval_strategy="steps");
    # fall back to the logging interval instead of failing on `step % None`.
    logging_interval = _resolve_steps(
        training_args.logging_steps, training_args.max_steps
    )
    eval_interval = (
        _resolve_steps(training_args.eval_steps, training_args.max_steps)
        or logging_interval
    )

    # create model, etc.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        use_cache=False,
    )
    logger.info(f"Using model: {model}")

    optimizer = create_optimizer(
        training_args.optim,
        model.parameters(),
        lr=training_args.learning_rate,
        weight_decay=training_args.weight_decay,
    )
    logger.info(f"Using optimizer: {optimizer}")

    lr_scheduler = get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=warmup_step,
        num_training_steps=training_args.max_steps,
    )
    logger.info(f"Using lr_scheduler: {lr_scheduler}")

    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

    # Checkpoint dirs, oldest first. Tracked on every rank (eval decisions use the
    # gathered loss, so they agree), so the end-of-training branch is consistent.
    saved_ckpt = []

    logger.info("***** Starting training loop *****")
    progress_bar = tqdm(
        range(training_args.max_steps), disable=not accelerator.is_local_main_process
    )
    effective_step = 0
    best_eval_loss = float("inf")
    patience_counter = 0
    stop_training = False

    for _ in range(num_epochs):
        model.train()
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)

                # gradient step (only when accumulation is complete)
                if accelerator.sync_gradients:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                    progress_bar.update(1)
                    effective_step += 1

                    if logging_interval and effective_step % logging_interval == 0:
                        avg_loss = accelerator.gather(loss).mean()
                        current_lr = lr_scheduler.get_last_lr()[0]

                        logger.info(
                            f"[Train-{effective_step}] loss: {avg_loss.item():.4f}, LR: {current_lr}"
                        )
                        log_obj = {
                            "train_loss": avg_loss.item(),
                            "lr": current_lr,
                            "effective_step": effective_step,
                        }
                        accelerator.log(log_obj, step=effective_step)

                    if eval_interval and effective_step % eval_interval == 0:
                        eval_loss = _evaluate(
                            accelerator,
                            model,
                            eval_dataloader,
                            training_args.per_device_eval_batch_size,
                        )
                        if eval_loss is None:
                            logger.warning(
                                f"Eval dataset is empty; skipping evaluation at step {effective_step}."
                            )
                        else:
                            logger.info(
                                f"[Eval-{effective_step}] loss: {eval_loss:.4f}"
                            )
                            log_obj = {
                                "eval_loss": eval_loss,
                                "effective_step": effective_step,
                            }
                            accelerator.log(log_obj, step=effective_step)

                            # check if this is the best model
                            if (
                                eval_loss
                                < best_eval_loss
                                - training_args.early_stopping_threshold
                            ):
                                best_eval_loss = eval_loss
                                patience_counter = 0

                                ckpt_dir = os.path.join(
                                    training_args.output_dir,
                                    f"checkpoint-{effective_step}",
                                )
                                logger.info(f"New best model! Saving to {ckpt_dir}")

                                # save the new checkpoint (collective: all ranks)
                                accelerator.wait_for_everyone()
                                save_model(accelerator, model, tokenizer, ckpt_dir)

                                # rotate out the oldest checkpoints (main rank deletes)
                                stale = _register_checkpoint(
                                    saved_ckpt,
                                    ckpt_dir,
                                    training_args.save_total_limit,
                                )
                                if accelerator.is_main_process:
                                    for old_ckpt in stale:
                                        logger.info(
                                            f"Removing old checkpoint: {old_ckpt}"
                                        )
                                        shutil.rmtree(old_ckpt)
                                accelerator.wait_for_everyone()
                            else:
                                patience_counter += 1

            # End of one batch, check for stop conditions
            stop_training = False
            if patience_counter >= training_args.early_stopping_patience:
                logger.info(
                    f"Limit: {training_args.early_stopping_patience} reached, stopping early."
                )
                stop_training = True

            if effective_step >= training_args.max_steps:
                logger.info(f"Max steps: {training_args.max_steps} reached, stopping.")
                stop_training = True

            if stop_training:
                break

        if stop_training:
            break

    progress_bar.close()
    accelerator.wait_for_everyone()
    logger.info("Training finished.")

    if training_args.save_best_model_at_the_end and saved_ckpt:
        if accelerator.is_main_process:
            best_ckpt = saved_ckpt[-1]  # only improvements are saved: newest is best
            logger.info(f"Copy best model {best_ckpt} to {training_args.output_dir}")
            shutil.copytree(best_ckpt, training_args.output_dir, dirs_exist_ok=True)
    else:
        if training_args.save_best_model_at_the_end:
            logger.warning(
                "No best checkpoint was saved (no evaluation improved); "
                "saving the final model instead."
            )
        logger.info(f"Saving model to {training_args.output_dir}.")
        accelerator.wait_for_everyone()
        save_model(accelerator, model, tokenizer, training_args.output_dir)
        accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: all configuration comes from command-line arguments
    # parsed inside train().
    train()
