import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


# Seed every random number generator the script touches (NumPy, PyTorch CPU,
# PyTorch CUDA) with the same fixed value so repeated runs start from the same
# state.
np.random.seed(1)
torch.manual_seed(1)
torch.random.manual_seed(1)
torch.cuda.manual_seed(1)
torch.cuda.manual_seed_all(1)
# Ask cuDNN to pick deterministic algorithms.
torch.backends.cudnn.deterministic = True


# hyperparameters
# epoch_num: number of iterations of the training loop.
epoch_num = 10000
# batch_num: number of sample groups in the training batch; also the factor the
# printed loss is multiplied by.
batch_num = 250
# mb_size: number of consecutive samples per group. All samples of one group
# are generated from the same target curve, and the loss is averaged per group.
mb_size = 2
# sub_num: number of distinct target curves mixed into the data (curve ids
# 1..sub_num are passed to func()).
sub_num = 3
# net_num: number of MainNet instances trained side by side.
net_num = 4


def func_noise(x):
    """Return small Gaussian noise with the same shape as ``x``.

    Samples are drawn from a standard normal distribution and scaled by 0.01.

    The noise is allocated on ``x.device`` so it can be added directly to
    tensors derived from ``x`` regardless of where ``x`` lives. For CPU inputs
    the drawn values are the same as without the explicit device.

    Args:
        x: Tensor whose shape (and device) the noise should match.

    Returns:
        A tensor of shape ``x.size()`` on ``x.device``.
    """
    return torch.randn(x.size(), device=x.device) * 0.01


def func(i,x):
    """Evaluate noisy target curve number ``i`` at the points ``x``.

    Curves (before noise is added):
        i == 1: 2 * |x| - 2
        i == 2: 2 * cos(3 * x)
        i == 3: 1.5 * log(2.5 - x) - 1
        any other i: the constant 0

    Args:
        i: Curve selector.
        x: Tensor of evaluation points.

    Returns:
        The selected curve evaluated at ``x`` plus ``func_noise(x)``.
    """
    # Ordered (selector, curve) pairs; the first selector equal to `i` wins.
    curves = (
        (1, lambda t: 2 * torch.abs(t) - 2),
        (2, lambda t: 2 * torch.cos(3 * t)),
        (3, lambda t: 1.5 * torch.log(-t + 2.5) - 1),
    )

    clean = 0  # fallback when no selector matches
    for selector, curve in curves:
        if i == selector:
            clean = curve(x)
            break

    return clean + func_noise(x)


def get_batch():
    """Draw one batch of (x, y) samples from a mixture of the target curves.

    ``batch_num * mb_size`` inputs are sampled uniformly from [-2, 2). The
    samples are organised in ``batch_num`` groups of ``mb_size`` consecutive
    entries; each group is assigned one random curve index in
    ``[0, sub_num)`` and every sample in the group takes its target from that
    curve (via ``func``, which also adds noise).

    Returns:
        Tuple ``(x, y)`` of 1-D tensors of length ``batch_num * mb_size``.
    """
    sample_count = batch_num * mb_size
    inputs = torch.rand(size=[sample_count]) * 4 - 2

    # One curve index per group, broadcast across the group and flattened so
    # it lines up element-wise with `inputs`.
    group_curve = torch.randint(0, sub_num, size=[batch_num, 1])
    curve_of_sample = group_curve.expand(-1, mb_size).reshape(-1)

    targets = torch.zeros_like(inputs)
    for curve_idx in range(sub_num):
        # Every curve is evaluated on all inputs; the mask keeps only the
        # samples assigned to this curve (func ids are 1-based).
        belongs = curve_of_sample == curve_idx
        targets += func(curve_idx + 1, inputs) * belongs
    return inputs, targets


class MainNet(torch.nn.Module):
    """Small fully connected regressor mapping one scalar to one scalar.

    Architecture: Linear(1, 32) -> ReLU -> Linear(32, 32) -> ReLU ->
    Linear(32, 32) -> ReLU -> Linear(32, 1), stored as ``self.net``.
    """

    def __init__(self,num=sub_num):
        """Build the layer stack.

        Args:
            num: Accepted for interface compatibility; not used by the
                constructor.
        """
        super().__init__()
        width = 32
        # Layers are created in the same order as they appear in the stack.
        layers = [nn.Linear(1, width), nn.ReLU()]
        for _ in range(2):
            layers.extend((nn.Linear(width, width), nn.ReLU()))
        layers.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Tensor of inputs; it is viewed as a column of shape (N, 1).

        Returns:
            Tensor of shape (N, 1) with one prediction per input element.
        """
        column = x.view(-1, 1)
        return self.net(column)


# Use the GPU when CUDA is available; otherwise keep models and data on the
# CPU instead of failing on the device transfer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build `net_num` independent networks, each with its own SGD optimizer.
model_list = []
opt_list = []
for i in range(net_num):
    # Move the model to its device BEFORE handing its parameters to the
    # optimizer, as recommended by the torch.optim documentation.
    model = MainNet().to(device)
    model_list.append(model)
    opt_list.append(torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4))
# reduction='none' keeps one loss value per sample so it can be regrouped later.
loss_func = torch.nn.MSELoss(reduction='none')

# A single fixed batch is drawn once and reused for every training iteration.
batch_x, batch_y = get_batch()
batch_x_cd = batch_x.to(device)
# Targets as a column so they match the (N, 1) network output.
batch_y_cd = batch_y.to(device).view(-1, 1)

# Winner-take-all training: every network predicts the whole batch, and each
# group of `mb_size` consecutive samples contributes only the loss of the
# network that fits that group best.
for epoch in range(epoch_num):
    loss_stat = []

    # One mean loss per group for every network, each of shape (num_groups,).
    compare_list = []
    for net in model_list:
        batch_pred = net(batch_x_cd)
        per_sample = loss_func(batch_pred, batch_y_cd)
        compare_list.append(per_sample.reshape(-1, mb_size).mean(dim=1))

    # Shape (net_num, num_groups); the minimum over networks picks the
    # best-fitting network per group, and those minima are averaged.
    loss_stack = torch.stack(compare_list, dim=0)
    loss_bp = torch.min(loss_stack, dim=0)[0].mean()

    # A single backward pass through the shared loss, then one step per
    # optimizer.
    for opt in opt_list:
        opt.zero_grad()
    loss_bp.backward()
    for opt in opt_list:
        opt.step()

    loss_stat.append(loss_bp.cpu().detach().numpy())
    print(epoch)
    loss_stat = np.array(loss_stat)
    print(loss_stat.mean()*batch_num)

    # Plot the learned curves at every 4000th iteration (skipping iteration 0).
    if epoch % 4000 == 0 and epoch != 0:

        print("TEST")

        # NOTE(review): `color` is not read anywhere in the visible code.
        color = ['r', 'g', 'b']
        show_x = torch.linspace(-2, 2, 101)
        show_x_np = show_x.detach().numpy()
        show_y_list = []

        # Transfer the evaluation grid once, to whichever device holds the
        # training batch, instead of hard-coding CUDA for every network.
        show_x_dev = show_x.to(batch_x_cd.device)

        plt.figure(1)
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for net in model_list:
                show_y_np = net(show_x_dev).cpu().numpy()
                show_y_list.append(show_y_np)
                plt.plot(show_x_np, show_y_np)
        plt.show()
