import argparse
import math
import os
from collections import deque

import numpy as np
import torch

import model

# Use the GPU only when one is present so the script also runs on CPU-only
# machines (every `.cuda()` call below is gated on this flag).
use_cuda = torch.cuda.is_available()


parser = argparse.ArgumentParser()
parser.add_argument('--grid_size', type=int, default=16,
                    help='cells per side; every grid has grid_size + 1 points per side')
parser.add_argument('--data_num', type=int, default=100000,
                    help='number of training samples to generate')
parser.add_argument('--max_iter_num', type=int, default=20000,
                    help='solver iteration cap; samples that hit it are regenerated')
parser.add_argument('--error_threshold', type=float, default=0.00001,
                    help='error threshold passed to the iterative solver')
opt = parser.parse_args()


# NOTE(review): model.JOR_iter is project-local; presumably a Jacobi
# over-relaxation iterative solver -- confirm in model.py.
JOC_solver = model.JOR_iter(error_threshold=opt.error_threshold, max_iter_num=opt.max_iter_num, use_cuda=use_cuda)



def set_boundary(x, bs):
    """Write four edge values onto the border of the 2-D grid ``x`` in place.

    Walking clockwise from the top-left corner, each side owns exactly one
    corner, so no border cell is written twice:

    * ``bs[0]`` -> top row, every column except the last
    * ``bs[1]`` -> right column, every row except the last
    * ``bs[2]`` -> bottom row, every column except the first
    * ``bs[3]`` -> left column, every row except the first

    Returns ``x`` itself.
    """
    top, right, bottom, left = bs[0], bs[1], bs[2], bs[3]
    x[0, :-1] = top
    x[:-1, -1] = right
    x[-1, 1:] = bottom
    x[1:, 0] = left
    return x

def set_boundary_L(x, bs):
    """Write boundary values for an L-shaped region onto the square grid ``x``.

    With ``n = x.shape[0]`` and ``h = (n - 1) // 2``, the values go to the
    outer border plus the edge of the top-right block (rows < h, columns > h),
    which is the block ``generation_test_data2`` removes from its mask:

    * ``bs[0]`` -> top row up to column ``h``, then column ``h + 1`` for rows < h
    * ``bs[1]`` -> row ``h - 1`` from column ``h + 1`` rightwards, then the right
      column for rows ``h .. n - 2``
    * ``bs[2]`` -> bottom row, every column except the first
    * ``bs[3]`` -> left column, every row except the first

    Writes happen in that order, so the shared cell ``(h - 1, h + 1)`` ends up
    holding ``bs[1]``. Only ``bs[0]`` .. ``bs[3]`` are read; extra entries are
    ignored. ``x`` is modified in place and returned.
    """
    n = x.shape[0]
    h = (n - 1) // 2
    notch_col = h + 1

    # top edge, then down the left side of the cut-out block
    x[0, :notch_col] = bs[0]
    x[:h, notch_col] = bs[0]
    # along the bottom of the cut-out block, then down the right edge
    x[h - 1, notch_col:] = bs[1]
    x[h:-1, -1] = bs[1]
    # bottom and left edges, as in the plain square domain
    x[-1, 1:] = bs[2]
    x[1:, 0] = bs[3]
    return x

def set_boundary_C(x, bs):
    """Rasterise two concentric circles onto the square grid ``x`` in place.

    Both circles are centred at ``(n / 2, n / 2)`` in index coordinates, where
    ``n = x.shape[0]``. The outer radius is ``(n - 1) / 2`` and the inner
    radius is one sixteenth of that. 1000 equally spaced angles in
    ``[0, 2*pi)`` are sampled. Each sample point is truncated with ``int()``
    to a ``(row, col)`` cell (row from the sine, column from the cosine).

    * outer circle: ``bs[0]`` for the first 500 angles, ``bs[1]`` for the rest.
      Later samples overwrite earlier ones on shared cells.
    * inner circle: ``bs[2]``.

    ``bs[3]`` is not read. Returns ``x`` itself.
    """
    grid_n = x.shape[0]
    outer_radius = (grid_n - 1) / 2
    inner_radius = outer_radius / 16
    centre_row = grid_n / 2.0
    centre_col = grid_n / 2.0
    samples = 1000

    for k in range(samples):
        angle = k / samples * 2 * np.pi
        s = np.sin(angle)
        c = np.cos(angle)
        outer_value = bs[0] if k < samples / 2 else bs[1]
        x[int(s * outer_radius + centre_row), int(c * outer_radius + centre_col)] = outer_value
        x[int(s * inner_radius + centre_row), int(c * inner_radius + centre_col)] = bs[2]
    return x

