import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Flatten

def buildCNN(nn_module, RSI_ver, config=None):
	"""Build the image-encoder CNN and store it as ``nn_module.imgBranch``.

	The encoder is five 3x3 convolutions with stride 2 and padding 1, each
	followed by a ReLU, so every stage halves the spatial size. For a 96x96
	RGB input the feature map goes
	(3, 96, 96) -> (32, 48, 48) -> (32, 24, 24) -> (64, 12, 12)
	-> (64, 6, 6) -> (64, 3, 3).

	Args:
		nn_module: object that receives the ``imgBranch`` attribute.
		RSI_ver: model version. For 2 or 3 a ``Flatten()`` layer is appended;
			for any other value the convolutional output is left as is.
		config: accepted but not used by this function.
	"""
	# (in_channels, out_channels) of each convolution, in order.
	channel_plan = [(3, 32), (32, 32), (32, 64), (64, 64), (64, 64)]

	layers = []
	for in_ch, out_ch in channel_plan:
		layers.append(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1))
		layers.append(nn.ReLU())

	# Versions 2 and 3 feed the features straight into Linear heads.
	if RSI_ver == 3 or RSI_ver == 2:
		layers.append(Flatten())

	nn_module.imgBranch = nn.Sequential(*layers)

def buildSoundBranch(nn_module, RSI_ver, config=None):
	"""Build the sound-encoder CNN and store it as ``nn_module.soundBranch``.

	The branch is only built for ``RSI_ver`` 2 or 3 and only when
	``config.sound_dim[1]`` (the number of time frames) is 100 or 600; in
	every other case the function returns without touching ``nn_module``.

	Both variants end with ``Flatten()`` so that the output is the raw CNN
	feature: 32*5 = 160 values for 100 frames and 32*8 = 256 values for 600
	frames. These are exactly the ``audio_cnn_dim`` widths the pretext
	networks in this file use for the heads they stack on ``soundBranch``.

	The 600-frame variant previously also appended
	``Linear(256, 128) -> ReLU -> Linear(128, representationDim)``, which made
	its output ``representationDim`` wide and incompatible with those heads
	(unless ``representationDim`` happened to be 256). That extra head has
	been removed; checkpoints saved with it contain ``soundBranch.13.*`` and
	``soundBranch.15.*`` entries that no longer exist.

	Args:
		nn_module: object that receives the ``soundBranch`` attribute.
		RSI_ver: model version; only 2 and 3 build a branch.
		config: must provide ``sound_dim`` when ``RSI_ver`` is 2 or 3. It is
			not accessed for other versions.
	"""
	if RSI_ver != 2 and RSI_ver != 3:
		return

	num_frames = config.sound_dim[1]
	if num_frames == 100:
		# (1, 100, 40) -> (32, 48, 1) -> (32, 23, 1) -> (32, 11, 1) -> (32, 5, 1)
		later_kernel_heights = (3, 3, 3)
	elif num_frames == 600:
		# (1, 600, 40) -> (32, 298, 1) -> (32, 147, 1) -> (32, 72, 1)
		# -> (32, 35, 1) -> (32, 17, 1) -> (32, 8, 1)
		later_kernel_heights = (5, 5, 3, 3, 3)
	else:
		# Unsupported frame count: as before, no branch is attached.
		return

	# The first convolution spans all 40 feature bins, collapsing the width to 1.
	layers = [nn.Conv2d(1, 32, (5, 40), stride=(2, 1)), nn.ReLU()]
	for height in later_kernel_heights:
		layers.append(nn.Conv2d(32, 32, (height, 1), stride=(2, 1)))
		layers.append(nn.ReLU())
	layers.append(Flatten())

	nn_module.soundBranch = nn.Sequential(*layers)

