import logging
import os
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Iterable, Optional, cast

import pandas as pd
import wandb
from datasets import DatasetDict
from torch.utils.data import Dataset
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
)
from transformers import TrainingArguments as TA
from transformers import set_seed

from lib_dl_base.defs.task_id import TaskID
from lib_dl_base.io.dirs import append_directories, get_artifacts_dir
from lib_dl_base.visualization.progress import should_show_progressbar

from ..models.load import ModelConfig, load_model
from .preprocess import preprocess_dataset


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Parent of the directory that contains this file.
ROOT_PATH = Path(__file__).parent.parent

# Rank of this process as given by the LOCAL_RANK environment variable
# (0 when the variable is unset). It is forwarded to the Hugging Face
# training arguments and used to restrict the explicit wandb run to rank 0.
# NOTE(review): presumably set by the distributed launcher -- confirm.
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))


@dataclass(kw_only=True)
class TrainingArguments:
    """A subset of the transformers.TrainingArguments used in our experiments.

    Every field is converted with ``dataclasses.asdict`` and forwarded
    unchanged as a keyword argument to ``transformers.TrainingArguments``
    when a training run is set up, so field names and accepted values have to
    match that class. All fields are keyword-only; ``learning_rate`` and
    ``num_train_epochs`` have no default and must always be supplied.
    """

    learning_rate: float
    # Besides being forwarded, this value is embedded in the name of the
    # final checkpoint directory ("epoch_<num_train_epochs>_checkpoint").
    num_train_epochs: int | float
    lr_scheduler_type: str = "linear"
    warmup_steps: int = 0
    weight_decay: float = 0.0
    optim: str = "adamw_torch"
    # Options that are currently not exposed; since they are not forwarded,
    # transformers.TrainingArguments falls back to its own defaults for them.
    # max_grad_norm: float | None = 1.0
    # bf16: bool = False
    # bf16_full_eval: bool = False
    # fp16_full_eval: bool = False
    deepspeed: str | Path | None = None
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    # When this is "steps" and eval_steps is greater than 1, one extra
    # evaluation is run after training if the last global step did not fall
    # on a multiple of eval_steps.
    evaluation_strategy: str = "no"
    eval_steps: int | float | None = None
    logging_steps: int = 500
    save_strategy: str = "no"
    # If "wandb" is listed, a wandb run is started explicitly on the
    # local-rank-0 process before training. A factory is used so that
    # instances do not share one list object.
    report_to: list[str] = field(default_factory=lambda: ["wandb"])


@dataclass
class TrainingConfig:
    """Configuration of a single call to ``train``."""

    # Passed to transformers.set_seed before anything else happens and to
    # dataset preprocessing when the data is not already preprocessed.
    seed: int
    # Arguments forwarded to transformers.TrainingArguments.
    args: TrainingArguments
    # If False, no training is performed; instead the model is loaded from
    # the final checkpoint directory of the corresponding run.
    train: bool = True
    # If True, the trained model is saved to the final checkpoint directory
    # once training has finished.
    save_final_checkpoint: bool = True
    # Project name used for wandb.init and the WANDB_PROJECT environment
    # variable.
    wandb_project_name: str = "llms"
    # Checkpoint handed to Trainer.train(resume_from_checkpoint=...);
    # None means no checkpoint is passed.
    resume_path: Path | None = None


@dataclass
class TrainingResult:
    """Outcome of ``train``: the model plus the training log, if any."""

    # The trained model, or the model loaded from the final checkpoint when
    # training was skipped.
    model: PreTrainedModel
    # DataFrame built from the trainer's log history; None when the model was
    # loaded from disk instead of being trained.
    training_log: Optional[pd.DataFrame]


