import argparse
import enum
import functools
import os
import pathlib
import typing

import accelerate
import datasets
import dotenv
import numpy as np
import safetensors
import torch
import transformers
import xtuner.utils
from einops import rearrange

import harmon.src.datasets.utils
import wandb
from divergent_memories.data import builders as _builders
from divergent_memories.harmon import utils as _utils

# Sentinel written into `labels` for positions that should not be trained on:
# prompt tokens (see `build_dataset`) and padding (see `collate_func_und`).
# -100 is the default `ignore_index` of `torch.nn.CrossEntropyLoss`.
# NOTE(review): the loss itself lives in the model code, which is not visible here.
IGNORE_INDEX = -100


def main() -> None:
    """Fine-tune the model on one of the understanding settings and optionally save it.

    Steps, in order:
      1. Load environment variables and create the ``accelerate`` accelerator.
      2. Parse CLI arguments (see ``parse_args``) and, on the main process, start a wandb run.
      3. Build the model/tokenizer, optionally from ``--base-model-dir``.
      4. Build the raw train/validation datasets for the selected setting and tokenize them.
      5. Train with ``transformers.Trainer``; if ``--save-final-model`` is set, the main
         process writes the trainer state and ``model.safetensors`` to the output directory.

    Raises:
        FileNotFoundError: If ``--base-model-dir`` is given but contains no ``model.safetensors``.
        ValueError: If the setting is not one of the handled ``SettingType`` members.
    """
    dotenv.load_dotenv()

    # Accelerator has to be initialized before parsing configs
    # (in case HF stuff is used dynamically)
    accelerator = accelerate.Accelerator()

    # Fix verbose logging in multi-process setting
    if not accelerator.is_main_process:
        transformers.logging.set_verbosity_error()
        transformers.utils.logging.disable_progress_bar()
        datasets.disable_progress_bars()

    args = parse_args()
    output_model_id = args.output_model_id
    setting = args.setting
    aux_fraction = args.aux_fraction
    save_final_model = args.save_final_model

    # Custom wandb init for additional config options
    if accelerator.is_main_process:
        os.environ.setdefault("WANDB_PROJECT", "train_generation")
        os.environ.setdefault("WANDB_LOG_MODEL", "false")
        os.environ.setdefault("WANDB_WATCH", "false")
        wandb.init(
            name=output_model_id,
            config={
                "setting": setting,
                "aux_fraction": aux_fraction,
                "save_final_model": save_final_model,
            },
        )

    training_args = transformers.TrainingArguments(
        output_dir=args.output_root / output_model_id,
        seed=args.seed,
        max_steps=args.max_steps,
        per_device_train_batch_size=32,  # TODO: Double check that matches!
        gradient_accumulation_steps=1,
        learning_rate=args.learning_rate,
        adam_beta1=0.9,
        adam_beta2=0.95,
        weight_decay=args.weight_decay,
        max_grad_norm=1.0,
        lr_scheduler_type="cosine",
        warmup_steps=args.warmup_steps,
        # TODO: Harmon uses a custom optimizer that only performs weight decay on 2D weights or smth;
        #  not sure if this breaks everything.
        optim="adamw_torch_fused",
        bf16=True,
        logging_steps=10,
        # eval_strategy="no",  # TODO: fix model to actually get eval metrics (loss)
        eval_strategy="steps",
        eval_steps=50,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps,
        # metric_for_best_model="eval_val_loss",  # TODO: fix if no val data available
        greater_is_better=False,
        # Keep all columns: the collator reads `pixel_values`/`input_ids`/`labels` itself.
        remove_unused_columns=False,
        # save_safetensors=False
    )

    accelerator.print("Building model")

    if args.base_model_dir is not None:
        checkpoint_path = args.base_model_dir / "model.safetensors"
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
    else:
        checkpoint_path = None
    model, tokenizer = _utils.build_model(use_dev_model=True, checkpoint_path=checkpoint_path)
    # Gradient checkpointing is switched off for the `mar` submodule and on for the `llm` submodule.
    model.mar.gradient_checkpointing_disable()
    model.llm.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model = model.train()

    # Build datasets
    accelerator.print("Building datasets")
    builder = _builders.ImageUnderstandingBuilder(args.data_root, args.seed)
    if setting == SettingType.SYNTHETIC_CONCEPTS:
        train_dataset, val_datasets = builder.build_concepts_train_val(
            use_hd=True,
            aux_fraction=aux_fraction,
        )
        loss_type = "image2text"
    elif setting == SettingType.TINY_MMLU:
        train_dataset, val_datasets = builder.build_tiny_mmlu_train()
        loss_type = "text2text"
    elif setting == SettingType.MMMU:
        train_dataset, val_datasets = builder.build_mmmu_train_val()
        loss_type = "image2text"
    else:
        raise ValueError(f"Invalid setting: {setting}")

    # Process dataset
    accelerator.print("Processing datasets")
    # One shared configuration so train and validation data are guaranteed to be processed identically.
    # TODO: Move processing to collator! (see train_generation.py)
    process_dataset = functools.partial(
        build_dataset,
        tokenizer=tokenizer,
        crop_image=True,
        image_length=1024 + 64,  # TODO: Make this more sane!
        image_size=512,
        prompt_template=xtuner.utils.PROMPT_TEMPLATE.qwen_chat,
        max_length=4096,
    )
    train_dataset = process_dataset(train_dataset)
    val_datasets = {key: process_dataset(dataset) for key, dataset in val_datasets.items()}

    data_collator = functools.partial(collate_func_und, pad_index=tokenizer.pad_token_id, loss_type=loss_type)

    accelerator.print("Starting training")
    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        eval_dataset=val_datasets,
    )
    trainer.train()
    if save_final_model and accelerator.is_main_process:
        # Import the submodule explicitly: a bare `import safetensors` only exposes
        # `safetensors.torch` if some other library happens to have imported it already.
        import safetensors.torch

        # The directory may not exist yet if no intermediate checkpoint was written.
        os.makedirs(training_args.output_dir, exist_ok=True)
        trainer.save_state()
        # trainer.save_model(output_dir=training_args.output_dir)
        safetensors.torch.save_model(
            model=model,
            filename=os.path.join(training_args.output_dir, "model.safetensors"),
        )

    accelerator.wait_for_everyone()
    accelerator.end_training()


