import numpy as np
import torch
import torch.nn.functional as F

# Device on which every tensor in this file is created. Prefer the GPU when PyTorch can
# see one; otherwise fall back to the CPU so the code still runs on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Returns a random target function constructed with k components and an m dimensional input, also returns projection U
def gen_rand_sine(rng, m, k, n_comps=5):
    """Build a random sum-of-sines target function together with its input projection.

    Args:
        rng: NumPy random generator; its ``normal`` and ``uniform`` methods supply the
            amplitudes, frequencies and phases (drawn in that order).
        m: Input dimension.
        k: Number of projected coordinates (columns of ``U``).
        n_comps: Number of sine terms summed per projected coordinate.

    Returns:
        ``(rand_sine, U)`` where ``U`` is an ``(m, k)`` float32 tensor on ``device`` with
        unit-norm columns and ``rand_sine`` maps an ``(n, m)`` input to an ``(n, k)`` output.

    Note:
        ``U`` is sampled with ``torch.randn`` (torch's global generator), not from ``rng``.
    """

    def _as_float64_tensor(array):
        # NumPy array -> float64 tensor on the module-level device.
        return torch.from_numpy(array).to(device, dtype=torch.float64)

    raw_projection = torch.randn(m, k).to(device, dtype=torch.float32)
    column_norms = torch.sqrt(torch.sum(raw_projection ** 2, dim=0, keepdim=True))
    U = raw_projection / column_norms  # Every column now has unit Euclidean norm

    coef_shape = (n_comps, 1, k)  # Middle axis of size 1 broadcasts over the n samples
    a = _as_float64_tensor(rng.normal(size=coef_shape))  # Amplitudes
    omega = _as_float64_tensor(2 * np.pi * rng.normal(size=coef_shape))  # Frequencies
    phi = _as_float64_tensor(rng.uniform(low=-np.pi, high=np.pi, size=coef_shape))  # Phases

    def rand_sine(X):
        """Evaluate the target on X of shape (n, m); the result has shape (n, k)."""
        Z = torch.einsum('ij, jk->ik', X, U)  # Projected inputs, shape (n, k)
        waves = a * torch.sin(omega * torch.unsqueeze(Z, 0) + phi)  # Shape (n_comps, n, k)
        return torch.sum(waves, dim=0)  # Collapse the component axis

    return rand_sine, U


# Computes output of modular architecture with input x
def run_modules(x, modules):
    """Evaluate a sum of small networks ("modules") on the same input.

    Each module is a sequence of layers (callables) applied in order, with a ReLU after
    every layer except the last one. The module outputs are added together, summed over
    ``dim=1`` and divided by ``sqrt(len(modules))``.

    Args:
        x: Input tensor passed unchanged to the first layer of every module.
        modules: Non-empty sized sequence of modules; each module is a sequence of layers.

    Returns:
        Tensor obtained by summing the accumulated module outputs over ``dim=1`` and
        scaling by ``1 / sqrt(len(modules))``.

    Raises:
        ValueError: If ``modules`` is empty (there is nothing to evaluate or normalise by).
    """
    if len(modules) == 0:
        raise ValueError('run_modules needs at least one module')

    output = None
    for params in modules:
        z = x
        last_layer = len(params) - 1
        for layer, linear_layer in enumerate(params):
            z = linear_layer(z)
            if layer != last_layer:
                z = F.relu(z)
        # Accumulate out of place: an in-place ``+=`` would write into the first module's
        # output tensor, which is ``x`` itself when that module has no layers.
        output = z if output is None else output + z
    return torch.sum(output, dim=1) / np.sqrt(len(modules))  # Normalize based on number of modules


