import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


# Seed every random-number source the script relies on (NumPy, the torch CPU
# generator, and the CUDA generators) so that data sampling and weight
# initialisation are reproducible across runs.
for _seed_fn in (
    np.random.seed,
    torch.manual_seed,
    torch.random.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(1)
del _seed_fn
# Ask cuDNN to pick deterministic kernels as well.
torch.backends.cudnn.deterministic = True

# Hyperparameters
epoch_num = 10000  # number of training epochs (one optimizer step each)
batch_num = 250    # number of minibatches in the training batch
mb_size = 2        # points per minibatch; all share one sub-function
sub_num = 3        # number of target sub-functions used by get_batch
net_num = 3        # number of networks trained side by side


def func_noise(x, scale=0.01):
    """Return zero-mean Gaussian noise with the same shape as *x*.

    The noise is drawn on ``x.device`` so it can be added to tensors derived
    from ``x`` on any device. Previously it was always drawn on the CPU, which
    made ``func`` fail for CUDA inputs. The dtype is torch's default floating
    dtype, as before.

    Args:
        x: tensor whose shape and device the noise should match.
        scale: multiplier applied to standard-normal samples, i.e. the noise
            standard deviation (default 0.01, the original fixed value).

    Returns:
        A tensor of shape ``x.size()`` on ``x.device``.
    """
    return torch.randn(x.size(), device=x.device) * scale


def func(i,x):
    """Evaluate target sub-function *i* at *x* and add noise from ``func_noise``.

    Sub-functions (``i`` is 1-based):
        1: 2*|x| - 2
        2: 2*cos(3*x)
        3: 1.5*log(2.5 - x) - 1   (finite only for x < 2.5)
    Any other ``i`` contributes a zero curve, so the result is noise only.
    """
    if i == 1:
        curve = 2 * x.abs() - 2
    elif i == 2:
        curve = 2 * (3 * x).cos()
    elif i == 3:
        curve = 1.5 * (2.5 - x).log() - 1
    else:
        curve = 0
    return curve + func_noise(x)


def get_batch():
    """Sample one training batch from a random mixture of the sub-functions.

    Returns ``(x, y)``, two 1-D tensors of length ``batch_num * mb_size``.
    ``x`` is uniform on [-2, 2). The points form ``batch_num`` consecutive
    minibatches of ``mb_size`` points. Every point in a minibatch takes its
    ``y`` from the same randomly chosen ``func(k, x)``, with k in 1..sub_num.
    """
    n_points = batch_num * mb_size
    xs = torch.rand(size=[n_points]) * 4 - 2

    # One 0-based sub-function id per minibatch, copied to each of its points.
    per_minibatch = torch.randint(0, sub_num, size=[batch_num, 1])
    point_ids = per_minibatch.expand(-1, mb_size).reshape(-1)

    ys = torch.zeros_like(xs)
    for k in range(1, sub_num + 1):
        # func runs on every point; the boolean mask keeps only the points
        # whose minibatch was assigned sub-function k.
        ys += func(k, xs) * (point_ids == k - 1)
    return xs, ys


class MainNet(torch.nn.Module):
    """Fully-connected regressor mapping each scalar input to a scalar output.

    Layer widths are 1 -> 32 -> 32 -> 32 -> 1, with a ReLU between consecutive
    linear layers and none after the last one.

    Args:
        num: accepted for interface compatibility (defaults to ``sub_num``);
            the network does not use it.
    """

    def __init__(self, num=sub_num):
        super().__init__()
        widths = [1, 32, 32, 32, 1]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if idx:
                # activation goes between linear layers, never before the first
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten *x* into a column of scalars and return predictions of shape (N, 1)."""
        column = x.view(-1, 1)
        return self.net(column)


def _per_minibatch_losses(model, loss_func, batch_x, batch_y):
    """Return *model*'s mean loss on each minibatch as a 1-D tensor.

    ``loss_func`` is expected to use ``reduction='none'`` so that it yields one
    loss per point of shape (N, 1). The N pointwise losses are regrouped into
    rows of ``mb_size`` consecutive points, one row per minibatch. This matches
    how ``get_batch`` assigns sub-functions. Each row is then averaged.
    """
    pointwise = loss_func(model(batch_x), batch_y)
    return pointwise.reshape(-1, mb_size).mean(dim=1)


def _plot_models(model_list, device, epoch):
    """Plot every network's prediction for x in [-2, 2] on a fresh figure.

    A new figure is created per call, rather than reusing figure 1, so curves
    from earlier checkpoints do not accumulate under non-blocking backends.
    """
    colors = ['r', 'g', 'b']
    show_x = torch.linspace(-2, 2, 101)
    show_x_np = show_x.numpy()
    fig, ax = plt.subplots()
    with torch.no_grad():  # inference only; no autograd graph needed
        for i, model in enumerate(model_list):
            show_y_np = model(show_x.to(device)).cpu().numpy()
            # Cycle the palette so that net_num > len(colors) does not raise.
            ax.plot(show_x_np, show_y_np, color=colors[i % len(colors)],
                    label=f'net {i}')
    ax.set_title(f'epoch {epoch}')
    ax.legend()
    plt.show()
    plt.close(fig)


def main():
    """Train ``net_num`` networks jointly with a winner-take-all loss.

    A single batch is drawn once with ``get_batch`` and reused every epoch.
    For each minibatch only the lowest-loss network enters the objective,
    because ``min`` routes the gradient to the selected element. This is
    presumably intended to let each network specialise on a subset of the
    sub-functions. The script runs on CUDA when available, otherwise on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_list = []
    opt_list = []
    for _ in range(net_num):
        # Move to the device *before* building the optimizer so that it refers
        # to the parameters in their final location.
        model = MainNet().to(device)
        model_list.append(model)
        opt_list.append(torch.optim.SGD(model.parameters(), lr=0.05,
                                        momentum=0.9, weight_decay=1e-4))
    loss_func = torch.nn.MSELoss(reduction='none')

    batch_x, batch_y = get_batch()
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device).view(-1, 1)  # match MainNet's (N, 1) output

    for epoch in range(epoch_num):
        # Shape (net_num, batch_num): loss of every network on every minibatch.
        loss_stack = torch.stack(
            [_per_minibatch_losses(m, loss_func, batch_x, batch_y)
             for m in model_list],
            dim=0,
        )
        # Winner-take-all: keep only the best network's loss per minibatch.
        loss_bp = loss_stack.min(dim=0).values.mean()

        for opt in opt_list:
            opt.zero_grad()
        loss_bp.backward()
        for opt in opt_list:
            opt.step()

        print(epoch)
        # loss_bp averages over minibatches; scaling by batch_num reports the
        # summed per-minibatch loss, as the original script did.
        print(loss_bp.item() * batch_num)

        if epoch % 4000 == 0 and epoch != 0:
            print("TEST")
            _plot_models(model_list, device, epoch)


if __name__ == '__main__':
    main()