# TODO: Is pad_index correct?
def collate_func_und(
    instances: typing.Sequence[typing.Dict], pad_index: int, loss_type: str
) -> typing.Dict[str, typing.Any]:
    """Collate processed examples into one right-padded batch.

    The example dicts in ``instances`` are only read, never modified, so the same
    examples can safely be collated more than once.

    Args:
        instances: Examples with 1-D tensors under ``"input_ids"`` and ``"labels"`` and,
            optionally, an image tensor under ``"pixel_values"``.
        pad_index: Value used to pad ``input_ids`` to the longest sequence in the batch.
        loss_type: Passed through unchanged under the ``"loss_type"`` key.

    Returns:
        A dict with ``"input_ids"``, ``"attention_mask"`` (bool, ``True`` on real tokens),
        ``"labels"`` (padded with ``IGNORE_INDEX``) and ``"loss_type"``; ``"pixel_values"``
        (stacked along a new first dimension) is included only if at least one example has it.
    """
    sequences = [example["input_ids"] for example in instances]
    targets = [example["labels"] for example in instances]
    # NOTE(review): examples without "pixel_values" are skipped, so a batch mixing image and
    # text-only examples yields fewer images than sequences — confirm this cannot happen.
    pixel_values = [example["pixel_values"] for example in instances if "pixel_values" in example]
    lengths = [len(sequence) for sequence in sequences]

    # TODO: Test left padding!
    input_ids = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_index)
    labels = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=IGNORE_INDEX)

    # Position j of row i is a real token iff j < lengths[i].
    positions = torch.arange(input_ids.shape[1], device=input_ids.device)
    attention_mask = positions.unsqueeze(0) < torch.tensor(lengths, device=input_ids.device).unsqueeze(1)

    collated = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "loss_type": loss_type,
    }

    if pixel_values:
        collated["pixel_values"] = torch.stack(pixel_values)

    return collated