# Find Y^T K^{-1} Y in a differentiable way
def solve_linsys(K, Y):
    """Return the quadratic form Y^T K^{-1} Y with a gradient that flows only through K.

    Both inputs are cast to float64. The returned scalar has the value
    ``beta^T K beta`` with ``beta = K^{-1} Y``; its gradient is that of
    ``-beta^T K beta`` with ``beta`` held constant, so backpropagation never goes through
    the linear solve.

    Args:
        K: Square matrix of shape (n, n).
        Y: Vector of shape (n,).

    Returns:
        Scalar float64 tensor.
    """
    K64 = K.to(dtype=torch.float64)
    Y64 = Y.to(dtype=torch.float64)

    # beta solves K beta = Y; K beta reconstructs Y, so beta . (K beta) = Y^T K^{-1} Y.
    beta = torch.linalg.solve(K64, Y64.unsqueeze(1)).squeeze(dim=1)
    quad_value = (beta * torch.einsum('ij,j->i', K64, beta)).sum()

    # Surrogate whose derivative w.r.t. K is -beta beta^T (beta treated as a constant).
    frozen_beta = beta.detach()
    surrogate = -(torch.einsum('ij,j->i', K64, frozen_beta) * frozen_beta).sum()

    # surrogate - surrogate.detach() is zero in value but carries the surrogate's gradient.
    return surrogate - surrogate.detach() + quad_value.detach()


# Try to learn a single module
def learn_one_module(X, Y, iters=2000, lr=0.001, batch_size=1024, lam=1000.0, threshold=0.1,
                     sigma=0.5, eps=1e-5):
    """Learn one pair of unit directions (u, v) with Adam on a kernel objective.

    Every iteration draws a random batch, splits each sample into a scalar coordinate
    ``proj_v = (x . u) / (v . u)`` and the residual ``x - proj_v * v``, builds the sum of a
    Gaussian kernel on the coordinate and a Gaussian kernel on the residual, and minimises
    the value returned by ``solve_linsys(K, Y_batch)``. A penalty proportional to
    ``relu(threshold - |u . v|)`` discourages u and v from becoming orthogonal.

    Args:
        X: Inputs of shape (n_pts, m); the optimisation runs on ``X.device`` in float64.
        Y: Targets indexed along their first axis with the same indices as ``X``.
        iters: Number of Adam steps.
        lr: Adam learning rate.
        batch_size: Maximum number of samples per step.
        lam: Weight of the orthogonality penalty (scaled by the detached error).
        threshold: Value of ``|u . v|`` below which the penalty becomes active.
        sigma: Bandwidth shared by both Gaussian kernels.
        eps: Diagonal jitter added to the kernel matrix before solving.

    Returns:
        ``(u, v)``: float32 unit vectors of shape (m,), normalised from the parameters
        after the final optimiser step (from the random initialisation if ``iters`` is 0).
    """
    m = X.shape[1]
    n_pts = X.shape[0]
    # Follow the data's device instead of a module-level global so that the parameters can
    # never end up on a different device than X.
    dev = X.device
    v_unnormed = torch.nn.Parameter(torch.randn(m).to(dev, dtype=torch.float64))
    u_unnormed = torch.nn.Parameter(torch.randn(m).to(dev, dtype=torch.float64))

    # Change dtype
    X = X.to(dtype=torch.float64)
    Y = Y.to(dtype=torch.float64)
    optimizer = torch.optim.Adam([v_unnormed, u_unnormed], lr=lr)

    def _unit(w):
        # Scale w to unit Euclidean norm.
        return w / torch.sqrt(torch.sum(w ** 2, dim=0))

    for i in range(iters):
        print('Module learning iteration ' + str(i))
        v = _unit(v_unnormed)
        u = _unit(u_unnormed)

        idx = torch.randperm(n_pts)[:batch_size]
        X_samp = X[idx, :]
        assert not torch.isnan(v).any()
        assert not torch.isnan(u).any()
        proj_v = torch.einsum('ij, j->i', X_samp, u) / torch.sum(v * u)  # Shape (batch, ), coordinate along v

        Xu = X_samp - torch.einsum('i,j->ij', proj_v, v)  # Shape (batch, m), residual after removing v component
        K = torch.exp(-(proj_v.unsqueeze(0) - proj_v.unsqueeze(1)) ** 2 / 2 / sigma ** 2) \
            + torch.exp(-torch.sum((Xu.unsqueeze(0) - Xu.unsqueeze(1)) ** 2, dim=2) / 2 / sigma ** 2)  # Kernel
        assert not torch.isnan(K).any()
        K = K + eps * torch.eye(K.shape[0], device=dev, dtype=torch.float64)  # For numerical stability

        Y_samp = Y[idx]

        # Solve Y = K beta
        error = solve_linsys(K, Y_samp)  # K has shape (batch, batch)
        assert error > 0

        # Penalty is active only while |u . v| < threshold; scaling by the detached error
        # keeps it on the same scale as the main objective without adding gradient paths.
        penalty = F.relu(threshold - torch.abs(torch.sum(u * v)))
        total_error = error + lam * error.detach() * penalty

        total_error.backward()
        print('Module learning error: ' + str(total_error))
        assert not torch.isnan(v_unnormed.grad).any()
        assert not torch.isnan(u_unnormed.grad).any()

        optimizer.step()
        optimizer.zero_grad()

    # Normalise from the final parameters so the last optimiser step is reflected in the
    # result; this also defines the result when no iteration was run.
    with torch.no_grad():
        u = _unit(u_unnormed)
        v = _unit(v_unnormed)
    return u.to(dtype=torch.float32), v.to(dtype=torch.float32)


