from __future__ import annotations

import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Callable, cast, override

import torch
from peft import PeftMixedModel, PeftModel
from transformers import TrainerControl, TrainerState, TrainingArguments
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction, SchedulerType
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer

from datasets import Dataset
from mow.common import defaults
from mow.common.tokenizer import init_model_or_tokenizer
from mow.modules.mow import MoW


@dataclass
class CustomTrainerConfig:
    """Training hyper-parameters and output locations.

    Each field documents itself through the ``help`` entry of its
    ``metadata``; ``copy_with`` derives a modified copy.
    """

    # --- schedule -------------------------------------------------------
    max_steps: int = field(
        default=1000,
        metadata={"help": "Maximum number of training steps."},
    )
    eval_steps: int = field(
        default=500,
        metadata={"help": "Number of steps between evaluations."},
    )
    logging_steps: int = field(
        default=100,
        metadata={"help": "Number of steps between logging."},
    )
    save_steps: int = field(
        default=100,
        metadata={"help": "Number of steps between saving."},
    )

    # --- optimisation ---------------------------------------------------
    batch_size: int = field(
        default=1,
        metadata={"help": "Batch size for training."},
    )
    learning_rate: float = field(
        default=5e-5,
        metadata={"help": "Learning rate for training."},
    )
    lr_scheduler_type: SchedulerType = field(
        default=SchedulerType.LINEAR,
        metadata={"help": "Type of learning rate scheduler."},
    )
    weight_decay: float = field(
        default=0.0,
        metadata={"help": "Weight decay for optimizer."},
    )
    warmup_steps: int = field(
        default=0,
        metadata={"help": "Number of warmup steps."},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={"help": "Number of steps to accumulate gradients."},
    )

    # --- output ---------------------------------------------------------
    output_dir: Path | None = field(
        default=None,
        metadata={"help": "Output directory for saving models."},
    )
    run_name: str | None = field(
        default=None,
        metadata={"help": "Name of the run."},
    )
    logging_dir: Path | str | None = field(
        default=None,
        metadata={"help": "Directory for saving logs."},
    )

    def copy_with(
        self,
        max_steps: int | None = None,
        eval_steps: int | None = None,
        logging_steps: int | None = None,
        save_steps: int | None = None,
        batch_size: int | None = None,
        learning_rate: float | None = None,
        lr_scheduler_type: SchedulerType | None = None,
        weight_decay: float | None = None,
        warmup_steps: int | None = None,
        gradient_accumulation_steps: int | None = None,
        output_dir: Path | None = None,
        run_name: str | None = None,
        logging_dir: Path | str | None = None,
    ) -> CustomTrainerConfig:
        """Return a new config with the given fields replaced.

        An argument left as ``None`` keeps the value stored on ``self``.
        As a consequence a field cannot be reset to ``None`` through this
        method.

        Returns:
            A fresh ``CustomTrainerConfig``; ``self`` is not modified.
        """
        requested = {
            "max_steps": max_steps,
            "eval_steps": eval_steps,
            "logging_steps": logging_steps,
            "save_steps": save_steps,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "lr_scheduler_type": lr_scheduler_type,
            "weight_decay": weight_decay,
            "warmup_steps": warmup_steps,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "output_dir": output_dir,
            "run_name": run_name,
            "logging_dir": logging_dir,
        }
        # Fall back to the current value wherever no override was supplied.
        resolved = {
            name: getattr(self, name) if override_value is None else override_value
            for name, override_value in requested.items()
        }
        return CustomTrainerConfig(**resolved)


