"""
Follows the structure in:
https://github.com/dtsip/in-context-learning/blob/main/src/tasks.py
https://github.com/dtsip/in-context-learning/blob/main/src/samplers.py

NoisyLinearRegressionTask
- Holds a batch of randomly initialized weight vectors (one per batch
  element) and evaluates inputs x as y = x · w plus Gaussian output noise.
"""

import torch
import torch.nn as nn
import math
import LSA_layer
import dotmap

class NoisyLinearRegressionTask:
    """A batch of noisy linear-regression problems with a shared input size.

    Each of the ``batch_size`` problems owns a weight vector drawn from a
    standard normal distribution when the task is constructed. Targets are
    generated as ``y = x @ w + eps`` where ``eps`` is Gaussian noise scaled
    by ``sqrt(output_variance)``.
    """

    def __init__(self, dimension, output_variance, batch_size, device):
        self.dimension = dimension
        self.output_variance = output_variance
        self.batch_size = batch_size
        self.device = device

        # One column weight vector per batch element: (batch_size, dimension, 1).
        self.weight_vectors = torch.randn(batch_size, dimension, 1, device=device)

    def sample_xs(self, sequence_length):
        """Draw inputs uniformly between -1 and 1.

        Returns:
            Tensor of shape (batch_size, sequence_length, dimension) on
            ``self.device``.
        """
        shape = (self.batch_size, sequence_length, self.dimension)
        return torch.empty(shape, device=self.device).uniform_(-1, 1)

    def evaluate(self, xs):
        """Compute noisy targets for ``xs`` and build the attention-layer input.

        Args:
            xs: Inputs of shape (batch_size, sequence_length, dimension),
                e.g. as returned by ``sample_xs``.

        Returns:
            A tuple ``(context, last_y, y)``:
            - ``context``: (batch_size, sequence_length, dimension + 1) --
              the inputs with each target appended as an extra feature; the
              target at the final position is replaced by 0.
            - ``last_y``: (batch_size,) -- the noisy target of the final
              position, i.e. the value hidden from ``context``.
            - ``y``: (batch_size, sequence_length, 1) -- the targets without
              reshaping. Note that its final entry has also been set to 0.
        """
        noise_scale = math.sqrt(self.output_variance)
        ys = torch.bmm(xs, self.weight_vectors)
        ys = ys + noise_scale * torch.randn(ys.size(), device=self.device)

        # Copy the final target out before it is masked in place below;
        # indexing returns a view, so without clone() it would become 0 too.
        target = ys[:, -1, 0].clone()

        # Hide the final target, then append the targets as a feature column.
        ys[:, -1, :] = 0
        context = torch.cat((xs, ys), dim=-1)
        return context, target, ys

if __name__ == "__main__":
    # Smoke test: build a batch of tasks, print tensor shapes, and run one
    # forward pass of the linear self-attention layer on the context.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dimension = 20
    noisy_lr_task = NoisyLinearRegressionTask(
        dimension=dimension, output_variance=0.001, batch_size=2048, device=device
    )
    xs = noisy_lr_task.sample_xs(sequence_length=20)
    # evaluate() returns (context, last_y, ys); the raw ys are not needed here.
    # Unpacking only two values raised "too many values to unpack".
    context, last_y, _ = noisy_lr_task.evaluate(xs)
    print(context.size())
    print(last_y.size())
    print("Last y in context tensor: ", context[:, -1, -1])
    print("Last y: ", last_y)

    config = dotmap.DotMap()
    # Each context token is an x (dimension features) with its y appended.
    config.point_dim = dimension + 1
    config.query_dim = dimension
    config.head_dim = config.point_dim
    config.num_heads = 1
    config.device = device
    lsa_layer = LSA_layer.LinearSelfAttention(config=config)
    predicted_ys = lsa_layer(context)
    # Read the prediction from the last feature of the final (query) token.
    predicted_ys = predicted_ys[:, -1, -1]
    print(predicted_ys.size())
    print("Predicted ys: ", predicted_ys)

