import random

import numpy as np
import torch
import torchvision
import numpy
import time

# from apex import amp
from torch import nn
# import matplotlib
# import matplotlib.pyplot as plt
from torch.autograd import Variable

from model.ViT import ViT
from dataset import read_cifar10, read_cifar100, read_tiny
from logger import MetricsLogger
import os
import warmup_scheduler
import tqdm

# Restrict CUDA to device index 0 for this process.
# NOTE(review): this is set after `import torch`; it is presumably still
# honoured because CUDA is initialised lazily -- confirm on the target setup.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

RANDOM_SEED = 1  # seed handed to set_seed(); any fixed integer works


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN settings.

    Args:
        seed: Integer seed applied to every generator.

    Side effects:
        Writes ``PYTHONHASHSEED`` into ``os.environ`` and sets the cuDNN
        ``deterministic`` / ``benchmark`` flags.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All of these take the seed as their single positional argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and turn off its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def _train_one_epoch(model, data_loader, cost, optimizer, device):
    """Run one optimisation pass over ``data_loader``.

    The model is called as ``model(x)`` and must return a pair
    ``(outputs, ljloss)``. The auxiliary ``ljloss`` is added to the
    classification loss before back-propagation.

    Returns:
        Tuple ``(ce_loss_sum, lj_loss_sum, n_correct, n_samples, n_batches)``
        where the two loss sums are accumulated per batch.
    """
    ce_loss_sum = 0.0
    lj_loss_sum = 0.0
    n_correct = 0
    n_samples = 0
    n_batches = 0
    # Iterate the loader directly: wrapping enumerate() hides its length from
    # tqdm, so the progress bar could not show a total.
    for x, labels in tqdm.tqdm(data_loader):
        n_batches += 1
        x, labels = x.to(device), labels.to(device)
        outputs, ljloss = model(x)
        ce_loss = cost(outputs, labels)
        # The two loss terms are tracked separately for logging.
        ce_loss_sum += ce_loss.item()
        lj_loss_sum += ljloss.item()
        total_loss = ljloss + ce_loss
        _, pred = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (pred == labels).sum().item()
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    return ce_loss_sum, lj_loss_sum, n_correct, n_samples, n_batches


def main():
    """Train and evaluate several ViT configurations on the ``read_tiny`` data.

    For every configuration in ``layer_template`` (index 1 is skipped) the RNGs
    are re-seeded, a fresh model/optimizer/scheduler is built, the model is
    trained for ``n_epochs`` epochs, and per-epoch metrics are written through
    a ``MetricsLogger``. The logger is closed even if training raises.
    """
    data_dir = './data'
    batch_size = 64
    n_epochs = 100
    lr = 1e-3
    weight_decay = 5e-5
    num_classes = 200
    # Each entry is [num_layers, heads, hidden dim, patch size].
    layer_template = [[12, 12, 384, 16], [12, 12, 384, 8], [6, 6, 192, 16], [6, 6, 192, 8]]

    data_loader_train, data_loader_test = read_tiny(batch_size, data_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for config_idx, layer in enumerate(layer_template):
        if config_idx == 1:
            continue
        # Re-seed before each configuration is built.
        set_seed(RANDOM_SEED)
        logger = MetricsLogger(log_file="TIN_ViT/ViT{:.0f},{:.0f}_Dim{:.0f}_Patch{:.0f}_LJLoss".format(layer[0], layer[1], layer[2], layer[3]))
        try:
            model = ViT(num_classes=num_classes, head=layer[1], num_layers=layer[0], img_size=64, patch=layer[3], hidden=layer[2], mlp_hidden=layer[2]*4, is_cls_token=False).to(device)
            cost = nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                         betas=(0.9, 0.999),
                                         weight_decay=weight_decay)
            # NOTE(review): T_max=200 is larger than n_epochs, so the cosine
            # schedule does not reach eta_min in this run -- confirm intended.
            base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200,
                                                                        eta_min=1e-5)
            scheduler = warmup_scheduler.GradualWarmupScheduler(optimizer, multiplier=1.,
                                                                total_epoch=5,
                                                                after_scheduler=base_scheduler)
            global_step = 0  # batches processed so far for this configuration
            model.train()
            since = time.time()
            for epoch in range(n_epochs):
                print("Epoch {}/{}".format(epoch + 1, n_epochs))
                ce_loss_sum, lj_loss_sum, n_correct, n_samples, n_batches = _train_one_epoch(
                    model, data_loader_train, cost, optimizer, device)
                global_step += n_batches
                train_acc = 100 * n_correct / n_samples
                test_acc = eval(model, data_loader_test, device)
                # eval() leaves the model in eval mode; switch back for training.
                model.train()
                logger.log_metrics_extra_loss(train_acc, test_acc, global_step, ce_loss_sum / len(data_loader_train), lj_loss_sum / len(data_loader_train), epoch + 1)

                scheduler.step()

            time_used = time.time() - since
            logger.log_direct('Time:  {:.0f}m {:.0f}s'.format(time_used // 60, time_used % 60))
        finally:
            # Release the log file even when training aborts with an exception.
            logger.close()


def eval(model, data_loader_test, device, half=False):
    """Return the top-1 accuracy (in percent) of ``model`` on a data loader.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    callers use it by this name.

    Args:
        model: Module whose call returns a pair; the first element holds the
            class scores and the second is ignored here.
        data_loader_test: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        half: When true, inputs are cast with ``.half()`` and labels with
            ``.long()`` after being moved to ``device``.

    Returns:
        ``100 * correct / total`` over all batches. The model is left in
        eval mode.
    """
    model.eval()
    n_seen = 0
    n_hits = 0
    with torch.no_grad():
        for batch in data_loader_test:
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            if half:
                inputs = inputs.half()
                targets = targets.long()
            scores, _ = model(inputs)
            # Index of the highest score along dim 1 is the predicted class.
            predicted = torch.max(scores.data, 1)[1]
            n_seen += targets.size(0)
            n_hits += (predicted == targets).sum().item()
    return 100 * n_hits / n_seen


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
