import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
import os
import numpy as np
import copy
from open_spiel.python import rl_agent
from utls import *

class DQNAgent:
    """Base class for the Double-DQN agents defined in this file.

    The base class owns the hyper-parameters, the replay buffer, the epsilon
    schedule and the OpenSpiel ``step`` glue. Subclasses are expected to
    create ``self.policy_net``, ``self.target_net`` and ``self.optimizer`` and
    to implement ``choose_action``.

    Every state handled by the agent is the concatenation
    ``[legal_mask, board_state]`` of length ``num_actions + state_dim``.
    """

    def __init__(self, params):
        """
        params: dict, contains the parameters for the agent. Missing keys fall
            back to the defaults visible below.
        """
        print("Initializing " + type(self).__name__ + " with parameters:", params)
        self.game = params.get('game', None)
        self.player_id = params.get('player_id', None) # None means agent needs to get the player_id from the environment
        self.num_actions = params.get('num_actions', 9)
        self.state_dim = params.get('state_dim', 27)
        self.gamma = params.get('gamma', 0.99)
        self.epsilon_start = params.get('epsilon_start', 1)
        self.epsilon_decay_duration  = params.get('epsilon_decay_duration', int(1e5))
        self.epsilon_end = params.get('epsilon_end', 0.1)
        self.buffer_size = params.get('buffer_size', 10000)
        self.min_buffer_size_to_learn = params.get('min_buffer_size_to_learn', 1000)
        self.batch_size = params.get('batch_size', 128)
        self.d_model = params.get('d_model', 64)
        self.lr = params.get('lr', 0.001)
        self.train_interval = params.get('train_interval', 10) # every x step
        self.target_net_update_interval = params.get('target_net_update_interval', 1000) # every x step
        self.step_counter = 0
        self.train_counter = 0
        self.device = torch.device(params.get('device', "cuda"))
        self.buffer = self.DQNbuffer(self.num_actions, self.state_dim, self.buffer_size, self.batch_size, self.device)

        self.loss = nn.MSELoss(reduction='mean')
        self.loss_history = []
        self.q_max_history = []
        self.q_min_history = []
        self.q_mean_history = []
        # for step in OpenSpiel, we need to store the data of the last step
        self.last_move_buffer = {'id': None, 'state': None, 'action': None, 'next_state': None, 'reward': None, 'done': None}
        # for self play, we need to store the data of the last step of both players
        self.last_move_buffer2 = {'id': None, 'state': None, 'action': None, 'next_state': None, 'reward': None, 'done': None}

    def add_to_train_buff(self, state, action, reward, next_state, done):
        """Store one transition and run the periodic training / target sync.

        Nothing is trained until ``min_buffer_size_to_learn`` transitions have
        been stored; after that ``train`` runs every ``train_interval`` stored
        transitions and the target net is synced every
        ``target_net_update_interval`` stored transitions.
        """
        self.buffer.add(state=state, action=action, next_state=next_state, reward=reward, done=done)
        self.step_counter += 1
        if self.step_counter < self.min_buffer_size_to_learn:
            return
        if self.step_counter % self.train_interval == 0:
            self.train()
        if self.step_counter % self.target_net_update_interval == 0:
            self.update_target_net()

    def set_epsilon(self, epsilon):
        """Store ``epsilon`` on the agent.

        NOTE(review): nothing in this class reads ``self.epsilon``; exploration
        uses ``get_epsilon``. Kept for callers that set it.
        """
        self.epsilon = epsilon

    def get_epsilon(self, power=1.0):
        """
        Returns the decayed epsilon value for the current ``step_counter``.
        modified from open_spiel/open_spiel/python/jax/dqn.py

        Epsilon moves from ``epsilon_start`` to ``epsilon_end`` over
        ``epsilon_decay_duration`` stored transitions and then stays at
        ``epsilon_end``. A non-positive duration means "no decay phase" and
        yields ``epsilon_end`` directly instead of dividing by zero.
        """
        if self.epsilon_decay_duration <= 0:
            return self.epsilon_end
        decay_steps = min(self.step_counter, self.epsilon_decay_duration)
        remaining = (1 - decay_steps / self.epsilon_decay_duration) ** power
        return self.epsilon_end + (self.epsilon_start - self.epsilon_end) * remaining

    def choose_action(self, state, deterministic=False):
        """Pick an action for ``state``; must be implemented by subclasses."""
        raise NotImplementedError

    def train(self):
        """Run one Double-DQN update on a sampled batch.

        The next action is selected with ``policy_net`` and evaluated with
        ``target_net``. Every subclass in this file overrides this method.
        """
        states, next_states, actions, rewards, dones = self.buffer.get()
        q_pred = self.policy_net(states)[np.arange(self.batch_size), actions]
        with torch.no_grad():
            next_actions = torch.argmax(self.policy_net(next_states), dim=1)
            q_next_pred = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
        q_target = rewards + self.gamma * q_next_pred * (1 - dones)
        loss = self.loss(q_pred, q_target).to(self.device)
        loss.backward()
        self.optimizer.step()
        with torch.no_grad():
            # NOTE(review): this clamps the *target* net's head weights and so
            # requires a net with a ``head_mlp`` layer; the optimizer only
            # updates the policy net. Confirm whether ``policy_net`` was meant.
            self.target_net.head_mlp.weight.clamp_(0.01, 2.0)
        self.optimizer.zero_grad()

        self.train_counter += 1
        self.loss_history.append(loss.item())
        self.q_max_history.append(q_pred.max().item())
        self.q_min_history.append(q_pred.min().item())
        self.q_mean_history.append(q_pred.mean().item())

    def update_target_net(self):
        """Copy the policy net's weights into the target net."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def save_model(self, path='model.pth'):
        """Write the policy net's ``state_dict`` to ``path``."""
        torch.save(self.policy_net.state_dict(), path)

    def load_model(self, path='model.pth'):
        """Load policy-net weights from ``path`` and sync the target net.

        The target net is synced as well so that training resumed after a load
        does not bootstrap from the target net's previous, unrelated weights.
        """
        self.policy_net.load_state_dict(torch.load(path, map_location=self.device))
        self.policy_net.to(self.device)
        self.update_target_net()

    def step(self, time_step, is_evaluation=False):
        '''
        time_step: open_spiel.python.TimeStep
        is_evaluation: when True the agent acts greedily and stores nothing.

        Returns an ``rl_agent.StepOutput`` with a one-hot ``probs`` vector, or
        ``None`` on the terminal time step.
        '''
        # for OpenSpiel, TTT and Breakthrough, info_state[0] = info_state[1]
        # So if you want to self play, you might need to mirror the state
        # exchange info_state[0][0:25] and info_state[0][25:50] for Breakthrough
        # Right now, the agent can only play as player 0
        if self.player_id is None:
            return self._step_self_play(time_step, is_evaluation)
        return self._step_single_player(time_step, is_evaluation)

    def _step_self_play(self, time_step, is_evaluation):
        """Handle one time step when the agent plays every seat itself."""
        current_player = time_step.observations["current_player"]
        if time_step.last():
            if is_evaluation:
                return None
            # Both pending moves end here. Each one is rewarded with the
            # terminal reward of the player who made it (its own 'id'); the
            # older move is stored first.
            for move in (self.last_move_buffer2, self.last_move_buffer):
                if move['state'] is not None:
                    self._store_terminal_transition(move, time_step.rewards[move['id']])
            return None

        legal_actions = time_step.observations["legal_actions"][current_player]
        board_state = state_vec_to_player_observation(game_name=self.game, time_step=time_step, player_id=current_player)
        state = self._with_legal_mask(board_state, legal_actions)  # (num_actions + state_dim, )
        if is_evaluation:
            return self._one_hot_output(self.choose_action(state=state, deterministic=True))

        action = self.choose_action(state)
        latest, older = self.last_move_buffer, self.last_move_buffer2
        if latest['state'] is not None:
            latest['next_state'] = state
            latest['done'] = 0
            if older['state'] is not None:
                # NOTE(review): `older` carries the observation recorded one
                # ply after its own move (the other seat's view) as
                # next_state, while its reward is read now, two plies later.
                # Confirm this pairing is intended.
                older['reward'] = time_step.rewards[current_player]
                self._push_transition(older)
            older.update(latest)
        latest['id'] = current_player
        latest['state'] = state
        latest['action'] = action
        return self._one_hot_output(action)

    def _step_single_player(self, time_step, is_evaluation):
        """Handle one time step when the agent plays the fixed seat ``player_id``."""
        if time_step.last():
            if is_evaluation:
                return None
            if self.last_move_buffer['state'] is not None:
                self._store_terminal_transition(self.last_move_buffer, time_step.rewards[self.player_id])
            return None

        board_state = np.asarray(time_step.observations["info_state"][0])
        legal_actions = time_step.observations["legal_actions"][self.player_id]
        state = self._with_legal_mask(board_state, legal_actions)  # (num_actions + state_dim, )
        # WARNING. mlp dqn does not use concatenated state.
        if is_evaluation:
            return self._one_hot_output(self.choose_action(state=state, deterministic=True))

        action = self.choose_action(state)
        latest = self.last_move_buffer
        if latest['state'] is not None:
            latest['next_state'] = state
            latest['reward'] = time_step.rewards[self.player_id]
            latest['done'] = 0
            self._push_transition(latest)
        latest['state'] = state
        latest['action'] = action
        return self._one_hot_output(action)

    def _with_legal_mask(self, board_state, legal_actions):
        """Return ``[legal_mask, board_state]`` as one flat vector."""
        legal_mask = np.zeros(self.num_actions, dtype=np.float32)
        legal_mask[legal_actions] = 1.0
        return np.concatenate((legal_mask, board_state), axis=0)

    def _one_hot_output(self, action):
        """Wrap ``action`` in a StepOutput whose probs are one-hot on it."""
        probs = np.zeros(self.num_actions)
        probs[action] = 1.0
        return rl_agent.StepOutput(action=action, probs=probs)

    def _push_transition(self, move):
        """Send the transition held in a last-move dict to the replay buffer."""
        self.add_to_train_buff(
            state=move['state'],
            action=move['action'],
            next_state=move['next_state'],
            reward=move['reward'],
            done=move['done']
        )

    def _store_terminal_transition(self, move, reward):
        """Close a pending move with ``reward``, store it and reset the dict."""
        # there must have legal actions, otherwise the net will output NaN
        move['next_state'] = np.ones(self.num_actions + self.state_dim, dtype=np.float32)
        move['reward'] = reward
        move['done'] = 1
        self._push_transition(move)
        move['state'] = None  # reset for the next trajectory

    class DQNbuffer:
        """Fixed-size ring buffer of transitions backed by numpy arrays."""

        def __init__(self, num_actions, state_dim, buffer_size, batch_size, device):
            self.batch_size = batch_size
            self.device = device
            self.count = 0  # total number of transitions ever added
            self.index = 0  # next slot to write
            self.states = np.zeros((buffer_size, state_dim + num_actions), dtype=np.float32)
            self.next_states = np.zeros((buffer_size, state_dim + num_actions), dtype=np.float32)
            self.actions = np.zeros((buffer_size,), dtype=np.int32)
            self.rewards = np.zeros((buffer_size,), dtype=np.float32)
            self.dones = np.zeros((buffer_size,), dtype=np.float32)

        def add(self, state, action, next_state, reward, done):
            """Write one transition, overwriting the oldest slot once full."""
            self.states[self.index, :] = np.asarray(state)
            self.next_states[self.index, :] = np.asarray(next_state)
            self.actions[self.index] = action
            self.rewards[self.index] = reward
            self.dones[self.index] = done
            self.count += 1
            self.index = (self.index + 1) % len(self.rewards)

        def clear(self):
            """Forget all stored transitions (the arrays are not zeroed)."""
            self.count = 0
            self.index = 0

        def get(self):
            """Sample ``batch_size`` transitions (with replacement) as tensors.

            Returns ``(states, next_states, actions, rewards, dones)`` on
            ``self.device``; ``actions`` is long, the rest are float32.
            """
            assert self.count >= self.batch_size, "Not enough data in buffer"
            indices = np.random.choice(min(self.count, len(self.rewards)), self.batch_size)
            states = torch.tensor(self.states[indices, :], dtype=torch.float32, device=self.device)
            next_states = torch.tensor(self.next_states[indices, :], dtype=torch.float32, device=self.device)
            actions = torch.tensor(self.actions[indices], dtype=torch.long, device=self.device)
            rewards = torch.tensor(self.rewards[indices], dtype=torch.float32, device=self.device)
            dones = torch.tensor(self.dones[indices], dtype=torch.float32, device=self.device)  # 1.0 if done else 0.0
            return states, next_states, actions, rewards, dones

