import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Flatten
import math


def buildSoundBranch(nn_module, RSI_ver, config=None):
	"""Attach the sound-encoder sub-modules to ``nn_module``.

	Sets ``nn_module.rnn`` (a bidirectional, batch-first GRU with hidden size
	512) and ``nn_module.cnn`` (the convolutional front end that feeds it).
	For ``RSI_ver == 1`` an additional ``nn_module.fc`` head
	(1024 -> 128 -> 128) is attached as well.

	The shape comments on the layers assume a (1, 600, 40) input per sample.

	Args:
		nn_module: object (normally an ``nn.Module``) that receives the attributes.
		RSI_ver: architecture version; must be 1, 2 or 3.
		config: unused; kept so existing call sites keep working.

	Raises:
		ValueError: if ``RSI_ver`` is not 1, 2 or 3. Previously such a value
			silently attached nothing.
	"""
	# The GRU is always created before the CNN so that, for a fixed random
	# seed, parameters are initialised in the same order as before.
	if RSI_ver==3:
		nn_module.rnn = torch.nn.GRU(input_size=128, hidden_size=512, batch_first=True, bidirectional=True)
		nn_module.cnn = nn.Sequential(
			nn.Conv2d(1, 64, (11, 40), stride=(2, 1), padding=(5, 0)), nn.ReLU(),  # (1, 600, 40)->(64, 300, 1)
			nn.Flatten(start_dim=-2),  # (64, 300, 1)->(64, 300)
			nn.Conv1d(64, 128, 3, stride=2, padding=1), nn.ReLU(),  # (64, 300)->(128, 150)
			nn.Conv1d(128, 128, 3, stride=2, padding=1), nn.ReLU(),  # (128, 150)->(128, 75)
		)
	elif RSI_ver==2 or RSI_ver==1:
		# Versions 1 and 2 share the same GRU and convolutional front end.
		nn_module.rnn = torch.nn.GRU(input_size=64 * 7, hidden_size=512, batch_first=True, bidirectional=True)
		nn_module.cnn = _buildRSI12SoundCNN()
		if RSI_ver==1:
			# Only version 1 owns a projection head on top of the GRU output.
			nn_module.fc = nn.Sequential(
				nn.Linear(2 * 512, 128), nn.ReLU(),
				nn.Linear(128, 128), nn.ReLU(),
			)
	else:
		raise ValueError("Unsupported RSI_ver: {!r} (expected 1, 2 or 3)".format(RSI_ver))


def _buildRSI12SoundCNN():
	"""Return the three-layer 2-D convolution stack shared by RSI versions 1 and 2."""
	return nn.Sequential(
		nn.Conv2d(1, 64, (11, 11), stride=(2, 2), padding=(5, 5)), nn.ReLU(),  # (1, 600, 40)->(64, 300, 20)
		nn.Conv2d(64, 64, (11, 5), stride=(2, 2), padding=(5, 5)), nn.ReLU(),  # (64, 300, 20)->(64, 150, 13)
		nn.Conv2d(64, 64, (7, 3), stride=(2, 2), padding=(1, 1)), nn.ReLU(),  # (64, 150, 13)->(64, 73, 7)
	)

def buildCNN(nn_module, RSI_ver, config=None):
	"""Attach the image encoder to ``nn_module`` as ``nn_module.imgBranch``.

	The encoder is a stack of 3x3 convolutions (padding 1) with ReLU
	activations, interleaved with 2x2 max-pooling and terminated by the
	project's ``Flatten`` layer. ``RSI_ver`` and ``config`` are accepted but
	not read here.

	The shapes in the comments assume a (3, 96, 96) input per sample.
	"""
	# One entry per convolution: (in_channels, out_channels, stride, pool_after)
	conv_plan = (
		(3, 32, 1, False),     # (3, 96, 96)->(32, 96, 96)
		(32, 32, 1, True),     # (32, 96, 96)->(32, 96, 96), pooled to (32, 48, 48)
		(32, 64, 1, True),     # (32, 48, 48)->(64, 48, 48), pooled to (64, 24, 24)
		(64, 64, 1, True),     # (64, 24, 24)->(64, 24, 24), pooled to (64, 12, 12)
		(64, 128, 1, True),    # (64, 12, 12)->(128, 12, 12), pooled to (128, 6, 6)
		(128, 128, 2, False),  # (128, 6, 6)->(128, 3, 3)
	)

	layers = []
	for in_ch, out_ch, conv_stride, pool_after in conv_plan:
		layers.append(nn.Conv2d(in_ch, out_ch, 3, stride=conv_stride, padding=1))
		layers.append(nn.ReLU())
		if pool_after:
			layers.append(nn.MaxPool2d(2, stride=2))
	layers.append(Flatten())

	nn_module.imgBranch = nn.Sequential(*layers)