def build_dataset(
    raw_dataset: datasets.Dataset,
    tokenizer: transformers.AutoTokenizer,
    crop_image: bool,
    image_length: int,
    image_size: int,
    prompt_template: dict[str, str],
    max_length: int,
):  # TODO: Clean up and test if this also works in collator
    """Tokenize (and, if present, image-preprocess) every sample of ``raw_dataset``.

    Each sample must provide ``"prompt"`` and ``"completion"``; an ``"image"`` entry is optional.

    Args:
        raw_dataset: Dataset to process via ``Dataset.map``.
        tokenizer: Tokenizer used to encode prompt and completion separately.
        crop_image: If true, images are squared with ``crop2square``; otherwise they are
            stretched to a square whose side is the longer image side.
        image_length: Number of image placeholder tokens prepended to the prompt of image samples.
        image_size: Side length (pixels) of the final square image.
        prompt_template: Mapping with an ``"INSTRUCTION"`` format string (taking ``input``)
            and a ``"SUFFIX"`` string appended to the completion.
        max_length: Maximum allowed number of tokens per sample.

    Returns:
        The mapped dataset in torch format, with ``input_ids``, ``labels`` and (for image
        datasets) ``pixel_values`` added; all other columns are still returned.

    Raises:
        ValueError: If a tokenized sample is longer than ``max_length``.
    """

    def _to_pixel_values(image):
        # First make the image square, either by cropping or by stretching.
        if crop_image:
            squared = harmon.src.datasets.utils.crop2square(image)
        else:
            longest_side = max(image.size)
            squared = image.resize(size=(longest_side, longest_side))

        # Make sure image is RGB, then bring it to the target resolution.
        resized = squared.convert("RGB").resize(size=(image_size, image_size))

        # Map pixel intensities from [0, 255] to [-1, 1] and move channels first.
        unit_range = torch.from_numpy(np.array(resized)).float() / 255
        return rearrange(2 * unit_range - 1, "h w c -> c h w")

    def _tokenize(prompt: str, completion: str, include_image: bool):
        # Image samples get `image_length` placeholder tokens in front of the prompt text.
        if include_image:
            user_input = _utils.DEFAULT_IMAGE_TOKEN * image_length + "\n" + prompt
        else:
            user_input = prompt
        prompt_text = prompt_template["INSTRUCTION"].format(input=user_input)
        prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=True, return_tensors="pt")

        completion_text = completion + prompt_template["SUFFIX"]
        completion_ids = tokenizer.encode(completion_text, add_special_tokens=True, return_tensors="pt")

        # Both encodings carry a leading batch dimension of one; drop it after joining.
        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)[0]

        num_tokens = len(input_ids)
        if num_tokens > max_length:
            raise ValueError(f"Input ids length {num_tokens} exceeds max length {max_length}.")

        # Assistant-only loss
        labels = input_ids.clone()
        num_prompt_tokens = prompt_ids.shape[1]
        labels[:num_prompt_tokens] = IGNORE_INDEX

        return input_ids, labels

    def _encode_sample(sample: dict[str, typing.Any]):
        has_image = "image" in sample
        input_ids, labels = _tokenize(sample["prompt"], sample["completion"], include_image=has_image)

        encoded = {"input_ids": input_ids, "labels": labels}
        if has_image:
            encoded["pixel_values"] = _to_pixel_values(sample["image"])

        return encoded

    processed = raw_dataset.map(
        _encode_sample,
        batched=False,
        keep_in_memory=True,
        load_from_cache_file=False,
    )

    # Only the model inputs are converted to tensors; the remaining columns are passed through.
    tensor_columns = ["input_ids", "labels"]
    if "image" in raw_dataset.column_names:
        tensor_columns = ["pixel_values", *tensor_columns]
    processed.set_format(type="torch", columns=tensor_columns, output_all_columns=True)

    return processed


class SettingType(enum.Enum):
    """Experiment settings; each member's value is its command-line spelling."""

    SYNTHETIC_CONCEPTS = "synthetic_concepts"
    TINY_MMLU = "tiny_mmlu"
    MMMU = "mmmu"

    def __str__(self) -> str:
        # Show the plain value (e.g. "mmmu") instead of "SettingType.MMMU".
        return str(self.value)


def parse_args(argv: typing.Optional[typing.Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``; pass an explicit list to parse programmatically.

    Returns:
        The parsed namespace. ``--setting``, ``--output-model-id`` and ``--aux-fraction``
        are required; ``--data-root`` and ``--output-root`` default to the ``DATA_ROOT`` /
        ``MODEL_OUTPUT_ROOT`` environment variables, falling back to ``./data`` and
        ``./models`` under the current working directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--setting",
        type=SettingType,
        required=True,
        help="Type of experiment",
        choices=list(SettingType),
    )
    parser.add_argument(
        "--output-model-id",
        type=str,
        required=True,
        help="Model name and wandb run name",
    )
    parser.add_argument(
        "--base-model-dir",
        type=pathlib.Path,
        required=False,
        help="Path to the base model directory. If not provided, will use the default model from Hugging Face.",
    )
    # The environment is read when this function runs, not at import time.
    parser.add_argument(
        "--data-root",
        type=pathlib.Path,
        default=pathlib.Path(os.getenv("DATA_ROOT", default=pathlib.Path.cwd() / "data")),
    )
    parser.add_argument(
        "--output-root",
        type=pathlib.Path,
        default=pathlib.Path(os.getenv("MODEL_OUTPUT_ROOT", default=pathlib.Path.cwd() / "models")),
    )
    parser.add_argument("--seed", type=int, default=178430, help="Random seed")
    parser.add_argument("--save-strategy", type=str, default="no", help="Save strategy")
    parser.add_argument("--save-steps", type=int, default=100, help="Save steps")
    parser.add_argument(
        "--save-final-model", action="store_true", help="Save final model (independent of save strategy)"
    )

    parser.add_argument(
        "--aux-fraction",
        type=float,
        required=True,
        help="Auxiliary data fraction (in relation to raw synthetic images)",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=1000,
        help="Maximum number of steps",
    )
    parser.add_argument(
        "--warmup-steps",
        type=int,
        default=10,
        help="Number of warmup steps",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-5,
        help="Learning rate",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.02,
        help="Weight decay",
    )

    return parser.parse_args(argv)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