class Unmasked_DQNAgent(DQNAgent):
    """MLP Double-DQN agent whose training target ignores the legal mask.

    The network only sees the board part of the state (the leading
    ``num_actions`` legal-mask entries are stripped). Illegal actions are
    still masked out when an action is chosen, but not when the bootstrap
    action is selected during training.
    """

    def __init__(self, params):
        super().__init__(params)
        hidden = self.d_model
        mlp = nn.Sequential(
            nn.Linear(self.state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_actions),
        )
        self.policy_net = mlp.to(self.device)
        self.target_net = copy.deepcopy(self.policy_net)
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)

    def train(self):
        """Run one Double-DQN update on a batch sampled from the replay buffer."""
        # DDQN: next action is selected by policy_net, value is from target_net
        states, next_states, actions, rewards, dones = self.buffer.get()
        board = states[:, self.num_actions:]
        next_board = next_states[:, self.num_actions:]

        self.optimizer.zero_grad()
        # Q-value of the action that was actually taken in each sampled state.
        q_pred = self.policy_net(board).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            greedy_next = self.policy_net(next_board).argmax(dim=1, keepdim=True)
            q_next_pred = self.target_net(next_board).gather(1, greedy_next).squeeze(1)
        # Terminal transitions (dones == 1) do not bootstrap.
        q_target = rewards + self.gamma * q_next_pred * (1 - dones)

        loss = self.loss(q_pred, q_target).to(self.device)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        self.train_counter += 1
        self.loss_history.append(loss.item())
        for history, value in (
            (self.q_max_history, q_pred.max()),
            (self.q_min_history, q_pred.min()),
            (self.q_mean_history, q_pred.mean()),
        ):
            history.append(value.item())

    def choose_action(self, state, deterministic=False):
        """Epsilon-greedy action over the legal actions encoded in ``state``.

        ``state`` is ``[legal_mask, board_state]``. With ``deterministic=True``
        the greedy legal action is always returned.
        """
        legal_mask = state[:self.num_actions]
        legal_actions = np.nonzero(legal_mask)[0]
        explore = (not deterministic) and random.random() < self.get_epsilon()
        if explore:
            return random.choice(legal_actions)

        board = torch.tensor(state[self.num_actions:], dtype=torch.float32).to(self.device)
        with torch.no_grad():
            q_vals = self.policy_net(board).flatten()
        illegal = ~torch.tensor(legal_mask, dtype=torch.bool, device=self.device)
        # Illegal actions can never win the argmax.
        masked_q = q_vals.masked_fill(illegal, float('-inf'))
        return torch.argmax(masked_q).item()

