import random

import numpy as np
import torch
import torchvision
import numpy
import time

from torch import nn
# import matplotlib
# import matplotlib.pyplot as plt
from torch.autograd import Variable

# from model.ResNetSE import ResNet
from model.ResNet import *
# from transformer_resnet import ResNet20
from dataset import read_cifar10, read_cifar100, read_tiny
from logger import MetricsLogger
import os
import tqdm

# Expose only the GPU with index 1 to CUDA for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Seed handed to set_seed() in main() before each model is trained.
RANDOM_SEED = 1  # arbitrary value; any integer works


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus current/all CUDA
    devices) with ``seed``, stores the seed in ``PYTHONHASHSEED``, and sets
    the cuDNN ``deterministic`` flag to True and ``benchmark`` flag to False.

    Args:
        seed: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _train_one_epoch(model, data_loader, cost, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    The model is called as ``outputs, ljloss = model(x)``; the sum of
    ``cost(outputs, labels)`` and ``ljloss`` is back-propagated.

    Args:
        model: Network in training mode returning ``(outputs, ljloss)``.
        data_loader: Iterable yielding ``(x, labels)`` batches.
        cost: Classification loss applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(cost_sum, ljloss_sum, n_correct, n_seen, n_steps)`` where the
        two sums accumulate the per-batch loss values, ``n_correct`` /
        ``n_seen`` count correctly classified / total samples, and
        ``n_steps`` is the number of optimizer steps taken.
    """
    cost_sum = 0.0
    ljloss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_steps = 0
    # Wrap the loader itself (not enumerate) so tqdm can read its length.
    for x, labels in tqdm.tqdm(data_loader):
        x, labels = x.to(device), labels.to(device)
        outputs, ljloss = model(x)
        ce_loss = cost(outputs, labels)
        cost_sum += ce_loss.item()
        ljloss_sum += ljloss.item()
        _, pred = torch.max(outputs, 1)
        n_seen += labels.size(0)
        n_correct += (pred == labels).sum().item()
        optimizer.zero_grad()
        (ljloss + ce_loss).backward()
        optimizer.step()
        n_steps += 1
    return cost_sum, ljloss_sum, n_correct, n_seen, n_steps


def main():
    """Train the selected ResNet variants and log per-epoch metrics.

    Loads the train/test loaders via ``read_tiny``, then for every selected
    architecture: seeds the RNGs, builds the model, trains it with SGD and a
    MultiStepLR schedule for ``n_epochs`` epochs, evaluates after every epoch
    and writes the metrics through ``MetricsLogger``.
    """
    data_dir = './data'
    batch_size = 64
    n_epochs = 100
    lr = 0.1
    momentum = 0.9
    weight_decay = 1e-4
    num_classes = 200
    # Keep constructors rather than instances: each model is built only when
    # it is about to be trained, and only after the seed has been set.
    model_builders = [resnet18, resnet34, resnet50, resnet101]
    depth_tags = ['18', '34', '50', '101']
    # Index of the first architecture to train. The original loop offset the
    # index by 3, i.e. it started at ResNet101 (and then ran out of range).
    start_index = 3

    data_loader_train, data_loader_test = read_tiny(batch_size, data_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for model_idx in range(start_index, len(model_builders)):
        # Seed before constructing the model so weight init is reproducible.
        set_seed(RANDOM_SEED)
        model = model_builders[model_idx](num_classes=num_classes).to(device)
        cost = nn.CrossEntropyLoss().to(device)

        optimizer = torch.optim.SGD(model.parameters(), lr,
                                    momentum=momentum,
                                    weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[41, 61, 81], gamma=0.1, last_epoch=-1)

        log_file = "TIN_ResNet/ResNet{}_LJLoss.log".format(depth_tags[model_idx])
        # Make sure the log directory exists; harmless if already present.
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        logger = MetricsLogger(log_file=log_file)
        try:
            global_step = 0  # optimizer steps taken so far for this model
            model.train()
            since = time.time()
            for epoch in range(n_epochs):
                print("Epoch {}/{}".format(epoch + 1, n_epochs))
                training_loss, discrete_loss, n_correct, n_seen, n_steps = \
                    _train_one_epoch(model, data_loader_train, cost,
                                     optimizer, device)
                global_step += n_steps
                # Advance the LR schedule once per epoch; without this call
                # the milestones above never take effect.
                scheduler.step()

                train_acc = 100 * n_correct / n_seen
                test_acc = eval(model, data_loader_test, device)
                model.train()  # eval() leaves the model in eval mode
                logger.log_metrics_extra_loss(
                    train_acc, test_acc, global_step,
                    training_loss / len(data_loader_train),
                    discrete_loss / len(data_loader_train),
                    epoch + 1)

            time_used = time.time() - since
            logger.log_direct('Time: {:.0f}m {:.0f}s'.format(time_used // 60, time_used % 60))
        finally:
            # Close the log even if training is interrupted by an error.
            logger.close()


def eval(model, data_loader_test, device, half=False):
    """Return the top-1 accuracy (in percent) of ``model`` on a data loader.

    The model is switched to eval mode and left that way; gradients are
    disabled for the whole pass. The model is expected to return a pair whose
    first element holds the per-class scores; the second element is ignored.

    Args:
        model: Network to evaluate.
        data_loader_test: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to.
        half: When True, inputs are cast with ``.half()`` and labels with
            ``.long()`` after being moved to ``device``.

    Returns:
        ``100 * correct / total`` over all samples seen.
    """
    model.eval()
    n_seen, n_hits = 0, 0
    with torch.no_grad():
        for inputs, targets in data_loader_test:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if half:
                inputs = inputs.half()
                targets = targets.long()
            scores, _ = model(inputs)
            predicted = torch.max(scores.data, 1)[1]
            n_seen += targets.size(0)
            n_hits += (predicted == targets).sum().item()
    return 100 * n_hits / n_seen


if __name__ == '__main__':
    main()