def train(
    task_id: TaskID,
    model_info: tuple[str, PreTrainedModel | None],
    dataset_info: tuple[str, DatasetDict | None],
    config: TrainingConfig,
    tokenizer: PreTrainedTokenizerBase,
    *,
    callbacks: list[TrainerCallback] | None = None,
    run_callbacks_initially: bool = True,
    data_already_preprocessed: bool = True,
    set_subdirs: Iterable[str] = ("action", "model", "dataset"),
    output_dir: Path | None = None,
) -> TrainingResult:
    """Train a model, or load the final checkpoint of an earlier run.

    Args:
        task_id: Base task identifier; it is refined with the action, model
            and dataset names according to ``set_subdirs``.
        model_info: ``(model_name, model)``. The model may be ``None`` only
            when ``config.train`` is False.
        dataset_info: ``(dataset_name, dataset)``. The dataset may be ``None``
            only when ``config.train`` is False.
        config: Seed, training arguments and run options.
        tokenizer: Tokenizer handed to the trainer and the data collator.
        callbacks: Optional trainer callbacks.
        run_callbacks_initially: Whether ``on_evaluate`` of every callback is
            invoked once before training starts.
        data_already_preprocessed: If False, the dataset is preprocessed
            before training.
        set_subdirs: Which of ``"action"``, ``"model"`` and ``"dataset"`` are
            set on the task identifier. Any iterable is accepted, including
            one-shot iterators.
        output_dir: Directory for training artifacts. When ``None`` it is
            derived from the artifacts directory and the refined task
            identifier.

    Returns:
        The trained model with its training log, or -- when ``config.train``
        is False -- the model loaded from the final checkpoint directory with
        ``training_log=None``.

    Raises:
        ValueError: If ``config.train`` is True and the model or the dataset
            is ``None``.
    """
    set_seed(config.seed)

    model_name, model = model_info
    dataset_name, dataset = dataset_info

    # Snapshot the iterable once: it is tested for membership several times,
    # and a one-shot iterator would be consumed by the first test. A plain
    # string already supports repeated tests and is left untouched.
    enabled_subdirs = (
        set_subdirs if isinstance(set_subdirs, str) else tuple(set_subdirs)
    )

    task_description = task_id
    if "action" in enabled_subdirs:
        task_description = task_description.set_action("training")
    if "model" in enabled_subdirs:
        task_description = task_description.set_model(model_name)
    if "dataset" in enabled_subdirs:
        task_description = task_description.set_dataset(dataset_name)
    if output_dir is None:
        output_dir = append_directories(get_artifacts_dir(), task_description)

    if not config.train:
        # Training is skipped; fall back to the checkpoint written by an
        # earlier run with the same output directory and epoch count.
        loaded_model = _load_model(output_dir, config.args.num_train_epochs)
        return TrainingResult(model=loaded_model, training_log=None)

    if model is None:
        raise ValueError("Must provide a model to train")
    if dataset is None:
        raise ValueError("Must provide a dataset to train on")
    return _train(
        model=model,
        dataset=dataset,
        tokenizer=tokenizer,
        config=config,
        task_id=task_description,
        output_dir=output_dir,
        callbacks=callbacks,
        run_callbacks_initially=run_callbacks_initially,
        data_already_preprocessed=data_already_preprocessed,
    )