class CustomTrainer(SFTTrainer):
    """``SFTTrainer`` whose ``SFTConfig`` is derived from a ``CustomTrainerConfig``.

    All artefacts of a run live under ``<output_dir>/<run_name>``. When the
    config carries no ``run_name`` a name of the form ``YYYY-MM-DD_NNN`` is
    generated, where ``NNN`` is one above the highest trial number already
    present for today's date in ``output_dir``.
    """

    def __init__(
        self,
        model: PreTrainedModel | PeftModel | PeftMixedModel,
        args: CustomTrainerConfig,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        tokenizer: PreTrainedTokenizerBase,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        data_collator: Callable | None = None,
        callbacks: list[TrainerCallback] | None = None,
        **trainer_kwargs,
    ):
        """Build the ``SFTConfig`` and initialise the underlying trainer.

        Args:
            model: Model to fine-tune.
            args: Hyper-parameters and output locations.
            train_dataset: Dataset used for training.
            eval_dataset: Dataset used for evaluation.
            tokenizer: Tokenizer paired with ``model``.
            compute_metrics: Optional metric function forwarded to the
                base trainer.
            data_collator: Optional collator; when omitted,
                ``defaults.default_data_collator_for_lm`` is used.
            callbacks: Optional trainer callbacks.
            **trainer_kwargs: Extra keyword arguments forwarded verbatim
                to ``SFTConfig``.
        """
        import warnings

        # Coerce so that a plain string (e.g. from a CLI parser) also works.
        if args.output_dir is not None:
            output_dir = Path(args.output_dir)
        else:
            output_dir = Path("/tmp/hf")

        if args.run_name is not None:
            self.__dirname = args.run_name
        else:
            self.__dirname = self._next_run_name(output_dir)

        self.__output_dir = output_dir / self.__dirname

        # A step interval of 0 switches the corresponding feature off.
        logging_strategy = "no" if args.logging_steps == 0 else "steps"
        save_strategy = "no" if args.save_steps == 0 else "steps"

        # A ``Path`` is taken as-is; a string names a sub-directory of the
        # run's ``logs`` folder.
        logs_root = self.__output_dir / "logs"
        if args.logging_dir is None:
            logging_dir = logs_root
        elif isinstance(args.logging_dir, Path):
            logging_dir = args.logging_dir
        else:
            logging_dir = logs_root / args.logging_dir

        # ``TrainingArguments`` rejects ``load_best_model_at_end`` unless
        # checkpoints are saved on the same strategy as evaluation and
        # ``save_steps`` is a multiple of the evaluation interval (see the
        # Transformers documentation). Only request it when that holds, so
        # that e.g. ``save_steps=0`` does not abort construction.
        # An ``eval_steps`` of 0 is documented to fall back to
        # ``logging_steps``; mirror that here.
        effective_eval_steps = args.eval_steps or args.logging_steps
        load_best_model_at_end = (
            save_strategy == "steps"
            and effective_eval_steps > 0
            and args.save_steps % effective_eval_steps == 0
        )
        if not load_best_model_at_end:
            warnings.warn(
                "load_best_model_at_end disabled: it requires save_steps "
                f"({args.save_steps}) to be a non-zero multiple of the "
                f"evaluation interval ({effective_eval_steps}).",
                stacklevel=2,
            )

        config = SFTConfig(
            output_dir=str(self.__output_dir),
            overwrite_output_dir=True,
            logging_dir=str(logging_dir),
            do_train=True,
            do_eval=True,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            eval_strategy="steps",
            logging_strategy=logging_strategy,
            save_strategy=save_strategy,
            max_steps=args.max_steps,
            eval_steps=args.eval_steps,
            logging_steps=args.logging_steps,
            save_steps=args.save_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            warmup_steps=args.warmup_steps,
            weight_decay=args.weight_decay,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            label_names=[],
            report_to=["tensorboard"],
            dataset_text_field="text",
            load_best_model_at_end=load_best_model_at_end,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            ignore_data_skip=True,
            **trainer_kwargs,
        )

        self.__tokenizer = tokenizer
        self.__model = model
        self.__args = config
        self.__train_dataset = train_dataset
        self.__eval_dataset = eval_dataset

        init_model_or_tokenizer(model=self.__model, tokenizer=self.__tokenizer)  # type: ignore

        if data_collator is None:
            data_collator = defaults.default_data_collator_for_lm(
                tokenizer=self.__tokenizer
            )

        super().__init__(
            model=self.__model,
            tokenizer=tokenizer,
            args=self.__args,
            train_dataset=self.__train_dataset,
            eval_dataset=self.__eval_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
        )
        self.can_return_loss = True

    @staticmethod
    def _next_run_name(output_dir: Path) -> str:
        """Return ``<today>_<NNN>``, one above today's highest trial in ``output_dir``."""
        # Take the date once so the name and the comparison below cannot
        # disagree when the call straddles midnight.
        today = datetime.now().strftime("%Y-%m-%d")
        trials = {0}
        if output_dir.is_dir():
            for entry in output_dir.iterdir():
                if not entry.is_dir():
                    continue
                m = re.match(r"(\d{4}-\d{2}-\d{2})_(\d{3})", entry.name)
                if m and m.group(1) == today:
                    trials.add(int(m.group(2)))
        return f"{today}_{max(trials) + 1:03d}"

    def train(  # type: ignore
        self,
        save_model: bool = True,
        model_name: str = "last",
        resume_from_checkpoint: bool | str = False,
    ):
        """Run training and optionally save the final model.

        Args:
            save_model: When true, save the model to
                ``<run dir>/<model_name>`` after training.
            model_name: Sub-directory name used for that final save.
            resume_from_checkpoint: Forwarded to the base ``train``.

        Returns:
            Whatever the base class ``train`` returns.
        """
        output = super().train(resume_from_checkpoint=resume_from_checkpoint)  # type: ignore
        if save_model:
            self.save_model(str(self.__output_dir / model_name))
        return output

    def save_model(self, output_dir: str | Path | None = None, **kwargs):  # type: ignore
        """Save model weights, model config and tokenizer to ``output_dir``.

        Args:
            output_dir: Target directory. ``None`` (the default, matching
                the base-class signature) means the run directory.
            **kwargs: ``relative=True`` resolves ``output_dir`` against the
                run directory. Any other keyword is accepted and ignored.
        """
        if output_dir is None:
            output_dir = self.__output_dir
        elif kwargs.pop("relative", False):
            output_dir = self.__output_dir / output_dir
        output_dir = str(output_dir)
        self.__model.save_pretrained(output_dir)
        self.__model.config.save_pretrained(output_dir)  # type: ignore
        self.__tokenizer.save_pretrained(output_dir)

    def test(
        self,
        test_dataset: Dataset,
        interactive: bool = True,
        sample: int = -1,
        use_cache: bool = True,
        max_length: int = 4096,
    ):
        """Generate and print a prediction for examples of ``test_dataset``.

        Each example's ``"text"`` is cut just before its last assistant
        turn; the model then generates that turn and the decoded result is
        printed.

        Args:
            test_dataset: Dataset whose rows have a ``"text"`` field.
            interactive: When true, wait for a line on stdin after every
                prediction; entering ``quit`` stops the loop.
            sample: Number of examples to draw after a seeded shuffle; a
                negative value uses the whole dataset in order. Values
                larger than the dataset are clamped to its length.
            use_cache: Forwarded to ``generate``.
            max_length: Forwarded to ``generate``.
        """
        header = "assistant<|end_header_id|>"

        if sample < 0:
            dataset = test_dataset
        else:
            count = min(sample, len(test_dataset))
            dataset = test_dataset.shuffle(seed=42).select(range(count))

        for elem in dataset:
            # Split on the full header rather than the bare word
            # "assistant", so that an answer which merely mentions the
            # word does not leak into the prompt.
            prefix, _, _ = elem["text"].rpartition(header)  # type: ignore
            prompt = f"{prefix}{header}\n\n"

            encoded = self.__tokenizer(prompt, return_tensors="pt")
            encoded = encoded.to(self.__model.device)  # type: ignore

            pred = self.__model.generate(
                **encoded,  # type: ignore
                max_length=max_length,
                pad_token_id=self.__tokenizer.pad_token_id,
                eos_token_id=self.__tokenizer.eos_token_id,
                tokenizer=self.__tokenizer,
                use_cache=use_cache,
            )[0]
            # Keep only the text of the last assistant turn.
            pred_decoded = (
                self.__tokenizer.decode(pred, skip_special_tokens=False)
                .split(header)[-1]
                .split("<|eot_id|>")[0]
                .strip()
            )
            print("Predicted:", pred_decoded)

            if interactive:
                cmd = input()
                if cmd == "quit":
                    break