class Masked_DQNAgent(Unmasked_DQNAgent):
    """Double-DQN agent whose bootstrap action is restricted to legal moves.

    Identical to the parent except that, during training, the next-state
    legal mask is applied before the policy net picks the greedy next action.
    """

    def __init__(self, params):
        super().__init__(params)

    def train(self):
        """Run one masked Double-DQN update on a sampled batch."""
        # DDQN: next action is selected by policy_net, value is from target_net
        states, next_states, actions, rewards, dones = self.buffer.get()
        next_illegal = ~next_states[:, :self.num_actions].bool()
        board = states[:, self.num_actions:]
        next_board = next_states[:, self.num_actions:]

        self.optimizer.zero_grad()
        # Q-value of the action that was actually taken in each sampled state.
        q_pred = self.policy_net(board).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_next_policy = self.policy_net(next_board)
            # Illegal next actions must never be chosen as the bootstrap action.
            q_next_policy.masked_fill_(next_illegal, float('-inf'))
            greedy_next = q_next_policy.argmax(dim=1, keepdim=True)
            q_next_pred = self.target_net(next_board).gather(1, greedy_next).squeeze(1)
        # Terminal transitions (dones == 1) do not bootstrap.
        q_target = rewards + self.gamma * q_next_pred * (1 - dones)

        loss = self.loss(q_pred, q_target).to(self.device)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        self.train_counter += 1
        self.loss_history.append(loss.item())
        for history, value in (
            (self.q_max_history, q_pred.max()),
            (self.q_min_history, q_pred.min()),
            (self.q_mean_history, q_pred.mean()),
        ):
            history.append(value.item())

