# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.nn.attention import SDPBackend, sdpa_kernel
from torchaudio.models.wav2vec2 import Wav2Vec2Model
from torchaudio.models.wav2vec2.components import _get_feature_extractor,_get_padding_mask,_compute_mask_indices,\
	FeatureProjection,ConvolutionalPositionalEmbedding,SelfAttention,FeedForward,EncoderLayer,Encoder,Transformer
from torchaudio.models.wav2vec2.components import MaskGenerator as _MaskGenerator

class Wav2Vec2Pretrainer(Wav2Vec2Model):
	"""
	wav2vec 2.0-style pretraining model on top of torchaudio's ``Wav2Vec2Model``.

	Components:
		- a 7-layer conv feature extractor (512 channels, group norm, no bias),
		- an encoder built by ``_get_encoder`` (DDP-compatible layer-drop Transformer),
		- ``aux``: linear head on the encoder output,
		- ``proj_feature``: linear projection of the UNMASKED conv features,
		  presumably used as targets by a contrastive loss (e.g. ``ContrastiveLoss``),
		- ``mask_generator``: time masking applied to the conv features.

	Args:
		out_channels (int): Output dim of ``proj_feature``; ``aux`` outputs
			``out_channels`` (``2*out_channels`` when ``dual_codebook`` is True).
		embed_dim (int): Transformer embedding dimension.
		prenorm (bool): Passed to ``_get_encoder`` as ``layer_norm_first``.
		dual_codebook (bool): Doubles the output width of ``aux``.
		**encoder_config: Remaining keyword arguments of ``_get_encoder``.
	"""
	def __init__(self, out_channels, embed_dim, prenorm=True, dual_codebook=False, **encoder_config):
		extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
		feature_extractor = _get_feature_extractor(
						norm_mode='group_norm',
						shapes=extractor_conv_layer_config,
						bias=False)
		encoder = _get_encoder(
							in_features=extractor_conv_layer_config[-1][0],
							embed_dim=embed_dim,
							layer_norm_first=prenorm,
							**encoder_config)
		aux = nn.Linear(embed_dim, out_channels*(int(dual_codebook)+1))
		super().__init__(feature_extractor, encoder, aux)
		self.proj_feature = nn.Linear(extractor_conv_layer_config[-1][0], out_channels)
		self.mask_generator = MaskGenerator(
			encoder_embed_dim=extractor_conv_layer_config[-1][0],
			mask_prob=0.065,
			mask_selection='static',
			mask_other=0.0,
			mask_length=10,
			no_mask_overlap=False,
			mask_min_space=1, # NOTE: not used when no_mask_overlap=False
			mask_channel_prob=0.0, # NOTE: No channel dropout during pretraining
			mask_channel_selection='static',
			mask_channel_other=0.0,
			mask_channel_length=64, # NOTE: Used for fine-tuning
			no_mask_channel_overlap=False,
			mask_channel_min_space=1, # NOTE: not used when no_mask_channel_overlap=False
		)

	def forward(self, waveforms, lengths=None):
		"""
		Args:
			waveforms: batch_size x max_length
			lengths: batch_size (optional)

		Returns:
			feature: ``proj_feature`` applied to the UNMASKED conv features.
			encoded: ``aux`` applied to the encoder output of the masked features.
			is_target: time-step mask indices from the mask generator
				(None when ``mask_prob == 0``).
			lengths: valid frame counts after the feature extractor (None if not given).
		"""
		feature, lengths = self.feature_extractor(waveforms, lengths)
		padding_mask = None if lengths is None \
						else _get_padding_mask(feature, lengths)
		# MaskGenerator writes the mask embedding into its input IN PLACE and returns
		# that same tensor. Mask a copy so that `feature` (the target below) keeps the
		# original content at the masked positions instead of the shared mask embedding.
		masked, is_target = self.mask_generator(feature.clone(), padding_mask)
		encoded = self.encoder(masked, lengths)
		encoded = self.aux(encoded)
		feature = self.proj_feature(feature)
		return feature, encoded, is_target, lengths
		