def soundBranch(nn_module, RSI_ver, sound):
	"""Encode a batch of sounds with ``nn_module.cnn`` followed by ``nn_module.rnn``.

	The CNN output is laid out as (batch, channels, time, ...). Axes 1 and 2
	are swapped to put time second, and every axis after time is merged into
	one feature axis before the sequence is fed to the GRU. The final hidden
	states of the two GRU directions are concatenated.

	The sequence length is taken from the CNN output instead of being
	hard-coded (75 for version 3, 73 for versions 1 and 2), so clips of other
	lengths are no longer silently folded into the batch axis. For the
	original input size the result is unchanged.

	Args:
		nn_module: object exposing the ``cnn`` and ``rnn`` sub-modules.
		RSI_ver: architecture version; must be 1, 2 or 3.
		sound: batched input tensor accepted by ``nn_module.cnn``.

	Returns:
		Tensor of shape (batch, 2 * hidden_size).

	Raises:
		ValueError: if ``RSI_ver`` is not 1, 2 or 3 (previously ``None`` was
			returned implicitly).
	"""
	if not (RSI_ver==3 or RSI_ver==2 or RSI_ver==1):
		raise ValueError("Unsupported RSI_ver: {!r} (expected 1, 2 or 3)".format(RSI_ver))

	cnn_out = nn_module.cnn(sound)
	# (B, C, T[, W]) -> (B, T, C[, W]) -> (B, T, C[*W]); for the 3-D output of
	# version 3 the flatten is a no-op.
	seq = torch.flatten(torch.transpose(cnn_out, dim0=1, dim1=2), start_dim=2)
	# Only the final hidden state is used; index 0 / 1 are the two directions
	# of the single-layer bidirectional GRU.
	_, h_n = nn_module.rnn(seq)
	return torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=1)


