import json
import argparse
import os
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from imagenet_pretrain.dataset import ImageNetDataLoader
from imagenet_pretrain.model import ResNet
from common.sam import SAM, disable_running_stats, enable_running_stats

# Apex is an optional dependency. `apex_available` records whether the import
# succeeded; main() reads it to choose between apex.parallel.DistributedDataParallel
# and torch.nn.parallel.DistributedDataParallel when --multi_gpu is set.
# NOTE(review): `amp` is imported here but not referenced in the visible part of
# this file — confirm whether it is still needed.
try:
    from apex import amp, parallel

    apex_available = True
except ModuleNotFoundError:
    # Apex is missing: remember that and tell the user which fallback is used.
    apex_available = False
    print(
        "Apex module not found. Defaulting to torch.nn.parallel.DistributedDataParallel for multi-gpu."
    )


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers relied on before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``dataset_file``, ``batch_size``,
        ``layers``, ``num_workers``, ``training_steps``, ``gradient_step_modulo``,
        ``model``, ``checkpoint_modulo``, ``checkpoint_directory``, ``sam``,
        ``cpu``, ``multi_gpu``, ``local_rank`` and ``excluded_classes``.
    """
    parser = argparse.ArgumentParser()
    # Dataset
    parser.add_argument("--dataset_file", type=str)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--layers", type=int, default=18)
    parser.add_argument("--num_workers", type=int, default=0)

    # Training
    parser.add_argument(
        "--training_steps",
        type=int,
        default=450000,
        help="Total number of gradient steps. Default is 450000, equivalent to 90 epochs of ImageNet.",
    )
    parser.add_argument(
        "--gradient_step_modulo",
        type=int,
        default=1,
        help="Number of batches to accumulate gradient before taking gradient step.",
    )

    # Model
    parser.add_argument("--model", type=str)

    # Checkpointing
    parser.add_argument(
        "--checkpoint_modulo",
        type=int,
        default=4500,
        help="Save checkpoint every x gradient step.",
    )
    parser.add_argument("--checkpoint_directory", type=str, default=".")

    # Device
    parser.add_argument("--sam", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--multi_gpu", action="store_true")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Do not specify this argument. torch.distributed.launch sets this value. "
        "Rank 0 handles all main thread operations like logging and checkpointing.",
    )

    # The hyphenated spelling is kept for backward compatibility; the underscore
    # alias matches every other flag. Both store into `excluded_classes` because
    # argparse derives the destination from the first long option string.
    parser.add_argument(
        "--excluded-classes", "--excluded_classes", default="excluded_classes.json"
    )

    return parser.parse_args(argv)


def load_checkpoint(checkpoint_directory):
    """Read ``last_checkpoint.pt`` from ``checkpoint_directory`` if it exists.

    Args:
        checkpoint_directory: Directory that may contain ``last_checkpoint.pt``.

    Returns:
        A 5-tuple ``(model_state_dict, optimizer_state_dict,
        lr_scheduler_state_dict, training_step, best_accuracy)``. When no
        checkpoint file is present the three state dicts are ``None`` and both
        counters are ``0``, i.e. training starts from scratch.
    """
    checkpoint_path = os.path.join(checkpoint_directory, "last_checkpoint.pt")

    # Nothing to resume from: report a fresh start.
    if not os.path.exists(checkpoint_path):
        return None, None, None, 0, 0

    # Everything is loaded onto the CPU regardless of where it was saved from.
    saved = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    return (
        saved["model"],
        saved["optimizer"],
        saved["lr_scheduler"],
        saved["training_step"],
        saved["best_accuracy"],
    )


def setup_training(local_rank, checkpoint_directory, cpu, multi_gpu):
    """Create the checkpoint directory, choose the device and, if requested,
    initialise the distributed process group.

    Args:
        local_rank: Rank of this process; only rank 0 creates the checkpoint
            directory, and the rank selects the CUDA device index.
        checkpoint_directory: Directory checkpoints will be written to.
        cpu: When true, force the CPU device.
        multi_gpu: When true, bind this process to its device and initialise
            ``torch.distributed`` with the NCCL backend using ``env://``.

    Returns:
        The ``torch.device`` this process should run on.

    Raises:
        FileExistsError: If ``checkpoint_directory`` exists but is not a directory.
    """
    if local_rank == 0:
        # exist_ok=True removes the check-then-create race: another launcher
        # sharing the filesystem may create the directory between an existence
        # test and the makedirs call.
        os.makedirs(checkpoint_directory, exist_ok=True)

    # Setup devices
    if cpu:
        device = torch.device("cpu")
    elif torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
    else:
        print(
            "Tried to set GPU as torch.device, but torch.cuda.is_available() returned False.",
            " Running on CPU. Can use --cpu to specify program to run on CPU. Default behaviour is to run on GPU.",
        )
        device = torch.device("cpu")

    if multi_gpu:
        # NOTE(review): this path passes `device` to torch.cuda.set_device and
        # uses the NCCL backend, so it looks like it expects a CUDA device —
        # confirm behaviour when combined with --cpu.
        torch.cuda.set_device(device)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    return device


def train_epoch(
    model,
    train_data_iterator,
    val_data_iterator,
    training_step,
    best_accuracy,
    lr_scheduler,
    optimizer,
    training_steps,
    gradient_step_modulo,
    checkpoint_modulo,
    local_rank,
    sam,
    checkpoint_directory,
    multi_gpu,
):
    """Run one pass over ``train_data_iterator``.

    Gradients are accumulated over ``gradient_step_modulo`` batches before each
    optimizer step. Every ``checkpoint_modulo`` optimizer steps the model is
    validated and, on rank 0, the result is printed and a checkpoint is saved.
    The pass ends early once ``training_step`` reaches ``training_steps``.

    Args:
        model: Network to train; called as ``model(samples)``.
        train_data_iterator: Iterable yielding ``(samples, targets)`` batches.
        val_data_iterator: Iterable passed to ``val_epoch`` at checkpoint time.
        training_step: Number of optimizer steps taken so far.
        best_accuracy: Best validation accuracy seen so far.
        lr_scheduler: Scheduler stepped once per optimizer step.
        optimizer: Optimizer; must expose ``first_step``/``second_step`` when
            ``sam`` is true, ``step`` otherwise.
        training_steps: Total number of optimizer steps to train for.
        gradient_step_modulo: Number of batches per optimizer step.
        checkpoint_modulo: Validate/checkpoint every this many optimizer steps.
        local_rank: Rank of this process; only rank 0 logs and saves.
        sam: Whether the two-step SAM update is used.
        checkpoint_directory: Directory passed to ``save_checkpoint``.
        multi_gpu: Forwarded to ``save_checkpoint``.

    Returns:
        Tuple ``(training_step, best_accuracy)`` after this pass.
    """
    model.train()
    batch_count = 0
    total_loss = 0

    for samples, targets in train_data_iterator:
        if sam:
            # Helper from common.sam; presumably re-enables normalisation
            # running statistics for the first forward pass — verify there.
            enable_running_stats(model)
        output = model(samples)
        loss = F.cross_entropy(output, targets)
        # Log the unscaled per-batch loss. Accumulating the value after the
        # division below would make the reported mean shrink by a factor of
        # gradient_step_modulo.
        total_loss += loss.item()
        # Scale so the accumulated gradient is the mean over the accumulated batches.
        (loss / gradient_step_modulo).backward()

        batch_count += 1
        if batch_count % gradient_step_modulo == 0:
            if sam:
                optimizer.first_step(zero_grad=True)
                disable_running_stats(model)
                # NOTE(review): the second forward/backward pass uses only the
                # current batch and is not divided by gradient_step_modulo.
                F.cross_entropy(model(samples), targets).backward()
                optimizer.second_step(zero_grad=True)
            else:
                optimizer.step()
                model.zero_grad()
            lr_scheduler.step()

            training_step += 1
            if (training_step % checkpoint_modulo) == 0:
                val_loss, val_accuracy = val_epoch(model, val_data_iterator)

                # val_epoch switches the model to eval mode; switch back.
                model.train()

                if local_rank == 0:
                    print(
                        f"Step {training_step} - Train loss: {total_loss / batch_count},"
                        f" Val loss: {val_loss}, Val accuracy: {100 * val_accuracy:.2f}%"
                    )
                    best_accuracy = save_checkpoint(
                        model,
                        checkpoint_directory,
                        val_accuracy,
                        best_accuracy,
                        training_step,
                        multi_gpu,
                        lr_scheduler,
                        optimizer,
                    )

            if training_step >= training_steps:
                break

    # Drop any partially accumulated gradient so it does not leak into the next pass.
    model.zero_grad()

    return training_step, best_accuracy


def val_epoch(model, val_data_iterator):
    """Evaluate ``model`` on every batch of ``val_data_iterator``.

    The model is put into eval mode and left there; gradients are disabled
    for the duration of the pass.

    Args:
        model: Network to evaluate; called as ``model(samples)``.
        val_data_iterator: Iterable yielding ``(samples, targets)`` batches.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the summed cross-entropy divided by the
        number of samples seen, and the fraction of samples whose arg-max
        prediction equals the target.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch, labels in val_data_iterator:
            scores = model(batch)
            # Sum (not mean) per batch so the final division weights every sample equally.
            loss_sum += F.cross_entropy(scores, labels, reduction="sum").item()
            predicted = scores.argmax(dim=1)
            num_correct += (predicted == labels.view_as(predicted)).sum().item()
            num_seen += scores.shape[0]

    mean_loss = loss_sum / num_seen
    accuracy = num_correct / num_seen
    return mean_loss, accuracy