# Trains modular/monolithic network with random initialization given n training samples, m input dims and k components
def train_baseline(n, target_func, N_modules, m, k, n_layers, width, n_test=10000, iters=10000,
                   batch_size=100, lr=0.001, bottleneck=True):
    """Train randomly initialised modules on freshly sampled Gaussian data.

    Args:
        n: Number of training samples.
        target_func: Callable applied to the (n, m) inputs; its output is summed over
            ``dim=1`` and divided by ``sqrt(k)`` to form the scalar targets.
        N_modules: Number of modules whose outputs are combined by ``run_modules``.
        m: Input dimension.
        k: Normalisation constant for the targets (divides by ``sqrt(k)``).
        n_layers: Controls module depth (see ``bottleneck``).
        width: Hidden layer width.
        n_test: Number of test samples.
        iters: Number of Adam steps; must be at least 1, otherwise no loss is ever
            computed and the final ``return`` fails.
        batch_size: Mini-batch size (capped at ``n``).
        lr: Adam learning rate.
        bottleneck: If true each module is ``m -> 1 -> width (x n_layers - 1) -> 1``;
            otherwise ``m -> width (x n_layers) -> 1``.

    Returns:
        ``(loss_test, loss_train)`` as CPU tensors from the last iteration, i.e. evaluated
        just before that iteration's optimiser step; ``loss_train`` is a mini-batch loss.
    """
    # Generate data
    x_train = torch.normal(torch.zeros(n, m)).to(device, dtype=torch.float32)
    x_test = torch.normal(torch.zeros(n_test, m)).to(device, dtype=torch.float32)
    y_train = torch.sum(target_func(x_train), dim=1) / np.sqrt(k)  # (n, ), normalize based on k
    y_test = torch.sum(target_func(x_test), dim=1) / np.sqrt(k)  # (n_test, ), normalize based on k

    # Layer sizes of one module, listed from input to output
    if bottleneck:
        layer_widths = [m, 1] + [width] * (n_layers - 1) + [1]
    else:
        layer_widths = [m] + [width] * n_layers + [1]

    modules = []
    for _ in range(N_modules):
        params = [torch.nn.Linear(w_in, w_out).to(device, dtype=torch.float32)
                  for w_in, w_out in zip(layer_widths, layer_widths[1:])]
        modules.append(params)

    # Set up optimizer
    optimizer = torch.optim.Adam([p for params in modules for layer in params for p in layer.parameters()], lr=lr)

    # Error of the all-zero predictor; it does not depend on the weights, so compute it once
    loss_random = torch.mean(y_test ** 2)
    n_batch = min(batch_size, n)

    for iteration in range(iters):
        idx = torch.randperm(n)[:n_batch]
        print('Iteration: ' + str(iteration))
        # Run prediction networks
        y_train_hat = run_modules(x_train[idx], modules)  # Shape (n_batch, )
        loss_train = torch.mean((y_train[idx] - y_train_hat) ** 2)  # Scalar

        # The test loss is only reported, never back-propagated, so skip graph construction
        with torch.no_grad():
            y_test_hat = run_modules(x_test, modules)  # Shape (n_test, )
            loss_test = torch.mean((y_test - y_test_hat) ** 2)  # Scalar

        loss_train.backward()
        print('Loss train: ' + str(loss_train))
        print('Loss test: ' + str(loss_test))
        print('Loss random: ' + str(loss_random))

        optimizer.step()
        optimizer.zero_grad()
    return loss_test.detach().cpu(), loss_train.detach().cpu()