class RSI3PretextNet(nn.Module):
	"""Image/sound representation network built from the RSI version 3 branches.

	An image encoder (``imgBranch``) and a sound encoder (``cnn`` + ``rnn``)
	each feed a projection head (``imgSupCon`` / ``soundSupCon``) whose output
	is L2-normalised. When ``config.pretextEmptyCenter`` is set, two extra
	heads (``imgBCE`` / ``soundBCE``) produce one logit per sample, and every
	sample whose sigmoid(logit) is below 0.5 has its feature replaced by
	``zero_feat``.
	"""

	def __init__(self, config):
		"""Build all sub-modules.

		Args:
			config: object providing ``representationDim`` and ``pretextEmptyCenter``.
		"""
		super(RSI3PretextNet, self).__init__()
		self.config=config
		buildCNN(self, 3)

		# The original called .cuda() unconditionally, which raises on hosts
		# without CUDA. Keep CUDA as the default but fall back to CPU; the
		# tensor is additionally moved to the feature's device at use time.
		zero_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
		self.zero_feat=torch.zeros((config.representationDim,), device=zero_device)

		self.imgSupCon=nn.Sequential(
			nn.Linear(128*9, 128), nn.ReLU(),
			nn.Linear(128, config.representationDim)
		)

		buildSoundBranch(self, 3)

		self.cached_sound=None # the goal sound can be cached and be encoded only once

		self.soundSupCon = nn.Sequential(
			nn.Linear(2 * 512, 128), nn.ReLU(),
			nn.Linear(128, 64), nn.ReLU(),
			nn.Linear(64, config.representationDim)
		)

		if config.pretextEmptyCenter:
			self.imgBCE = nn.Sequential(
				nn.Linear(128 * 9, 128), nn.ReLU(),
				nn.Linear(128, 1)
			)

			self.soundBCE = nn.Sequential(
				nn.Linear(2 * 512, 128), nn.ReLU(),
				nn.Linear(128, 64), nn.ReLU(),
				nn.Linear(64, 1)
			)

	def _zero_rejected(self, feat, logits):
		"""Overwrite in place the rows of ``feat`` whose sigmoid(logit) is below 0.5 with ``zero_feat``."""
		# Match device and dtype of the features so the assignment also works
		# when the model lives on another device or runs in another precision.
		zero = self.zero_feat.to(device=feat.device, dtype=feat.dtype)
		feat[torch.sigmoid(logits)<0.5,:]=zero

	def forward(self, image, sound_positive, sound_negative, is_train=False):
		"""Encode the image and the positive / negative sounds.

		Args:
			image: batched image tensor or ``None``; only the first three
				channels are used.
			sound_positive: sound tensor. If every element is ``inf`` the
				sound is not encoded and the cached encoding is returned.
			sound_negative: sound tensor or ``None``.
			is_train: not read by this class.

		Returns:
			dict with keys ``image_feat``, ``sound_feat_positive``,
			``sound_feat_negative``, ``image_BCE``, ``sound_BCE``,
			``image_feat_raw`` and ``pos_sound_raw``. Entries whose input was
			skipped are ``None``.

		Side effects:
			Updates ``self.cached_sound`` whenever ``sound_positive`` is encoded.
		"""
		image_feat, image_feat_raw = None, None
		sound_feat_negative = None
		pos_sound_raw = None
		image_BCE, sound_BCE=None, None

		if image is not None:
			image_feat_raw=self.imgBranch(image[:, :3, :,:])
			image_feat=F.normalize(self.imgSupCon(image_feat_raw), p=2, dim=1)
			if self.config.pretextEmptyCenter:
				image_BCE=self.imgBCE(image_feat_raw).squeeze()
				self._zero_rejected(image_feat, image_BCE)

		# sound feat positive

		# at RL training and testing, we can use the cached sound encoding
		# assuming every env reset at the same time

		if not torch.isinf(sound_positive).all():
			pos_sound_raw=soundBranch(self, 3, sound_positive)
			sound_feat = F.normalize(self.soundSupCon(pos_sound_raw), p=2, dim=1)
			if self.config.pretextEmptyCenter:
				sound_BCE = self.soundBCE(pos_sound_raw).squeeze()
				self._zero_rejected(sound_feat, sound_BCE)
			self.cached_sound=sound_feat
		sound_feat_positive=self.cached_sound

		if sound_negative is not None: # during the VAR training, sound_negative is None for sure
			neg_rnn_out=soundBranch(self, 3, sound_negative)
			sound_feat_negative = F.normalize(self.soundSupCon(neg_rnn_out), p=2, dim=1)
			if self.config.pretextEmptyCenter:
				# NOTE(review): this replaces the positive sound's logits in the
				# returned 'sound_BCE' with the negative sound's - confirm intended.
				sound_BCE = self.soundBCE(neg_rnn_out).squeeze()
				self._zero_rejected(sound_feat_negative, sound_BCE)

		d = {'image_feat': image_feat, 'sound_feat_positive': sound_feat_positive,
			 'sound_feat_negative': sound_feat_negative, 'image_BCE': image_BCE, 'sound_BCE': sound_BCE,
			 'image_feat_raw': image_feat_raw, 'pos_sound_raw': pos_sound_raw
			 }
		return d


