from dataset import dataloader
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from models import model_my
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
# Hyper-parameters and command-line options.
# Use the GPU when one is available; otherwise run everything on the CPU
# instead of failing on the first .cuda() call.
use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64)
# Number of training epochs.
parser.add_argument('--EPOCH', type=int, default=1000)
# Upper bound (inclusive) of the per-batch iteration count k drawn during training.
parser.add_argument('--max_iter_k', type=int, default=20)
# Iteration cap handed to both solvers; a test batch that reaches it is counted as invalid.
parser.add_argument('--max_iter_num', type=int, default=200000)
parser.add_argument('--down_layer_num', type=int, default=2)
parser.add_argument('--smoothing_num', type=int, default=2)
# Evaluation on the test loaders is skipped for epochs below this index.
parser.add_argument('--start_test_epoch', type=int, default=15)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--alpha', type=float, default=0.8)
parser.add_argument('--error_threshold', type=float, default=0.0001)
# SGD is the default optimizer. '--use_sgd' is kept for backward compatibility,
# but because its default is already True it could never be turned off;
# '--no_sgd' writes False to the same destination and selects Adam.
parser.add_argument('--use_sgd', action='store_true', default=True)
parser.add_argument('--no_sgd', dest='use_sgd', action='store_false',
                    help='use Adam instead of the default SGD optimizer')
parser.add_argument('--manualSeed', type=int, help='manual seed')
opt = parser.parse_args()
print(opt)
# Fix every random seed so a run is repeatable for a given --manualSeed.
# When no seed is supplied, draw one at random and store it back on `opt`.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)  # NumPy module.
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)  # if you are using multi-GPU.
# Keep cuDNN auto-tuning off: with benchmark enabled the selected convolution
# algorithm may differ between runs, which would undo the seeding above.
cudnn.benchmark = False
cudnn.deterministic = True

# Load the data: one training loader and a collection of test loaders.
train_loader, test_loaders = dataloader.get_dataloader(opt)

# Keyword arguments shared by the trainable model and the reference solver.
_solver_kwargs = dict(
    error_threshold=opt.error_threshold,
    max_iter_num=opt.max_iter_num,
    down_layer_num=opt.down_layer_num,
    smoothing_num=opt.smoothing_num,
    use_cuda=use_cuda,
)

# `net` is the model being trained; `net_base` is the solver whose iteration
# count serves as the reference during evaluation.
net = model_my.Unet_iter(alpha=opt.alpha, **_solver_kwargs)
net_base = model_my.Multi_grid_iter(**_solver_kwargs)
if use_cuda:
    net, net_base = net.cuda(), net_base.cuda()

# Optimizer, learning-rate schedule and loss function.
# SGD uses the configured learning rate with momentum; Adam runs at 1% of it.
optimizer = (
    optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9)
    if opt.use_sgd
    else optim.Adam(net.parameters(), lr=opt.lr * 0.01)
)
# Halve the learning rate every 20 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.5,
    last_epoch=-1,
)
criterion = nn.MSELoss()

def _prepare_batch(batch, to_cuda):
    """Turn one raw loader batch into the inputs expected by the solvers.

    Args:
        batch: the 10-tuple ``(y, b, b1, b2, G, G1, G2, f, f1, f2)`` yielded
            by a data loader.
        to_cuda: when true, every returned tensor is moved to the GPU.

    Returns:
        ``(x, y, Gs, bs, fs)`` where ``x`` is the initial guess, ``y`` the
        target, and ``Gs``/``bs``/``fs`` are three-element lists ordered as
        ``[G, G1, G2]``, ``[b, b1, b2]``, ``[f, f1, f2]``.
    """
    y, b, b1, b2, G, G1, G2, f, f1, f2 = batch
    # Insert a size-1 axis at dim 1 (presumably the channel axis - TODO confirm).
    y, b, b1, b2, f, f1, f2 = (t.unsqueeze(1) for t in (y, b, b1, b2, f, f1, f2))
    # Only the first G/G1/G2 of the batch is used
    # (presumably identical for every sample - TODO confirm).
    G, G1, G2 = G[0], G1[0], G2[0]
    # Initial guess: uniform random values weighted by G, values of b weighted by (1 - G).
    x = torch.rand(y.size()) * G + (1 - G) * b
    if to_cuda:
        x, y = x.cuda(), y.cuda()
        b, b1, b2 = b.cuda(), b1.cuda(), b2.cuda()
        f, f1, f2 = f.cuda(), f1.cuda(), f2.cuda()
        G, G1, G2 = G.cuda(), G1.cuda(), G2.cuda()
    return x, y, [G, G1, G2], [b, b1, b2], [f, f1, f2]


