import torch.nn.functional as F
import matplotlib.pyplot as plt
from timeit import default_timer
from utils.utilities3 import *
from utils.adam import Adam
from utils.params import get_args
from model_dict import get_model
import math
import os

# Pin every random number generator the script touches (NumPy, PyTorch CPU,
# PyTorch CUDA) to one seed, and ask cuDNN for deterministic algorithms.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)
torch.backends.cudnn.deterministic = True

# Parse the run configuration once; everything below is derived from it.
args = get_args()

# Train / test .mat files are looked up inside args.data_path.
TRAIN_PATH, TEST_PATH = (
    os.path.join(args.data_path, mat_name)
    for mat_name in ('./darcy_rough_train.mat', './darcy_rough_test.mat')
)

# Number of samples taken from the train and test files.
ntrain, ntest = 1000, 100
# Channel counts; not referenced again in the visible part of this script.
in_channels, out_channels = 1, 1

# Subsampling strides along the two grid axes, and the grid size that
# results from striding a grid of args.h x args.w points.
r1, r2 = args.h_down, args.w_down
s1 = int(((args.h - 1) / r1) + 1)
s2 = int(((args.w - 1) / r2) + 1)

# Optimisation hyper-parameters. step_size is used both as the StepLR period
# and as the checkpoint interval further down.
batch_size, learning_rate, epochs = args.batch_size, args.learning_rate, args.epochs
step_size, gamma = args.step_size, args.gamma

# Where the model's state_dict is written.
model_save_path, model_save_name = args.model_save_path, args.model_save_name


# Build the network selected by args and report its parameter count.
model = get_model(args)
print(count_params(model))

### forward problem
# Inputs come from the 'coeff' field and targets from the 'sol' field.
# Each field is truncated to the first ntrain / ntest samples, subsampled with
# strides (r1, r2), then cropped to an s1 x s2 grid.
reader = MatReader(TRAIN_PATH)
x_train, y_train = (
    reader.read_field(field)[:ntrain, ::r1, ::r2][:, :s1, :s2]
    for field in ('coeff', 'sol')
)

# Re-use the same reader object for the test file.
reader.load_file(TEST_PATH)
x_test, y_test = (
    reader.read_field(field)[:ntest, ::r1, ::r2][:, :s1, :s2]
    for field in ('coeff', 'sol')
)


### inverse problem (disabled): uncomment the lines below to learn 'coeff' from normalized, noise-perturbed 'sol'
# reader = MatReader(TRAIN_PATH)
# y_train = reader.read_field('coeff')[:ntrain, ::r1, ::r2][:, :s1, :s2]
# x_train = reader.read_field('sol')[:ntrain, ::r1, ::r2][:, :s1, :s2]

# reader.load_file(TEST_PATH)
# y_test = reader.read_field('coeff')[:ntest, ::r1, ::r2][:, :s1, :s2]
# x_test = reader.read_field('sol')[:ntest, ::r1, ::r2][:, :s1, :s2]

# noise = 0.1
# x_normalizer = UnitGaussianNormalizer(x_train)
# x_train = x_normalizer.encode(x_train)
# xnoise = noise * torch.randn(*x_train.shape, dtype=torch.float32)
# x_train += xnoise
# x_test = x_normalizer.encode(x_test)
# xnoise = noise * torch.randn(*x_test.shape, dtype=torch.float32)
# x_test += xnoise

# Fit the target normalizer on the training targets only and encode them.
# y_test is deliberately left un-encoded: the evaluation loop decodes the
# model output instead. (UnitGaussianNormalizer comes from utils.utilities3;
# presumably it standardises to zero mean / unit variance -- confirm there.)
y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)
# Move the normalizer to the GPU after encoding, since the training loop
# decodes CUDA tensors with it.
y_normalizer.cuda()

# Append a trailing singleton dimension to the inputs: (N, s1, s2) -> (N, s1, s2, 1).
x_train, x_test = (
    inputs.reshape(count, s1, s2, 1)
    for inputs, count in ((x_train, ntrain), (x_test, ntest))
)

# Only the training loader is shuffled; the test loader keeps a fixed order.
train_set = torch.utils.data.TensorDataset(x_train, y_train)
test_set = torch.utils.data.TensorDataset(x_test, y_test)
train_loader = torch.utils.data.DataLoader(
    train_set,
    shuffle=True,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    shuffle=False,
    batch_size=batch_size,
)

################################################################
# training and evaluation
################################################################
# Loss built with size_average=False; the training loop accumulates it per
# batch and divides the epoch total by ntrain / ntest itself.
# (LpLoss comes from utils.utilities3 -- presumably a relative Lp error summed
# over the batch; confirm there.)
myloss = LpLoss(size_average=False)

# Project-local Adam (utils.adam) with weight decay, plus a step schedule that
# multiplies the learning rate by `gamma` every `step_size` epochs.
optimizer = Adam(
    model.parameters(),
    weight_decay=1e-4,
    lr=learning_rate,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    gamma=gamma,
    step_size=step_size,
)

# Main loop: one optimisation pass over the training set, one evaluation pass
# over the test set, then a log line `epoch, seconds, train error, test error`
# (both errors are per-sample averages) and a periodic checkpoint.
for ep in range(epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0.0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        # Use the real size of this batch: the loaders do not drop the last
        # batch, so it is smaller than `batch_size` whenever the dataset size
        # is not a multiple of it, and reshaping to `batch_size` would fail.
        n_batch = x.shape[0]

        optimizer.zero_grad()
        out = model(x).reshape(n_batch, s1, s2)
        # Targets were encoded before training; decode both prediction and
        # target so the loss is measured on un-normalized values.
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)

        loss = myloss(out.reshape(n_batch, -1), y.reshape(n_batch, -1))
        loss.backward()

        optimizer.step()
        train_l2 += loss.item()

    # Learning-rate schedule advances once per epoch.
    scheduler.step()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            n_batch = x.shape[0]

            out = model(x).reshape(n_batch, s1, s2)
            # y_test was never encoded, so only the prediction is decoded.
            out = y_normalizer.decode(out)

            test_l2 += myloss(out.reshape(n_batch, -1), y.reshape(n_batch, -1)).item()

    # The loss is accumulated per batch (size_average=False), so divide by the
    # sample counts to report per-sample averages.
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print(ep, t2 - t1, train_l2, test_l2)
    # Checkpoint every `step_size` epochs, and also after the final epoch so
    # the fully trained weights are not lost when (epochs - 1) is not a
    # multiple of `step_size`. Each save overwrites the same file.
    if ep % step_size == 0 or ep == epochs - 1:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(model_save_path, exist_ok=True)
        print('save model')
        torch.save(model.state_dict(), os.path.join(model_save_path, model_save_name))
