import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import random

from resnet18_32x32 import ResNet18_32x32



def train_step(model, features, labels, criterion, optimizer):
    """Run one optimisation step on a single batch.

    Puts ``model`` in training mode, clears the optimizer's gradients, runs
    a forward pass, back-propagates ``criterion(model(features), labels)``
    and applies one ``optimizer.step()``.

    Args:
        model: Callable module whose output is scored along dimension 1.
        features: Input batch passed to ``model``.
        labels: Target class indices compared against the arg-max of the
            model output.
        criterion: Loss callable taking ``(predictions, labels)``.
        optimizer: Optimizer whose gradients are zeroed and then stepped.

    Returns:
        Tuple ``(loss, accuracy)`` of Python floats for this batch, both
        measured on the forward pass made before the weight update.
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass and loss.
    logits = model(features)
    batch_loss = criterion(logits, labels)

    # Fraction of samples whose highest-scoring class equals the label.
    predicted = torch.max(logits, dim=1).indices
    batch_acc = (predicted == labels).float().mean()

    # Backward pass and parameter update.
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item(), batch_acc.item()

def valid_step(model, features, labels, criterion):
    """Evaluate a single batch without updating the model.

    Args:
        model: Callable module whose output is scored along dimension 1.
        features: Input batch passed to ``model``.
        labels: Target class indices compared against the arg-max of the
            model output.
        criterion: Loss callable taking ``(predictions, labels)``.

    Returns:
        Tuple ``(loss, accuracy)`` of Python floats for this batch.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    # Inference mode: dropout layers have no effect.
    model.eval()

    # Gradient tracking is disabled for the whole forward pass.
    with torch.no_grad():
        logits = model(features)
        batch_loss = criterion(logits, labels).item()
        predicted = torch.max(logits, dim=1).indices
        batch_acc = (predicted == labels).float().mean().item()

    return batch_loss, batch_acc



def train_model(model, criterion, dl_train, dl_valid, optimizer, device, num_epochs=200, scheduler=None):
    """Train ``model`` and return the weights of its best validation epoch.

    Each epoch runs ``train_step`` over every batch of ``dl_train`` and then
    ``valid_step`` over every batch of ``dl_valid``. Reported losses and
    accuracies are the mean of the per-batch values. Whenever the mean
    validation accuracy exceeds the best seen so far, a deep copy of the
    model's ``state_dict`` is kept.

    Args:
        model: Module to train; it is moved to ``device`` in place.
        criterion: Loss callable taking ``(predictions, labels)``.
        dl_train: Iterable yielding ``(features, labels)`` training batches.
        dl_valid: Iterable yielding ``(features, labels)`` validation batches.
        optimizer: Optimizer forwarded to ``train_step``.
        device: Device the model and every batch are moved to.
        num_epochs: Number of passes over ``dl_train``.
        scheduler: Optional learning-rate scheduler, stepped once at the end
            of every epoch. ``None`` disables scheduling.

    Returns:
        The ``state_dict`` snapshot with the highest mean validation
        accuracy, or the initial weights if no epoch scored above 0.0.

    Note:
        The best weights are only returned, not loaded back: ``model`` itself
        is left holding the weights of the final epoch.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    print("Start Training.............")
    model.to(device)
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        loss_sum = 0.
        acc_sum = 0.
        # Pre-bind the counter so an empty dl_train cannot leave it undefined.
        step = 0

        for step, (features, labels) in enumerate(dl_train, 1):

            features, labels = features.to(device), labels.to(device)
            loss, acc = train_step(model, features, labels, criterion, optimizer)

            loss_sum += loss
            acc_sum += acc
            if step % 100 == 0:
                print(f"[step = {step}] loss: {loss_sum/step:.6f}, acc: {acc_sum/step:.4f}")

        # Avoid dividing by zero when no training batch was produced.
        train_batches = max(step, 1)

        val_loss_sum = 0.0
        val_acc_sum = 0.0
        # Starts at 1 so the averages below stay defined for an empty dl_valid.
        val_step = 1

        for val_step, (features, labels) in enumerate(dl_valid, 1):

            features, labels = features.to(device), labels.to(device)
            val_loss, val_acc = valid_step(model, features, labels, criterion)

            val_loss_sum += val_loss
            val_acc_sum += val_acc

        epoch_val_acc = val_acc_sum / val_step

        # The run of spaces before "val_loss" reproduces the existing log
        # format exactly.
        print(
            f"\nEPOCH = {epoch}, loss = {loss_sum/train_batches:.6f}, acc = {acc_sum/train_batches:.4f}, "
            f"            val_loss = {val_loss_sum/val_step:.6f}, val_acc = {epoch_val_acc:.4f}"
        )

        if epoch_val_acc > best_acc:
            best_acc = epoch_val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        # The scheduler is optional (defaults to None), so only step it when
        # one was supplied.
        if scheduler is not None:
            scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    return best_model_wts

def main():
    """Train ResNet18_32x32 on CIFAR-10 and save the best weights.

    Seeds the Python, NumPy and PyTorch random generators, builds augmented
    training and plain test data loaders, trains with SGD under a cosine
    annealing schedule, and writes the best ``state_dict`` returned by
    ``train_model`` to ``./weights/resnet18_best.pth``.

    Note:
        The CIFAR-10 test split is passed to ``train_model`` as the
        validation set, so the saved "best" weights are selected on it.
    """
    # Set before any CUDA call so the device restriction reliably applies.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'

    SEED = 100
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_epochs = 200
    batch_size = 128

    # Create the output directory up front: a missing directory would
    # otherwise only surface in torch.save, after training has finished.
    weights_path = "./weights/resnet18_best.pth"
    os.makedirs(os.path.dirname(weights_path), exist_ok=True)

    # Training images get random crop (with padding) and horizontal flip.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    # Test images are only converted and normalised.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    train_dataset = torchvision.datasets.CIFAR10(root='../datasets/',
                                                train=True,
                                                transform=transform_train,
                                                download=True)

    test_dataset = torchvision.datasets.CIFAR10(root='../datasets/',
                                                train=False,
                                                transform=transform_test,
                                                download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=2)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=2)

    model = ResNet18_32x32()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
    # Tie the annealing period to the epoch count so the two cannot drift
    # apart when num_epochs is changed.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_model_wts = train_model(model, criterion, train_loader, test_loader, optimizer, device, num_epochs, scheduler)
    torch.save(best_model_wts, weights_path)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":

    main()



