import logging
import random
from pathlib import Path
from typing import Optional

import torch
import transformers
from chronos import ChronosBoltPipeline
from chronos.chronos_bolt import ChronosBoltConfig, ChronosBoltModelForForecasting
from torch.utils.data import DataLoader, IterableDataset
from transformers import Trainer, TrainingArguments
from transformers.models.t5.configuration_t5 import T5Config
from transformers.trainer_utils import seed_worker
from typer_config import use_yaml_config

from chroma.train.dataset import get_training_dataset

# Install a root handler with a timestamped format. This runs at import time;
# logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Module-wide logger. Note it is named after this file's path (__file__),
# not the dotted module name (__name__).
LG = logging.getLogger(__file__)
# Let INFO-level records from this module through.
LG.setLevel(logging.INFO)


def reset_non_editable_fields(
    pretrained_chronos_config: dict, training_chronos_config: dict
) -> dict:
    """Reconcile a requested Chronos config with the one stored in a checkpoint.

    The result contains exactly the keys of ``training_chronos_config``. For
    each key, the pretrained value wins whenever the checkpoint defines it
    (i.e. it is present and not ``None``); otherwise the requested value is
    kept. Keys that exist only in ``pretrained_chronos_config`` are not copied.

    Args:
        pretrained_chronos_config: Chronos config fields read from the
            pretrained checkpoint (may be empty).
        training_chronos_config: Chronos config fields requested for training.

    Returns:
        A new dict; neither input is modified. A warning is logged for every
        requested value that is overridden by a different pretrained value.
    """
    chronos_config = {}
    for key, requested in training_chronos_config.items():
        pretrained = pretrained_chronos_config.get(key)

        # Missing from the checkpoint, or stored as None: the field is free to set.
        if pretrained is None:
            chronos_config[key] = requested
            continue

        if pretrained != requested:
            # Values are passed as %-style arguments so the logging module does
            # the formatting. Passing stray extra arguments alongside a message
            # without placeholders makes the handler fail and drops the warning.
            LG.warning(
                "%s=%s was provided but %s is not editable when using a pretrained checkpoint. "
                "Resetting %s=%s based on the pretrained checkpoint.",
                key,
                requested,
                key,
                key,
                pretrained,
            )

        chronos_config[key] = pretrained

    return chronos_config


def load_model(
    model_class,
    chronos_config: ChronosBoltConfig,
    model_id="google/t5-efficient-tiny",
    random_init=True,
):
    """
    Instantiate ``model_class`` on top of the T5 configuration of ``model_id``.

    The T5 config is fetched with ``T5Config.from_pretrained`` and its
    ``initializer_factor`` is set to 0.05. Then:

    * ``random_init=True``: the requested Chronos config is attached as-is
      and the model is constructed from the config alone.
    * ``random_init=False``: the requested Chronos config is first reconciled
      with any ``chronos_config`` already present on the loaded config (via
      ``reset_non_editable_fields``), and the weights are loaded from
      ``model_id`` with ``ignore_mismatched_sizes=True``.

    Returns the constructed model.
    """
    t5_config = T5Config.from_pretrained(model_id)
    t5_config.initializer_factor = 0.05

    if not random_init:
        LG.info(f"Using pretrained initialization from {model_id}")
        # A plain T5 checkpoint has no chronos_config attribute; fall back to {}.
        checkpoint_fields = getattr(t5_config, "chronos_config", {})
        requested_fields = chronos_config.__dict__
        t5_config.chronos_config = reset_non_editable_fields(
            checkpoint_fields, requested_fields
        )
        return model_class.from_pretrained(
            model_id, config=t5_config, ignore_mismatched_sizes=True
        )

    LG.info("Using random initialization")
    t5_config.chronos_config = chronos_config.__dict__
    return model_class(config=t5_config)


class ChromaTrainer(Trainer):
    """``Trainer`` variant that feeds an ``IterableDataset`` to a plain ``DataLoader``."""

    def get_train_dataloader(self):
        """Build the training ``DataLoader`` from ``self.train_dataset``.

        The loader is configured from ``self.args`` (batch size, worker count,
        pinning, persistence, drop-last, prefetch factor) and uses
        ``self.data_collator`` and ``seed_worker``. No sampler is passed.

        Raises:
            ValueError: if no training dataset was provided.
            TypeError: if the training dataset is not an ``IterableDataset``.
        """
        train_dataset = self.train_dataset

        # Check for None first: isinstance(None, IterableDataset) is False, so
        # testing the type first would hide this more specific error.
        if train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        # Explicit raise rather than assert, so the check survives `python -O`.
        if not isinstance(train_dataset, IterableDataset):
            raise TypeError(
                "ChromaTrainer requires an IterableDataset as train_dataset, "
                f"got {type(train_dataset).__name__}."
            )

        dataloader_params = {
            "batch_size": self.args.train_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "drop_last": self.args.dataloader_drop_last,
            "worker_init_fn": seed_worker,
            "prefetch_factor": self.args.dataloader_prefetch_factor,
        }

        return DataLoader(train_dataset, **dataloader_params)