# Trains modular network with initialization for first layer found using a kernel-based method
def train_ours(n, target_func, N_modules, m, k, n_layers, width, n_test=10000, iters=10000,
               batch_size=100, lr=0.001, sigma=1.0, module_batch_size=1024, module_lr=0.01, module_iters=200,
               bottleneck=True):
    """Train modules whose first-layer weights are initialised by ``learn_one_module``.

    Data generation, module construction and the training loop match ``train_baseline``;
    the difference is that, before training, directions learned from the training data are
    copied into first-layer weights (biases are left at their random initialisation).

    Args:
        n: Number of training samples.
        target_func: Callable applied to the (n, m) inputs; its output is summed over
            ``dim=1`` and divided by ``sqrt(k)`` to form the scalar targets.
        N_modules: Number of modules whose outputs are combined by ``run_modules``.
        m: Input dimension.
        k: Normalisation constant for the targets (divides by ``sqrt(k)``).
        n_layers: Controls module depth (see ``bottleneck``).
        width: Hidden layer width.
        n_test: Number of test samples.
        iters: Number of Adam steps; must be at least 1, otherwise no loss is ever
            computed and the final ``return`` fails.
        batch_size: Mini-batch size (capped at ``n``).
        lr: Adam learning rate.
        sigma: Kernel bandwidth forwarded to ``learn_one_module``.
        module_batch_size: Batch size forwarded to ``learn_one_module``.
        module_lr: Learning rate forwarded to ``learn_one_module``.
        module_iters: Iteration count forwarded to ``learn_one_module``.
        bottleneck: If true each module is ``m -> 1 -> width (x n_layers - 1) -> 1`` and
            one direction is learned per module; otherwise modules are
            ``m -> width (x n_layers) -> 1`` and ``width`` directions are learned and
            written into the first layer of the first module only.

    Returns:
        ``(loss_test, loss_train)`` as CPU tensors from the last iteration, i.e. evaluated
        just before that iteration's optimiser step; ``loss_train`` is a mini-batch loss.
    """
    # Generate data
    x_train = torch.normal(torch.zeros(n, m)).to(device, dtype=torch.float32)
    x_test = torch.normal(torch.zeros(n_test, m)).to(device, dtype=torch.float32)
    y_train = torch.sum(target_func(x_train), dim=1) / np.sqrt(k)  # (n, ), normalize based on k
    y_test = torch.sum(target_func(x_test), dim=1) / np.sqrt(k)  # (n_test, ), normalize based on k

    # Layer sizes of one module, listed from input to output
    if bottleneck:
        layer_widths = [m, 1] + [width] * (n_layers - 1) + [1]
    else:
        layer_widths = [m] + [width] * n_layers + [1]

    modules = []
    for _ in range(N_modules):
        params = [torch.nn.Linear(w_in, w_out).to(device, dtype=torch.float32)
                  for w_in, w_out in zip(layer_widths, layer_widths[1:])]
        modules.append(params)

    # Set up optimizer (the in-place weight copies below keep these parameter objects valid)
    optimizer = torch.optim.Adam([p for params in modules for layer in params for p in layer.parameters()], lr=lr)

    # Initialize modules: one learned direction per module, or per hidden unit of the
    # first layer when there is no bottleneck
    n_directions = N_modules if bottleneck else width
    us = []
    for i in range(n_directions):
        print('Learning module ' + str(i) + '...')
        u, _ = learn_one_module(x_train, y_train, sigma=sigma, lr=module_lr, iters=module_iters,
                                batch_size=module_batch_size)
        us.append(u)
    U_pred = torch.stack(us, dim=1)  # (m, n_directions)

    with torch.no_grad():
        if bottleneck:
            for i in range(n_directions):
                modules[i][0].weight.copy_(U_pred[:, i])
        else:
            # Only the first module receives the learned directions
            modules[0][0].weight.copy_(U_pred.T)

    # Error of the all-zero predictor; it does not depend on the weights, so compute it once
    loss_random = torch.mean(y_test ** 2)
    n_batch = min(batch_size, n)

    for iteration in range(iters):
        idx = torch.randperm(n)[:n_batch]
        print('Iteration: ' + str(iteration))
        # Run prediction networks
        y_train_hat = run_modules(x_train[idx], modules)  # Shape (n_batch, )
        loss_train = torch.mean((y_train[idx] - y_train_hat) ** 2)  # Scalar

        # The test loss is only reported, never back-propagated, so skip graph construction
        with torch.no_grad():
            y_test_hat = run_modules(x_test, modules)  # Shape (n_test, )
            loss_test = torch.mean((y_test - y_test_hat) ** 2)  # Scalar

        loss_train.backward()
        print('Loss train: ' + str(loss_train))
        print('Loss test: ' + str(loss_test))
        print('Loss random: ' + str(loss_random))

        optimizer.step()
        optimizer.zero_grad()
    return loss_test.detach().cpu(), loss_train.detach().cpu()


