import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from derivatives import derivatives
import argparse

# --- Experiment configuration ------------------------------------------------
step = 70             # stride used to subsample the dense grid into training points
input_dim = 1         # first argument given to MLP(...) (input features of fc1)
width = 256           # second argument given to MLP(...) (hidden layer width)
train_length = 512    # number of points on the dense grid
criterion = nn.MSELoss()

# Dense grid of `train_length` evenly spaced points on [0, 10].
x = np.linspace(0, 10, train_length)

# Target signal: sum of two sinusoids evaluated on the dense grid.
sig1 = 3 * np.sin(2 * np.pi * 0.2 * x)
sig2 = 5 * np.sin(2 * np.pi * 0.1 * x)
sig = sig1 + sig2

# float32 CPU tensors of the signal and of the grid. The dtype conversion
# copies the data, so neither tensor shares memory with the NumPy arrays.
sigTorch = torch.from_numpy(sig).float()
coordinates = torch.from_numpy(x).float()

# Training set: every `step`-th sample of the dense grid, shaped as (N, 1)
# float32 column tensors and moved to the GPU.
train_labels = sigTorch[::step].float().unsqueeze(-1).cuda()
train_coords = coordinates[::step].unsqueeze(-1).cuda()
# The networks are fed coordinates divided by 10 (the grid spans [0, 10]).
scaled_train_coords = train_coords.div(10)

class gau(nn.Module):
    """Element-wise Gaussian activation ``exp(-0.5 * x**2 / sigma**2)``.

    Args:
        sigma: Width of the Gaussian. Defaults to 0.05, the value that used
            to be hard-coded, so ``gau()`` behaves exactly as before.

    Raises:
        ValueError: If ``sigma`` is not strictly positive.
    """

    def __init__(self, sigma=0.05):
        super().__init__()
        # `not sigma > 0` also rejects NaN, which would otherwise propagate
        # silently through every forward pass.
        if not sigma > 0:
            raise ValueError(f"sigma must be positive, got {sigma!r}")
        # Stored as a plain attribute (not a parameter or buffer) so that
        # state_dict() stays empty and previously saved checkpoints still load.
        self.sigma = sigma

    def forward(self, x):
        """Apply the activation element-wise; the output has the shape of ``x``."""
        return (-0.5 * (x ** 2) / (self.sigma ** 2)).exp()

def pretrain(net, epochs=3000000, lr=1e-4, log_every=300000):
    """Pretrain ``net`` in place on a single sinusoid using full-batch Adam.

    The target is ``sin(2*pi*3*x)`` evaluated on the module-level dense grid
    ``x``; the network inputs are ``train_length`` evenly spaced points on
    [0, 10] divided by 10. The loss is the module-level ``criterion``.

    Args:
        net: Module to optimise. Its parameters are updated in place and it
            is left in training mode.
        epochs: Number of full-batch optimisation steps. The default is the
            value that was previously hard-coded.
        lr: Adam learning rate.
        log_every: Print the current loss every ``log_every`` steps
            (including step 0). A falsy value disables the loss printing.

    Returns:
        None.
    """
    print('Pretraining:')

    # Put the data wherever the network lives instead of assuming CUDA; for a
    # network that was moved with .cuda() this is the same placement as before.
    device = next(net.parameters()).device

    target = torch.from_numpy(np.sin(2 * np.pi * 3. * x)).float().reshape(-1, 1).to(device)
    inputs = torch.from_numpy(np.linspace(0, 10, train_length) / 10).float().reshape(-1, 1).to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # Training mode only needs to be set once; nothing in the loop changes it.
    net.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(net(inputs), target)
        loss.backward()
        if log_every and epoch % log_every == 0:
            print("loss")
            print(loss.item())
        optimizer.step()