def get_training_arguments(
    max_steps: int,
    learning_rate: float,
    dataloader_num_workers: int,
    seed: int,
    output_path: Path,
    save_steps: int,
) -> TrainingArguments:
    """Assemble the ``TrainingArguments`` used for a training run.

    Only the values passed in are configurable; everything else (batch size,
    scheduler, optimizer, logging cadence, regularisation, ...) is fixed here.
    Logs are written to ``output_path / "logs"``.
    """
    # TF32 is switched on only when CUDA is present and the current device
    # reports a compute-capability major version of at least 8.
    use_tf32 = False
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        use_tf32 = major >= 8

    return TrainingArguments(
        # Output, logging and checkpointing.
        output_dir=str(output_path),
        logging_dir=str(output_path / "logs"),
        logging_strategy="steps",
        logging_steps=500,
        save_strategy="steps",
        save_steps=save_steps,
        report_to=["tensorboard"],
        # Schedule and batch layout.
        max_steps=max_steps,
        per_device_train_batch_size=256,
        gradient_accumulation_steps=1,
        dataloader_num_workers=dataloader_num_workers,
        # Optimizer and learning-rate schedule.
        optim="adamw_torch_fused",
        learning_rate=learning_rate,
        lr_scheduler_type="linear",
        warmup_ratio=0.0,
        weight_decay=0.1,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.98,
        # Runtime behaviour.
        tf32=use_tf32,
        torch_compile=True,
        ddp_find_unused_parameters=False,
        remove_unused_columns=False,
        seed=seed,
    )


@use_yaml_config()
def train(
    *,
    dataset_names_or_paths: list[str] = ["data/"],
    model_id: str = "google/t5-efficient-tiny",
    max_steps: int = 200_000,
    learning_rate: float = 1e-3,
    random_init: bool = True,
    output_dir: str = "./output/",
    dataloader_num_workers: int = 1,
    seed: Optional[int] = None,
    probabilities: Optional[list[float]] = None,
    preprocess_data: bool = True,
    save_steps: int = 5000,
) -> Path:
    """Train a Chronos-Bolt forecasting model and save the final checkpoint.

    Args:
        dataset_names_or_paths: Datasets forwarded to ``get_training_dataset``.
            The list default is never mutated by this function.
        model_id: Identifier used for the T5 config and, when
            ``random_init`` is false, for the pretrained weights.
        max_steps: Number of optimisation steps.
        learning_rate: Initial learning rate.
        random_init: Build the model from its config instead of loading
            pretrained weights.
        output_dir: Directory for checkpoints and logs; created if missing.
        dataloader_num_workers: Worker count for the training dataloader.
        seed: Random seed; drawn at random when ``None``.
        probabilities: Forwarded to ``get_training_dataset`` (presumably
            per-dataset sampling weights; confirm against that helper).
        preprocess_data: Forwarded to ``get_training_dataset`` as ``preprocess``.
        save_steps: Checkpoint interval in steps.

    Returns:
        Path of the final checkpoint, ``<output_dir>/checkpoint-final``.
    """
    chronos_config = ChronosBoltConfig(
        context_length=2048,
        prediction_length=64,
        input_patch_size=16,
        input_patch_stride=16,
        quantiles=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
        use_reg_token=True,
    )

    if seed is None:
        # random.randint includes both endpoints. The largest seed accepted by
        # NumPy's legacy seeding (used by transformers.set_seed) is 2**32 - 1.
        seed = random.randint(0, 2**32 - 1)
    LG.info(f"Using SEED: {seed}")
    transformers.set_seed(seed=seed)

    output_path = Path(output_dir)
    if output_path.exists():
        LG.warning(f"Output directory {output_path} already exists. Overwriting.")
    else:
        # exist_ok guards against another process creating it in between.
        output_path.mkdir(parents=True, exist_ok=True)

    LG.info(f"Logging dir: {output_path}")
    LG.info(
        f"Loading and filtering {len(dataset_names_or_paths)} datasets for training: {dataset_names_or_paths}"
    )

    model = load_model(
        model_class=ChronosBoltModelForForecasting,
        chronos_config=chronos_config,
        model_id=model_id,
        random_init=random_init,
    )

    model.config.chronos_pipeline_class = ChronosBoltPipeline.__name__

    # load_model has already stored the effective Chronos config on the model.
    # With a pretrained checkpoint, non-editable fields were reset to the
    # checkpoint's values, so the requested config must not be written back
    # over it: the saved config would then disagree with the weights.
    effective_chronos_config = getattr(model.config, "chronos_config", None)
    if effective_chronos_config is None:
        effective_chronos_config = chronos_config.__dict__
        model.config.chronos_config = effective_chronos_config

    min_past: int = 64
    # Build training windows with the lengths the model was actually built with.
    shuffled_train_dataset = get_training_dataset(
        dataset_names_or_paths,
        probabilities,
        prediction_length=effective_chronos_config["prediction_length"],
        context_length=effective_chronos_config["context_length"],
        min_past=min_past,
        preprocess=preprocess_data,
    )

    # Define training args
    training_args = get_training_arguments(
        max_steps=max_steps,
        learning_rate=learning_rate,
        dataloader_num_workers=dataloader_num_workers,
        seed=seed,
        output_path=output_path,
        save_steps=save_steps,
    )

    # Create Trainer instance
    trainer = ChromaTrainer(
        model=model,
        args=training_args,
        train_dataset=shuffled_train_dataset,
    )
    LG.info("Training")

    trainer.train()

    final_checkpoint = output_path / "checkpoint-final"
    model.save_pretrained(final_checkpoint)
    return final_checkpoint