class PredictionCallback(TrainerCallback):
    """Periodically log one sample prediction to ``predictions.jsonl``.

    Whenever the callback triggers, the next dataset example is run
    through ``model.generate`` and a JSON line with the prompt, the
    expected answer and the predicted answer is appended to
    ``<output_dir>/predictions.jsonl``. The dataset is cycled endlessly.
    """

    def __init__(
        self,
        model: MoW,
        dataset: Dataset,
        config: CustomTrainerConfig,
        max_new_tokens: int = 10,
    ):
        """Store the collaborators and create the prediction generator.

        Args:
            model: Model used for generation; its ``tokenizer``,
                ``device`` and ``generate`` attributes are used.
            dataset: Examples to predict on; each must provide ``nodes``,
                ``adjacency_matrix``, ``relation_matrix``, ``context``
                and ``text``.
            config: Training config; its ``output_dir`` must be set.
            max_new_tokens: Generation budget added to the prompt length
                to form ``max_length``.

        Raises:
            AssertionError: If ``config.output_dir`` is ``None``.
        """
        # The generator body does not run until first advanced, so it is
        # safe to create it before the attributes below are assigned.
        self.gen = self.on_step_end_generate()

        self.model = model
        self.dataset = dataset
        self.max_new_tokens = max_new_tokens

        assert (
            config.output_dir is not None
        ), "Output directory must be specified in the training config."

        # Coerce so that a plain string also supports the ``/`` operator.
        self.output_dir = Path(config.output_dir)

    def on_step_end_generate(self):
        """Yield once per logged prediction, cycling over the dataset.

        Each advance of the generator processes exactly one example and
        appends one line to ``predictions.jsonl``. The generator ends if
        a full pass over the dataset produces no example, so an empty
        dataset cannot cause an endless loop.
        """
        header = "assistant<|end_header_id|>"
        eot = "<|eot_id|>"

        while True:
            produced = False
            for example in self.dataset:
                produced = True
                example = cast(dict, example)

                # NOTE(review): these are assumed to already be tensors
                # (``.to`` is called on them below) — confirm the dataset
                # format.
                hidden_states: torch.Tensor = example["nodes"]
                adj_mat: torch.Tensor = example["adjacency_matrix"]
                rel: torch.Tensor = example["relation_matrix"]
                context: torch.Tensor = example["context"]

                # Everything up to and including the last assistant header
                # is the prompt; what follows is the reference answer.
                text: str = example["text"]
                head, _, answer = text.rpartition(header)
                prompt = f"{head}{header}\n\n"
                answer = answer.split(eot)[0].strip()

                input_ids = self.model.tokenizer(
                    prompt, return_tensors="pt"
                ).input_ids

                device = self.model.device
                hidden_states = hidden_states.to(device)
                adj_mat = adj_mat.to(device)
                rel = rel.to(device)
                context = context.to(device)
                input_ids = input_ids.to(device)

                # No gradients are needed for logging predictions. The
                # context wraps only this call, never the ``yield``, so
                # the training loop's grad mode is left untouched.
                # NOTE(review): the model's train/eval mode is not
                # changed here — confirm whether that matters for
                # ``generate``.
                with torch.no_grad():
                    output = self.model.generate(
                        hidden_states=hidden_states,
                        adjacency_matrix=adj_mat,
                        relation_matrix=rel,
                        context=context,
                        input_ids=input_ids,
                        max_length=input_ids.shape[1] + self.max_new_tokens,
                    )
                decoded = self.model.tokenizer.decode(
                    output[0], skip_special_tokens=False
                )
                # Keep what follows the last full assistant header, or the
                # whole text when no header is present. Splitting on the
                # bare word "assistant" would truncate predictions that
                # merely contain that word.
                _, _, tail = decoded.rpartition(header)
                pred = tail.split(eot)[0].strip()

                self.output_dir.mkdir(parents=True, exist_ok=True)
                with open(
                    self.output_dir / "predictions.jsonl",
                    "a",
                    encoding="utf-8",
                ) as f:
                    s = json.dumps(
                        {
                            "input": prompt,
                            "expected_output": answer,
                            "predicted_output": pred,
                        },
                    )
                    f.write(s + "\n")

                yield

            if not produced:
                # Empty dataset: stop instead of spinning forever.
                return

    @override
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Log one prediction when ``global_step`` hits the trigger points.

        The trigger fires when ``global_step`` is ``1``,
        ``eval_steps + 1``, ``2 * eval_steps + 1`` and so on. Nothing
        happens when ``eval_steps`` is ``None`` or ``0``.
        """
        # Truthiness guards both ``None`` and ``0`` (the latter would
        # raise ``ZeroDivisionError`` in the modulo).
        if args.eval_steps and (state.global_step - 1) % args.eval_steps == 0:
            # The default keeps an exhausted generator from raising.
            next(self.gen, None)
        return super().on_step_end(args, state, control, **kwargs)