class RSI3PretextNet(nn.Module):
	"""Image/sound representation network with optional "empty" classifiers.

	The image and sound branches (built by ``buildCNN`` / ``buildSoundBranch``)
	produce raw features. ``imgSupCon`` / ``soundSupCon`` project them to
	``config.representationDim`` values, which are L2-normalised. When
	``config.pretextEmptyCenter`` is set, ``imgBCE`` / ``soundBCE`` additionally
	produce one logit per sample from the raw features; outside training,
	samples whose ``sigmoid(logit) < 0.5`` get their feature replaced by
	``zero_feat``.

	Attributes set here:
		zero_feat: all-zero vector of length ``config.representationDim``.
		cached_sound: last positive sound feature, reused when the positive
			sound passed to ``forward`` is entirely ``inf``.
	"""
	def __init__(self, config):
		super(RSI3PretextNet, self).__init__()
		self.config=config
		buildCNN(self, 3, config)
		# Previously created with an unconditional .cuda(), which raised on
		# hosts without CUDA. It is still placed on the GPU when one exists and
		# is moved to the feature tensor's device at the point of use.
		zero_device = 'cuda' if torch.cuda.is_available() else 'cpu'
		self.zero_feat = torch.zeros((config.representationDim,), device=zero_device)

		self.imgSupCon=nn.Sequential(
			nn.Linear(64*9, 128), nn.ReLU(),
			nn.Linear(128, config.representationDim)
		)

		buildSoundBranch(self, 3, config)

		self.cached_sound=None # the goal sound can be cached and be encoded only once

		# Width of the flattened sound CNN output: 32*5 for 100 frames, else 32*8.
		audio_cnn_dim=160 if config.sound_dim[1]==100 else 32*8
		self.soundSupCon = nn.Sequential(
			nn.Linear(audio_cnn_dim, 128), nn.ReLU(),
			nn.Linear(128, config.representationDim)
		)

		if config.pretextEmptyCenter:
			self.imgBCE = nn.Sequential(
				nn.Linear(64 * 9, 128), nn.ReLU(),
				nn.Linear(128, 1)
			)

			self.soundBCE = nn.Sequential(
				nn.Linear(audio_cnn_dim, 128), nn.ReLU(),
				nn.Linear(128, 1)
			)

	def _zero_rejected(self, feat, logits):
		"""Overwrite, in place, the rows of ``feat`` whose sigmoid(logit) < 0.5."""
		zero = self.zero_feat.to(device=feat.device, dtype=feat.dtype)
		feat[torch.sigmoid(logits) < 0.5, :] = zero

	def _encode_sound(self, sound, is_train):
		"""Encode one sound batch; returns (raw feature, normalised feature, logits).

		``logits`` is None unless ``config.pretextEmptyCenter`` is set.
		"""
		raw = self.soundBranch(sound)
		feat = F.normalize(self.soundSupCon(raw), p=2, dim=1)
		logits = None
		if self.config.pretextEmptyCenter:
			# squeeze(-1) keeps the batch axis even for a batch of one.
			logits = self.soundBCE(raw).squeeze(-1)
			if not is_train:
				self._zero_rejected(feat, logits)
		return raw, feat, logits

	def forward(self, image, sound_positive, sound_negative, is_train=False):
		"""Encode an image and up to two sounds.

		Args:
			image: image batch or None. Only the first three channels are used.
			sound_positive: tensor with the positive sound. If every element is
				``inf`` the sound is not encoded and ``cached_sound`` is
				returned instead (None if nothing has been cached yet).
			sound_negative: negative sound batch or None.
			is_train: when False and ``config.pretextEmptyCenter`` is set,
				features of samples classified as empty are zeroed.

		Returns:
			dict with keys ``image_feat``, ``sound_feat_positive``,
			``sound_feat_negative``, ``image_BCE``, ``sound_BCE``,
			``image_feat_raw`` and ``pos_sound_raw``; entries that were not
			computed are None. The BCE entries are logits of shape (batch,).
		"""
		image_feat, image_feat_raw = None, None
		sound_feat_negative = None
		pos_sound_raw = None
		image_BCE, sound_BCE = None, None

		if image is not None:
			image_feat_raw=self.imgBranch(image[:, :3, :,:])
			image_feat=F.normalize(self.imgSupCon(image_feat_raw), p=2, dim=1)
			if self.config.pretextEmptyCenter:
				image_BCE=self.imgBCE(image_feat_raw).squeeze(-1)
				if not is_train:
					self._zero_rejected(image_feat, image_BCE)

		# sound feat positive

		# at RL training and testing, we can use the cached sound encoding
		# assuming every env reset at the same time
		if not torch.isinf(sound_positive).all():
			pos_sound_raw, sound_feat, sound_BCE = self._encode_sound(sound_positive, is_train)
			self.cached_sound=sound_feat
		sound_feat_positive=self.cached_sound

		if sound_negative is not None:
			_, sound_feat_negative, neg_sound_BCE = self._encode_sound(sound_negative, is_train)
			if neg_sound_BCE is not None:
				# NOTE(review): as in the original code, the logits of the
				# negative sound replace those of the positive sound in the
				# returned 'sound_BCE' -- confirm this is what callers expect.
				sound_BCE = neg_sound_BCE

		d = {'image_feat': image_feat, 'sound_feat_positive': sound_feat_positive,
			 'sound_feat_negative': sound_feat_negative, 'image_BCE': image_BCE, 'sound_BCE': sound_BCE,
			 'image_feat_raw': image_feat_raw, 'pos_sound_raw': pos_sound_raw
			 }
		return d