def get_checkpoint(
    model, best_accuracy, training_step, multi_gpu, lr_scheduler, optimizer
):
    """Collect everything needed to resume training into one dictionary.

    Args:
        model: Network being trained; when ``multi_gpu`` is true its weights
            are read from ``model.module`` instead of ``model`` itself.
        best_accuracy: Best validation accuracy so far.
        training_step: Number of optimizer steps taken so far.
        multi_gpu: Whether ``model`` is a wrapper exposing ``.module``.
        lr_scheduler: Scheduler whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.

    Returns:
        Dict with keys ``"optimizer"``, ``"training_step"``, ``"best_accuracy"``,
        ``"model"`` and ``"lr_scheduler"``.
    """
    return {
        "optimizer": optimizer.state_dict(),
        "training_step": training_step,
        "best_accuracy": best_accuracy,
        # Store the unwrapped network's weights when running multi-GPU.
        "model": (model.module if multi_gpu else model).state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
    }


def save_checkpoint(
    model,
    checkpoint_directory,
    val_accuracy,
    best_accuracy,
    training_step,
    multi_gpu,
    lr_scheduler,
    optimizer,
):
    """Write the current training state to ``checkpoint_directory``.

    Always writes ``checkpoint_<training_step>.pt`` and ``last_checkpoint.pt``.
    When ``val_accuracy`` beats ``best_accuracy`` the stored best accuracy is
    updated and ``best_checkpoint.pt`` is written as well.

    Args:
        model: Network being trained.
        checkpoint_directory: Directory the files are written to.
        val_accuracy: Validation accuracy of the current model.
        best_accuracy: Best validation accuracy before this call.
        training_step: Number of optimizer steps taken so far.
        multi_gpu: Forwarded to ``get_checkpoint``.
        lr_scheduler: Scheduler whose state is saved.
        optimizer: Optimizer whose state is saved.

    Returns:
        The (possibly updated) best validation accuracy.
    """
    state = get_checkpoint(
        model, best_accuracy, training_step, multi_gpu, lr_scheduler, optimizer
    )

    filenames = []
    if val_accuracy > best_accuracy:
        # New best: record it in the saved state and keep a dedicated copy.
        best_accuracy = val_accuracy
        state["best_accuracy"] = val_accuracy
        filenames.append("best_checkpoint.pt")
    filenames.append(f"checkpoint_{training_step}.pt")
    filenames.append("last_checkpoint.pt")

    for filename in filenames:
        torch.save(state, os.path.join(checkpoint_directory, filename))

    return best_accuracy


