import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import ipdb as pdb
import numpy as np

from src.components.transformer_utils import LearnablePositionalEncoding
from src.components.transformer_encoder import Transformer




class TransformerCLF(nn.Module):
	"""Token-embedding Transformer with a linear read-out on the last position.

	Token ids are embedded, scaled by sqrt(d_model), optionally passed through
	a ``LearnablePositionalEncoding``, run through the ``Transformer`` backbone,
	and the representation at the final sequence position is projected to
	``n_out`` outputs.

	Args:
		n_words: number of rows of the input ``nn.Embedding`` (vocabulary size).
		d_model: embedding / model width.
		n_layer: passed to the backbone as ``n_layers``.
		n_head: passed to the backbone as ``num_heads``.
		n_out: width of the linear read-out. With ``n_out == 1`` ``forward``
			squeezes the output dimension away.
		dropout: dropout value handed to the positional encoding.
		pos_encode_type: currently ignored; a ``LearnablePositionalEncoding``
			is always built. Kept for call-site compatibility.
		d_ffn: passed to the backbone as ``d_ffn``; defaults to ``4 * d_model``.
	"""

	def __init__(self, n_words, d_model, n_layer, n_head, n_out=1, dropout=0.0, pos_encode_type='learnable', d_ffn=None):
		super(TransformerCLF, self).__init__()
		self.model_type = 'SAN'

		# NOTE(review): `pos_encode_type` is never consulted; only the learnable
		# positional encoding is implemented here.
		self.pos_encoder = LearnablePositionalEncoding(d_model, dropout)

		self.name = f"mysan_d_model={d_model}_layer={n_layer}_head={n_head}"
		self.pos_encode = True  # checked in forward() before applying pos_encoder

		self.d_model = d_model
		self.n_words = n_words
		if d_ffn is None:
			d_ffn = 4 * d_model

		# Registration order (pos_encoder, encoder, _backbone, _read_out) is kept
		# so existing state_dict checkpoints keep the same key ordering.
		self.encoder = nn.Embedding(n_words, d_model)
		self._backbone = Transformer(d_model=d_model, n_layers=n_layer, num_heads=n_head, d_ffn=d_ffn)
		self._read_out = nn.Linear(d_model, n_out)

		print('My TransformerCLF Normal Training')

	@staticmethod
	def _combine(xs_b, ys_b):
		"""Interleave the x's and the y's into a single sequence.

		Each y is widened to ``dim`` by zero-padding after its first slot and
		placed right after its matching x, giving ``[x_0, y_0, x_1, y_1, ...]``
		along the sequence axis.

		Args:
			xs_b: tensor of shape (bsize, points, dim).
			ys_b: tensor holding bsize * points values, e.g. (bsize, points).

		Returns:
			Tensor of shape (bsize, 2 * points, dim).
		"""
		bsize, points, dim = xs_b.shape
		# new_zeros inherits ys_b's dtype and device, so half/double inputs are
		# not silently promoted to float32 by the padding.
		ys_b_wide = torch.cat(
			(ys_b.reshape(bsize, points, 1), ys_b.new_zeros(bsize, points, dim - 1)),
			dim=2,
		)
		zs = torch.stack((xs_b, ys_b_wide), dim=2)  # (bsize, points, 2, dim)
		return zs.view(bsize, 2 * points, dim)

	def forward(self, x, masking=False):
		"""Predict from the last sequence position.

		Args:
			x: token ids, (batch_size, seq_len).
			masking: forwarded unchanged to the backbone.

		Returns:
			Assuming the backbone returns (batch_size, seq_len, d_model):
			(batch_size,) when the read-out has a single output, otherwise
			(batch_size, n_out).
		"""
		embeds = self.encoder(x) * math.sqrt(self.d_model)
		if self.pos_encode:
			embeds = self.pos_encoder(embeds)

		output = self._backbone(embeds, masking)
		# The read-out is a Linear over the last dimension and only the final
		# position is used, so project that position alone rather than the
		# whole sequence.
		prediction = self._read_out(output[:, -1])

		if prediction.size(-1) == 1:
			return prediction[:, 0]
		return prediction







	