class Wav2Vec2FineTuner(Wav2Vec2Pretrainer):
	"""
	Fine-tuning model built on top of ``Wav2Vec2Pretrainer``.

	Replaces the pretraining head ``aux`` with a ``vocab_size``-way linear layer,
	deletes ``proj_feature``, freezes the conv feature extractor and freezes the
	first ``num_frozen_encoder_layers`` Transformer layers.

	Args:
		vocab_size (int): Output dimension of the new ``aux`` head.
		mask_prob (float): Time-step masking probability (applied in training mode only).
		mask_channel_prob (float): Channel masking probability (applied in training mode only).
		num_frozen_encoder_layers (int): Number of bottom Transformer layers whose
			parameters get ``requires_grad=False``.
		**kwargs: Forwarded to ``Wav2Vec2Pretrainer.__init__``.
	"""
	def __init__(self, vocab_size, mask_prob, mask_channel_prob, num_frozen_encoder_layers=0,**kwargs):
		super().__init__(**kwargs)
		self.aux = nn.Linear(self.aux.in_features, vocab_size)
		# NOTE: LayerNorm returns NaN in some settings during finetuning...
		# self.encoder.feature_projection.layer_norm \
		# 	= Fp32LayerNorm(
		# 		self.encoder.feature_projection.layer_norm.normalized_shape,
		# 		self.encoder.feature_projection.layer_norm.eps,
		# 		self.encoder.feature_projection.layer_norm.elementwise_affine,
		# 		not self.encoder.feature_projection.layer_norm.bias is None,
		# 		)
		# self.encoder.transformer.layer_norm \
		# 	= Fp32LayerNorm(
		# 		self.encoder.transformer.layer_norm.normalized_shape,
		# 		self.encoder.transformer.layer_norm.eps,
		# 		self.encoder.transformer.layer_norm.elementwise_affine,
		# 		not self.encoder.transformer.layer_norm.bias is None,
		# 		)
		# for layer in self.encoder.transformer.layers:
		# 	layer.layer_norm \
		# 		= Fp32LayerNorm(
		# 			layer.layer_norm.normalized_shape,
		# 			layer.layer_norm.eps,
		# 			layer.layer_norm.elementwise_affine,
		# 			not layer.layer_norm.bias is None,
		# 			)
		# 	layer.final_layer_norm \
		# 		= Fp32LayerNorm(
		# 			layer.final_layer_norm.normalized_shape,
		# 			layer.final_layer_norm.eps,
		# 			layer.final_layer_norm.elementwise_affine,
		# 			not layer.final_layer_norm.bias is None,
		# 			)
		del self.proj_feature # NOTE: Unused parameters must be removed for DDP to work.
		self.mask_generator.mask_prob = mask_prob
		self.mask_generator.mask_channel_prob = mask_channel_prob
		for p in self.feature_extractor.parameters():
			p.requires_grad = False
		# Freeze whole layer modules rather than matching parameter names by substring:
		# 'transformer.layers.1' is also a substring of 'transformer.layers.10', '...11', etc.
		# max(..., 0) keeps a negative count a no-op (a negative slice would freeze layers).
		for layer in self.encoder.transformer.layers[:max(num_frozen_encoder_layers, 0)]:
			for p in layer.parameters():
				p.requires_grad = False

	def _convs(self, x, lengths = None):
		"""
		Debugging re-implementation of the feature extractor's conv stack.

		Prints the minimum activation after each conv and each layer norm. It is not
		called by ``forward``. Returns (batch, frame, feature) features and the
		updated lengths.
		"""
		x = x.unsqueeze(1)  # (batch, channel==1, frame)
		for l,layer in enumerate(self.feature_extractor.conv_layers):
			# x, lengths = layer(x, lengths)  # (batch, feature, frame)
			x = layer.conv(x)
			# x = x.clamp(max=5.0)
			print('{}-th conv'.format(l), x.min())
			if layer.layer_norm is not None:
				x = layer.layer_norm(x)
				print('{}-th layer_norm'.format(l), x.min())
			x = nn.functional.gelu(x)
			# print('{}-th gelu'.format(l), x.max())

			if lengths is not None:
				lengths = torch.div(lengths - layer.kernel_size, layer.stride, rounding_mode="floor") + 1
				# When input length is 0, the resulting length can be negative. So fix it here.
				lengths = torch.max(torch.zeros_like(lengths), lengths)
		x = x.transpose(1, 2)  # (batch, frame, feature)
		return x, lengths

	@staticmethod
	def _encoded_diversity(encoded, lengths):
		"""
		Per-utterance mean inner product <e_i, e_j> over valid frame pairs i < j.

		Args:
			encoded: batch x frames x dim encoder output.
			lengths: valid frame counts per utterance, or None (all frames valid).

		Returns:
			Tensor of shape (batch,). Utterances with fewer than two valid frames
			get 0 instead of 0/0 = NaN.
		"""
		batch_size, num_frames = encoded.shape[:2]
		device = encoded.device
		if lengths is None:
			pad_mask = torch.zeros((batch_size, num_frames), dtype=torch.bool, device=device)
		else:
			pad_mask = torch.arange(num_frames, device=device) >= lengths[:, None]
		# Mask padded columns and the lower triangle incl. the diagonal -> only pairs i < j remain.
		# Padding is at the end, so a valid column j > i implies row i is valid too.
		lower_mask = torch.tril(torch.ones((num_frames, num_frames), device=device)).bool()
		combined_mask = pad_mask.unsqueeze(1) | lower_mask
		num_pairs = (~combined_mask).float().sum(dim=(1, 2)).clamp(min=1.0)
		gram = torch.bmm(encoded, encoded.transpose(1, 2))
		return gram.masked_fill(combined_mask, 0.0).sum(dim=(1, 2)) / num_pairs

	def forward(self, waveforms, lengths=None):
		"""
		Args:
			waveforms: batch_size x max_length
			lengths: batch_size (optional)

		Returns:
			logits: ``aux`` applied to the encoder output.
			lengths: valid frame counts after the feature extractor (None if not given).
			encoded_diversity: per-utterance mean inner product of encoder outputs
				over valid frame pairs i < j.
		"""
		# feature,lengths = self._convs(waveforms, lengths)
		feature, lengths = self.feature_extractor(waveforms, lengths)
		# feature = feature.clone() # NOTE: This sets feature.requires_grad=True. See https://stackoverflow.com/a/71224759

		padding_mask = None if lengths is None \
						else _get_padding_mask(feature, lengths)
		# SpecAugment-style masking is a training-time regularizer only; evaluation
		# and inference must see the complete features.
		if self.training:
			masked, _ = self.mask_generator(feature, padding_mask)
		else:
			masked = feature
		# with sdpa_kernel(SDPBackend.MATH):
		encoded = self.encoder(masked, lengths)

		encoded_diversity = self._encoded_diversity(encoded, lengths)

		logits = self.aux(encoded)
		return logits, lengths, encoded_diversity