def train(
    model,
    train_data_iterator,
    val_data_iterator,
    training_steps,
    optimizer_state_dict,
    lr_scheduler_state_dict,
    training_step,
    best_accuracy,
    gradient_step_modulo,
    checkpoint_modulo,
    local_rank,
    sam,
    checkpoint_directory,
    multi_gpu,
):
    """Build the optimizer and LR scheduler, restore their state if given, and
    repeat ``train_epoch`` until ``training_steps`` optimizer steps are done.

    The optimizer is SGD (lr 0.1, momentum 0.9, weight decay 1e-4), wrapped in
    SAM when ``sam`` is true; the learning rate follows a ``StepLR`` schedule
    with a step size of 150000 optimizer steps.

    Args:
        model: Network to train.
        train_data_iterator: Iterable of training batches, re-iterated each pass.
        val_data_iterator: Iterable of validation batches.
        training_steps: Total number of optimizer steps to train for.
        optimizer_state_dict: Optimizer state to resume from, or a falsy value.
        lr_scheduler_state_dict: Scheduler state to resume from, or a falsy value.
        training_step: Number of optimizer steps already taken.
        best_accuracy: Best validation accuracy so far.
        gradient_step_modulo: Number of batches per optimizer step.
        checkpoint_modulo: Validate/checkpoint every this many optimizer steps.
        local_rank: Rank of this process.
        sam: Whether to use the SAM optimizer wrapper.
        checkpoint_directory: Directory checkpoints are written to.
        multi_gpu: Whether the model is wrapped for distributed training.

    Raises:
        RuntimeError: If a full pass over ``train_data_iterator`` takes no
            optimizer step, which would otherwise loop forever.
    """
    if sam:
        optimizer = SAM(
            model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9, weight_decay=1e-4
        )
        # The scheduler drives the wrapped SGD optimizer, not the SAM wrapper.
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer.base_optimizer, step_size=150000
        )
    else:
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4
        )
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150000)

    if optimizer_state_dict:
        optimizer.load_state_dict(
            optimizer_state_dict
        )  # Needs to be loaded after lr scheduler is initialized
    if lr_scheduler_state_dict:
        lr_scheduler.load_state_dict(lr_scheduler_state_dict)

    while training_step < training_steps:
        step_before = training_step
        training_step, best_accuracy = train_epoch(
            model,
            train_data_iterator,
            val_data_iterator,
            training_step,
            best_accuracy,
            lr_scheduler,
            optimizer,
            training_steps,
            gradient_step_modulo,
            checkpoint_modulo,
            local_rank,
            sam,
            checkpoint_directory,
            multi_gpu,
        )
        # train_epoch discards partially accumulated gradients, so a pass that
        # takes no step (empty data, or fewer batches than gradient_step_modulo)
        # can never make progress: fail loudly instead of spinning forever.
        if training_step == step_before:
            raise RuntimeError(
                "No gradient step was taken during a full pass over the training data; "
                "the training iterator yields fewer batches than gradient_step_modulo "
                f"({gradient_step_modulo})."
            )


