from dataset import dataloader
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from models import model
import torch.backends.cudnn as cudnn
import random
# Hyperparameters.
use_cuda = True
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--EPOCH', type=int, default=1000)
parser.add_argument('--max_iter_k', type=int, default=20)
parser.add_argument('--max_iter_num', type=int, default=20000)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--error_threshold', type=float, default=0.0001)
# SGD is the default optimizer. `--use_sgd` is kept for backward compatibility
# (as a store_true flag defaulting to True it could never turn SGD off);
# `--use_adam` writes to the same `use_sgd` destination and selects Adam.
parser.add_argument('--use_sgd', action='store_true', default=True)
parser.add_argument('--use_adam', dest='use_sgd', action='store_false')
parser.add_argument('--manualSeed', type=int, help='manual seed')
opt = parser.parse_args()
print(opt)

# Fix all random seeds. If none was given, draw one and store it in `opt`
# so it is printed and the run can be reproduced later.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)  # NumPy module.
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)  # Every visible GPU (includes the current one).
# Deterministic cuDNN: benchmark mode must stay off, otherwise cuDNN may
# auto-select different algorithms between runs and break reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Data: one training loader plus a list of test loaders.
train_loader, test_loaders = dataloader.get_dataloader(opt)

# `net` (Conv_iter) is the model being trained; `net_base` (JOR_iter) is the
# reference solver it is compared against during evaluation. Both share the
# same stopping criteria so their iteration counts are comparable.
solver_kwargs = dict(
    error_threshold=opt.error_threshold,
    max_iter_num=opt.max_iter_num,
    use_cuda=use_cuda,
)
net = model.Conv_iter(**solver_kwargs)
net_base = model.JOR_iter(**solver_kwargs)
if use_cuda:
    net, net_base = net.cuda(), net_base.cuda()

# Optimizer, learning-rate schedule and loss.
# SGD with momentum by default; the Adam path uses a learning rate 100x
# smaller than --lr.
if opt.use_sgd:
    optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9)
else:
    optimizer = optim.Adam(net.parameters(), lr=opt.lr * 0.01)
# Halve the learning rate every 20 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
criterion = nn.MSELoss()

# Per test loader: lowest layer ratio seen so far, and the ops ratio from the
# same evaluation. Starting at +inf (rather than a finite cap) guarantees the
# first valid result is recorded, however large it is.
best_layer = [float('inf')] * len(test_loaders)
best_ops = [float('inf')] * len(test_loaders)
for epoch in range(opt.EPOCH):
    # ---- Training ----
    net.train()
    train_loss = 0.0
    train_max_error = 0.0
    for y, b, G, _iter_num, _f in train_loader:
        # Random k in [1, max_iter_k] passed to the model for this batch
        # (presumably the number of unrolled iterations - TODO confirm).
        k = np.random.randint(1, opt.max_iter_k + 1)
        y, b = y.unsqueeze(1), b.unsqueeze(1)
        # Only the first sample's G is used; presumably G is shared across
        # the batch - TODO confirm against the dataloader.
        G = G[0]
        # Initial guess: random values weighted by G, plus b weighted by (1 - G).
        x = torch.rand(y.size()) * G + (1 - G) * b
        if use_cuda:
            x, y, b, G = x.cuda(), y.cuda(), b.cuda(), G.cuda()
        pre = net(x, G, b, k)
        loss = criterion(pre, y)  # MSELoss(input, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch value is a per-sample mean.
        train_loss += loss.item() * y.shape[0]
        error = (y.detach() - pre.detach()).abs().max().item()
        train_max_error = max(train_max_error, error)
    # Advance the LR schedule once per epoch, AFTER this epoch's optimizer
    # updates (required order since PyTorch 1.1; stepping first skips a value).
    scheduler.step()
    train_loss /= len(train_loader.dataset)
    print("Epoch:", epoch, "train loss:%.8f" % train_loss, "train max error:%.8f" % train_max_error, "manualSeed:", opt.manualSeed)

    # ---- Evaluation: iteration cost of `net` relative to `net_base` ----
    net.eval()
    with torch.no_grad():
        for idx, test_loader in enumerate(test_loaders):
            test_layer0 = 0.0
            test_ops0 = 0.0
            valid = 0
            invalid = 0
            for y, b, G, _iter_num, f in test_loader:
                y, b = y.unsqueeze(1), b.unsqueeze(1)
                # A 1-D `f` means "no f" and is passed as None; otherwise add the
                # channel dim. Done once so CPU and GPU runs see the same input.
                f = None if f.dim() == 1 else f.unsqueeze(1)
                G = G[0]
                x = torch.rand(y.size())
                if use_cuda:
                    x, y, b, G = x.cuda(), y.cuda(), b.cuda(), G.cuda()
                    if f is not None:
                        f = f.cuda()
                x = x * G + (1 - G) * b
                _, model_iter_num = net.evaluation(x, G, b, f)
                _, base_iter_num = net_base(x, G, b, f)
                # Either solver hitting the iteration cap means it did not
                # converge; such batches are excluded from the averages.
                if model_iter_num == opt.max_iter_num or base_iter_num == opt.max_iter_num:
                    invalid += 1
                    continue
                valid += 1
                test_layer0 += model_iter_num * net.conv_num / base_iter_num
                # Cost model encoded here: one model iteration costs
                # 4 + 9 * (conv_num - 1) units, one baseline iteration costs 4.
                test_ops0 += model_iter_num * (4 + (net.conv_num - 1) * 9) / (base_iter_num * 4)
            # Average over the batches that were actually accumulated; with no
            # valid batch report nan (which never beats the current best).
            if valid:
                test_layer0 /= valid
                test_ops0 /= valid
            else:
                test_layer0 = test_ops0 = float('nan')
            if test_layer0 < best_layer[idx]:
                best_layer[idx] = test_layer0
                best_ops[idx] = test_ops0
            print("     ", "test loader %d Layer:%.3f, Ops:%.3f, Best Layer:%.3f, Ops:%.3f, invalid:%d" % (
                idx + 1, test_layer0, test_ops0, best_layer[idx], best_ops[idx], invalid))