class ContrastiveLoss(nn.Module):
	"""
	Contrastive (InfoNCE-style) loss between encoder outputs and targets.

	For every position t, the candidate set is the true target t plus up to
	``num_distractors`` other positions of the same utterance, chosen uniformly
	at random. The loss is ``logsumexp(sim(t, candidates)) - sim(t, t)``, where
	``sim`` is cosine similarity divided by ``temperature``. It is averaged over
	the target positions of each utterance and then over the batch.

	Args:
		num_distractors (int): Number of negatives per target. It is clamped to the
			number of available positions for short sequences.
		temperature (float): Softmax temperature.
		**kwargs: Ignored.
	"""
	def __init__(self, num_distractors, temperature=0.1, **kwargs):
		super().__init__()
		self.num_distractors = num_distractors
		self.temperature = temperature

	def forward(self, encoded, quantized, is_target):
		"""
		Args:
			encoded: B x L x D predictions.
			quantized: B x L x D targets.
			is_target: B x L bool mask of positions that contribute to the loss.
				None means every position contributes. MaskGenerator returns None
				when mask_prob == 0.

		Returns:
			Scalar loss. Utterances without any target contribute 0 instead of NaN.
		"""
		L = encoded.size(1)
		encoded = F.normalize(encoded, p=2.0, dim=-1)
		quantized = F.normalize(quantized, p=2.0, dim=-1)
		similarity = torch.bmm(encoded, quantized.transpose(-1,-2)) / self.temperature # BxLxL
		log_target = similarity.diagonal(dim1=-2,dim2=-1)
		# topk fails when k exceeds the number of columns, so clamp it for short sequences.
		num_candidates = min(self.num_distractors + 1, similarity.size(-1))
		sample_seed = torch.rand_like(similarity) + torch.eye(L,
														device=similarity.device) # NOTE: always choose the diagonal
		_, sample_idxs = sample_seed.topk(k=num_candidates, dim=-1, largest=True)
		log_samples = similarity.take_along_dim(sample_idxs, dim=-1).logsumexp(dim=-1)
		loss = log_samples - log_target
		if is_target is None:
			is_target = torch.ones_like(loss, dtype=torch.bool)
		# Avoid 0/0 = NaN for utterances without targets (which would poison the mean).
		num_targets = is_target.float().sum(dim=1).clamp(min=1.0)
		loss = loss.masked_fill(~is_target, 0.0).sum(dim=1) / num_targets
		return loss.mean()
	
