import os

import hydra
import torch
import torch.distributed as dist
import torch.optim as optim
import wandb
from configs.template import MainConfig
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf
from src.dataset import TSPDataset
from src.model import Model
from src.trainer import Trainer
from torch.multiprocessing.spawn import spawn
from torch.nn.parallel import DistributedDataParallel as DDP

# Let float32 matmuls use faster, reduced-precision kernels where the hardware supports them.
# Enabled because PyTorch emitted a warning recommending this setting.
torch.set_float32_matmul_precision("high")

# `torch.vmap` seems to trigger a lot of recompilations. Raise the recompilation limit to try
# to capture them all.
torch._dynamo.config.recompile_limit = 32


def launch_training(rank: int, world_size: int, dict_config: DictConfig):
    """Run training on one GPU, optionally as one rank of a multi-GPU DDP job.

    Builds the datasets, model, optimizer, scheduler and trainer from the configuration,
    optionally restores a checkpoint, and then runs `Trainer.train`. Only rank 0 opens a
    wandb run; every other rank passes `None` as the run to the trainer.

    Args:
        rank: Index of this process. It also selects the CUDA device (`cuda:{rank}`).
        world_size: Total number of processes. A value above 1 enables DDP over NCCL.
        dict_config: Raw OmegaConf configuration; parsed into a `MainConfig` and also
            logged as the wandb run config.
    """
    config = MainConfig.from_dict(dict_config)
    device = torch.device(f"cuda:{rank}")
    distributed = world_size > 1

    # Bind this process to its own GPU so that NCCL and any implicit `cuda` allocation
    # use `cuda:{rank}` instead of defaulting to device 0.
    torch.cuda.set_device(device)

    if distributed:  # Init DDP.
        # `setdefault` keeps the historical single-node defaults while letting a launcher
        # (or a second job on the same host) override the rendezvous address and port.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    run = None
    try:
        # Datasets.
        training = TSPDataset.from_npz(config.training_dataset.filepath)
        evaluations = [TSPDataset.from_npz(data.filepath) for data in config.evaluation_datasets]

        # Model, optimizer & trainer states.
        model = Model(
            config.model.hidden_dim,
            config.model.ff_dim,
            config.model.n_heads,
            config.model.n_layers,
            use_alibi=config.model.use_alibi,
            use_coords=config.model.use_coords,
            use_random_ids=config.model.use_random_ids,
            use_rope=config.model.use_rope,
            use_ssmax=config.model.use_ssmax,
        ).to(device)
        if distributed:
            model = DDP(model, device_ids=[rank])
        optimizer = optim.AdamW(model.parameters(), config.trainer.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, config.trainer.training_iters, eta_min=config.trainer.learning_rate_min
        )
        trainer = Trainer(
            device,
            config.trainer.evaluation_batch_size,
            config.trainer.evaluation_every,
            config.trainer.evaluation_iters,
            # The configured batch size is split across ranks (integer division, so any
            # remainder is dropped).
            config.trainer.training_batch_size // world_size,
            config.trainer.training_iters,
            world_size,
        )

        if config.trainer.resume is None:
            # Fresh run: start at iteration 0 and forbid wandb from resuming anything.
            initial_iter = 0
            run_id = None
            resume = "never"
        else:
            checkpoint = torch.load(config.trainer.resume / "checkpoint.pth", map_location=device)
            # The checkpoint stores the bare model's weights; under DDP the bare model
            # lives in `.module`.
            bare_model = model.module if distributed else model
            bare_model.load_state_dict(checkpoint["model-state"])
            optimizer.load_state_dict(checkpoint["optimizer-state"])
            scheduler.load_state_dict(checkpoint["scheduler-state"])

            run_id = checkpoint["run-id"]
            # Continue with the iteration following the one stored in the checkpoint.
            initial_iter = checkpoint["iter-id"] + 1
            resume = "must"

        if rank == 0:  # Only the first device will log.
            run = wandb.init(
                project="tsp-equivariant",
                group=config.wandb.group,
                config=OmegaConf.to_container(dict_config),
                entity=config.wandb.entity,
                mode=config.wandb.mode,
                id=run_id,
                resume=resume,
            )

        trainer.train(model, optimizer, scheduler, training, evaluations, initial_iter, run)
    finally:
        # Clean up even when setup (not just training) fails, so that neither the wandb run
        # nor the process group is left dangling.
        if run is not None:
            run.finish()

        if distributed:
            dist.destroy_process_group()


@hydra.main(config_path="configs", config_name="default", version_base="1.3")
def main(dict_config: DictConfig):
    """Hydra entry point: normalise paths, then train on every visible GPU.

    With exactly one GPU, training runs in the current process. With several, one process
    per GPU is spawned and each receives its rank as the first argument.

    Args:
        dict_config: Configuration composed by Hydra. Its dataset and resume paths are
            rewritten in place as absolute paths.

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    # Convert user-supplied paths to absolute ones via Hydra's helper before handing the
    # config to the worker processes.
    dict_config["data"]["training"] = to_absolute_path(dict_config["data"]["training"])
    dict_config["data"]["evaluation"] = [
        to_absolute_path(p) for p in dict_config["data"]["evaluation"]
    ]
    if dict_config["trainer"]["resume"] is not None:
        dict_config["trainer"]["resume"] = to_absolute_path(dict_config["trainer"]["resume"])

    n_gpus = torch.cuda.device_count()
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`, which would
    # let a zero-GPU machine fall through to `spawn(..., nprocs=0)`.
    if n_gpus < 1:
        raise RuntimeError("No GPU detected")

    if n_gpus == 1:
        launch_training(rank=0, world_size=1, dict_config=dict_config)
    else:
        # `spawn` prepends the process index (the rank) to `args`.
        spawn(launch_training, args=(n_gpus, dict_config), nprocs=n_gpus)


# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