def _train(
    model: PreTrainedModel,
    dataset: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    config: TrainingConfig,
    task_id: TaskID,
    output_dir: Path,
    *,
    callbacks: list[TrainerCallback] | None = None,
    run_callbacks_initially: bool = True,
    data_already_preprocessed: bool = True,
) -> TrainingResult:
    """Train ``model`` with the Hugging Face ``Trainer``.

    The function optionally opens a wandb run, builds the Hugging Face
    training arguments, optionally preprocesses the dataset, optionally
    invokes the callbacks once before training, trains, runs a final
    evaluation when the last step was not an evaluation step, and optionally
    saves the final checkpoint.

    Args:
        model: Model to train; the same object is returned in the result.
        dataset: Dataset with a ``"train"`` and a ``"test"`` split.
        tokenizer: Tokenizer used by the trainer and the data collator.
        config: Seed, training arguments and run options.
        task_id: Task identifier whose ``name`` labels the run.
        output_dir: Directory for trainer outputs, logs and the final
            checkpoint.
        callbacks: Optional trainer callbacks.
        run_callbacks_initially: Whether ``on_evaluate`` of every callback is
            invoked once (with the epoch set to 0) before training.
        data_already_preprocessed: If False, ``dataset`` is preprocessed
            first.

    Returns:
        The model together with the trainer's log history as a DataFrame.

    Raises:
        Any exception raised during setup, training, evaluation or saving is
        re-raised after the wandb run (if one was opened) has been closed
        with a failing exit code.
    """
    os.environ["WANDB_WATCH"] = "all"
    os.environ["WANDB_PROJECT"] = config.wandb_project_name

    # Only the local-rank-0 process opens a wandb run explicitly.
    wandb_run = None
    if "wandb" in config.args.report_to and LOCAL_RANK == 0:
        wandb_run = wandb.init(
            project=config.wandb_project_name,
            name=task_id.name,
            reinit=True,
        )

    try:
        hf_kwargs = asdict(config.args)
        # Our config allows a Path here, while transformers.TrainingArguments
        # documents `deepspeed` as a path string or a dict; hand over a str.
        if isinstance(hf_kwargs.get("deepspeed"), Path):
            hf_kwargs["deepspeed"] = str(hf_kwargs["deepspeed"])
        training_args = TA(
            output_dir=str(output_dir),
            logging_dir=str(output_dir / "runs"),
            local_rank=LOCAL_RANK,
            run_name=task_id.name,
            disable_tqdm=not should_show_progressbar(),
            **hf_kwargs,
        )

        if data_already_preprocessed:
            processed_dataset = dataset
        else:
            processed_dataset = preprocess_dataset(
                dataset=dataset,
                tokenizer=tokenizer,
                seed=config.seed,
            )
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            return_tensors="pt",
            pad_to_multiple_of=8,
        )

        data_type = next(model.parameters()).dtype
        logger.info("The model is loaded in %s format.", data_type)

        training_dataset = cast(Dataset, processed_dataset["train"])
        eval_dataset = cast(Dataset, processed_dataset["test"])
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            train_dataset=training_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            callbacks=callbacks,
        )
        if callbacks is not None and run_callbacks_initially:
            logger.info("Computing initial metrics")
            # Work on a copy so the trainer's real state is not modified by
            # marking this call as "epoch 0".
            trainer_state = deepcopy(trainer.state)
            trainer_state.epoch = 0
            for callback in callbacks:
                callback.on_evaluate(
                    trainer.args,
                    trainer_state,
                    trainer.control,
                    model=model,
                    tokenizer=tokenizer,
                )

        logger.info("Training")
        trainer.train(
            resume_from_checkpoint=(
                str(config.resume_path)
                if config.resume_path is not None
                else None
            )
        )

        # Step-based evaluation only fires on multiples of eval_steps, so the
        # final step needs one extra evaluation when it is not such a multiple.
        # NOTE(review): values of eval_steps that are not greater than 1
        # (including fractional ones, which the config's type allows) never
        # trigger this final evaluation -- confirm that this is intended.
        end_state = trainer.state
        eval_steps = trainer.args.eval_steps
        if (
            training_args.evaluation_strategy == "steps"
            and eval_steps is not None
            and eval_steps > 1
            and end_state.global_step % eval_steps != 0
        ):
            logger.info("Computing final metrics")
            trainer.evaluate()

        if config.save_final_checkpoint:
            final_output_dir = get_final_output_dir(
                output_dir,
                config.args.num_train_epochs,
            )
            logger.info("Saving Model to %s", final_output_dir)
            trainer.save_model(output_dir=str(final_output_dir))
    except BaseException:
        # Do not leave the wandb run open when something fails; close it
        # with a non-zero exit code and let the original error propagate.
        if wandb_run is not None:
            wandb_run.finish(exit_code=1)
        raise

    if wandb_run is not None:
        wandb_run.finish()

    training_history = pd.DataFrame(trainer.state.log_history)
    logger.info("Done training.")

    return TrainingResult(
        model=model,
        training_log=training_history,
    )


def _load_model(
    output_dir: Path,
    epochs: int | float,
) -> PreTrainedModel:
    """Load the model stored in the final checkpoint directory of a run.

    Args:
        output_dir: Output directory of the training run.
        epochs: Number of training epochs of that run; it is part of the
            checkpoint directory name.

    Returns:
        The model returned by ``load_model`` for that directory.
    """
    checkpoint_dir = get_final_output_dir(output_dir, epochs)
    # No model id is given; the model is located via the directory alone.
    checkpoint_config = ModelConfig(model_id=None, base_dir=checkpoint_dir)
    return load_model(checkpoint_config)


def get_final_output_dir(
    output_dir: Path,
    epochs: int | float,
) -> Path:
    return output_dir / f"epoch_{epochs}_checkpoint"
