import math

import torch
import torch.nn as nn

from resonance import sampling


class ParallelLinear:
    """A bank of independent scalar linear models ``y = w * x + b`` run side by side.

    Model ``i`` owns ``weight[i]`` and ``bias[i]`` and is trained with plain
    gradient descent plus classical momentum on a mean-squared-error objective.
    """

    def __init__(self, num_models):
        self.num_models = num_models

        # Allocated uninitialised here; reset_parameters() fills them in.
        self.weight, self.bias = torch.Tensor(num_models), torch.Tensor(num_models)

        # Momentum buffers, one entry per model.
        self.velocity_weight, self.velocity_bias = torch.Tensor(num_models), torch.Tensor(num_models)

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias from U(-1, 1) and zero both momentum buffers."""
        for param in (self.weight, self.bias):
            nn.init.uniform_(param, -1.0, 1.0)

        for buffer in (self.velocity_weight, self.velocity_bias):
            nn.init.zeros_(buffer)

    def __call__(self, input):
        """Apply model ``i`` to entry ``i`` of the last dimension of ``input``."""
        scaled = input * self.weight
        return scaled + self.bias

    def update(self, learning_rate, momentum, inputs, targets):
        """Take one momentum-SGD step on the batch mean-squared error of every model.

        Both gradients come from predictions made with the parameters as they
        were *before* this step, so weight and bias are updated simultaneously.
        """
        n = inputs.shape[0]
        residual = targets - self(inputs)

        grad_w = -2 * torch.sum(residual * inputs, dim=0) / n
        grad_b = -2 * torch.sum(residual, dim=0) / n

        # Parameters and buffers are rebound to new tensors, not mutated in place.
        self.velocity_weight = momentum * self.velocity_weight - learning_rate * grad_w
        self.weight = self.weight + self.velocity_weight

        self.velocity_bias = momentum * self.velocity_bias - learning_rate * grad_b
        self.bias = self.bias + self.velocity_bias


class ParallelLinearADAM:
    """A bank of independent scalar linear models ``y = w * x + b`` trained with Adam.

    Parameters are stored in ``self.params`` under ``'weight'`` and ``'bias'``.
    The ``weight`` / ``bias`` properties expose the same tensors under the
    attribute names used by ``ParallelLinear``. Code that reads
    ``model.weight`` / ``model.bias`` (e.g. ``vectorized_distance_to_target``)
    therefore accepts either class.
    """

    def __init__(self, num_models):
        self.num_models = num_models
        self.params = {
            'weight': torch.Tensor(num_models),
            'bias': torch.Tensor(num_models),
        }

        # XXX: To shoehorn this in as a quick test, we fix beta_2 and vary beta_1 as we
        # did momentum in the SGDm version.
        self.beta_2 = 0.999
        self.epsilon = 1e-8

        # First- and second-moment estimates, keyed like self.params.
        self.m_t = {}
        self.v_t = {}

        # Number of update() calls so far; drives Adam's bias correction.
        self.time = 0

        self.reset_parameters()

    @property
    def weight(self):
        """Per-model slopes (alias for ``self.params['weight']``)."""
        return self.params['weight']

    @weight.setter
    def weight(self, value):
        self.params['weight'] = value

    @property
    def bias(self):
        """Per-model intercepts (alias for ``self.params['bias']``)."""
        return self.params['bias']

    @bias.setter
    def bias(self, value):
        self.params['bias'] = value

    def reset_parameters(self):
        """Re-draw parameters from U(-1, 1), clear moment estimates and the step count."""
        self.time = 0

        for name in ['weight', 'bias']:
            nn.init.uniform_(self.params[name], -1.0, 1.0)

            self.m_t[name] = torch.zeros_like(self.params[name])
            self.v_t[name] = torch.zeros_like(self.params[name])

    def __call__(self, input):
        """Apply model ``i`` to entry ``i`` of the last dimension of ``input``."""
        return input * self.params['weight'] + self.params['bias']

    def update_param(self, name, grad, step_size, beta_1):
        """Apply one bias-corrected Adam step to ``self.params[name]``.

        Must run after ``self.time`` has been incremented (``update`` does this).
        With ``self.time == 0`` the bias-correction denominators are zero.
        """
        self.m_t[name] = beta_1 * self.m_t[name] + (1 - beta_1) * grad
        self.v_t[name] = self.beta_2 * self.v_t[name] + (1 - self.beta_2) * torch.square(grad)

        m_corrected = self.m_t[name] / (1 - beta_1 ** self.time)
        v_corrected = self.v_t[name] / (1 - self.beta_2 ** self.time)

        self.params[name] = self.params[name] - step_size * m_corrected / (torch.sqrt(v_corrected) + self.epsilon)

    def update(self, learning_rate, beta_1, inputs, targets):
        """Take one Adam step on the batch mean-squared error of every model.

        Both gradients use predictions made before either parameter changes.
        """
        self.time += 1

        batch_size = inputs.shape[0]
        predictions = self(inputs)

        weight_grad = -2 * torch.sum((targets - predictions) * inputs, dim=0) / batch_size
        self.update_param('weight', weight_grad, learning_rate, beta_1)

        bias_grad = -2 * torch.sum(targets - predictions, dim=0) / batch_size
        self.update_param('bias', bias_grad, learning_rate, beta_1)


def vectorized_loss(predictions, targets):
    """Return the root-mean-square error of each column of ``predictions`` vs ``targets``.

    Dimension 0 is treated as the batch dimension and reduced away. For 2-D
    inputs the result therefore has one entry per column (per run/model).
    Dividing the column L2 norm by ``sqrt(predictions.shape[0])`` makes the value
    independent of batch size.
    """
    # torch.norm is deprecated; a 2-norm over one dimension is vector_norm.
    column_norms = torch.linalg.vector_norm(predictions - targets, ord=2, dim=0)
    return column_norms / math.sqrt(predictions.shape[0])


def vectorized_distance_to_target(target_functions, learned_functions):
    """Euclidean distance between matching (weight, bias) pairs of two model banks.

    Both arguments must expose ``.weight`` and ``.bias`` tensors. The result is
    elementwise, one distance per model.
    """
    delta_w = target_functions.weight - learned_functions.weight
    delta_b = target_functions.bias - learned_functions.bias
    return (delta_w ** 2 + delta_b ** 2).sqrt()


def test_parallel_linear():
    """Check that ParallelLinear applies one independent ``y = w * x + b`` per column.

    Two points on a line fix its slope and intercept. Recovering them from
    different row pairs of the same column must give the same answer, and that
    answer must match the model's own parameters.
    """
    # Test with batch size 3, and 5 models in parallel.
    test_models = ParallelLinear(5)
    test_input = torch.tensor([
        [1.0, 2.0, 3.0, 4.0, 5.0],
        [1.1, 2.1, 3.1, 4.1, 5.1],
        [1.2, 2.2, 3.2, 4.2, 5.2],
    ])
    test_output = test_models(test_input)
    assert test_output.shape == test_input.shape

    # Each column should have been operated on by a single model of the form
    # y = w * x + b, so test for consistency within each column across rows.

    # Slope: w = (y_i - y_j) / (x_i - x_j), from row pairs (0, 2) and (1, 2).
    w_1 = (test_output[0] - test_output[2]) / (test_input[0] - test_input[2])
    w_2 = (test_output[1] - test_output[2]) / (test_input[1] - test_input[2])

    # Intercept: b = (y_j * x_i - y_i * x_j) / (x_i - x_j), from row pairs (1, 2) and (0, 2).
    b_1 = (test_output[2] * test_input[1] - test_output[1] * test_input[2]) / (test_input[1] - test_input[2])
    b_2 = (test_output[2] * test_input[0] - test_output[0] * test_input[2]) / (test_input[0] - test_input[2])

    # Rows differ by only 0.1, so float32 cancellation leaves errors around 1e-5
    # (slope) to 1e-4 (intercept); tolerances are sized accordingly.
    assert torch.allclose(w_1, w_2, rtol=1e-3, atol=1e-4)
    assert torch.allclose(w_1, test_models.weight, rtol=1e-3, atol=1e-4)

    assert torch.allclose(b_1, b_2, rtol=1e-3, atol=1e-3)
    assert torch.allclose(b_1, test_models.bias, rtol=1e-3, atol=1e-3)


def test_vectorized_loss():
    """Check the output shape of ``vectorized_loss`` and that it does not scale with batch size.

    Relies on the ``sampling`` helpers. Their output layout is assumed to be
    (batch_size, runs_per_frequency * len(frequencies)); confirm against sampling.
    """
    frequencies = [0.1, 0.2]
    batch_size = 1000
    runs_per_frequency = 3
    amplitude = 1
    variance = 1.0

    samplers = [sampling.SinusoidalMeanGaussian(amplitude, frequency, variance) for frequency in frequencies]

    # Should get one loss per run.
    samples_2p5 = sampling.generate_domain_samples(2.5, samplers, batch_size, runs_per_frequency)
    test_samples = sampling.generate_test_domain_samples(samplers, batch_size, runs_per_frequency)
    test_losses = vectorized_loss(test_samples, samples_2p5)

    assert test_losses.shape == torch.Size([runs_per_frequency * len(frequencies)])

    # Larger batch size should not scale loss: the 10x-larger batch must be the
    # one evaluated here, otherwise this compares test_losses with itself.
    many_samples_2p5 = sampling.generate_domain_samples(2.5, samplers, batch_size * 10, runs_per_frequency)
    many_test_samples = sampling.generate_test_domain_samples(samplers, batch_size * 10, runs_per_frequency)
    many_test_losses = vectorized_loss(many_test_samples, many_samples_2p5)

    assert torch.allclose(test_losses, many_test_losses, rtol=1e-1, atol=1e-3)

if __name__ == '__main__':
    # Run the module's self-tests in their original order.
    for run_test in (test_vectorized_loss, test_parallel_linear):
        run_test()
