import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
import numpy as np
import datetime
import os
import random
from cifar.model_cifar import ResNet18_cifar10

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer exposes no parameter groups.
    """
    groups = optimizer.param_groups
    return next((group['lr'] for group in groups), None)

def _build_cifar10_loaders(batch):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ./data.

    The training split is augmented with random horizontal flips and random
    32x32 crops of a 4-pixel zero-padded image. Both splits are normalised
    with the same per-channel mean/std constants.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        # CIFAR images are already 32x32, so RandomCrop(32) without padding
        # always returns the whole image (a no-op). Pad first so the crop
        # actually translates the image.
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                     transform=train_transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                    transform=test_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch, shuffle=False)
    return train_loader, test_loader


def _make_optimizer(opt, params, lr, weight_decay):
    """Construct the optimizer named by ``opt`` ('adam', 'sgd' or 'rmsprop').

    Raises:
        ValueError: if ``opt`` is not one of the supported names.
    """
    if opt == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if opt == 'sgd':
        # Apply the same weight decay as the other optimizers.
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    if opt == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr, weight_decay=weight_decay)
    raise ValueError("Invalid optimizer")


def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` as a percentage (float).

    Switches the model to eval mode, so layers such as BatchNorm or Dropout
    (if present) use inference behaviour. The caller must switch back to
    train mode before training again. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            predicted = outputs.argmax(dim=-1).cpu()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total if total else 0.0


def train_model(load=False, lr=0.1, num_epochs=100, weight_decay=5e-4,
                batch=128, opt='adam', use_scheduler=True):
    """Train ResNet18_cifar10 on CIFAR-10 and checkpoint the best late-stage model.

    Args:
        load: if True, initialise the model from
            ``./pretrain/resnet18_cifar10_<opt>.pth`` before training.
        lr: learning rate. When ``use_scheduler`` is True it is also the
            peak (``max_lr``) of the OneCycle schedule.
        num_epochs: number of passes over the training set.
        weight_decay: weight decay passed to the optimizer.
        batch: mini-batch size for both the train and test loaders.
        opt: optimizer name, case-insensitive: 'adam', 'sgd' or 'rmsprop'.
        use_scheduler: if True, step a OneCycleLR schedule after every batch.

    Raises:
        ValueError: if ``opt`` is not a supported optimizer name.

    Side effects: downloads CIFAR-10 into ./data if missing, creates
    ./pretrain, and overwrites the checkpoint whenever test accuracy
    improves during the final quarter of training.
    """
    # Normalise the name so 'Adam', 'SGD', ... are accepted and the
    # checkpoint filename does not depend on the caller's casing.
    opt = opt.lower()
    print("The hyper_parameters are: optimizer={}, lr={}, num_epochs={}, weight_decay={}, batch={}".format(opt, lr, num_epochs, weight_decay, batch))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs('./pretrain', exist_ok=True)
    ckpt_path = './pretrain/resnet18_cifar10_{}.pth'.format(opt)
    # NOTE(review): benchmark=True lets cuDNN choose algorithms by timing,
    # which works against the run-to-run reproducibility deterministic=True
    # aims for. PyTorch's reproducibility notes recommend benchmark=False
    # when determinism matters.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    model = ResNet18_cifar10().to(device)
    if load:
        print("Loading pre-trained model...")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print("Training from scratch")

    # Built before the (potentially slow) dataset download so an invalid
    # optimizer name fails fast.
    optimizer = _make_optimizer(opt, model.parameters(), lr, weight_decay)
    train_loader, test_loader = _build_cifar10_loaders(batch)
    criterion = nn.CrossEntropyLoss()
    scheduler = None
    if use_scheduler:
        # OneCycleLR is stepped once per batch, hence steps_per_epoch.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr, epochs=num_epochs,
            steps_per_epoch=len(train_loader))

    best_acc = 0.0
    current_time = datetime.datetime.now().strftime("%m-%d %H:%M:%S")
    print("Starting Time:", current_time)
    bar = tqdm.tqdm(range(num_epochs))
    for epoch in bar:
        # _evaluate leaves the model in eval mode; restore training mode.
        model.train()
        lrs = []
        train_loss = []
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Record the lr actually used for this step, before the
            # scheduler advances it.
            lrs.append(get_lr(optimizer))
            if scheduler is not None:
                scheduler.step()
            train_loss.append(loss.item())
        avg_loss = np.mean(train_loss)
        learning_rate = np.mean(lrs)

        accuracy = _evaluate(model, test_loader, device)

        print("{:.2f}e-2 is the learning rate".format(learning_rate * 100))
        bar.set_description(
            "Loss: {:.4f}, Test Acc: {:.2f}%".format(avg_loss, accuracy))
        # Only checkpoint during the last quarter of training, keeping the
        # best test accuracy seen in that window.
        if epoch > num_epochs * 3 / 4 and accuracy > best_acc:
            best_acc = accuracy
            print("updating")
            torch.save(model.state_dict(), ckpt_path)
    print(best_acc, "is the best accuracy so far")


if __name__ == '__main__':
    # Seed Python's and PyTorch's (CPU + current CUDA device) RNGs so that
    # weight initialisation and data shuffling are repeatable across runs.
    seed = 73
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    print("random seed: ", seed)
    # Other configurations that have been tried:
    #   train_model(lr=0.01, opt='sgd')
    #   train_model(lr=0.05, opt='rmsprop')
    train_model(lr=0.01, opt='adam')

