import torch
import torch.nn as nn
from models.model_utils import set_attributes_from_args
from models.retrieval_wrapper import RetrievalAgent
import math

class PositionalEncoding(nn.Module):
    """
    Standard sinusoidal positional encoding.
    From "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).

    Adds a fixed (non-trainable) sinusoidal table to the input, then applies
    dropout.

    Args:
        d_model: Embedding size. Odd sizes are supported; the last channel
            then carries a sine term only.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence the precomputed table supports.
        batch_first: If False (default), inputs are
            ``[seq_len, batch_size, d_model]``; if True,
            ``[batch_size, seq_len, d_model]``.
    """
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000,
                 batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # Shape [max_len, 1, d_model]: broadcasts over the batch dim of seq-first input.
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine channel than frequencies.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim], or
               [batch_size, seq_len, embedding_dim] when ``batch_first``.

        Returns:
            Tensor of the same shape with the positional encoding added and
            dropout applied.
        """
        if self.batch_first:
            # [seq_len, 1, d] -> [1, seq_len, d] so it broadcasts over the batch.
            x = x + self.pe[:x.size(1)].transpose(0, 1)
        else:
            x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class REGENT(nn.Module):
    """Retrieval-augmented policy mixing a nearest-neighbour action with a transformer prediction.

    For each query, the retrieval agent supplies ``num_neighbors`` indices into
    the ``'state'`` dataset together with a ``deltas`` score per neighbour.
    Neighbours are ordered by ascending delta. They are then interleaved with
    the query as ``s_1, a_1, ..., s_k, a_k, s_query`` and encoded by a causally
    masked, batch-first transformer encoder. The output at the query token is
    blended with the first (lowest-delta) neighbour's action using the weight
    ``exp(-_lambda * ||query_state - s_nearest||_2)``.
    """

    def __init__(self, policy_cfg, **kwargs):
        DEFAULT_CONFIG = {
            # Required, no default
            'input_len': None,
            'output_len': None,
            'device': None,
            'env_cfg': None,

            'hidden_size': 256,
            '_lambda': 10.0,
            'dropout': 0.1,
            'n_layer': 3,
            'n_head': 4,
            'optimizer': 0
        }
        super(REGENT, self).__init__()
        set_attributes_from_args(self, DEFAULT_CONFIG, kwargs)

        self.state_encoder = nn.Linear(self.input_len, self.hidden_size).to(self.device)
        self.action_encoder = nn.Linear(self.output_len, self.hidden_size).to(self.device)

        self.pos_encoder = PositionalEncoding(self.hidden_size, self.dropout).to(self.device)
        encoder_layers = nn.TransformerEncoderLayer(
            self.hidden_size, self.n_head, self.hidden_size * 4, self.dropout, batch_first=True
        ).to(self.device)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, self.n_layer).to(self.device)

        self.action_head = nn.Linear(self.hidden_size, self.output_len).to(self.device)

        self.init_weights()

        # Tag every parameter of this model with its optimizer id.
        for module in (self.action_head, self.pos_encoder, self.state_encoder,
                       self.action_encoder, self.transformer_encoder):
            for param in module.parameters():
                param.optimizer = self.optimizer

        self.retrieval_agent = RetrievalAgent(self.env_cfg, policy_cfg)

        # When not mixed, the same features fill both the retrieval and the
        # delta-state slots, so raw inputs are duplicated along the feature axis.
        self.mixed = bool(
            self.env_cfg.get("mixed", False)
            and self.env_cfg['retrieval']['demo_pkl'] != self.env_cfg['delta_state']['demo_pkl']
        )

        datasets = self.retrieval_agent.agent.datasets
        split_sizes = [
            len(datasets['retrieval'].obs_matrix[0][0]),
            len(datasets['delta_state'].obs_matrix[0][0]),
        ]
        # Cumulative feature boundaries; features [0:input_splits[0]] are the retrieval slice.
        self.input_splits = torch.cumsum(torch.tensor(split_sizes), dim=0)

        state_data = datasets['state']
        self.s_dataset = state_data.flattened_obs_matrix
        self.a_dataset = state_data.flattened_act_matrix
        self.action_size = len(state_data.flattened_act_matrix[0])
        self.s_scaler = state_data.obs_scaler
        self.output_scaler = state_data.act_scaler
        self.combined_dataset = torch.cat([self.s_dataset, self.a_dataset], dim=-1).contiguous()
        # k neighbour (state, action) token pairs plus the query-state token.
        self.seq_len = 2 * self.retrieval_agent.num_neighbors + 1
        self.causal_mask = self._generate_causal_mask(self.seq_len).to(self.device)

        self.validation = False
        self.val_start_index = -1
        self.val_s_dataset = []

    def prepare_to_train(self, data_loader):
        """Cache neighbour retrievals for every sample yielded by ``data_loader``.

        Batches are built from ``data_loader.batch_sampler`` and ``dataset[i][0]``.
        Cache keys are sample indices plus an offset. The offset is 0 in
        training mode. In validation mode it is one past the largest key already
        cached, so validation entries do not collide with training entries.

        In validation mode the retrieval states are also collected, reordered by
        sample index, transformed with ``s_scaler`` and stored in
        ``self.val_s_dataset`` for ``forward``.
        """
        validating = self.validation
        if validating:
            existing_keys = self.retrieval_agent.cache.keys()
            self.val_start_index = int(max(existing_keys, default=-1)) + 1
            index_offset = self.val_start_index
        else:
            index_offset = 0

        # Fresh per call, so repeated validation preparation does not try to
        # extend the tensor stored by a previous call.
        all_indices = []
        val_states = []

        dataset = data_loader.dataset
        for batch_indices in data_loader.batch_sampler:
            batch_input = torch.stack([dataset[i][0] for i in batch_indices])
            indices = list(batch_indices)

            if self.retrieval_agent.lookback == 1:
                batch_input = batch_input.unsqueeze(dim=1)
            if not self.mixed:
                batch_input = batch_input.repeat(1, 1, 2)
            batch_input = batch_input.to(self.retrieval_agent.agent.device)

            assert batch_input.shape[2] == self.input_splits[-1]

            retrieval_state = batch_input[:, :, 0:self.input_splits[0]]

            self.retrieval_agent.cache_result_for_train(
                retrieval_state, [idx + index_offset for idx in indices]
            )

            if validating:
                all_indices.extend(indices)
                val_states.extend(retrieval_state)

        if validating:
            # NOTE(review): entries keep the lookback axis ([lookback, dim]) while the
            # inference path in forward uses only the last step; confirm s_scaler
            # handles this shape.
            stacked_dataset = torch.stack(val_states)
            ordered = torch.empty_like(stacked_dataset)
            # Place each state at its sample index (loader order may be shuffled).
            ordered[torch.tensor(all_indices)] = stacked_dataset
            self.val_s_dataset = self.s_scaler.transform(ordered)

    def init_weights(self) -> None:
        """Initialize the linear layers: weights ~ U(-0.1, 0.1), biases zero."""
        initrange = 0.1
        for layer in (self.state_encoder, self.action_encoder, self.action_head):
            layer.weight.data.uniform_(-initrange, initrange)
            layer.bias.data.zero_()

    def _generate_causal_mask(self, sz: int) -> torch.Tensor:
        """Generates a causal mask of size sz x sz (0 on/below the diagonal, -inf above)."""
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def forward(self, input):
        """Predict actions for a batch.

        In training or validation mode, ``input`` is iterated element-wise. Each
        element (plus the validation offset) keys the retrieval cache filled by
        ``prepare_to_train`` and indexes the stored query states. It is
        presumably a batch of sample indices (TODO confirm).

        Otherwise, ``input`` holds raw observations and neighbours are retrieved
        on the fly.

        Returns:
            ``weight * a_nearest + (1 - weight) * a_transformer`` with
            ``weight = exp(-_lambda * ||query_state - s_nearest||_2)``. The
            neighbour action term is in ``output_scaler``-transformed space; no
            inverse transform is applied to the result.
        """
        use_cache = self.state_encoder.training or self.validation
        index_offset = self.val_start_index if self.validation else 0
        batch_size = len(input)

        all_neighbors = []
        all_deltas = []
        all_query_states = []
        if use_cache:
            query_source = self.val_s_dataset if self.validation else self.s_dataset
            for i in input:
                neighbors, deltas = self.retrieval_agent.cache[i + index_offset]
                all_neighbors.append(neighbors)
                all_deltas.append(deltas)
                all_query_states.append(query_source[i])
        else:
            if self.retrieval_agent.lookback == 1:
                input = input.unsqueeze(dim=1)
            if not self.mixed:
                input = input.repeat(1, 1, 2)

            assert input.shape[2] == self.input_splits[-1]

            retrieval_state = input[:, :, 0:self.input_splits[0]]

            neighbors, deltas = self.retrieval_agent.get_neighbors(retrieval_state)

            all_neighbors.extend(neighbors)
            all_query_states.extend(self.s_scaler.transform(retrieval_state[:, -1]))
            all_deltas.extend(deltas)

        query_state = torch.stack(all_query_states).to(self.device)
        # Reorder each sample's neighbours by ascending delta; column 0 is then
        # treated as the nearest neighbour.
        order = torch.stack(all_deltas).to(self.device).sort(dim=1).indices
        neighbors_indices = torch.gather(torch.stack(all_neighbors).to(self.device), 1, order)

        neighbor_data = self.combined_dataset[neighbors_indices.flatten()].view(
            batch_size, self.retrieval_agent.num_neighbors, -1
        ).to(self.device)

        neighbor_states, neighbor_actions = torch.split(
            neighbor_data, [self.input_len, self.output_len], dim=2
        )

        # Only actions are scaled here; neighbour states are used as stored.
        neighbor_actions = self.output_scaler.transform(neighbor_actions)

        s_nearest = neighbor_states[:, 0, :]
        a_R_and_P = neighbor_actions[:, 0, :]

        state_embeds = self.state_encoder(torch.cat([neighbor_states, query_state.unsqueeze(1)], dim=1))
        action_embeds = self.action_encoder(neighbor_actions)

        # Interleave tokens as s_1, a_1, ..., s_k, a_k, s_query.
        input_seq = torch.empty(batch_size, self.seq_len, self.hidden_size, device=self.device)
        input_seq[:, 0::2, :] = state_embeds
        input_seq[:, 1::2, :] = action_embeds

        # PositionalEncoding indexes its table by dim 0 ([seq_len, batch, dim]),
        # but input_seq is batch-first: transpose so encodings follow token position.
        input_seq = self.pos_encoder(input_seq.transpose(0, 1)).transpose(0, 1)

        if self.causal_mask is None or self.causal_mask.size(0) != self.seq_len:
            self.causal_mask = self._generate_causal_mask(self.seq_len).to(self.device)

        transformer_output = self.transformer_encoder(input_seq, mask=self.causal_mask)

        # The query state is the last token.
        query_output_embedding = transformer_output[:, -1, :]
        a_transformer = self.action_head(query_output_embedding)

        dist = torch.linalg.vector_norm(query_state - s_nearest, ord=2, dim=1, keepdim=True)
        weight = torch.exp(-self._lambda * dist)

        return weight * a_R_and_P + (1 - weight) * a_transformer

    def compile(self):
        """Wrap each submodule with ``torch.compile(mode="reduce-overhead")``."""
        self.action_head = torch.compile(self.action_head, mode="reduce-overhead")
        self.pos_encoder = torch.compile(self.pos_encoder, mode="reduce-overhead")
        self.state_encoder = torch.compile(self.state_encoder, mode="reduce-overhead")
        self.action_encoder = torch.compile(self.action_encoder, mode="reduce-overhead")
        self.transformer_encoder = torch.compile(self.transformer_encoder, mode="reduce-overhead")