def _get_encoder(
		in_features: int,
		embed_dim: int,
		dropout_input: float,
		pos_conv_kernel: int,
		pos_conv_groups: int,
		num_layers: int,
		num_heads: int,
		attention_dropout: float,
		ff_interm_features: int,
		ff_interm_dropout: float,
		dropout: float,
		layer_norm_first: bool,
		layer_drop: float,
	):
	"""
	Build a wav2vec2 ``Encoder`` (feature projection + Transformer).

	This is a copy of ``torchaudio.models.wav2vec2.components._get_encoder``. The
	only difference is that the Transformer is a ``DDPCompatibleTransformer``,
	whose layer-drop keeps every layer in the graph for DDP training.

	Modules are constructed in the same order as upstream: projection, positional
	conv, then attention followed by feed-forward for each layer. This keeps the
	random parameter initialization identical for a given seed.

	Returns:
		Encoder: ``Encoder(feature_projection, transformer)``.
	"""
	feature_projection = FeatureProjection(in_features, embed_dim, dropout_input)
	pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups)

	def _build_layer():
		# One encoder block; see the original fairseq implementation:
		# https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782
		self_attention = SelfAttention(
			embed_dim=embed_dim,
			num_heads=num_heads,
			dropout=attention_dropout,
		)
		ffn = FeedForward(
			io_features=embed_dim,
			intermediate_features=ff_interm_features,
			intermediate_dropout=ff_interm_dropout,
			output_dropout=dropout,
		)
		return EncoderLayer(
			attention=self_attention,
			dropout=dropout,
			layer_norm_first=layer_norm_first,
			feed_forward=ffn,
		)

	layers = nn.ModuleList([_build_layer() for _ in range(num_layers)])
	# With pre-norm layers, the final LayerNorm is applied after the stack instead of before it.
	transformer = DDPCompatibleTransformer(
		pos_conv_embed=pos_conv,
		dropout=dropout,
		layers=layers,
		layer_norm_first=not layer_norm_first,
		layer_drop=layer_drop,
	)
	return Encoder(feature_projection, transformer)