class RSI2PretextNet(nn.Module):
	"""Image/sound representation network built from the RSI version 2 branches.

	An image encoder (``imgBranch``) and a sound encoder (``cnn`` + ``rnn``)
	each feed a projection head (``imgTriplet`` / ``soundTriplet``) whose
	output is L2-normalised. When ``config.pretextEmptyCenter`` is set, two
	extra heads (``imgBCE`` / ``soundBCE``) produce one logit per sample; when
	``is_train`` is false, every sample whose sigmoid(logit) is below 0.5 has
	its feature replaced by ``zero_feat``.
	"""

	def __init__(self, config):
		"""Build all sub-modules.

		Args:
			config: object providing ``representationDim`` and ``pretextEmptyCenter``.
		"""
		super(RSI2PretextNet, self).__init__()
		self.config=config

		# The original called .cuda() unconditionally, which raises on hosts
		# without CUDA. Keep CUDA as the default but fall back to CPU; the
		# tensor is additionally moved to the feature's device at use time.
		zero_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
		self.zero_feat = torch.zeros((config.representationDim,), device=zero_device)
		buildCNN(self, 2, config)

		buildSoundBranch(self, 2, config)
		self.imgTriplet = nn.Sequential(
			nn.Linear(128 * 9, 128), nn.ReLU(),
			nn.Linear(128, config.representationDim)
		)

		self.soundTriplet = nn.Sequential(
			nn.Linear(2 * 512, 128), nn.ReLU(),
			nn.Linear(128, 64), nn.ReLU(),
			nn.Linear(64, config.representationDim)
		)

		if config.pretextEmptyCenter:
			self.imgBCE = nn.Sequential(
				nn.Linear(128 * 9, 128), nn.ReLU(),
				nn.Linear(128, 1)
			)

			self.soundBCE = nn.Sequential(
				nn.Linear(2 * 512, 128), nn.ReLU(),
				nn.Linear(128, 64), nn.ReLU(),
				nn.Linear(64, 1)
			)

		self.cached_sound=None # the goal sound can be cached and be encoded only once

	def _zero_rejected(self, feat, logits):
		"""Overwrite in place the rows of ``feat`` whose sigmoid(logit) is below 0.5 with ``zero_feat``."""
		# Match device and dtype of the features so the assignment also works
		# when the model lives on another device or runs in another precision.
		zero = self.zero_feat.to(device=feat.device, dtype=feat.dtype)
		feat[torch.sigmoid(logits)<0.5,:]=zero

	def forward(self, image, sound_positive, sound_negative, is_train=False):
		"""Encode the image and the positive / negative sounds.

		Args:
			image: batched image tensor or ``None``; only the first three
				channels are used.
			sound_positive: sound tensor. If every element is ``inf`` the
				sound is not encoded and the cached encoding is returned.
			sound_negative: sound tensor or ``None``.
			is_train: when true, features are never replaced by ``zero_feat``
				and no logits are computed for the negative sound.

		Returns:
			dict with keys ``image_feat``, ``sound_feat_positive``,
			``sound_feat_negative``, ``image_BCE``, ``sound_BCE``,
			``image_feat_raw`` and ``pos_sound_raw``. Entries whose input was
			skipped are ``None``.

		Side effects:
			Updates ``self.cached_sound`` whenever ``sound_positive`` is encoded.
		"""
		image_feat, image_feat_raw = None, None
		sound_feat_negative = None
		image_BCE, sound_BCE=None, None
		pos_sound_raw = None
		mask_features = not is_train

		if image is not None:
			image_feat_raw = self.imgBranch(image[:, :3, :, :])
			image_feat = F.normalize(self.imgTriplet(image_feat_raw), p=2, dim=1)
			if self.config.pretextEmptyCenter:
				image_BCE=self.imgBCE(image_feat_raw).squeeze()
				if mask_features:
					self._zero_rejected(image_feat, image_BCE)

		# sound feat positive

		# at RL training and testing, we can use the cached sound encoding
		# assuming every env reset at the same time
		if not torch.isinf(sound_positive).all():
			pos_sound_raw = soundBranch(self, 2, sound_positive)
			sound_feat = F.normalize(self.soundTriplet(pos_sound_raw), p=2, dim=1)
			if self.config.pretextEmptyCenter:
				sound_BCE = self.soundBCE(pos_sound_raw).squeeze()
				if mask_features:
					self._zero_rejected(sound_feat, sound_BCE)

			self.cached_sound = sound_feat
		sound_feat_positive=self.cached_sound

		if sound_negative is not None:
			neg_rnn_out = soundBranch(self, 2, sound_negative)
			sound_feat_negative = F.normalize(self.soundTriplet(neg_rnn_out), p=2, dim=1)
			if self.config.pretextEmptyCenter and mask_features:
				# NOTE(review): this replaces the positive sound's logits in the
				# returned 'sound_BCE' with the negative sound's - confirm intended.
				sound_BCE = self.soundBCE(neg_rnn_out).squeeze()
				self._zero_rejected(sound_feat_negative, sound_BCE)

		d={'image_feat': image_feat, 'sound_feat_positive': sound_feat_positive,
		   'sound_feat_negative':sound_feat_negative, 'image_BCE':image_BCE, 'sound_BCE':sound_BCE,
		   'image_feat_raw':image_feat_raw, 'pos_sound_raw':pos_sound_raw
		   }

		return d

