import torch
import numpy as np
import torch.nn as nn

from gen_rl.envs.paintGym.renderer.stroke_gen import draw
from gen_rl.envs.paintGym.renderer.model import FCN


def save(path="data/renderer.pkl"):
    """Write the renderer network's weights to ``path``.

    Reads the module-level ``net`` (created in the ``__main__`` section of this
    script). Every tensor in the saved state dict is on the CPU, so the
    checkpoint can be loaded on a machine without a GPU.

    Args:
        path: Destination file. Defaults to ``"data/renderer.pkl"``, the
            location this script has always written to.
    """
    state = net.state_dict()
    # Copy each tensor to the CPU instead of moving the whole network to the
    # CPU and back: the live model (and the optimizer's references to its
    # parameters) stays where it is. Reassigning values in the returned dict
    # keeps its ``_metadata`` attribute, which ``load_state_dict`` may consult.
    for key, tensor in list(state.items()):
        state[key] = tensor.cpu()
    torch.save(state, path)


if __name__ == "__main__":
    # Train the FCN renderer to reproduce draw() on random stroke parameters,
    # checkpointing to ./data via save().
    import os

    import torch.optim as optim

    # exist_ok: a re-run must not crash just because ./data already exists.
    os.makedirs("./data", exist_ok=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    criterion = nn.MSELoss()
    # Move the network to its device once, before building the optimizer,
    # instead of calling .cuda() on every iteration.
    net = FCN().to(device)
    # The initial lr is overwritten by the schedule below before the first update.
    optimizer = optim.Adam(net.parameters(), lr=1e-4)
    batch_size = 64

    save()
    for step in range(500000):
        net.train()

        # Sample random 10-dim stroke parameter vectors in [0, 1) and render
        # each one with draw() to obtain its target.
        train_batch = []
        ground_truth = []
        for _ in range(batch_size):
            f = np.random.uniform(0, 1, 10)
            train_batch.append(f)
            ground_truth.append(draw(f))

        # Stack with numpy first: torch.tensor() on a list of ndarrays is very slow.
        train_batch = torch.as_tensor(
            np.asarray(train_batch), dtype=torch.float32, device=device
        )
        ground_truth = torch.as_tensor(
            np.asarray(ground_truth), dtype=torch.float32, device=device
        )

        # Apply the piecewise-constant schedule *before* the update so that the
        # lr for this step is the one the schedule specifies.
        if step < 200000:
            lr = 1e-4
        elif step < 400000:
            lr = 1e-5
        else:
            lr = 1e-6
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        gen = net(train_batch)
        optimizer.zero_grad()
        loss = criterion(gen, ground_truth)
        loss.backward()
        optimizer.step()
        print("train/loss", loss.item(), step)

        if step % 100 == 0:
            net.eval()
            # NOTE(review): this re-scores the batch that was just trained on,
            # so it is not a held-out validation loss.
            with torch.no_grad():
                val_loss = criterion(net(train_batch), ground_truth)
            print("val/loss", val_loss.item(), step)
        if step % 1000 == 0:
            save()

    # The in-loop checkpoint fires only every 1000 steps; keep the final weights too.
    save()
