import numpy as np
import torch
import torch.nn.functional as F

# Device for every tensor and module created in this file. Falls back to the CPU
# when no CUDA device is available so the code still runs on CPU-only hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def gen_rand_sine(rng, m, k, n_comps=5):
    """Build a random target function: sums of sinusoids of random projections.

    A random ``(m, k)`` projection ``U`` with unit-norm columns is drawn,
    together with per-component amplitudes ``a``, angular frequencies ``omega``
    and phases ``phi`` (each of shape ``(n_comps, 1, k)``). The returned
    function maps an ``(N, m)`` input ``X`` to the ``(N, k)`` tensor

        Y[:, j] = sum_c a[c, 0, j] * sin(omega[c, 0, j] * (X @ U)[:, j] + phi[c, 0, j])

    Args:
        rng: Random generator providing ``normal(size=...)`` and
            ``uniform(low=..., high=..., size=...)`` (e.g. a NumPy
            ``Generator``); used for ``a``, ``omega`` and ``phi``.
        m: Input dimension.
        k: Number of projection directions, i.e. output columns.
        n_comps: Number of sinusoidal components summed per output column.

    Returns:
        ``(rand_sine, U)``: the target function and the float32 ``(m, k)``
        projection matrix on ``device``. Because ``a``/``omega``/``phi`` are
        float64, ``rand_sine`` returns float64 values.

    Note:
        ``U`` is drawn from torch's global RNG, not from ``rng``; seed torch as
        well (``torch.manual_seed``) if the target must be reproducible.
    """
    U_unnormed = torch.randn(m, k).to(device, dtype=torch.float32)
    U = U_unnormed / torch.sqrt(torch.sum(U_unnormed ** 2, dim=0, keepdim=True))  # Normalize columns

    a = rng.normal(size=(n_comps, 1, k))
    omega = 2 * np.pi * rng.normal(size=(n_comps, 1, k))
    phi = rng.uniform(low=-np.pi, high=np.pi, size=(n_comps, 1, k))
    a = torch.from_numpy(a).to(device, dtype=torch.float64)
    omega = torch.from_numpy(omega).to(device, dtype=torch.float64)
    phi = torch.from_numpy(phi).to(device, dtype=torch.float64)

    def rand_sine(X):
        """Evaluate the random sine target on ``X`` of shape ``(N, m)``; returns ``(N, k)``."""
        # einsum (a matrix product) requires both operands to share a dtype.
        # Use their promoted common dtype so float64 or integer inputs work
        # instead of raising. float32 X is passed through unchanged (no copy).
        common = torch.promote_types(X.dtype, U.dtype)
        Z = torch.einsum('ij, jk->ik', X.to(common), U.to(common))  # Shape (N, k)
        # Z.unsqueeze(0) is (1, N, k) and broadcasts against the (n_comps, 1, k)
        # parameters; summing over dim 0 adds up the components.
        Y = torch.sum(a * torch.sin(omega * torch.unsqueeze(Z, 0) + phi), dim=0)  # Shape (N, k)
        return Y

    return rand_sine, U


def train(target_func, n, m, k, n_layers, width, n_test=10000, iters=10000,
          batch_size=100, lr=0.001):
    """Fit a ReLU MLP to a scalar projection of ``target_func`` and report losses.

    Train and test inputs are drawn from a standard normal and moved to
    ``device`` as float32. For an input ``x`` the regression target is
    ``sum(target_func(x), dim=1) / sqrt(k)``. The network is ``n_layers``
    hidden ``Linear(width)`` + ReLU blocks followed by a ``Linear(., 1)`` head.
    It is trained with Adam on random minibatches of the training set.
    Losses are printed every iteration.

    Args:
        target_func: Callable applied to the ``(n, m)`` / ``(n_test, m)`` input
            tensors; expected to return ``(N, k)`` tensors.
        n: Number of training samples.
        m: Input dimension.
        k: Used for the ``1 / sqrt(k)`` target normalisation.
        n_layers: Number of hidden layers (0 gives a single linear map).
        width: Hidden-layer width.
        n_test: Number of held-out test samples.
        iters: Number of optimisation steps; must be positive.
        batch_size: Minibatch size (capped at ``n``).
        lr: Adam learning rate.

    Returns:
        ``(loss_test, loss_train)`` as detached CPU scalar tensors, both measured
        in the last iteration before its optimizer step.

    Raises:
        ValueError: If ``iters`` is not positive (no loss would ever be computed).
    """
    if iters <= 0:
        raise ValueError('iters must be positive, got ' + str(iters))

    # Generate data
    x_train = torch.normal(torch.zeros(n, m)).to(device, dtype=torch.float32)
    x_test = torch.normal(torch.zeros(n_test, m)).to(device, dtype=torch.float32)
    y_train = target_func(x_train)  # (n, k)
    y_test = target_func(x_test)  # (n_test, k)

    # Targets and the trivial-predictor baseline do not depend on the model,
    # so compute them once rather than every iteration.
    scale = np.sqrt(k)
    target_train = torch.sum(y_train, dim=1) / scale  # (n,)
    target_test = torch.sum(y_test, dim=1) / scale  # (n_test,)
    loss_random = torch.mean(target_test ** 2)  # MSE of always predicting 0

    # Build the MLP. Layers are constructed in the same order as before, so
    # their random initialisation is unchanged.
    layers = []
    last_width = m
    for _ in range(n_layers):
        layers.append(torch.nn.Linear(last_width, width))
        layers.append(torch.nn.ReLU())
        last_width = width
    layers.append(torch.nn.Linear(last_width, 1))
    model = torch.nn.Sequential(*layers).to(device, dtype=torch.float32)

    # Set up optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for iteration in range(iters):
        idx = torch.randperm(n)[:min(batch_size, n)]
        print('Iteration: ' + str(iteration))

        y_train_hat = model(x_train[idx])  # (batch, 1)
        loss_train = torch.mean((target_train[idx] - torch.sum(y_train_hat, dim=1)) ** 2)

        # The test loss is for monitoring only. Without no_grad, the forward
        # pass would build (and keep alive) an autograd graph over the whole
        # test set every iteration, even though it is never backpropagated.
        with torch.no_grad():
            y_test_hat = model(x_test)  # (n_test, 1)
            loss_test = torch.mean((target_test - torch.sum(y_test_hat, dim=1)) ** 2)

        loss_train.backward()
        print('Loss train: ' + str(loss_train))
        print('Loss test: ' + str(loss_test))
        print('Loss random: ' + str(loss_random))

        optimizer.step()
        optimizer.zero_grad()
    return loss_test.detach().cpu(), loss_train.detach().cpu()