def main(bandwidth, activation):
    """Train ten MLPs on the subsampled signal and save per-run statistics.

    Each of the ten runs builds a fresh MLP with four hidden layers, trains it
    with SGD (lr=1e-4, momentum=0.9) on ``scaled_train_coords`` /
    ``train_labels`` and saves its weights to
    ``{activation}_{run}_{bandwidth}.pt``. Per run it then records:

    * three values returned by the project-local ``derivatives`` helper for
      the keys 'H', 'jac1train' and 'jac2train',
    * the spectral norm (``ord=2``) of the first-layer weight matrix,
    * the squared spectral norm of each hidden layer's activations on the
      training inputs.

    The collected values are written to ``.npy`` files and the fit of the
    last trained network is plotted, saved as PNG and shown.

    Args:
        bandwidth: 'high' or 'low'. With ``activation == 'gauss'`` and 'low'
            the layer weights are re-initialised from N(0, 0.03); with
            ``activation == 'relu'`` and 'high' the network is first passed
            through ``pretrain``. The value is also embedded in file names.
        activation: 'gauss' (``gau`` activations, 10000 epochs) or 'relu'
            (``nn.ReLU`` activations, 100000 epochs).

    Raises:
        ValueError: If ``activation`` is not 'gauss' or 'relu'.
    """
    epochs_by_activation = {'relu': 100000, 'gauss': 10000}
    # Fail early with a clear message; previously an unsupported activation
    # only surfaced later as an unbound `epochs` variable.
    if activation not in epochs_by_activation:
        raise ValueError(
            f"activation must be 'gauss' or 'relu', got {activation!r}"
        )
    epochs = epochs_by_activation[activation]

    def make_activation():
        """Return a new activation module for the selected ``activation``."""
        return gau() if activation == 'gauss' else nn.ReLU()

    class MLP(nn.Module):
        """Fully connected network: four hidden layers plus a scalar output."""

        def __init__(self, input_dim, width):
            super().__init__()
            # The registration order fc1, act1, ..., fco is deliberate: it
            # fixes the order of the random draws for the initial weights and
            # the ordering of modules/parameters seen by other code.
            self.fc1 = nn.Linear(input_dim, width)
            self.act1 = make_activation()
            self.fc2 = nn.Linear(width, width)
            self.act2 = make_activation()
            self.fc3 = nn.Linear(width, width)
            self.act3 = make_activation()
            self.fc4 = nn.Linear(width, width)
            self.act4 = make_activation()
            self.fco = nn.Linear(width, 1)
            if activation == 'gauss' and bandwidth == 'low':
                self.init_weights()

        def init_weights(self):
            """Redraw every layer's weight matrix from N(0, 0.03); biases are untouched."""
            b = 0.03
            with torch.no_grad():
                for layer in (self.fc1, self.fc2, self.fc3, self.fc4, self.fco):
                    layer.weight.normal_(0, b)

        def features(self, x):
            """Return the activations of the four hidden layers, first to last."""
            e1 = self.act1(self.fc1(x))
            e2 = self.act2(self.fc2(e1))
            e3 = self.act3(self.fc3(e2))
            e4 = self.act4(self.fc4(e3))
            return e1, e2, e3, e4

        def forward(self, x):
            """Map inputs through the hidden layers and the linear output layer."""
            return self.fco(self.features(x)[-1])

        def name(self):
            return "MLP"

    hess = []
    jac1 = []
    jac2 = []
    weight = []
    features = []

    for run in range(10):
        print(run)
        net = MLP(input_dim, width).cuda()

        optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, momentum=0.9)

        if activation == 'relu' and bandwidth == 'high':
            pretrain(net)

        # Training mode is set once; nothing inside the loop changes it.
        net.train()
        for _ in range(epochs):
            optimizer.zero_grad()
            loss = criterion(net(scaled_train_coords), train_labels)
            loss.backward()
            optimizer.step()

        torch.save(net.state_dict(), f'{activation}_{run}_{bandwidth}.pt')

        # `derivatives` is a project-local helper whose semantics are not
        # visible in this file. NOTE(review): the two [8, 1] arguments
        # presumably describe the training inputs and labels (both are 8 x 1
        # here) -- confirm against the helper.
        d = derivatives(net, criterion, [8, 1], [8, 1], 'cuda:0')
        d.update((scaled_train_coords, train_labels))

        hess.append(d.power('H'))
        jac1.append(d.power('jac1train'))
        jac2.append(d.power('jac2train'))
        weight.append(torch.linalg.matrix_norm(net.fc1.weight, ord=2).detach().cpu().numpy())

        # One forward pass yields all four hidden activations; the loop
        # variable is distinct from `run` so the run index is never clobbered.
        with torch.no_grad():
            hidden_activations = net.features(scaled_train_coords)
            feature_norms = [
                torch.linalg.matrix_norm(hidden, ord=2).cpu().numpy() ** 2
                for hidden in hidden_activations
            ]

        features.append(feature_norms)

    hess = np.array(hess)
    jac1 = np.array(jac1)
    jac2 = np.array(jac2)
    weight = np.array(weight)
    features = np.array(features)

    # NOTE(review): these file names carry an "_Adam" suffix although the
    # training loop above uses SGD with momentum (only pretrain() uses Adam).
    # The names are left unchanged because other scripts may load them.
    np.save(f'hess_{activation}_{bandwidth}_Adam.npy', hess)
    np.save(f'jac1_{activation}_{bandwidth}_Adam.npy', jac1)
    np.save(f'jac2_{activation}_{bandwidth}_Adam.npy', jac2)
    np.save(f'feat_{activation}_{bandwidth}_Adam.npy', features)
    np.save(f'weight_{activation}_{bandwidth}_Adam.npy', weight)

    # Plot the last trained network on the dense grid together with the
    # training points.
    net.eval()
    out = net(coordinates.reshape(-1, 1).cuda() / 10).reshape(-1, 1)
    plt.figure()
    plt.xlabel('$x$', fontsize=20)
    plt.ylabel('$y$', fontsize=20)
    plt.xticks(fontsize=15)
    plt.yticks(fontsize=15)

    plt.plot(coordinates.detach().cpu().numpy(), out.detach().cpu().numpy())
    plt.scatter(train_coords.detach().cpu().numpy(), train_labels.detach().cpu().numpy())

    plt.savefig(f'{activation}_{bandwidth}_interpolate_Adampretrain.png', format='png', dpi=300, bbox_inches='tight')
    plt.show()


if __name__ == "__main__":
    # Command-line entry point: two required positional arguments select the
    # experiment variant and are forwarded unchanged to main().
    arg_parser = argparse.ArgumentParser(description="Train using gradient descent.")
    for arg_name, allowed in (
        ('bandwidth', ['high', 'low']),
        ('activation', ['gauss', 'relu']),
    ):
        arg_parser.add_argument(arg_name, type=str, choices=allowed)
    args = arg_parser.parse_args()

    main(args.bandwidth, args.activation)