# Best (lowest) iteration ratio seen so far for each test loader, together with
# the ops ratio recorded at that same epoch.
best_layer = [10.0] * len(test_loaders)
best_ops = [10.0] * len(test_loaders)
for epoch in range(opt.EPOCH):
    # ---- training ----
    net.train()
    train_loss = 0.0
    train_max_error = 0.0
    for batch in tqdm(train_loader):
        # Number of iterations for this batch, drawn uniformly from [1, max_iter_k].
        k = np.random.randint(1, opt.max_iter_k + 1)
        x, y, Gs, bs, fs = _prepare_batch(batch, use_cuda)
        pre = net(x, Gs, bs, fs, k)
        # Conventional (input, target) order; MSE is symmetric so the value is unchanged.
        loss = criterion(pre, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by the batch size to get a dataset-level mean below.
        train_loss += loss.item() * y.shape[0]
        error = (y.detach() - pre.detach()).abs().max().item()
        train_max_error = max(train_max_error, error)
    # Advance the LR schedule only after this epoch's optimizer steps; stepping
    # it first would shift the whole schedule one epoch early.
    scheduler.step()
    train_loss /= len(train_loader.dataset)
    print("Epoch:", epoch, "train loss:%.8f" % train_loss, "train max error:%.8f" % train_max_error, "manualSeed:", opt.manualSeed)
    if epoch < opt.start_test_epoch:
        continue

    # ---- evaluation ----
    net.eval()
    with torch.no_grad():
        for idx, test_loader in enumerate(test_loaders):
            test_layer0 = 0.0
            test_ops0 = 0.0
            valid = 0
            invalid = 0
            for batch in test_loader:
                x, y, Gs, bs, fs = _prepare_batch(batch, use_cuda)
                model_pre, model_iter_num = net.evaluation(x, Gs, bs, fs)
                base_pre, base_iter_num = net_base(x, Gs, bs, fs)
                # A solver that hit the iteration cap is treated as not converged;
                # such batches are excluded from the averages.
                if model_iter_num == opt.max_iter_num or base_iter_num == opt.max_iter_num:
                    invalid += 1
                    continue
                valid += 1
                test_layer0 += model_iter_num / base_iter_num
                # NOTE(review): 9/4 is presumably the per-iteration cost of the
                # model relative to the baseline - confirm.
                test_ops0 += model_iter_num * 9.0 / 4.0 / base_iter_num
            # Average over the batches that actually contributed, so the result
            # is correct for any test batch size.
            test_layer0 /= max(valid, 1)
            test_ops0 /= max(valid, 1)
            # Without any valid batch the averages are a meaningless 0.0 and
            # must not be recorded as the best result.
            if valid > 0 and test_layer0 < best_layer[idx]:
                best_layer[idx] = test_layer0
                best_ops[idx] = test_ops0
            print("     ", "test loader %d Layer:%.3f, Ops:%.3f, Best Layer:%.3f, Ops:%.3f, invalid:%d" % (
                idx + 1, test_layer0, test_ops0, best_layer[idx], best_ops[idx], invalid))





