import numpy as np
import torch
import torch.nn.functional as F

# Compute device for every tensor and layer in this file. Prefer the GPU, but fall
# back to the CPU so the code still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Returns a random target function constructed with k components and an m dimensional input, also returns projection U
def gen_rand_sine(rng, m, k, n_comps=5):
    U_unnormed = torch.randn(m, k).to(device, dtype=torch.float32)
    U = U_unnormed / torch.sqrt(torch.sum(U_unnormed ** 2, dim=0, keepdim=True))  # Normalize columns

    a = rng.normal(size=(n_comps, 1, k))  # Random amplitudes
    omega = 2 * np.pi * rng.normal(size=(n_comps, 1, k))  # Random frequencies
    phi = rng.uniform(low=-np.pi, high=np.pi, size=(n_comps, 1, k))  # Random phases
    a = torch.from_numpy(a).to(device, dtype=torch.float64)
    omega = torch.from_numpy(omega).to(device, dtype=torch.float64)
    phi = torch.from_numpy(phi).to(device, dtype=torch.float64)

    # Takes input with shape (n, m), output has shape (n, k); sum over last axis to get a scalar
    def rand_sine(X):
        # X has shape (n, m)
        # U has shape (m, k)
        # Compute distances between X and U
        distances = torch.sqrt(torch.sum((torch.unsqueeze(X, 2) - torch.unsqueeze(U, 0)) ** 2, dim=1))  # Shape (n, k)

        Y = torch.sum(a * torch.sin(omega * torch.unsqueeze(distances, 0) + phi), dim=0)  # Shape (n, k)
        return Y

    return rand_sine, U


def run_modules(x, modules):
    """Evaluate a sum of modular networks on ``x``.

    Each module is a list of layers. Layer 0 is used as a distance layer: for
    weight ``W`` (out_features, in_features) and bias ``b`` (out_features,),
    it computes ``relu(||x - W[j]|| + b[j])`` for every row ``W[j]``. Later
    layers are called directly, with a ReLU between them and no activation
    after the last one. Layer 0 always gets its ReLU, even when it is the only
    layer. The module outputs are summed elementwise, then summed over the
    last axis and scaled by ``1 / sqrt(len(modules))``.

    Args:
        x: input tensor of shape (n, in_features).
        modules: non-empty sequence of layer lists. Layer 0 of each must
            expose ``weight`` and ``bias``; later layers must be callable.

    Returns:
        Tensor of shape (n,).

    Raises:
        ValueError: if ``modules`` is empty.
    """
    if not modules:
        raise ValueError('run_modules needs at least one module')

    output = None
    for params in modules:
        z = x
        last_idx = len(params) - 1
        for layer_idx, layer in enumerate(params):
            if layer_idx == 0:
                W = layer.weight  # (out_features, in_features)
                b = layer.bias  # (out_features,)
                # (n, 1, in) - (1, out, in) -> Euclidean distances of shape (n, out_features)
                distances = torch.sqrt(torch.sum((torch.unsqueeze(z, 1) - torch.unsqueeze(W, 0)) ** 2, dim=2))
                z = F.relu(distances + torch.unsqueeze(b, 0))  # (n, out_features)
            else:
                z = layer(z)
                if layer_idx != last_idx:
                    z = F.relu(z)
        # Out-of-place add: `output += z` would overwrite the first module's output,
        # which autograd may have saved for backward (e.g. a ReLU result), and
        # backward would then fail.
        output = z if output is None else output + z
    return torch.sum(output, dim=1) / np.sqrt(len(modules))  # Normalize by the number of modules


