import argparse
import importlib
import os
import random
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import get_dataset
from models.init import get_conf_params


def _build_model(net_name, architecture, num_classes, batchnorm, init, fan):
    """Instantiate class ``net_name`` from module ``models.<architecture>``.

    Raises:
        ImportError: if ``models.<architecture>`` cannot be imported.
        AttributeError: if that module defines nothing named ``net_name``.
    """
    module_ = importlib.import_module(f"models.{architecture}")
    try:
        net_class = getattr(module_, net_name)
    except AttributeError as err:
        raise AttributeError(
            f"Class {net_name} does not exist in module models.{architecture}."
        ) from err
    return net_class(
        num_classes=num_classes,
        batch_norm=batchnorm,
        conf_params=get_conf_params(init),
        fan=fan,
        hooks=False,
    )


def _train_one_epoch(net, loader, criterion, optimizer, device):
    """Run one optimisation pass of ``net`` over ``loader``.

    Returns:
        ``(mean_loss, accuracy_percent)``: the mean per-batch loss and the
        training accuracy (in percent) accumulated over the whole epoch.

    Raises:
        RuntimeError: if the loss becomes NaN or infinite.
        ValueError: if ``loader`` yields no samples.
    """
    net.train()
    running_loss = 0.0
    num_batches = 0
    correct, total = 0, 0
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # Abort before backward(): once the loss diverges, every further update is garbage.
        if not torch.isfinite(loss).all():
            raise RuntimeError(f"Loss is not finite: {loss.item()}")
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        total += targets.size(0)
        correct += outputs.argmax(dim=1).eq(targets).sum().item()

    if total == 0:
        raise ValueError("Training loader yielded no samples.")
    return running_loss / num_batches, 100.0 * correct / total


