"""
Implementation of GD construction of https://arxiv.org/pdf/2212.07677.pdf
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from noisy_lr_data_generation import NoisyLinearRegressionTask

class GDBaseline:
    """Baseline that predicts with one step of GD on the in-context examples.

    The prediction for a query point is lr * (sum_i y_i x_i)^T x_query.
    """

    def __init__(self, lr):
        # Scalar multiplier applied to the final prediction.
        self.lr = lr

    def evaluate(self, x_context, y_context, x_query):
        """Return the one-step GD prediction for each batch element.

        Args:
            x_context: context inputs, shape (batch_size, sequence_length, point_dim).
            y_context: context targets, shape (batch_size, sequence_length, 1).
            x_query: query inputs, shape (batch_size, point_dim).

        Returns:
            Tensor of shape (batch_size,).
        """
        # sum_i y_i * x_i, reduced over the sequence axis -> (batch_size, point_dim)
        weighted_sum = (x_context * y_context).sum(dim=1)
        n_batch, n_features = weighted_sum.shape

        # Arrange as a row vector times a column vector per batch element.
        row = weighted_sum.view(n_batch, 1, n_features)
        column = x_query.view(n_batch, n_features, 1)

        # (batch_size, 1, 1) -> (batch_size,)
        inner = torch.matmul(row, column).squeeze(dim=1).squeeze(dim=1)
        return self.lr * inner

###### The following functions use sampling to approximate
###### the learning rate predicted by the theory.

"""
The numerator is E[sum_i y_i w_hat^T x_i].
Here, since w is sampled from N(0, 1), and the output variance is sigma^2,
w_hat is the ridge estimate (X^T X + sigma^2 * I)^(-1) X^T y, where row i of X is x_i.
"""
def sample_numerator(train_config, sample_size):
    """Monte Carlo estimate of E[(sum_i y_i x_i)^T w_hat].

    w_hat is the ridge-regression vector (X^T X + sigma^2 I)^(-1) X^T y
    computed from the context points, with sigma^2 taken from
    train_config.output_variance.

    Args:
        train_config: config object; this function reads `dimension`,
            `output_variance` and `sequence_length` from it.
        sample_size: number of independent samples to average over (passed
            to the task as its batch size).

    Returns:
        The sample mean of the dot products, as a Python float.
    """
    # Note that task samples an independent weight vector for each batch element.
    # Furthermore, the noise added to the outputs is independent for each batch element.
    # The device is auto-selected: CUDA when available, CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    task = NoisyLinearRegressionTask(dimension=train_config.dimension, output_variance=train_config.output_variance,
                                     batch_size=sample_size, device=device)

    # Sample xs, ys, and take out the last one
    xs = task.sample_xs(train_config.sequence_length) # (batch_size, n_points + 1, dimension)
    _, _, ys = task.evaluate(xs) # (batch_size, n_points + 1, 1)
    xs = xs[:, :-1, :] # (batch_size, n_points, dimension)
    ys = ys[:, :-1, :] # (batch_size, n_points, 1)

    xs_t = torch.transpose(xs, 1, 2) # (batch_size, dimension, n_points)

    # X^T y is exactly sum_i (x_i y_i); it is both the left factor of the dot
    # product and the right-hand side of the ridge system, so compute it once.
    xty = torch.bmm(xs_t, ys) # (batch_size, dimension, 1)

    # Regularised Gram matrix X^T X + sigma^2 I
    gram = torch.bmm(xs_t, xs) # (batch_size, dimension, dimension)
    gram = gram + train_config.output_variance * torch.eye(train_config.dimension, device=device)

    # Solve the linear system instead of forming an explicit inverse: this is
    # cheaper and numerically more stable than inv(gram) @ xty.
    ridge_w = torch.linalg.solve(gram, xty) # (batch_size, dimension, 1)

    # Compute dot product (sum_i y_i x_i)^T w_hat per batch element
    dot_products = torch.bmm(torch.transpose(xty, 1, 2), ridge_w) # (batch_size, 1, 1)
    dot_products = dot_products.squeeze(1).squeeze(1) # (batch_size,)
    return torch.mean(dot_products).item()

"""
The denominator is E[(sum_i y_i x_i)^T (sum_i y_i x_i)].
"""
def sample_denominator(train_config, sample_size):
    """Monte Carlo estimate of E[(sum_i y_i x_i)^T (sum_i y_i x_i)].

    Args:
        train_config: config object; this function reads `dimension`,
            `output_variance`, `sequence_length` and `lsa_config.device`.
        sample_size: number of independent samples to average over (passed
            to the task as its batch size).

    Returns:
        The sample mean of the squared norms, as a Python float.
    """
    regression_task = NoisyLinearRegressionTask(
        dimension=train_config.dimension,
        output_variance=train_config.output_variance,
        batch_size=sample_size,
        device=train_config.lsa_config.device,
    )

    # Draw the full sequence, then keep only the context points (drop the last one).
    all_xs = regression_task.sample_xs(train_config.sequence_length)  # (batch_size, n_points + 1, dimension)
    _, _, all_ys = regression_task.evaluate(all_xs)  # (batch_size, n_points + 1, 1)
    context_xs = all_xs[:, :-1, :]  # (batch_size, n_points, dimension)
    context_ys = all_ys[:, :-1, :]  # (batch_size, n_points, 1)

    # sum_i (x_i y_i), kept as a row vector per batch element.
    xy_sum = (context_xs * context_ys).sum(dim=1).unsqueeze(1)  # (batch_size, 1, dimension)

    # Squared norm of that vector: row times its own transpose.
    squared_norms = torch.bmm(xy_sum, xy_sum.transpose(1, 2))  # (batch_size, 1, 1)
    squared_norms = squared_norms.squeeze(1).squeeze(1)  # (batch_size,)
    return squared_norms.mean().item()

def theoretical_lr(train_config, sample_size=10000):
    """Estimate the learning rate as a ratio of two sampled expectations.

    Args:
        train_config: config object forwarded unchanged to both estimators.
        sample_size: number of samples each estimator averages over.

    Returns:
        sample_numerator(...) / sample_denominator(...) as a Python float.
    """
    # Operands are evaluated left to right, so the numerator is sampled
    # before the denominator; each estimator draws its own samples.
    return (
        sample_numerator(train_config, sample_size)
        / sample_denominator(train_config, sample_size)
    )