import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

def ml_img():
    """Generate labelled 2-D sample points tracing the letters "M" and "L".

    Points are produced as integer ``(row, col)`` pixel coordinates and then
    scaled by 0.01. Every stroke is 4 pixels thick. Pixels on the letter
    strokes get label 1; pixels on two sine-shaped curves get label 0.

    Returns:
        ``(x, y)`` where ``x`` is a float array of shape (N, 2) holding the
        scaled coordinates and ``y`` is an integer array of shape (N, 1)
        holding the labels.
    """
    coords = []
    labels = []

    def emit(points, label):
        # Append a batch of points that all share the same label.
        coords.extend(points)
        labels.extend([label] for _ in points)

    # Outer slanted legs of the "M": one drifts left, one drifts right
    # as the row increases.
    for row in range(40, 160):
        quarter = int(row / 4)
        emit([[row, col - quarter] for col in range(118, 122)]
             + [[row, col + quarter] for col in range(158, 162)], 1)

    # Two sine-shaped curves (row offset varies with the column); label 0.
    for c in range(246, 320):
        wave = int(20 * np.sin(c / 15))
        emit([pt for r in range(258, 262)
              for pt in ([r + wave, c - 20], [r + wave, c - 140])], 0)

    # Inner strokes of the "M", converging towards the middle.
    for row in range(40, 100):
        half = int(row / 2)
        emit([[row, col + half] for col in range(88, 92)]
             + [[row, col - half] for col in range(188, 192)], 1)

    # The "L": vertical bar first, then the horizontal bar.
    vertical_rows = range(40, 160)
    horizontal_cols = range(246, 320)
    for row in vertical_rows:
        emit([[row, col] for col in range(246, 250)], 1)
    for col in horizontal_cols:
        emit([[r, col] for r in range(158, 162)], 1)

    # Debug output: number of scan lines used to draw the "L".
    print(len(vertical_rows) + len(horizontal_cols))
    return np.array(coords) * 0.01, np.array(labels)

def create_ssgail_dataset(device=None):
    """Return the ``ml_img`` point set as float32 tensors on ``device``.

    Args:
        device: Target torch device. ``None`` (the default) selects CUDA
            when ``torch.cuda.is_available()`` and the CPU otherwise, so the
            function also works on machines without a GPU.

    Returns:
        ``(X, y)`` where ``X`` has shape (N, 2) (scaled coordinates) and
        ``y`` has shape (N, 1) (labels), both ``float32``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    xGrid, yGrid = ml_img()
    print(xGrid.shape, yGrid.shape)

    # astype() returns a fresh array, so the tensors below do not alias
    # anything the caller can see; no extra clone() is needed.
    X = torch.from_numpy(xGrid.astype(np.float32)).view(-1, 2)
    y_torch = torch.from_numpy(yGrid.reshape(-1, 1).astype(np.float32)).view(-1, 1)

    return X.to(device), y_torch.to(device)

# Training data, built once at import time: X_train holds the (N, 2) scaled
# coordinates produced by ml_img() and y_train the matching (N, 1) float32 labels.
X_train, y_train = create_ssgail_dataset()

class gail(torch.nn.Module):
    """Six-layer fully connected classifier with a sigmoid output.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of output units.
        hid_dim: Width of each of the five hidden layers.
    """

    def __init__(self, input_dim, output_dim, hid_dim=256):
        super().__init__()
        # Layers are created in a fixed order (linear1 .. linear6) so weight
        # initialisation and state_dict keys stay stable.
        self.linear1 = torch.nn.Linear(input_dim, hid_dim)
        self.linear2 = torch.nn.Linear(hid_dim, hid_dim)
        self.linear3 = torch.nn.Linear(hid_dim, hid_dim)
        self.linear4 = torch.nn.Linear(hid_dim, hid_dim)
        self.linear5 = torch.nn.Linear(hid_dim, hid_dim)
        self.linear6 = torch.nn.Linear(hid_dim, output_dim)

    def forward(self, x):
        """Map inputs of shape (..., input_dim) to probabilities in (0, 1)."""
        # The raw input goes straight into the first layer; applying ReLU
        # here would zero out every negative input coordinate.
        x = self.linear1(x)
        for layer in (self.linear2, self.linear3, self.linear4,
                      self.linear5, self.linear6):
            x = layer(F.relu(x))
        return torch.sigmoid(x)

# Classifier, loss and optimiser for the point-labelling task. The model is
# placed on the same device as the training tensors instead of assuming CUDA.
model_AND = gail(2, 1).to(X_train.device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model_AND.parameters(), lr=0.01)

def train(model, criterion, optimizer, X, y, iter1):
    """Run ``iter1`` full-batch gradient steps and return the per-step losses.

    Args:
        model: Module mapping ``X`` to predictions.
        criterion: Loss callable ``criterion(y_hat, y)`` returning a
            single-element tensor.
        optimizer: Optimizer over ``model``'s parameters.
        X: Full input batch, fed to ``model`` on every step.
        y: Targets matching ``X``.
        iter1: Number of steps to run.

    Returns:
        List of float losses, one per step, each measured before that
        step's parameter update. Gradients are left zeroed on return.
    """
    all_loss = []
    # Discard gradients accumulated outside this loop so they are not
    # folded into the first update.
    optimizer.zero_grad()
    for epoch in range(iter1):
        y_hat = model(X)
        loss = criterion(y_hat, y)

        all_loss.append(loss.item())
        print(epoch, iter1, all_loss[-1])

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    return all_loss

# Evaluation grid: every (row, col) index pair in [0, n_columns) scaled by
# 0.01, flattened row-major into an (n_columns**2, 2) float32 array.
n_columns = 300
a = np.repeat(np.arange(n_columns), n_columns).reshape(n_columns, n_columns) * 0.01
x0 = a.reshape(n_columns, n_columns, 1)
x1 = a.T.reshape(n_columns, n_columns, 1)
xGrid = np.concatenate((x0, x1), axis=-1).reshape(-1, 2).astype(np.float32)
# Keep the grid on the same device as the training data / model.
X_test = torch.from_numpy(xGrid).clone().view(-1, 2).to(X_train.device)

out_dir = 'figs_gail'
# savefig() does not create missing directories.
os.makedirs(out_dir, exist_ok=True)
for i in range(100):
    all_loss = train(model_AND, criterion, optimizer, X_train, y_train, 100)
    # Inference only: skip building an autograd graph over the whole grid.
    with torch.no_grad():
        to_y = model_AND(X_test).reshape(n_columns, n_columns)
    # A fresh figure per snapshot, closed after saving, so images do not
    # pile up in one global figure across iterations.
    fig, ax = plt.subplots()
    ax.imshow(to_y.cpu().numpy(), cmap='viridis')
    fig.savefig(os.path.join(out_dir, f'plot_{i}.png'))
    plt.close(fig)
