from tensorflow_datasets import nearest_neighbors
from models.model_utils import hash_tensor
from models.retrieval_wrapper import RetrievalAgent
from nn_agent_torch import NNAgentEuclideanStandardized
import torch.nn as nn
import math
import torch
import pickle
from logging_utils import logger
import torch._dynamo
# Set TorchDynamo's `cache_size_limit` configuration option to 64.
torch._dynamo.config.cache_size_limit = 64

# Switch read by LWR.forward: when truthy the regression is solved for the
# whole batch in one call, otherwise the batch is processed item by item.
BATCHED_LWR=True

# Make sure the name `profile` is defined at module level whether or not a
# profiler has injected it.
try:
    # If kernprof is running, profile will be available as builtin
    profile
except NameError:
    # Otherwise import no-op version
    from nn_util import profile

class LWR(nn.Module):
    """Locally weighted regression (LWR) policy on top of a retrieval agent.

    For every query the retrieval agent returns neighbor indices and their
    distances. A linear model with a bias term is fitted to the neighbors'
    (observation, action) pairs, each neighbor weighted by its inverse
    distance, and the model is evaluated at the most recent observation of
    the query.
    """

    # Floor applied to neighbor distances before they are inverted. Without
    # it an exact match (distance 0) produces an infinite weight and, after
    # normalization, NaNs in the whole regression.
    _MIN_DISTANCE = 1e-12

    def __init__(self, env_cfg, policy_cfg):
        """Build the underlying ``RetrievalAgent`` and remember its device.

        Args:
            env_cfg: Environment configuration forwarded to ``RetrievalAgent``.
            policy_cfg: Policy configuration forwarded to ``RetrievalAgent``.
        """
        super().__init__()

        self.retrieval_agent = RetrievalAgent(env_cfg, policy_cfg)
        self.device = self.retrieval_agent.agent.device

    @staticmethod
    def _prepend_ones(features, device):
        """Return ``features`` as float64 on ``device`` with a leading column of ones."""
        features = features.to(device=device, dtype=torch.float64)
        return torch.cat([torch.ones_like(features[..., :1]), features], dim=-1)

    @staticmethod
    def _fit_weighted_linear_model(design, targets, weights, retry_message):
        """Solve the weighted least-squares normal equations for ``theta``.

        Works both for a single problem and for a leading batch dimension.

        Args:
            design: Design matrix including the ones column, shaped
                ``(..., K, D + 1)``.
            targets: Regression targets, shaped ``(..., K, A)``.
            weights: Per-neighbor weights, shaped ``(..., K)``.
            retry_message: Text printed when the first solver fails and the
                pseudo-inverse fallback is attempted.

        Returns:
            ``theta`` shaped ``(..., D + 1, A)``, or ``None`` when the
            fallback solver fails as well.
        """
        weighted_design_t = design.transpose(-1, -2) * weights.unsqueeze(-2)
        moment = weighted_design_t @ targets

        try:
            return torch.linalg.lstsq(weighted_design_t @ design, moment).solution
        except RuntimeError:
            # torch.linalg.LinAlgError is a RuntimeError subclass, so this
            # covers both kinds of solver failure.
            print(retry_message)

        try:
            return torch.pinverse(weighted_design_t @ (design + 1e-8)) @ moment
        except Exception:
            # Deliberate best effort: the caller substitutes a fallback
            # action. KeyboardInterrupt/SystemExit are no longer swallowed.
            return None

    def forward(self, input):
        """Predict an action for each query by locally weighted regression.

        Args:
            input: Query observations. When the retrieval agent's
                ``lookback`` is 1 a time axis is inserted at dim 1; the
                result must be 3-dimensional. Only the last time step
                (``input[:, -1]``) is used as the regression query.

        Returns:
            A tensor with one predicted action per batch item. If the
            batched solver fails completely, the first stored action
            (``act_matrix[0][0]``) is returned instead.
        """
        if self.retrieval_agent.lookback == 1:
            input = input.unsqueeze(dim=1)

        assert input.ndim == 3

        agent = self.retrieval_agent.agent
        device = agent.device
        dataset = agent.datasets['retrieval']

        neighbors, deltas = self.retrieval_agent.get_neighbors(input)

        # Inverse-distance weights. The clamp keeps exact matches finite and
        # the normalization is per sample, so one query's weights never
        # depend on the other queries in the batch.
        weights = 1 / deltas.to(device=device, dtype=torch.float64).clamp_min(self._MIN_DISTANCE)
        weights = weights / weights.sum(dim=-1, keepdim=True)

        design = self._prepend_ones(dataset.flattened_obs_matrix[neighbors], device)
        targets = dataset.flattened_act_matrix[neighbors].double()

        query = self._prepend_ones(dataset.obs_scaler.transform(input[:, -1]), device)

        if BATCHED_LWR:
            theta = self._fit_weighted_linear_model(
                design, targets, weights, "FAILED TO CONVERGE, ADDING NOISE"
            )
            if theta is None:
                print("Something went wrong, likely a very large number (> e+150) was encountered. Returning arbitrary action.")
                # NOTE(review): this is a single stored action, not a
                # batch-shaped tensor like the normal return value; kept
                # as-is because callers may rely on it. Confirm.
                return dataset.act_matrix[0][0]

            return torch.bmm(theta.transpose(1, 2), query.unsqueeze(-1)).squeeze(2)

        results = []
        per_item = zip(design, targets, weights, query)
        for i, (design_i, targets_i, weights_i, query_i) in enumerate(per_item):
            # Same inverse-distance weights as the batched path (the raw
            # distances used previously favoured the farthest neighbors).
            theta_i = self._fit_weighted_linear_model(
                design_i,
                targets_i,
                weights_i,
                f"FAILED TO CONVERGE for batch item {i}, ADDING NOISE",
            )
            if theta_i is None:
                print(f"CRITICAL FAILURE for batch item {i}. Returning arbitrary action.")
                # Match dtype/device of the predictions so torch.stack
                # below can combine them.
                arbitrary_action = torch.as_tensor(dataset.act_matrix[0][0]).to(
                    device=device, dtype=torch.float64
                )
                results.append(arbitrary_action)
                continue

            results.append(theta_i.T @ query_i)

        return torch.stack(results, dim=0)

    def to(self, *args, **kwargs):
        """Move the module and, when a device is given, the retrieval agent.

        Accepts the same arguments as ``nn.Module.to``. The target device is
        taken from a positional device/str/int, from a positional tensor
        (its device), or from the ``device`` keyword. Calls that only change
        dtype, or that pass ``device=None``, leave the retrieval agent
        where it is.

        Returns:
            Whatever ``nn.Module.to`` returns (the module itself).
        """
        result = super().to(*args, **kwargs)

        target = kwargs.get('device')
        if args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                target = first.device
            elif isinstance(first, (torch.device, str, int)):
                target = first

        if target is not None:
            self.retrieval_agent.agent.to_device(torch.device(target))
            # Mirror __init__ so self.device does not go stale after a move.
            self.device = self.retrieval_agent.agent.device

        return result