def solve_linsys(K, Y):
    """Return ``Y^T K^{-1} Y`` with a hand-written gradient with respect to ``K``.

    Let ``beta = K^{-1} Y``. The forward value is ``beta . (K beta)``, which
    equals ``Y^T K^{-1} Y`` up to round-off. The result does not backpropagate
    through the linear solve. Instead it carries the analytic gradient
    ``-beta beta^T`` with respect to ``K``:
    * This is exact for symmetric ``K``.
    * For non-symmetric ``K`` the true gradient is ``-(K^{-T} Y) beta^T``.
    No gradient flows to ``Y``.

    Args:
        K: (n, n) matrix; cast to float64.
        Y: (n,) vector; cast to float64.

    Returns:
        0-dim float64 tensor.
    """
    K = K.to(dtype=torch.float64)
    Y = Y.to(dtype=torch.float64)
    # beta is treated as a constant below, so do not record an autograd graph
    # for the solve.
    with torch.no_grad():
        beta = torch.squeeze(torch.linalg.solve(K, Y.unsqueeze(1)), dim=1)
    # quad = beta^T K beta: differentiable in K only, with d(quad)/dK = beta beta^T.
    quad = torch.sum(beta * torch.einsum('ij,j->i', K, beta))
    # Straight-through trick: forward value is quad, gradient is that of -quad,
    # i.e. -beta beta^T = d(Y^T K^{-1} Y)/dK for symmetric K.
    return quad.detach() - (quad - quad.detach())


def learn_one_module(X, Y, iters=2000, lr=0.001, batch_size=1024,
                     sigma=0.5, eps=1e-5):
    """Learn one first-layer direction ``u`` via a kernel objective.

    On every iteration:
    1. Sample a random minibatch.
    2. Reduce each point to its Euclidean distance from ``u``.
    3. Build a Gaussian kernel on those 1-D distances.
    4. Take an Adam step on ``u`` to minimize ``solve_linsys(K, Y_batch)``,
       i.e. ``Y^T K^{-1} Y``.

    Args:
        X: (n_pts, m) inputs; converted to float64.
        Y: (n_pts,) targets; converted to float64.
        iters: number of Adam steps.
        lr: Adam learning rate.
        batch_size: minibatch size (effectively capped at n_pts).
        sigma: Gaussian kernel bandwidth.
        eps: jitter added to the kernel diagonal for numerical stability.

    Returns:
        The learned ``u`` as a detached float32 tensor of shape (m,).
    """
    n_pts, m = X.shape[0], X.shape[1]
    u = torch.nn.Parameter(torch.randn(m).to(device, dtype=torch.float64))

    X64 = X.to(dtype=torch.float64)
    Y64 = Y.to(dtype=torch.float64)
    opt = torch.optim.Adam([u], lr=lr)

    for step in range(iters):
        print('Module learning iteration ' + str(step))

        batch = torch.randperm(n_pts)[:batch_size]
        xb = X64[batch, :]  # (b, m)
        assert not torch.isnan(u).any()

        # Each batch point becomes a scalar feature: its distance to u.
        r = torch.sqrt(((xb - u[None, :]) ** 2).sum(dim=1))  # (b,)
        gram = torch.exp(-(r[None, :] - r[:, None]) ** 2 / 2 / sigma ** 2)  # (b, b)
        assert not torch.isnan(gram).any()
        gram = gram + eps * torch.eye(gram.shape[0]).to(device, dtype=torch.float64)

        loss = solve_linsys(gram, Y64[batch])
        assert loss > 0

        loss.backward()
        print('Module learning error: ' + str(loss))
        assert not torch.isnan(u.grad).any()

        opt.step()
        opt.zero_grad()

    return u.detach().to(dtype=torch.float32)