class RSI2PretextNet(nn.Module):
	"""Image/sound representation network with two projection heads.

	The image and sound branches (built by ``buildCNN`` / ``buildSoundBranch``)
	produce raw features; ``imgTriplet`` and ``soundTriplet`` project them to
	``config.representationDim`` values, which are then L2-normalised.
	"""
	def __init__(self, config):
		super(RSI2PretextNet, self).__init__()
		self.config = config
		buildCNN(self, 2, config)
		buildSoundBranch(self, 2, config)

		def projection_head(in_features):
			# Two-layer MLP shared in shape by both modalities.
			return nn.Sequential(
				nn.Linear(in_features, 128), nn.ReLU(),
				nn.Linear(128, config.representationDim)
			)

		self.imgTriplet = projection_head(64 * 9)

		# Width of the flattened sound CNN output: 32*5 for 100 frames, else 32*8.
		if config.sound_dim[1] == 100:
			audio_cnn_dim = 160
		else:
			audio_cnn_dim = 32 * 8
		self.soundTriplet = projection_head(audio_cnn_dim)

	def forward(self, image, sound_positive, sound_negative):
		"""Encode an image and up to two sounds.

		Args:
			image: image batch or None. Only the first three channels are used.
			sound_positive: positive sound batch; always encoded.
			sound_negative: negative sound batch or None.

		Returns:
			dict with keys ``image_feat``, ``sound_feat_positive``,
			``sound_feat_negative``, ``image_feat_raw`` and ``pos_sound_raw``;
			entries belonging to an input that was None are None.
		"""
		outputs = {
			'image_feat': None,
			'sound_feat_positive': None,
			'sound_feat_negative': None,
			'image_feat_raw': None,
			'pos_sound_raw': None,
		}

		if image is not None:
			rgb = image[:, :3, :, :]
			raw_image = self.imgBranch(rgb)
			outputs['image_feat_raw'] = raw_image
			outputs['image_feat'] = F.normalize(self.imgTriplet(raw_image), p=2, dim=1)

		raw_positive = self.soundBranch(sound_positive)
		outputs['pos_sound_raw'] = raw_positive
		outputs['sound_feat_positive'] = F.normalize(self.soundTriplet(raw_positive), p=2, dim=1)

		if sound_negative is not None:
			raw_negative = self.soundBranch(sound_negative)
			outputs['sound_feat_negative'] = F.normalize(self.soundTriplet(raw_negative), p=2, dim=1)

		return outputs

class representationImgLinear(torch.nn.Module):
	"""Single linear layer from 64*9 image features to ``config.taskNum + 1`` outputs."""
	def __init__(self, config):
		super(representationImgLinear, self).__init__()
		self.config = config
		in_features = 64 * 9
		out_features = config.taskNum + 1
		self.linear_image = torch.nn.Linear(in_features, out_features)

	def forward(self, img_in):
		"""Apply the linear layer to ``img_in`` and return the result."""
		projected = self.linear_image(img_in)
		return projected