class ResNetDQNAgent(Masked_DQNAgent):
    """Masked Double-DQN agent with a convolutional residual Q-network.

    The flat board state of length ``state_dim`` is interpreted as
    ``in_channels`` square planes (``params['num_layers']``, default 4).
    Training and action selection are inherited from the parent classes.

    NOTE(review): the inherited ``choose_action`` calls ``policy_net`` without
    switching it to eval mode, so the BatchNorm layers normalise each single
    state with its own statistics. Confirm this is intended.
    """

    def __init__(self, params: dict = None):
        """
        params: dict of agent parameters; ``None`` is treated as an empty dict.

        Raises:
            ValueError: if ``state_dim`` cannot be split into
                ``num_layers`` square planes.
        """
        if params is None:
            params = {}
        # The parent constructor also builds an MLP policy/target/optimizer;
        # they are replaced below.
        super().__init__(params)
        self.in_channels = params.get('num_layers', 4)
        if self.in_channels <= 0:
            raise ValueError(f"num_layers must be positive, got {self.in_channels}")
        side = int(math.sqrt(self.state_dim // self.in_channels))
        self.height = side
        self.width = side
        # A mismatching size would let the network's reshape to
        # (-1, C, H, W) silently regroup a training batch into a different
        # number of samples, so reject it up front.
        if self.in_channels * self.height * self.width != self.state_dim:
            raise ValueError(
                f"state_dim={self.state_dim} cannot be reshaped into "
                f"{self.in_channels} square planes "
                f"(got {self.height}x{self.width} per plane)"
            )
        self.num_res_blocks = params.get('num_res_blocks', 6)

        self.policy_net = self.ResNet(
            num_actions=self.num_actions,
            in_channels=self.in_channels,
            height=self.height,
            width=self.width,
            num_res_blocks=self.num_res_blocks,
            d_model=self.d_model
        ).to(self.device)
        self.target_net = copy.deepcopy(self.policy_net)
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)

    class ResNet(nn.Module):
        """Conv stem + residual blocks + linear head producing one Q per action."""

        def __init__(self, num_actions, in_channels, height, width, num_res_blocks, d_model=128):
            super().__init__()
            self.num_actions = num_actions
            self.in_channels = in_channels
            self.height = height
            self.width = width
            self.stem = nn.Sequential(
                nn.Conv2d(in_channels, d_model, kernel_size=3, padding=1),
                nn.BatchNorm2d(d_model),
                nn.ReLU()
            )
            res_blocks = [self.ResidualBlock(d_model) for _ in range(num_res_blocks)]
            self.res_body = nn.Sequential(*res_blocks)

            # Probe the flattened feature size with a dummy input. The probe
            # runs in eval mode so the all-zero input does not leak into the
            # BatchNorm running statistics.
            was_training = self.training
            self.eval()
            with torch.no_grad():
                dummy_input = torch.zeros(1, in_channels, height, width)
                dummy_output = self.res_body(self.stem(dummy_input))
                flattened_size = dummy_output.flatten(1).size(1)
            self.train(was_training)

            self.output = nn.Linear(flattened_size, self.num_actions)

        def forward(self, board_state):
            """Map flat board states to Q-values of shape (batch, num_actions).

            ``board_state`` may be a single flat state or a batch of them; it
            is reshaped to ``(-1, in_channels, height, width)``.
            """
            # reshape (not view) so non-contiguous slices are accepted too
            planes = board_state.reshape(
                -1, self.in_channels, self.height, self.width
            )
            features = self.res_body(self.stem(planes))
            return self.output(features.flatten(1))

        class ResidualBlock(nn.Module):
            """Two 3x3 conv + BatchNorm layers with an identity skip connection."""

            def __init__(self, num_filters):
                super().__init__()
                self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
                self.bn1 = nn.BatchNorm2d(num_filters)
                self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
                self.bn2 = nn.BatchNorm2d(num_filters)

            def forward(self, x):
                residual = x
                out = F.relu(self.bn1(self.conv1(x)))
                out = self.bn2(self.conv2(out))
                out += residual  # The skip connection
                out = F.relu(out)
                return out
            
class LGAN_DQNAgent(DQNAgent):
    """Double-DQN agent whose Q-network is the ``LegalAttentionNet`` below.

    The network receives the full ``[legal_mask, board_state]`` vector and
    handles the legal mask itself, so neither ``choose_action`` nor ``train``
    applies an explicit mask.
    """

    def __init__(self, params):
        super().__init__(params)
        self.num_heads = params.get('num_heads', 8)
        self.k_embedding_dim = params.get('k_embedding_dim', 16)
        self.policy_net = self.LegalAttentionNet(input_dim=self.state_dim, num_actions=self.num_actions,
                    d_model=self.d_model, num_heads=self.num_heads, k_embedding_dim=self.k_embedding_dim,
                    num_layers=params.get('num_layers', 1), average_func=params.get('average_func', 'mlp')).to(self.device)
        self.target_net = copy.deepcopy(self.policy_net)
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)

    def choose_action(self, state, deterministic=False):
        """Epsilon-greedy action; the greedy branch is the argmax of the net output."""
        if not deterministic and random.random() < self.get_epsilon():
            legal_mask = state[:self.num_actions]
            legal_actions = np.nonzero(legal_mask)[0]
            action = random.choice(legal_actions)
        else:
            state = torch.tensor(state, dtype=torch.float32).to(self.device)
            with torch.no_grad():
                q_val = self.policy_net(state)
            action = torch.argmax(q_val).item()
        return action

    def train(self):
        """Run one Double-DQN update on a batch sampled from the replay buffer."""
        states, next_states, actions, rewards, dones = self.buffer.get()

        self.optimizer.zero_grad()
        q_pred = self.policy_net(states)[np.arange(self.batch_size), actions]
        with torch.no_grad():
            # DDQN: next action is selected by policy_net, value is from target_net
            next_actions = torch.argmax(self.policy_net(next_states), dim=1)
            q_next_pred = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
        q_target = rewards + self.gamma * q_next_pred * (1 - dones)
        loss = self.loss(q_pred, q_target).to(self.device)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.train_counter += 1
        self.loss_history.append(loss.item())
        self.q_max_history.append(q_pred.max().item())
        self.q_min_history.append(q_pred.min().item())
        self.q_mean_history.append(q_pred.mean().item())

    class LegalAttentionNet(nn.Module):
        """Attention-style Q-network over ``[legal_mask, board_state]`` inputs.

        Action embeddings (zeroed for illegal actions) act as queries, the
        board entries concatenated with fixed sin/cos position codes act as
        keys, and the raw board entries act as values. The output is the log
        of the resulting non-negative score per action, so an action whose
        score is exactly zero (e.g. an illegal one) gets ``-inf``.
        """

        def __init__(self, input_dim, num_actions, d_model=512, num_heads=8, k_embedding_dim=16, num_layers=1, average_func='mlp'):
            """
            average_func: ``'mean'`` averages the per-head scores, ``'mlp'``
                combines them with the bias-free ``head_mlp`` layer.

            Raises:
                NotImplementedError: if ``num_layers == 0``.
                ValueError: if ``average_func`` is not ``'mean'`` or ``'mlp'``.
            """
            super().__init__()
            assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
            if average_func not in ('mean', 'mlp'):
                raise ValueError(f"average_func must be 'mean' or 'mlp', got {average_func!r}")
            self.num_actions = num_actions
            self.input_dim = input_dim
            self.d_model = d_model
            self.k_proj = nn.Linear(1 + k_embedding_dim, d_model)
            self.temperature1 = math.sqrt(d_model)
            self.d_head = d_model // num_heads
            self.num_heads = num_heads
            self.q_embedding = nn.Embedding(num_actions, d_model)
            self.register_buffer('action_indices', torch.arange(num_actions))
            self.head_mlp = nn.Linear(num_heads, 1, bias=False)
            nn.init.constant_(self.head_mlp.weight, 1.0)
            if num_layers == 0:
                raise NotImplementedError("learnable PE does not work, forbidden for now.")
            elif num_layers == 1:
                self.register_buffer("k_embedding", self.build_sincos_pe(input_dim, k_embedding_dim))
            else:
                self.register_buffer("k_embedding", self.build_sincos_pe_3d(input_dim, k_embedding_dim, num_layers))
            # The head-combination mode is stored as a plain string and
            # dispatched in the `average_func` method. A lambda closing over
            # `self` would be shared by copy.deepcopy (functions are copied by
            # reference), making a copied target net use the *original* net's
            # head_mlp weights.
            self.average_mode = average_func

        def average_func(self, x):
            """Reduce the trailing head dimension of ``x`` according to ``average_mode``."""
            if self.average_mode == 'mean':
                return torch.mean(x, dim=-1)
            return self.head_mlp(torch.relu(x)).squeeze(-1)

        def forward(self, x):
            """Return per-action log-scores of shape (B, A) for ``x`` of shape (B, A + S) or (A + S,)."""
            # B: batch size, batch_size
            # S: length of the input, input_dim
            # A: number of actions, num_action
            # d_model: number of latent variables
            # h: number of heads, num_heads
            # d_head: number of latent variables for each head
            if x.dim() == 1:
                x = x.unsqueeze(0)                                                                          # inference
            batch_size = x.size(0)
            legal_mask = x[:, :self.num_actions]                                                            # (B, A)
            Q = self.q_embedding(self.action_indices).unsqueeze(0).expand(batch_size, -1, -1)               # (B, A, d_model)
            Q = Q * legal_mask.unsqueeze(-1)                                                                # (B, A, d_model)
            Q = Q.reshape(batch_size, self.num_actions, self.num_heads, self.d_head).transpose(1, 2)        # (B, h, A, d_head)
            K = x[:, self.num_actions:].unsqueeze(-1)                                                       # (B, S, 1)
            K_pos_vectors = self.k_embedding.expand(batch_size, -1, -1)                                     # (B, S, k_embedding_dim)
            K = torch.cat([K, K_pos_vectors], dim=-1)                                                       # (B, S, 1 + k_embedding_dim)
            K = self.k_proj(K)                                                                              # (B, S, d_model)
            # NOTE(review): this is a plain reshape of the (B, S, d_model)
            # projection, not a reshape-then-permute head split, so the
            # (B, h, d_head, S) layout mixes positions and features. Kept
            # unchanged because altering it would change the outputs of
            # already-trained weights; confirm whether it is intentional.
            K = K.reshape(batch_size, self.num_heads, self.d_head, -1)                                      # (B, h, d_head, S)
            attn_scores = torch.matmul(Q, K) / self.temperature1                                            # (B, h, A, S)
            attn_scores = attn_scores.permute(0, 2, 3, 1)                                                   # (B, A, S, h)
            attn_weights = torch.relu(self.average_func(attn_scores))                                       # (B, A, S)
            V = x[:, self.num_actions:].unsqueeze(-1)                                                       # (B, S, 1)
            z = torch.relu(torch.matmul(attn_weights, V).squeeze(-1))                                       # (B, A)
            # Legal actions get a tiny positive floor so their log stays finite.
            z = z + legal_mask * 1e-8
            q_val = torch.log(z)
            return q_val

        def get_attn_weights(self, x):
            """Return the (A, S) attention weights for one unbatched input as a numpy array."""
            # output shape must be (a, s)
            # shape(x) = (a + s, )
            legal_mask = x[:self.num_actions].unsqueeze(-1)                                                 # (A, 1)
            action_indices = torch.arange(self.num_actions, device=x.device)                                # (a, )
            Q = self.q_embedding(action_indices)
            Q = Q * legal_mask
            Q = Q.reshape(self.num_actions, self.num_heads, self.d_head).permute(1, 0, 2)                   # (num_heads, a, d_head)
            K = x[self.num_actions:].unsqueeze(-1)                                                          # (s, 1)
            K_pos_vectors = self.k_embedding.squeeze(0)
            K = torch.cat([K, K_pos_vectors], dim=-1)                                                       # (s, 1 + k_embedding_dim)
            K = self.k_proj(K)                                                                              # (s, d_model)
            K = K.reshape(self.num_heads, self.d_head, -1)                                                  # (h, d_head, S)
            attn_scores = torch.matmul(Q, K) / self.temperature1                                            # (h, a, s)
            attn_scores = attn_scores.permute(1, 2, 0)                                                      # (a, s, h)
            attn_weights = torch.relu(self.average_func(attn_scores))                                       # (a, s)
            return attn_weights.detach().cpu().numpy()

        def build_sincos_pe(self, input_dim, d_pe):
            """Return a fixed (1, input_dim, d_pe) sin/cos position code; ``d_pe`` must be even."""
            assert d_pe % 2 == 0, "d_pe must be even for sin+cos"
            pos = torch.arange(input_dim).unsqueeze(1)                      # (input_dim, 1)
            i = torch.arange(d_pe // 2).unsqueeze(0)                        # (1, d_pe//2)
            base = 2 * input_dim
            angle = pos / (base ** (2 * i / d_pe))                          # (input_dim, d_pe//2)
            pe = torch.zeros(input_dim, d_pe)
            pe[:, 0::2] = torch.sin(angle)
            pe[:, 1::2] = torch.cos(angle)
            return pe.unsqueeze(0)                                          # (1, input_dim, d_pe)

        def build_sincos_pe_3d(self, input_dim, d_pe, num_layers):
            """Return a (1, input_dim, d_pe) position code built from layer, row and column codes.

            Four dimensions encode the layer index; the remaining ``d_pe - 4``
            are split evenly between row and column.
            """
            # Assumes input is flattened in [layer-row-col] order, matches spatial position.
            assert d_pe > 4 , "check d_pe."
            assert (d_pe - 4) % 2 == 0 , "check d_pe."
            assert input_dim % num_layers == 0, "input_dim must be divisible by num_layer."
            layer_size = input_dim // num_layers
            grid_size_float = math.sqrt(layer_size)
            assert grid_size_float == int(grid_size_float), "Each layer's size must be a perfect square."
            grid_size = int(grid_size_float)

            d_layer = 4
            d_remain = d_pe - d_layer
            d_x = d_remain // 2
            d_y = d_remain // 2
            pe_z = self.build_sincos_pe(num_layers, d_layer).squeeze(0)   # (L, d_layer)
            pe_x = self.build_sincos_pe(grid_size, d_x).squeeze(0)        # (H, d_x)
            pe_y = self.build_sincos_pe(grid_size, d_y).squeeze(0)        # (W, d_y)
            pe_z = pe_z[:, None, None, :].expand(num_layers, grid_size, grid_size, d_layer)
            pe_x = pe_x[None, :, None, :].expand(num_layers, grid_size, grid_size, d_x)
            pe_y = pe_y[None, None, :, :].expand(num_layers, grid_size, grid_size, d_y)

            pe = torch.cat([pe_z, pe_x, pe_y], dim=-1)  # (L, H, W, d_pe)
            pe = pe.reshape(1, input_dim, d_pe)         # flatten to (1, input_dim, d_pe)
            return pe
