import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from prelu_torch_modules import ShallowPreluNet
from load_mnist_datasets import load_mnist


class TwolayerReLU(nn.Module):
    """Two-layer fully connected network computing ``V(relu(W(x - center)))``.

    Args:
        center: Optional tensor subtracted from every input before the first
            layer. It is stored as a non-trainable parameter. When omitted, a
            zero tensor of shape ``[1, input_dim]`` is used.
        bias: Whether both linear layers carry a bias term. Any falsy value
            (including the default ``None``) disables the bias.
        input_dim: Size of the input features. May be omitted when ``center``
            is given, in which case it is taken from ``center.shape[-1]``.
        output_dim: Size of the network output.
        hidden_dim: Width of the hidden layer.
        init_std: If not ``None``, both weight matrices are re-initialised
            from a zero-mean normal distribution with this standard deviation.

    Raises:
        ValueError: If neither ``center`` nor ``input_dim`` is provided.
    """

    def __init__(self, center=None, bias=None, input_dim=None, output_dim=None, hidden_dim=None, init_std=None):
        super().__init__()
        if center is None:
            if input_dim is None:
                raise ValueError('either `center` or `input_dim` must be provided')
            center = torch.zeros([1, input_dim])
        elif input_dim is None:
            # The centre is broadcast against the inputs, so its last
            # dimension is the feature dimension.
            input_dim = center.shape[-1]

        # Kept as a frozen parameter so it follows the module across devices
        # and is stored in the state dict, but is never updated by optimizers.
        self.center = nn.Parameter(center, requires_grad=False)
        self.W = nn.Linear(input_dim, hidden_dim, bias=bool(bias))
        self.V = nn.Linear(hidden_dim, output_dim, bias=bool(bias))
        if init_std is not None:
            nn.init.normal_(self.W.weight, std=init_std)
            nn.init.normal_(self.V.weight, std=init_std)

    def forward(self, xb, return_feature=False):
        """Run the network on a batch.

        Args:
            xb: Input batch; ``center`` is subtracted from it by broadcasting.
            return_feature: When true, also return the hidden activations.

        Returns:
            The output of the second layer, or the tuple
            ``(output, hidden_activations)`` when ``return_feature`` is true.
        """
        # Without the non-linearity the two layers collapse into a single
        # linear map, so the ReLU is applied to the hidden layer here.
        hidden = F.relu(self.W(xb - self.center))
        out = self.V(hidden)
        if return_feature:
            return out, hidden
        return out


if __name__ == '__main__':
    # Trains a TwolayerReLU classifier with plain SGD on a subset of MNIST
    # digits, prints loss/accuracy periodically, and finally saves the data,
    # hidden features and learned weights to 'result.pt'.

    # Hyperparameters.
    batch_size = 1000
    epochs = 50
    lr = 1e-1
    h = 50  # hidden-layer width
    log_every = 10  # report metrics every `log_every` epochs

    digits = [0, 1, 2]
    n_class = len(digits)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # NOTE(review): the "*_full" loaders are presumably single-batch loaders
    # over a whole split (only their first batch is ever used below) --
    # confirm against load_mnist.
    (train_data_loader_batch, train_data_loader_full, _, test_data_loader_full,
     centers, n_train, n_test, input_dim) = load_mnist(
        batch_size, subset=digits, n_class=n_class, normalize=False)

    model = TwolayerReLU(center=centers[0], bias=False, input_dim=input_dim, hidden_dim=h,
                         output_dim=n_class, init_std=1e-6).to(device).train()

    opt = optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        if epoch % log_every == 0:
            with torch.no_grad():
                # Training metrics: a single forward pass provides both the
                # loss and the accuracy.
                xb, yb = next(iter(train_data_loader_full))
                xb, yb = xb.to(device), yb.to(device)
                pred = model(xb, return_feature=False)
                loss = F.cross_entropy(pred, yb)
                train_acc_val = (torch.argmax(pred, 1) == yb).float().mean()

                # Held-out accuracy.
                xb, yb = next(iter(test_data_loader_full))
                xb, yb = xb.to(device), yb.to(device)
                valid_acc_val = (torch.argmax(model(xb, return_feature=False), 1) == yb).float().mean()
            print(epoch, loss.detach().cpu().numpy(), train_acc_val.detach().cpu().numpy(),
                  valid_acc_val.detach().cpu().numpy())

        for xb, yb in train_data_loader_batch:
            pred = model(xb.to(device), return_feature=False)
            loss = F.cross_entropy(pred, yb.to(device))

            loss.backward()
            opt.step()
            opt.zero_grad()

    # Switch to evaluation mode once, after training has finished.
    model.eval()
    with torch.no_grad():
        xb, yb = next(iter(train_data_loader_full))
        # H holds the hidden-layer features the model returns alongside its
        # output when return_feature=True.
        _, H = model(xb.to(device), return_feature=True)

    # TwolayerReLU exposes its layers as `W` and `V`; the weights are detached
    # and moved to the CPU so the file can be loaded without autograd state or
    # a GPU.
    torch.save(
        dict(center=centers[0], x=xb.detach().cpu(), y=yb.detach().cpu(), H=H.detach().cpu(),
             W=model.W.weight.detach().cpu(),
             V=model.V.weight.detach().cpu()), 'result.pt')