def main():
    """Script entry point: parse options, restore any checkpoint, build the
    model and data loaders, then run training.

    The class list is every index in ``range(1000)`` that is not among the
    values of the JSON object stored in the ``--excluded-classes`` file.
    """
    args = parse_args()

    # Resume state; all None / zero when no last_checkpoint.pt exists.
    resume_state = load_checkpoint(args.checkpoint_directory)
    (
        model_state_dict,
        optimizer_state_dict,
        lr_scheduler_state_dict,
        training_step,
        best_accuracy,
    ) = resume_state

    device = setup_training(
        local_rank=args.local_rank,
        checkpoint_directory=args.checkpoint_directory,
        cpu=args.cpu,
        multi_gpu=args.multi_gpu,
    )

    with open(args.excluded_classes) as f:
        excluded = set(json.load(f).values())
    classes = sorted(set(range(1000)) - excluded)

    model = ResNet(total_classes=len(classes), layers=args.layers).to(device)
    if model_state_dict is not None:
        model.load_state_dict(model_state_dict)

    def build_loader(sample_set, shuffle):
        # Both loaders share everything except the split and the shuffling flag.
        return ImageNetDataLoader(
            args.dataset_file,
            args.batch_size,
            sample_set=sample_set,
            shuffle=shuffle,
            num_workers=args.num_workers,
            device=device,
            classes=classes,
        )

    train_data_iterator = build_loader("train", shuffle=True)
    val_data_iterator = build_loader("val", shuffle=False)

    # Wrap for distributed training, preferring apex when it was importable.
    if args.multi_gpu and apex_available:
        model = parallel.DistributedDataParallel(model)
    elif args.multi_gpu:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )

    train(
        model=model,
        train_data_iterator=train_data_iterator,
        val_data_iterator=val_data_iterator,
        training_steps=args.training_steps,
        optimizer_state_dict=optimizer_state_dict,
        lr_scheduler_state_dict=lr_scheduler_state_dict,
        training_step=training_step,
        best_accuracy=best_accuracy,
        gradient_step_modulo=args.gradient_step_modulo,
        checkpoint_modulo=args.checkpoint_modulo,
        local_rank=args.local_rank,
        sam=args.sam,
        checkpoint_directory=args.checkpoint_directory,
        multi_gpu=args.multi_gpu,
    )


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