class representationImgLinear(nn.Module):
	"""Single linear layer mapping a 128*9-dimensional input to ``config.taskNum + 1`` outputs."""

	def __init__(self, config):
		"""Create the layer; ``config`` must provide ``taskNum``."""
		super(representationImgLinear, self).__init__()
		self.config = config
		in_features = 128 * 9
		out_features = config.taskNum + 1
		self.linear_image = nn.Linear(in_features, out_features)

	def forward(self, img_in):
		"""Apply ``linear_image`` to ``img_in`` and return the result."""
		out = self.linear_image(img_in)
		return out

class representationSoundLinear(nn.Module):
	"""Single linear layer mapping a 1024-dimensional input to ``config.taskNum + 1`` outputs."""

	def __init__(self, config):
		"""Create the layer; ``config`` must provide ``taskNum``."""
		super(representationSoundLinear, self).__init__()
		self.config = config
		in_features = 1024
		out_features = config.taskNum + 1
		self.linear_sound = nn.Linear(in_features, out_features)

	def forward(self, sound_in):
		"""Apply ``linear_sound`` to ``sound_in`` and return the result."""
		out = self.linear_sound(sound_in)
		return out

class RSI1Pretrain(nn.Module):
	"""Thin wrapper that owns an ``RSI1PretrainBase`` as ``self.base`` and delegates to it."""

	def __init__(self, config=None):
		"""Build the wrapped base network from ``config`` and keep a reference to ``config``."""
		super(RSI1Pretrain, self).__init__()
		self.base = RSI1PretrainBase(config)
		self.config=config

	def forward(self, img, sound):
		"""Run the wrapped network on ``img`` and ``sound`` and return its output unchanged."""
		# Call the module itself rather than ``.forward`` so that hooks
		# registered on ``self.base`` are executed.
		return self.base(img, sound)

class RSI1PretrainBase(nn.Module):
	"""RSI version 1 network with three output heads.

	``forward`` returns, in order: the output of ``soundAux`` (computed from
	the sound feature), the output of ``inSight`` (computed from the image
	feature) and the output of ``exi`` (computed from the sum of both).
	"""

	def __init__(self, config):
		"""Build the encoders and heads; ``config`` must provide ``taskNum``."""
		super(RSI1PretrainBase, self).__init__()
		self.config = config
		self.encoderNumLayer = 1
		self.soundFrames = 600
		self.soundMFCCFeat = 40

		# Sub-modules are created in a fixed order: image branch, sound branch
		# (rnn, cnn, fc), then the three heads.
		buildCNN(self, 1)
		buildSoundBranch(self, 1)

		task_count = self.config.taskNum

		# head on the flattened image feature
		self.inSight = nn.Sequential(
			nn.Linear(128 * 9, 128),
			nn.ReLU(),
			nn.Linear(128, task_count),
		)

		# head on the 128-dimensional sound feature
		self.soundAux = nn.Sequential(
			nn.Linear(128, 64),
			nn.ReLU(),
			nn.Linear(64, task_count),
		)

		# single-output head on the fused image + sound feature
		self.exi = nn.Sequential(
			nn.Linear(128 * 9, 256),
			nn.ReLU(),
			nn.Linear(256, 128),
			nn.ReLU(),
			nn.Linear(128, 1),
		)

	def forward(self, img, sound):
		"""Return ``(sound_aux, inSight, exi)`` for a batch of images and sounds."""
		n_samples = img.size(0)

		visual = self.imgBranch(img)
		in_sight_out = self.inSight(visual.flatten(start_dim=1))

		audio = self.fc(soundBranch(self, 1, sound))
		aux_out = self.soundAux(audio)

		# The sound feature gets a trailing axis of size 1 so that it
		# broadcasts over the last axis of the image feature viewed as
		# (batch, 128, -1).
		fused = audio.unsqueeze(-1) + visual.reshape(n_samples, 128, -1)
		exi_out = self.exi(fused.flatten(start_dim=1))

		return aux_out, in_sight_out, exi_out