class DDPCompatibleTransformer(Transformer):
	"""
	torchaudio ``Transformer`` whose layer-drop keeps every layer in the autograd graph.

	Upstream layer-drop skips the call to a dropped layer, so that layer's
	parameters get no gradient in that step, which DDP training does not accept.
	Here every layer is always run. A dropped layer's output is added with weight 0:
	the previous activation passes through unchanged, but the layer's parameters
	still take part in the backward pass (with zero gradient).
	"""
	def forward(self, x, attention_mask=None, position_bias=None):
		x = self._preprocess(x)
		for layer in self.layers:
			x_new, position_bias_new = layer(x, attention_mask, position_bias=position_bias)
			# `<` rather than `<=`: with layer_drop == 0.0 no layer may ever be dropped,
			# even if rand() happens to return exactly 0.0.
			dropped = self.training and torch.rand(1).item() < self.layer_drop
			if dropped:
				# Keep the previous output; the zero-weighted term only ties the layer into the graph.
				x = x + 0.0 * x_new
				if position_bias is not None and position_bias_new is not None:
					position_bias = position_bias + 0.0 * position_bias_new
			else:
				# Take the layer output directly (no 0*x blend, which would turn inf into NaN).
				# Like upstream, adopt a position_bias first produced by a layer.
				x = x_new
				position_bias = position_bias_new

		if not self.layer_norm_first:
			x = self.layer_norm(x)
		return x
	
class MaskGenerator(_MaskGenerator):
	"""
	torchaudio ``MaskGenerator`` with a fix for low ``mask_channel_prob``.

	Channel masking passes ``min_masks=1`` to ``_compute_mask_indices``. With the
	default value (0), that call can raise an unexpected error.

	NOTE: masking is applied IN PLACE. The returned tensor is ``x`` itself.
	"""
	def forward(self, x, padding_mask):
		"""
		Args:
			x (Tensor): The encoded representations after feature extraction module.
			padding_mask (Tensor or None): The padding mask of the same dimension as shape,
				which will prevent masking padded elements.

		Returns:
			Tensor: The feature representations after masking (``x``, modified in place).
			Tensor or None: The generated time-step mask indices (None when ``mask_prob`` is 0).
		"""
		batch_size, num_frames, num_channels = x.shape

		mask_indices = None
		if self.mask_prob > 0:
			mask_indices = _compute_mask_indices(
				(batch_size, num_frames),
				padding_mask,
				self.mask_prob,
				self.mask_length,
				self.mask_selection,
				self.mask_other,
				min_masks=2,
				no_overlap=self.no_mask_overlap,
				min_space=self.mask_min_space,
			).to(x.device)
			# Cast mask_embedding to x's dtype for mixed-precision training.
			# See https://github.com/pytorch/audio/issues/2847 for details.
			x[mask_indices] = self.mask_embedding.to(x.dtype)

		if self.mask_channel_prob > 0:
			channel_mask = _compute_mask_indices(
				(batch_size, num_channels),
				None,
				self.mask_channel_prob,
				self.mask_channel_length,
				self.mask_channel_selection,
				self.mask_channel_other,
				min_masks=1, # NOTE: the fix; upstream's default (0) can raise an unexpected error.
				no_overlap=self.no_mask_channel_overlap,
				min_space=self.mask_channel_min_space,
			)
			# Broadcast the per-(batch, channel) mask over all time steps.
			channel_mask = channel_mask.to(x.device).unsqueeze(1).expand(-1, num_frames, -1)
			x[channel_mask] = 0

		return x, mask_indices
	
class Fp32LayerNorm(nn.LayerNorm):
	"""
	LayerNorm that always computes in float32 and casts the result back to the input dtype.

	Borrowed from https://docs.pytorch.org/torchtune/stable/generated/torchtune.modules.Fp32LayerNorm.html
	"""
	def forward(self, x: torch.Tensor) -> torch.Tensor:
		"""
		Args:
			x (torch.Tensor): Input tensor.

		Returns:
			torch.Tensor: The normalized tensor, with the same shape and dtype as ``x``.
		"""
		weight_fp32 = None if self.weight is None else self.weight.float()
		bias_fp32 = None if self.bias is None else self.bias.float()
		normalized = F.layer_norm(x.float(), self.normalized_shape, weight_fp32, bias_fp32, self.eps)
		# F.layer_norm returns a tensor on x's device, so only the dtype needs restoring.
		return normalized.to(dtype=x.dtype)