def train_baseline(n, target_func, N_modules, m, k, n_layers, width, n_test=10000, iters=10000,
                   batch_size=100, lr=0.001, bottleneck=True, wandb_log=False):
    """Train a randomly initialized modular (or monolithic) network with Adam.

    Inputs are standard-normal samples. Targets are
    ``target_func(x).sum(dim=1) / sqrt(k)``. Each of the ``N_modules``
    modules is a list of ``torch.nn.Linear`` layers evaluated by
    ``run_modules``, which treats layer 0 as a distance layer:

    * ``bottleneck=True``: ``Linear(m, 1)``, then ``n_layers - 1`` layers
      ``Linear(., width)``, then ``Linear(., 1)``.
    * ``bottleneck=False``: ``n_layers`` layers ``Linear(., width)``, then
      ``Linear(., 1)``.

    Args:
        n: number of training samples.
        target_func: maps an (n, m) tensor to an (n, k) tensor.
        N_modules: number of modules.
        m: input dimension.
        k: number of target components; used only for normalization.
        n_layers: depth parameter of each module (see above).
        width: hidden width of each module.
        n_test: number of test samples.
        iters: number of Adam steps; must be positive.
        batch_size: training minibatch size (capped at ``n``).
        lr: Adam learning rate.
        bottleneck: architecture switch (see above).
        wandb_log: currently unused.

    Returns:
        ``(loss_test, loss_train)`` as CPU tensors, both taken from the last
        iteration's forward pass, i.e. before its optimizer step.

    Raises:
        ValueError: if ``iters`` is not positive (there would be no loss to return).
    """
    if iters < 1:
        raise ValueError('iters must be positive, got ' + str(iters))

    # Generate data
    x_train = torch.normal(torch.zeros(n, m)).to(device, dtype=torch.float32)
    x_test = torch.normal(torch.zeros(n_test, m)).to(device, dtype=torch.float32)
    y_train = torch.sum(target_func(x_train), dim=1) / np.sqrt(k)  # (n,), normalized by k
    y_test = torch.sum(target_func(x_test), dim=1) / np.sqrt(k)  # (n_test,), normalized by k

    modules = []
    for _ in range(N_modules):
        last_width = m
        params = []
        if bottleneck:
            params.append(torch.nn.Linear(last_width, 1).to(device, dtype=torch.float32))
            last_width = 1
            for _ in range(n_layers - 1):
                params.append(torch.nn.Linear(last_width, width).to(device, dtype=torch.float32))
                last_width = width
        else:
            for _ in range(n_layers):
                params.append(torch.nn.Linear(last_width, width).to(device, dtype=torch.float32))
                last_width = width
        params.append(torch.nn.Linear(last_width, 1).to(device, dtype=torch.float32))
        modules.append(params)

    # Set up optimizer
    optimizer = torch.optim.Adam([p for params in modules for layer in params for p in layer.parameters()], lr=lr)

    # Test MSE of always predicting 0. It never changes, so compute it once.
    loss_random = torch.mean(y_test ** 2)

    for iteration in range(iters):
        idx = torch.randperm(n)[:min(batch_size, n)]
        print('Iteration: ' + str(iteration))
        y_train_hat = run_modules(x_train[idx], modules)  # (min(batch_size, n),)
        loss_train = torch.mean((y_train[idx] - y_train_hat) ** 2)  # scalar

        # The test loss is only reported, never backpropagated, so do not build
        # an autograd graph over all n_test samples.
        with torch.no_grad():
            y_test_hat = run_modules(x_test, modules)  # (n_test,)
            loss_test = torch.mean((y_test - y_test_hat) ** 2)  # scalar

        loss_train.backward()
        print('Loss train: ' + str(loss_train))
        print('Loss test: ' + str(loss_test))
        print('Loss random: ' + str(loss_random))

        optimizer.step()
        optimizer.zero_grad()
    return loss_test.detach().cpu(), loss_train.detach().cpu()