def set_G_C(x):
    """Flood-fill ``x`` with 1.0 over the ring region used by the circular domain.

    The boundary is rasterised with ``set_boundary_C`` (all four values 1.0)
    into a scratch tensor. A 4-connected breadth-first fill then starts at
    the seed cell ``(r1, 2 * r1)`` with ``r1 = (n - 1) // 4``. It writes 1.0
    into every cell of ``x`` it reaches and stops at boundary cells and at
    cells of ``x`` that already hold 1.0. For the grid sizes used in this
    script the seed lies between the inner and outer circles.

    Args:
        x: 2-D tensor, modified in place. Callers pass zeros, so the result
            is a 0/1 mask that is 1 on the filled region.

    Returns:
        ``x`` itself.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    r1 = (n_rows - 1) // 4
    boundary = set_boundary_C(torch.zeros_like(x), [1.0, 1.0, 1.0, 1.0])

    x[r1, 2 * r1] = 1.0
    queue = deque([(r1, 2 * r1)])
    while queue:
        row, col = queue.popleft()  # O(1), unlike list.pop(0) / del Q[0]
        for d_row, d_col in ((-1, 0), (1, 0), (0, 1), (0, -1)):
            nr, nc = row + d_row, col + d_col
            # Stay on the grid: a negative index would silently wrap to the
            # opposite edge and an index >= size would raise.
            if not (0 <= nr < n_rows and 0 <= nc < n_cols):
                continue
            if x[nr, nc] == 1.0 or boundary[nr, nc] == 1.0:
                continue
            x[nr, nc] = 1.0
            queue.append((nr, nc))
    return x




def _square_domain(grid_size):
    """Return the square-domain mask: ones inside, zeros on the border."""
    G = torch.ones(grid_size + 1, grid_size + 1)
    return set_boundary(G, torch.zeros(4))


def _generate_dataset(opt, G, apply_boundary, num_bs, prefix, tag,
                      mask_initial=False, with_f=False):
    """Solve ``opt.data_num`` random boundary-value problems and save them.

    For every sample:

    1. Draw ``num_bs`` random boundary values.
    2. Write them with ``apply_boundary`` onto a random initial guess and
       onto an otherwise-zero ``b``.
    3. Pass both to ``JOC_solver`` with the mask ``G``, plus a random source
       term ``f`` that is zero on the boundary when ``with_f`` is set.

    A sample whose reported iteration count equals ``opt.max_iter_num`` is
    redrawn; presumably it did not converge.

    Files written (``np.save``): ``{prefix}_data{tag}.npy``,
    ``{prefix}_b{tag}.npy``, ``{prefix}_iter_num{tag}.npy``,
    ``{prefix}_G{tag}.npy`` and, when ``with_f``, ``{prefix}_all_f{tag}.npy``.

    Args:
        opt: namespace providing ``data_num``, ``grid_size``, ``max_iter_num``.
        G: CPU mask tensor of shape ``(grid_size + 1, grid_size + 1)``.
        apply_boundary: callable ``(grid, bs) -> grid`` writing boundary values.
        num_bs: number of random boundary values drawn per sample.
        prefix: leading part of the output file names (e.g. ``"train"``).
        tag: suffix inserted before ``.npy`` (e.g. ``"1"``).
        mask_initial: multiply the random initial guess by ``G`` first.
        with_f: also generate and pass a random source term ``f``.
    """
    size = opt.grid_size + 1
    all_data = torch.rand(opt.data_num, size, size)
    # Every slot of these is overwritten in the loop, so no random fill is needed.
    all_b = torch.zeros(opt.data_num, size, size)
    all_iter_num = torch.zeros(opt.data_num)
    all_f = torch.rand(opt.data_num, size, size) if with_f else None

    # G is loop-invariant: move it to the solver's device once.
    G_dev = G.cuda() if use_cuda else G

    i = 0
    while i < opt.data_num:
        bs = torch.rand(num_bs)
        init = all_data[i] * G if mask_initial else all_data[i]
        all_data[i] = apply_boundary(init, bs)
        b = apply_boundary(torch.zeros(size, size), bs)
        all_b[i] = b
        if with_f:
            all_f[i] = apply_boundary(all_f[i], torch.zeros(num_bs))

        U_0 = all_data[i].unsqueeze(0).unsqueeze(0)
        if use_cuda:
            U_0, b = U_0.cuda(), b.cuda()
        extra = ()
        if with_f:
            extra = (all_f[i].cuda() if use_cuda else all_f[i],)

        ans, all_iter_num[i] = JOC_solver(U_0, G_dev, b, opt.grid_size, *extra)
        if all_iter_num[i].item() == opt.max_iter_num:
            # Hit the iteration cap: redraw the initial guess (and f) and
            # retry the same slot without advancing i.
            all_data[i] = torch.rand(size, size)
            if with_f:
                all_f[i] = torch.rand(size, size)
            print(i, "again")
            continue
        all_data[i] = ans.squeeze().cpu()
        print(i, all_iter_num[i].item())
        i += 1

    def _path(stem):
        return "{}_{}{}.npy".format(prefix, stem, tag)

    np.save(_path("data"), all_data.numpy())
    np.save(_path("b"), all_b.numpy())
    np.save(_path("iter_num"), all_iter_num.numpy())
    np.save(_path("G"), G.cpu().numpy())
    if with_f:
        np.save(_path("all_f"), all_f.numpy())


def generation_train_data(opt):
    """Generate the square-domain training set (``train_*.npy``)."""
    _generate_dataset(opt, _square_domain(opt.grid_size), set_boundary, 4, "train", "")


def generation_test_data1(opt):
    """Generate the square-domain test set (``test_*1.npy``)."""
    _generate_dataset(opt, _square_domain(opt.grid_size), set_boundary, 4, "test", "1")


def generation_test_data2(opt):
    """Generate the L-shaped-domain test set (``test_*2.npy``)."""
    G = _square_domain(opt.grid_size)
    half = int(opt.grid_size / 2)
    # Remove the top-right block (rows < half, columns > half) to form the L.
    G[0:half, half + 1:] = 0
    # NOTE(review): 6 boundary values are drawn but set_boundary_L reads only 4.
    _generate_dataset(opt, G, set_boundary_L, 6, "test", "2", mask_initial=True)


def generation_test_data3(opt):
    """Generate the circular-domain test set (``test_*3.npy``)."""
    G = set_G_C(torch.zeros(opt.grid_size + 1, opt.grid_size + 1))
    _generate_dataset(opt, G, set_boundary_C, 4, "test", "3", mask_initial=True)


def generation_test_data4(opt):
    """Generate the square-domain test set with a random source term (``test_*4.npy``)."""
    _generate_dataset(opt, _square_domain(opt.grid_size), set_boundary, 4, "test", "4",
                      with_f=True)


if __name__ == "__main__":
    generation_train_data(opt)
    # The test sets are small and use a 64-cell grid. Use a copy of the options
    # so the training configuration in `opt` is left untouched.
    test_opt = argparse.Namespace(**vars(opt))
    test_opt.data_num = 20
    test_opt.grid_size = 64
    generation_test_data2(test_opt)
    generation_test_data3(test_opt)
    generation_test_data4(test_opt)
    generation_test_data1(test_opt)