def _evaluate(net, loader, device):
    """Return the top-1 accuracy of ``net`` on ``loader`` as a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = net(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    if total == 0:
        raise ValueError("Evaluation loader yielded no samples.")
    return correct / total


def main(
    nepochs,
    net_name,
    architecture,
    dataset,
    batchnorm,
    init,
    fan,
    device,
    seed,
    *,
    lr=1e-2,
    momentum=0.9,
    weight_decay=0,
    train_batch_size=128,
    test_batch_size=100,
    num_workers=4,
    save_dir="/tmp",
):
    """Train ``models.<architecture>.<net_name>`` on ``dataset`` and save the best weights.

    After every epoch the model is evaluated on the test split. The state dict
    with the highest test accuracy is written to
    ``<save_dir>/<net_name>_<architecture>_<init>_<fan>_<seed>.pth``.

    Args:
        nepochs: number of training epochs.
        net_name: class name looked up in ``models.<architecture>``.
        architecture: module name under the ``models`` package.
        dataset: dataset identifier forwarded to ``get_dataset``.
        batchnorm: forwarded to the network as ``batch_norm``.
        init: key passed to ``get_conf_params``; also part of the checkpoint name.
        fan: forwarded to the network as ``fan``; also part of the checkpoint name.
        device: torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        seed: used only to name the checkpoint (RNG seeding is the caller's job).
        lr, momentum, weight_decay: SGD hyper-parameters.
        train_batch_size, test_batch_size: DataLoader batch sizes.
        num_workers: DataLoader worker processes for both splits.
        save_dir: directory for the checkpoint; created if missing.

    Returns:
        The path the best state dict was saved to.

    Raises:
        AttributeError: if ``net_name`` is not defined in ``models.<architecture>``.
        RuntimeError: if the training loss becomes NaN or infinite.
        ValueError: if either data split is empty.
    """
    print(f"Device: {device}")
    print("==> Preparing data..")
    trainset, testset, num_classes = get_dataset(dataset=dataset)

    print(f"==> Using dataset {dataset}")
    trainloader = DataLoader(
        trainset, batch_size=train_batch_size, shuffle=True, num_workers=num_workers
    )
    testloader = DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers
    )

    print(f"==> Using {net_name} ({architecture})")
    net = _build_model(net_name, architecture, num_classes, batchnorm, init, fan)
    print(net)
    net = net.to(device)
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("==> Total number of parameters to be trained", num_params)
    print(f"==> Attributes\n\tClasses = {num_classes}\n\tArchitecture = {architecture}")
    print(f"\tBatchnorm = {batchnorm}\n\tInit = {init}\n\tFan = {fan}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum
    )
    # T_max is a few epochs longer than the run, presumably so the learning rate
    # never decays all the way to zero. `verbose=` is deprecated in recent PyTorch,
    # so the learning rate is printed explicitly each epoch instead.
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=nepochs + 5)

    print("==> Training the model..")
    # Snapshot the initial weights so there is always something to save.
    best_model = {"state_dict": deepcopy(net.state_dict()), "metric": -np.inf}
    for epoch in range(nepochs):
        current_lr = lr_scheduler.get_last_lr()[0]
        train_loss, train_acc = _train_one_epoch(
            net, trainloader, criterion, optimizer, device
        )
        test_acc = _evaluate(net, testloader, device)
        if test_acc > best_model["metric"]:
            best_model["metric"] = test_acc
            # deepcopy: state_dict() references the live tensors that later steps overwrite.
            best_model["state_dict"] = deepcopy(net.state_dict())
        lr_scheduler.step()
        print(
            "Epoch {} out of {} | loss: {:.3f} | accuracy: {:.3f}".format(
                epoch + 1, nepochs, train_loss, train_acc
            )
        )
        print("Accuracy on test: {:.3f}".format(100.0 * test_acc))
        print("Learning rate: {:.3e}".format(current_lr))

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(
        save_dir, f"{net_name}_{architecture}_{init}_{fan}_{seed}.pth"
    )
    torch.save(best_model["state_dict"], save_path)
    print(f"==> Saved best model to {save_path}")
    return save_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a network and save the weights with the best test accuracy."
    )
    parser.add_argument(
        "net", type=str, help="network class name inside models.<architecture>"
    )
    parser.add_argument(
        "--dataset", type=str, default="CIFAR10", help="dataset passed to get_dataset"
    )
    parser.add_argument(
        "--architecture",
        "-a",
        default="short_conv",
        help="module under models/ that defines the network class",
    )
    parser.add_argument(
        "--batchnorm", "-bn", action="store_true", help="build the network with batch norm"
    )
    parser.add_argument(
        "--brock", action="store_true", help="use the 'brock' init instead of 'he'"
    )
    parser.add_argument(
        "--fan_out", action="store_true", help="use 'fan_out' instead of 'fan_in'"
    )
    parser.add_argument(
        "--cpu", action="store_true", help="train on the CPU instead of CUDA"
    )
    parser.add_argument(
        "--epochs", type=int, default=300, help="number of training epochs"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    args = parser.parse_args()

    if args.epochs < 1:
        parser.error("--epochs must be at least 1")
    # Fail fast instead of crashing later, after the dataset has been loaded.
    if not args.cpu and not torch.cuda.is_available():
        parser.error("CUDA is not available; pass --cpu to train on the CPU")

    # reproducibility (torch.manual_seed also seeds every CUDA device).
    # NOTE: cuDNN may still pick non-deterministic kernels; set
    # torch.backends.cudnn.deterministic = True if bit-exact GPU runs are required.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    init = "brock" if args.brock else "he"
    fan = "fan_out" if args.fan_out else "fan_in"
    device = "cpu" if args.cpu else "cuda"
    main(
        nepochs=args.epochs,
        net_name=args.net,
        architecture=args.architecture,
        dataset=args.dataset,
        batchnorm=args.batchnorm,
        init=init,
        fan=fan,
        device=device,
        seed=args.seed,
    )