def train_ours(n, target_func, N_modules, m, k, n_layers, width, n_test=10000, iters=10000,
               batch_size=100, lr=0.001, sigma=1.0, module_batch_size=1024, module_lr=0.01, module_iters=200,
               bottleneck=True, wandb_log=False):
    """Train a modular network whose first layers are initialized by ``learn_one_module``.

    Data and architecture match ``train_baseline``: standard-normal inputs,
    targets ``target_func(x).sum(dim=1) / sqrt(k)``, and modules built from
    ``torch.nn.Linear`` layers evaluated by ``run_modules``. Before training,
    first-layer directions are learned with the kernel objective:

    * ``bottleneck=True``: one direction per module. Each is copied into that
      module's ``Linear(m, 1)`` first-layer weight.
    * ``bottleneck=False``: ``width`` directions, copied as the rows of the
      first-layer weight of ``modules[0]`` only.

    Args:
        n: number of training samples.
        target_func: maps an (n, m) tensor to an (n, k) tensor.
        N_modules: number of modules.
        m: input dimension.
        k: number of target components; used only for normalization.
        n_layers: depth parameter of each module.
        width: hidden width of each module.
        n_test: number of test samples.
        iters: number of Adam steps for the full network; must be positive.
        batch_size: training minibatch size (capped at ``n``).
        lr: Adam learning rate for the full network.
        sigma: kernel bandwidth passed to ``learn_one_module``.
        module_batch_size: minibatch size passed to ``learn_one_module``.
        module_lr: learning rate passed to ``learn_one_module``.
        module_iters: iterations passed to ``learn_one_module``.
        bottleneck: architecture switch (see above).
        wandb_log: currently unused.

    Returns:
        ``(loss_test, loss_train)`` as CPU tensors, both taken from the last
        iteration's forward pass, i.e. before its optimizer step.

    Raises:
        ValueError: if ``iters`` is not positive (there would be no loss to return).
    """
    if iters < 1:
        raise ValueError('iters must be positive, got ' + str(iters))

    # Generate data
    x_train = torch.normal(torch.zeros(n, m)).to(device, dtype=torch.float32)
    x_test = torch.normal(torch.zeros(n_test, m)).to(device, dtype=torch.float32)
    y_train = torch.sum(target_func(x_train), dim=1) / np.sqrt(k)  # (n,), normalized by k
    y_test = torch.sum(target_func(x_test), dim=1) / np.sqrt(k)  # (n_test,), normalized by k

    modules = []
    for _ in range(N_modules):
        last_width = m
        params = []
        if bottleneck:
            params.append(torch.nn.Linear(last_width, 1).to(device, dtype=torch.float32))
            last_width = 1
            for _ in range(n_layers - 1):
                params.append(torch.nn.Linear(last_width, width).to(device, dtype=torch.float32))
                last_width = width
        else:
            for _ in range(n_layers):
                params.append(torch.nn.Linear(last_width, width).to(device, dtype=torch.float32))
                last_width = width
        params.append(torch.nn.Linear(last_width, 1).to(device, dtype=torch.float32))
        modules.append(params)

    # Set up optimizer
    optimizer = torch.optim.Adam([p for params in modules for layer in params for p in layer.parameters()], lr=lr)

    # Learn the first-layer directions: one per module with a bottleneck,
    # otherwise one per hidden unit of the first layer.
    n_learned = N_modules if bottleneck else width
    us = []
    for i in range(n_learned):
        print('Learning module ' + str(i) + '...')
        u = learn_one_module(x_train, y_train, sigma=sigma, lr=module_lr, iters=module_iters,
                             batch_size=module_batch_size)
        us.append(u)
    U_pred = torch.stack(us, dim=1)  # (m, n_learned)

    with torch.no_grad():
        if bottleneck:
            for i in range(N_modules):
                # weight has shape (1, m); the (m,) column broadcasts into it.
                modules[i][0].weight.copy_(U_pred[:, i])
        else:
            # weight of Linear(m, width) has shape (width, m).
            # NOTE(review): only modules[0] is initialized here; any additional
            # modules keep their random first layer — confirm this is intended.
            modules[0][0].weight.copy_(U_pred.T)

    # Test MSE of always predicting 0. It never changes, so compute it once.
    loss_random = torch.mean(y_test ** 2)

    for iteration in range(iters):
        idx = torch.randperm(n)[:min(batch_size, n)]
        print('Iteration: ' + str(iteration))
        y_train_hat = run_modules(x_train[idx], modules)  # (min(batch_size, n),)
        loss_train = torch.mean((y_train[idx] - y_train_hat) ** 2)  # scalar

        # The test loss is only reported, never backpropagated, so do not build
        # an autograd graph over all n_test samples.
        with torch.no_grad():
            y_test_hat = run_modules(x_test, modules)  # (n_test,)
            loss_test = torch.mean((y_test - y_test_hat) ** 2)  # scalar

        loss_train.backward()
        print('Loss train: ' + str(loss_train))
        print('Loss test: ' + str(loss_test))
        print('Loss random: ' + str(loss_random))

        optimizer.step()
        optimizer.zero_grad()
    return loss_test.detach().cpu(), loss_train.detach().cpu()