# Set module params to values (except first layer of each module)
def set_params(modules, values):
    """Overwrite weights and biases of every layer except the first in each module.

    Args:
        modules: Sequence of modules, each a sequence of layers with ``weight`` and
            ``bias`` tensors.
        values: Nested sequence indexed as ``values[i][j]``; element 0 is copied into the
            weight and element 1 into the bias of layer ``j`` of module ``i``. Entries for
            ``j == 0`` are never read.
    """
    with torch.no_grad():
        for module_idx, layers in enumerate(modules):
            for layer_idx, layer in enumerate(layers):
                if layer_idx == 0:
                    continue  # First layer keeps its current parameters
                layer.weight.copy_(values[module_idx][layer_idx][0])
                layer.bias.copy_(values[module_idx][layer_idx][1])


# Returns n_samp to achieve desired error rate
def binary_search(train, desired_err, *args, search_iters=18, **kwargs):
    """Search over log2(sample count) for the amount of data that reaches ``desired_err``.

    Starts at 2**12 samples. While no run has met the target the exponent grows by 2 per
    step; once an upper bound exists the search bisects between the largest failing and
    the smallest succeeding exponent.

    Args:
        train: Callable invoked as ``train(n_samples, *args, **kwargs)`` that returns a
            pair whose first element is the error compared against ``desired_err``.
        desired_err: Target error.
        *args: Extra positional arguments forwarded to ``train``.
        search_iters: Maximum number of calls to ``train``.
        **kwargs: Extra keyword arguments forwarded to ``train``.

    Returns:
        ``2 ** cur_logsamp`` for the next candidate exponent (a midpoint that has not
        itself been evaluated, possibly a non-integer float), returned as soon as the
        bracket is narrower than 0.3 or after ``search_iters`` runs; ``None`` if the next
        candidate exponent reaches 22.
    """
    min_logsamp = 0
    cur_logsamp = 12
    max_logsamp = None

    for search_iter in range(search_iters):
        print('Binary search iteration ' + str(search_iter))
        print('Cur samples ' + str(int(2 ** cur_logsamp)))
        err, _ = train(int(2 ** cur_logsamp), *args, **kwargs)
        print('Error ' + str(err))

        if err > desired_err:  # Increase samples
            min_logsamp = cur_logsamp
            if max_logsamp is None:
                cur_logsamp = cur_logsamp + 2
            else:
                cur_logsamp = (cur_logsamp + max_logsamp) / 2
        else:  # Decrease samples
            max_logsamp = cur_logsamp
            cur_logsamp = (cur_logsamp + min_logsamp) / 2

        if cur_logsamp >= 22:  # Give up at 2**22 samples, before the data set becomes too large for the GPU
            return None

        if max_logsamp is not None and max_logsamp - min_logsamp < 0.3:  # Bracket is accurate enough
            return 2 ** cur_logsamp

    return 2 ** cur_logsamp