class representationSoundLinear(torch.nn.Module):
	"""Single linear layer from raw sound features to ``config.taskNum + 1`` outputs.

	The input width follows the same rule the pretext networks in this file
	use for the flattened sound CNN output: 160 when ``config.sound_dim[1]``
	is 100, otherwise 32*8. It was previously fixed at 160, which only fits
	the 100-frame configuration. A config without a ``sound_dim`` attribute
	keeps the former width of 160.
	"""
	def __init__(self, config):
		super(representationSoundLinear, self).__init__()
		self.config=config
		sound_dim = getattr(config, 'sound_dim', None)
		if sound_dim is None or sound_dim[1] == 100:
			audio_cnn_dim = 160
		else:
			audio_cnn_dim = 32 * 8
		self.linear_sound=torch.nn.Linear(audio_cnn_dim, config.taskNum+1)


	def forward(self, sound_in):
		"""Apply the linear layer to ``sound_in`` and return the result."""
		return self.linear_sound(sound_in)

class RSI1Pretrain(nn.Module):
	"""Thin wrapper that holds an ``RSI1PretrainBase`` as ``self.base``.

	Args:
		config: accepted but not used.
	"""
	def __init__(self, config=None):
		super(RSI1Pretrain, self).__init__()
		self.base = RSI1PretrainBase()

	def forward(self, sound):
		"""Run ``sound`` through ``self.base`` and return its output."""
		# Call the module rather than its .forward() so that hooks registered
		# on ``base`` are executed; calling .forward() directly skips them.
		return self.base(sound)

class RSI1PretrainBase(nn.Module):
	"""Bidirectional-LSTM sound encoder with attention and a 4-way output head.

	Pipeline in ``forward``: per-frame MLP (40 -> 512), one-layer bidirectional
	LSTM (hidden size 512 per direction), attention over the LSTM outputs
	conditioned on the final hidden states, context MLP (1024 -> 128) and the
	``soundAux`` head (128 -> 4).
	"""
	def __init__(self):
		super(RSI1PretrainBase, self).__init__()

		self.soundBiLSTM = nn.LSTM(input_size=512, hidden_size=512,
								   num_layers=1, batch_first=True, bidirectional=True)

		self.soundMlp = nn.Sequential(
			nn.Linear(40, 977), nn.ReLU(),
			nn.Linear(977, 512), nn.ReLU(),
		)
		# Input is [LSTM output (1024), final hidden states (1024)] per step.
		self.attnMlp = nn.Sequential(
			nn.Linear(2048, 512), nn.ReLU(),
			nn.Linear(512, 512), nn.ReLU(),

		)
		self.vA = nn.Linear(512, 1)
		self.contextMlp = nn.Sequential(
			nn.Linear(1024, 256), nn.ReLU(),
			nn.Linear(256, 128), nn.ReLU(),
		)

		self.soundAux = nn.Sequential(
			nn.Linear(128, 64), nn.ReLU(),
			nn.Linear(64, 4),
		)


	def forward(self, sound):
		"""Encode a batch of sound sequences.

		Args:
			sound: tensor of shape (batch_size, seq_length, 40). Any sequence
				length is accepted (it used to be fixed at 100).

		Returns:
			Tensor of shape (batch_size, 4), including when batch_size is 1.
		"""
		sound = self.soundMlp(sound)

		# Forward propagate LSTM
		# out: tensor of shape (batch_size, seq_length, hidden_size*2)
		out, (h_n, c_n) = self.soundBiLSTM(sound)
		seq_length = out.size(1)

		# attention
		# Final hidden state of both directions: (batch_size, 1024).
		h_n_concat = torch.cat([h_n[0], h_n[1]], dim=1)
		# Broadcast it along the time axis to pair it with every LSTM output.
		h_n_concat_expand = h_n_concat.unsqueeze(1).expand(-1, seq_length, -1)
		h_n_out_concat = torch.cat([out, h_n_concat_expand], dim=2)
		h_n_out_concat = self.attnMlp(h_n_out_concat)
		# One weight per time step, normalised over the sequence.
		score = F.softmax(self.vA(h_n_out_concat), dim=1)
		# (batch, 1024, seq) x (batch, seq, 1) -> (batch, 1024, 1). Only the
		# trailing axis is squeezed so a batch of one keeps its batch axis.
		context = torch.bmm(torch.transpose(out, dim0=1, dim1=2), score).squeeze(2)
		context = self.contextMlp(context)

		sound_aux = self.soundAux(context)

		return sound_aux