import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from resnet18_32x32 import ResNet18_32x32



def train_step(model, features, labels, criterion, optimizer):
    """Run a single optimization step on one batch.

    Args:
        model: network being trained; it is switched to training mode first.
        features: input batch passed directly to ``model``.
        labels: target class indices compared with the arg-max prediction.
        criterion: loss function, called as ``criterion(logits, labels)``.
        optimizer: optimizer whose gradients are cleared and then stepped.

    Returns:
        ``(loss, accuracy)`` for this batch as Python floats. Accuracy is the
        fraction of samples whose arg-max prediction equals the label.
    """
    # Training-mode layer behavior (e.g. dropout active).
    model.train()
    # Clear gradients left over from the previous step.
    optimizer.zero_grad()

    # Forward pass and loss.
    logits = model(features)
    batch_loss = criterion(logits, labels)

    # Batch accuracy from the arg-max over the class dimension.
    predicted = logits.max(dim=1).indices
    batch_acc = predicted.eq(labels).float().mean()

    # Backpropagate and apply the parameter update.
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item(), batch_acc.item()

def valid_step(model, features, labels, criterion):
    """Evaluate ``model`` on one batch without recording gradients.

    Args:
        model: network to evaluate; it is switched to evaluation mode first.
        features: input batch passed directly to ``model``.
        labels: target class indices compared with the arg-max prediction.
        criterion: loss function, called as ``criterion(logits, labels)``.

    Returns:
        ``(loss, accuracy)`` for this batch as Python floats. Accuracy is the
        fraction of samples whose arg-max prediction equals the label.
    """
    # Evaluation-mode layer behavior (e.g. dropout disabled).
    model.eval()
    # No autograd graph is needed for evaluation.
    with torch.no_grad():
        logits = model(features)
        batch_loss = criterion(logits, labels)
        hits = logits.max(dim=1).indices.eq(labels)
        batch_acc = hits.float().mean()

    return batch_loss.item(), batch_acc.item()



def train_model(model, criterion, dl_train, dl_valid, optimizer, device, num_epochs=200, log_step_freq=100,
                scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs and return the best weights.

    After every epoch the model is evaluated on ``dl_valid``. A deep copy of
    the ``state_dict`` with the highest validation accuracy seen so far is
    kept and returned. The model object itself keeps its final-epoch weights.

    Args:
        model: network to train; moved to ``device`` in place.
        criterion: loss function forwarded to ``train_step`` / ``valid_step``.
        dl_train: iterable of ``(features, labels)`` training batches.
        dl_valid: iterable of ``(features, labels)`` validation batches.
        optimizer: optimizer forwarded to ``train_step``.
        device: device that every batch is moved to.
        num_epochs: number of passes over ``dl_train``.
        log_step_freq: print running training metrics every this many batches.
        scheduler: optional LR scheduler. ``scheduler.step()`` is called once
            at the end of each training epoch. ``None`` (the default) disables
            it.

    Returns:
        A deep copy of the ``state_dict`` with the best validation accuracy.
        If validation accuracy never exceeds 0, this is the initial weights.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    print("Start Training.............")
    model.to(device)
    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)

        # --- training pass ---
        # Metrics are accumulated weighted by batch size, so epoch averages are
        # exact even when the final batch is smaller than the rest. Weighting
        # the loss assumes the criterion averages over the batch, as
        # nn.CrossEntropyLoss does by default.
        loss_sum = 0.
        correct_sum = 0.
        seen = 0

        for step, (features, labels) in enumerate(dl_train, 1):
            features, labels = features.to(device), labels.to(device)
            loss, acc = train_step(model, features, labels, criterion, optimizer)

            batch_size = labels.size(0)
            loss_sum += loss * batch_size
            correct_sum += acc * batch_size
            seen += batch_size
            if step % log_step_freq == 0:
                print(f"[step = {step}] loss: {loss_sum/seen:.6f}, acc: {correct_sum/seen:.4f}")

        if scheduler is not None:
            scheduler.step()

        # --- validation pass ---
        val_loss_sum = 0.0
        val_correct_sum = 0.0
        val_seen = 0

        for features, labels in dl_valid:
            features, labels = features.to(device), labels.to(device)
            val_loss, val_acc = valid_step(model, features, labels, criterion)

            batch_size = labels.size(0)
            val_loss_sum += val_loss * batch_size
            val_correct_sum += val_acc * batch_size
            val_seen += batch_size

        # max(..., 1) prevents a ZeroDivisionError when a loader is empty.
        train_loss = loss_sum / max(seen, 1)
        train_acc = correct_sum / max(seen, 1)
        epoch_val_loss = val_loss_sum / max(val_seen, 1)
        epoch_val_acc = val_correct_sum / max(val_seen, 1)

        print(f"\nEPOCH = {epoch}, loss = {train_loss:.6f}, acc = {train_acc:.4f}, "
              f"val_loss = {epoch_val_loss:.6f}, val_acc = {epoch_val_acc:.4f}")

        # Keep a snapshot of the best weights; ties keep the earlier epoch.
        if epoch_val_acc > best_acc:
            best_acc = epoch_val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    return best_model_wts

def main():
    """Train ResNet18_32x32 on CIFAR-10 and save the best weights.

    The datasets are downloaded into ``../datasets/``. The ``state_dict``
    with the best test-set accuracy is written to ``./weights/raw.pth``; the
    directory is created if it is missing.
    """

    # Fixed seeds and deterministic cuDNN for reproducible runs.
    SEED = 100
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)

    # Prefer the second GPU as before, but fall back to the first one when only
    # a single GPU exists ('cuda:1' is not a valid device there).
    if torch.cuda.is_available():
        device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
    else:
        device = torch.device('cpu')

    num_epochs = 50  # 50 epochs
    batch_size = 128
    learning_rate = 0.01  # learning rate 0.01

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # zero-pad 4 px on each side, then randomly crop back to 32x32
        transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor()
    ])

    # Download the CIFAR-10 datasets.
    train_dataset = torchvision.datasets.CIFAR10(root='../datasets/',
                                                 train=True,
                                                 transform=transform_train,
                                                 download=True)

    test_dataset = torchvision.datasets.CIFAR10(root='../datasets/',
                                                train=False,
                                                transform=transform_test,
                                                download=True)

    # Data loaders.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=2)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=2)

    model = ResNet18_32x32()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Optional (currently disabled): decay the LR by gamma=0.1 every 7 epochs.
    # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    best_model_wts = train_model(model, criterion, train_loader, test_loader, optimizer, device, num_epochs)

    # torch.save does not create missing parent directories.
    os.makedirs("./weights", exist_ok=True)
    torch.save(best_model_wts, "./weights/raw.pth")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()



