import math
from numpy import real
from models.mlp import MLP
from models.model_utils import forward_with_checkpoint, hash_tensor, get_scalers_from_data_path
from models.model_wrapper import ModelWrapper
from models.retrieval_wrapper import RetrievalAgent
import torch.nn as nn
import torch
import torch.nn.functional as F
from logging_utils import logger
import torch._dynamo
# Set TorchDynamo's `cache_size_limit` config value to 64.
torch._dynamo.config.cache_size_limit = 64

# When True, DANWrapper passes a whole batch of retrieval states to the
# retrieval agent in one call; when False, it queries one sample at a time.
BATCHED_RETRIEVAL = True

try:
    # If kernprof is running, profile will be available as builtin
    profile
except NameError:
    # Otherwise import no-op version
    from nn_util import profile

class DANWrapper(ModelWrapper):
    def __init__(self, wrapped: nn.Module, env_cfg, policy_cfg, **kwargs):
        """Wrap ``wrapped`` with neighbor retrieval over the demonstration datasets.

        Args:
            wrapped: Network applied to every retrieved neighbor row.
            env_cfg: Environment config. Read for ``"mixed"``, the
                ``retrieval`` / ``delta_state`` demo pickles and, when
                ``weighted_sum`` is set, ``"demo_pkl"``. Also forwarded to
                ``RetrievalAgent``.
            policy_cfg: Policy config. Read for ``"neighbor_batch"`` and
                forwarded to ``RetrievalAgent``.
            **kwargs: Optional ablation / debugging switches (``diffusion``,
                ``use_delta``, ``use_delta_magnitude``, ``include_action``,
                ``random_neighbors``, ``ic_regularizer`` (+ ``_lambda``,
                ``_bandwidth``), ``residual``, ``bc_architecture``,
                ``weighted_sum``, ``scalar_output``, ``normalize_scalar``,
                ``permutation_dependent``, ``save_deltas``, ``save_queries``).

        Raises:
            AssertionError: If ``normalize_scalar`` is requested without
                ``scalar_output``.
        """
        super().__init__()

        self.wrapped = wrapped
        self.retrieval_agent = RetrievalAgent(env_cfg, policy_cfg)
        agent = self.retrieval_agent.agent
        datasets = agent.datasets

        # Normalised to a real bool so that falsy-but-not-False config values
        # (e.g. None) are treated as "not mixed" by the later checks.
        self.mixed = bool(
            env_cfg.get("mixed", False)
            and env_cfg['retrieval']['demo_pkl'] != env_cfg['delta_state']['demo_pkl']
        )

        # Cumulative column boundaries of [retrieval features | delta-state features]
        # inside a concatenated observation vector.
        self.input_splits = torch.cumsum(
            torch.tensor([
                len(datasets['retrieval'].obs_matrix[0][0]),
                len(datasets['delta_state'].obs_matrix[0][0]),
            ]),
            dim=0,
        )
        logger.info(f"DAN input splits: {self.input_splits}")

        state_dataset = datasets['state']
        self.s_dataset = state_dataset.flattened_obs_matrix
        self.a_dataset = state_dataset.flattened_act_matrix
        self.state_size = len(self.s_dataset[0])
        self.action_size = len(self.a_dataset[0])
        self.delta_s_dataset = datasets['delta_state'].flattened_obs_matrix
        self.delta_s_size = self.delta_s_dataset.shape[-1]
        self.delta_s_scaler = datasets['delta_state'].obs_scaler
        self.neighbor_batch = policy_cfg.get("neighbor_batch", -1)
        self.is_diffusion = kwargs.get("diffusion", False)

        # For ablations, not recommended to stray from defaults
        self.use_delta = kwargs.get("use_delta", True)
        self.use_delta_magnitude = kwargs.get("use_delta_magnitude", False)
        self.include_action = kwargs.get("include_action", True)
        self.random_neighbors = kwargs.get("random_neighbors", False)
        self.ic_regularizer = kwargs.get("ic_regularizer", False)
        self.ic_regularizer_lambda = kwargs.get("ic_regularizer_lambda", 0.0)
        self.ic_regularizer_bandwidth = kwargs.get("ic_regularizer_bandwidth", -1)
        self.residual = kwargs.get("residual", False)
        self.bc_architecture = kwargs.get("bc_architecture", False)
        self.weighted_sum = kwargs.get("weighted_sum", False)
        self.scalar_output = kwargs.get("scalar_output", False)
        self.normalize_scalar = kwargs.get("normalize_scalar", False)

        # Per-neighbor rows gathered in forward(): [state | action | delta-state],
        # or [state | delta-state] when actions are ablated. Built exactly once
        # so the (potentially large) concatenation is not allocated twice.
        if self.include_action:
            columns = [self.s_dataset, self.a_dataset, self.delta_s_dataset]
        else:
            columns = [self.s_dataset, self.delta_s_dataset]
        self.combined_dataset = torch.cat(columns, dim=-1).contiguous()

        if self.weighted_sum:
            _, self.action_scaler = get_scalers_from_data_path(env_cfg['demo_pkl'])
            self.action_scaler.to_device(agent.device)

        if self.normalize_scalar:
            assert self.scalar_output, "normalize_scalar requires scalar_output"

        if self.weighted_sum and not self.scalar_output:
            # Replace the output scaler with an identity transform.
            # NOTE(review): assumes `output_scaler` is provided by the base
            # class before this point -- confirm against ModelWrapper.
            import numpy as np
            self.output_scaler.mean_np = np.zeros_like(self.output_scaler.mean_np)
            self.output_scaler.scale_np = np.ones_like(self.output_scaler.scale_np)
            self.output_scaler.mean_torch = torch.as_tensor(self.output_scaler.mean_np)
            self.output_scaler.scale_torch = torch.as_tensor(self.output_scaler.scale_np)
            self.output_scaler.to_device(agent.device)

        # Will be set in factory if needed
        self.use_deep_set = False
        self.use_set_transformer = False
        self.permutation_dependent = kwargs.get("permutation_dependent", False)
        if self.permutation_dependent:
            # MLP over the concatenated per-neighbor actions of one sample.
            self.permutation_dependent_aggregator = MLP(
                input_len=self.action_size * self.retrieval_agent.num_neighbors,
                output_len=self.action_size,
                device=agent.device,
                hidden_dims=[256, 256, 256],
                batch_norm=True,
            )

        # Validation flags etc.
        self.validation = False
        self.val_start_index = -1
        self.val_delta_s_dataset = []

        # Purely for debugging/visualization
        self.save_deltas = kwargs.get("save_deltas", False)
        self.save_queries = kwargs.get("save_queries", False)
        self.deltas = []
        self.queries = []

    def prepare_to_train(self, data_loader):
        """Pre-compute and cache retrieval results for every sample in ``data_loader``.

        Each batch produced by ``data_loader.batch_sampler`` is turned into
        retrieval queries, and the results are cached in the retrieval agent
        under the sample's dataset index (plus an offset in validation mode).

        When ``self.validation`` is set, cache keys are offset past the largest
        key already present in the cache, and ``self.val_delta_s_dataset`` is
        rebuilt as a scaled tensor ordered by dataset index.

        Args:
            data_loader: Loader exposing ``batch_sampler`` and ``dataset``;
                ``dataset[i][0]`` must be the observation tensor of sample ``i``.
        """
        device = self.retrieval_agent.agent.device
        dataset = data_loader.dataset

        if self.validation:
            # Validation entries live after all existing (training) cache keys.
            self.val_start_index = torch.tensor(list(self.retrieval_agent.cache.keys())).max().item() + 1
            index_offset = self.val_start_index
        else:
            index_offset = 0

        # Collected locally so a repeated call does not try to extend the
        # tensor produced by a previous call.
        val_indices = []
        val_delta_states = []

        for batch_indices in data_loader.batch_sampler:
            sample_indices = list(batch_indices)
            batch = torch.stack([dataset[i][0] for i in sample_indices])

            if self.retrieval_agent.lookback == 1:
                # Add the lookback axis expected by the slicing below.
                batch = batch.unsqueeze(dim=1)
            if not self.mixed:
                # Same observation serves as both retrieval and delta-state features.
                batch = batch.repeat(1, 1, 2)
            batch = batch.to(device)

            assert batch.shape[2] == self.input_splits[-1]

            retrieval_state = batch[:, :, 0:self.input_splits[0]]
            delta_state = batch[:, -1, self.input_splits[0]:self.input_splits[1]]

            cache_keys = [i + index_offset for i in sample_indices]

            if BATCHED_RETRIEVAL:
                self.retrieval_agent.cache_result_for_train(retrieval_state, cache_keys)
            else:
                for r_state, key in zip(retrieval_state, cache_keys):
                    self.retrieval_agent.cache_result_for_train(r_state, [key])

            if self.validation:
                val_indices.extend(sample_indices)
                val_delta_states.extend(delta_state)

        # We need to build the validation delta s dataset
        if self.validation:
            stacked = torch.stack(val_delta_states)

            # Batches may arrive in any order; scatter rows back to dataset order.
            reordered = torch.empty_like(stacked)
            reordered[torch.tensor(val_indices)] = stacked
            self.val_delta_s_dataset = self.delta_s_scaler.transform(reordered)

    #@torch._dynamo.disable
    @profile
    def forward(self, input):
        """Predict by running ``wrapped`` on retrieved neighbors and aggregating.

        Args:
            input: Depends on the mode.
                * Training/validation with ``is_diffusion`` or
                  ``ic_regularizer``: 2-D tensor whose column 0 is the sample
                  index and whose last ``action_size`` columns are the target
                  action.
                * Training/validation otherwise: iterable of sample indices
                  used as keys into the retrieval cache.
                * Inference: observation tensor; after the optional
                  unsqueeze/repeat its last dimension must equal
                  ``input_splits[-1]``.

        Returns:
            * Diffusion training/validation: the output of ``wrapped`` on the
              neighbor rows concatenated with the target actions.
            * IC-regularised training/validation: scalar mean loss.
            * Otherwise: one aggregated action per sample (deep set, set
              transformer, permutation-dependent MLP, weighted sum, or plain
              mean over neighbors).
        """
        num_neighbors = self.retrieval_agent.num_neighbors
        # Training and validation use cached neighbors keyed by sample index.
        use_cache = self.wrapped.training or self.validation
        index_offset = self.val_start_index if self.validation else 0

        real_actions = None
        if (self.is_diffusion or self.ic_regularizer) and use_cache:
            # Each input in batch will be index + action, so split those up
            real_actions = input[:, -self.action_size:]

            # Extract indices back to a list of ints
            input = input[:, 0].to(torch.uint32).tolist()

        batch_size = len(input)
        if not use_cache:
            if self.retrieval_agent.lookback == 1:
                input = input.unsqueeze(dim=1)
            if not self.mixed:
                # Same observation serves as both retrieval and delta-state features.
                input = input.repeat(1, 1, 2)

        all_neighbors = []
        all_delta_state = []
        if use_cache:
            delta_source = self.val_delta_s_dataset if self.validation else self.delta_s_dataset
            for i in input:
                if self.random_neighbors:
                    neighbors = torch.randint(0, len(self.combined_dataset), size=(num_neighbors,), dtype=torch.int32)
                else:
                    neighbors = self.retrieval_agent.cache[i + index_offset]

                all_neighbors.append(neighbors)
                all_delta_state.append(delta_source[i])
        else:
            assert input.shape[2] == self.input_splits[-1]

            if self.bc_architecture:
                # Plain behaviour cloning at inference: no retrieval at all.
                return self.wrapped(input[:, -1, 0:self.input_splits[0]])

            retrieval_state = input[:, :, 0:self.input_splits[0]]
            delta_state = input[:, -1, self.input_splits[0]:self.input_splits[1]]

            if BATCHED_RETRIEVAL:
                if self.random_neighbors:
                    neighbors = torch.randint(0, len(self.combined_dataset), size=(len(retrieval_state), num_neighbors), dtype=torch.int32)
                else:
                    neighbors = self.retrieval_agent.get_neighbors(retrieval_state)

                if self.save_queries:
                    self.queries.append(retrieval_state[:, -1])

                all_neighbors.extend(neighbors)
                all_delta_state.extend(delta_state)
            else:
                for r_state, d_state in zip(retrieval_state, delta_state):
                    all_neighbors.append(self.retrieval_agent.get_neighbors(r_state))
                    all_delta_state.append(d_state)

        all_neighbors = torch.stack(all_neighbors).flatten()

        # Indexing with a tensor copies, so the in-place edit below does not
        # touch combined_dataset.
        all_data = self.combined_dataset[all_neighbors]

        if self.use_delta:
            delta_s_rhs = torch.stack(all_delta_state)
            if not use_cache:
                # Live queries arrive unscaled; the stored datasets are scaled.
                delta_s_rhs = self.delta_s_scaler.transform(delta_s_rhs)
            # Repeat each query once per neighbor: [B, D] -> [B * K, D]
            delta_s_rhs = delta_s_rhs.unsqueeze(dim=1).expand(batch_size, num_neighbors, -1).reshape(-1, all_delta_state[0].shape[-1])

            all_data[:, -self.delta_s_size:] -= delta_s_rhs

        # Captured before the optional magnitude reduction so that debugging
        # output and the IC bandwidth always see the full delta vectors.
        delta_columns = all_data[:, -self.delta_s_size:]

        if self.use_delta and self.use_delta_magnitude:
            # Replace the delta vector by its norm (norm of the delta columns,
            # not of the state/action columns).
            delta_magnitude = torch.linalg.norm(delta_columns, dim=1, keepdim=True)
            all_data = torch.cat([all_data[:, :-self.delta_s_size], delta_magnitude], dim=-1)

        if self.save_deltas:
            self.deltas.append(delta_columns)

        # [B, K, D_o] -> [B * K, D_o]
        inputs = all_data.view((-1, all_data.shape[-1]))

        if self.is_diffusion and use_cache:
            real_actions = real_actions.unsqueeze(dim=1).expand(batch_size, num_neighbors, -1).reshape(-1, self.action_size)
            inputs = torch.cat((inputs, real_actions), dim=1)

            noise_loss = self.wrapped(inputs)
            return noise_loss

        if self.residual:
            # Network predicts a correction on top of the neighbor's own action.
            all_actions = self.output_scaler.transform(self.a_dataset[all_neighbors]) + self.wrapped(inputs)
        elif self.bc_architecture:
            all_actions = self.wrapped(inputs[:, :self.state_size])
        else:
            all_actions = self.wrapped(inputs)

        if self.ic_regularizer and use_cache:
            # `input` is the list of sample indices in this mode.
            return self._ic_regularized_loss(inputs, all_actions, real_actions, input, delta_columns)

        if self.use_deep_set or self.use_set_transformer:
            if not use_cache:
                all_actions = self.output_scaler.transform(all_actions)

            aggregator = self.deep_set if self.use_deep_set else self.set_transformer
            combined_actions = aggregator(all_actions)

            if not use_cache:
                combined_actions = self.output_scaler.inverse_transform(combined_actions)

            return combined_actions

        if self.permutation_dependent:
            if not use_cache:
                all_actions = self.output_scaler.transform(all_actions)

            # Concatenate the K neighbor predictions of each sample: [B, K * A]
            all_actions = self.permutation_dependent_aggregator(all_actions.reshape(batch_size, -1))

            if not use_cache:
                all_actions = self.output_scaler.inverse_transform(all_actions)

            return all_actions

        if self.weighted_sum:
            weights = all_actions.reshape(batch_size, num_neighbors, -1)
            if self.normalize_scalar:
                # Out-of-place so the network output needed by autograd is not overwritten.
                weights = weights / weights.sum(dim=1, keepdim=True)

            neighbor_actions = inputs[:, self.state_size:self.state_size + self.action_size].reshape(batch_size, num_neighbors, -1)
            weighted_actions = (weights * neighbor_actions).sum(dim=1)

            if use_cache:
                weighted_actions = self.action_scaler.transform(weighted_actions)

            return weighted_actions

        # Default: average the per-neighbor predictions.
        return all_actions.view(batch_size, num_neighbors, -1).mean(dim=1)

    def _ic_regularized_loss(self, inputs, all_actions, real_actions, sample_indices, delta_columns):
        """Return the mean of supervised MSE plus the weighted IC regulariser.

        Args:
            inputs: Network input rows, ``[B * K, D]``; the first
                ``s_dataset.shape[-1]`` columns are the neighbor states.
            all_actions: Network output for every row of ``inputs``.
            real_actions: Target action per sample, ``[B, action_size]``.
            sample_indices: List of the ``B`` dataset indices in the batch.
            delta_columns: Delta-state columns per row, ``[B * K, delta_s_size]``.
        """
        batch_size = len(sample_indices)
        num_neighbors = self.retrieval_agent.num_neighbors

        # Bandwidth is the median distance to neighbors
        if self.ic_regularizer_bandwidth == -1:
            neighbor_distances = torch.linalg.norm(
                delta_columns.reshape(batch_size, num_neighbors, -1), dim=-1
            )  # (batch_size, num_neighbors)
            bandwidth = torch.quantile(neighbor_distances, 0.5, dim=1).unsqueeze(1).unsqueeze(1)
        else:
            bandwidth = self.ic_regularizer_bandwidth

        neighbor_states = inputs[:, :self.s_dataset.shape[-1]].reshape(
            batch_size, num_neighbors, -1
        )  # (batch_size, num_neighbors, state_dim)
        pred_actions = all_actions.reshape(
            batch_size, num_neighbors, -1
        )  # (batch_size, num_neighbors, action_dim)

        # Pairwise neighbor-state distances -> Gaussian affinities.
        distances = torch.cdist(neighbor_states, neighbor_states, p=2)  # (B, K, K)
        weights = torch.exp(-0.5 * (distances / bandwidth) ** 2)

        # Symmetrise, then row-normalise (guarding against all-zero rows).
        affinity = 0.5 * (weights + weights.transpose(-2, -1))
        row_sums = affinity.sum(dim=-1, keepdim=True)  # (B, K, 1)
        row_sums = torch.where(row_sums > 0, row_sums, torch.ones_like(row_sums))
        affinity = affinity / row_sums

        # ||a_i - a_j||^2 = ||a_i||^2 + ||a_j||^2 - 2 <a_i, a_j>
        action_norms_sq = torch.sum(pred_actions ** 2, dim=-1, keepdim=True)  # (B, K, 1)
        action_dots = torch.bmm(pred_actions, pred_actions.transpose(-2, -1))  # (B, K, K)
        squared_diffs = action_norms_sq + action_norms_sq.transpose(-2, -1) - 2 * action_dots

        ic_regularizers = torch.sum(affinity * squared_diffs, dim=(-2, -1))  # (batch_size,)

        if self.bc_architecture:
            mean_pred_actions = self.wrapped(self.input_scaler.inverse_transform(self.s_dataset[sample_indices]))
        else:
            mean_pred_actions = torch.mean(pred_actions, dim=1)  # (batch_size, action_dim)
        supervised_losses = F.mse_loss(mean_pred_actions, real_actions, reduction='none').mean(dim=-1)

        total_losses = supervised_losses + self.ic_regularizer_lambda * ic_regularizers
        return torch.mean(total_losses)

    def to(self, *args, **kwargs):
        """Move the module via the parent ``to`` and then move the retrieval agent.

        The retrieval agent is moved to ``self.device`` as read after the
        parent call. Returns whatever the parent ``to`` returned.
        """
        moved = super().to(*args, **kwargs)
        retrieval_backend = self.retrieval_agent.agent
        retrieval_backend.to_device(self.device)
        return moved

    def compile(self):
        """Replace ``wrapped`` (and ``set_transformer``, if present) with
        ``torch.compile`` versions using the "reduce-overhead" mode."""
        compile_mode = "reduce-overhead"
        self.wrapped = torch.compile(self.wrapped, mode=compile_mode)
        if not hasattr(self, "set_transformer"):
            return
        self.set_transformer = torch.compile(self.set_transformer, mode=compile_mode)
