import torch
import torch.nn as nn
import numpy as np
import os

class MDRMUnit(nn.Module):
	"""Sequence encoder built from two stacked bidirectional LSTMs.

	Pipeline: BiLSTM(input_size -> hidden_size per direction) -> Dropout(0.5)
	-> Linear(2 * hidden_size -> hidden_size) + ReLU -> BiLSTM(hidden_size ->
	hidden_size per direction). Both LSTMs use ``batch_first=True``.
	"""

	def __init__(self, input_size, hidden_size):
		super().__init__()
		# Sub-modules are created in this exact order so parameter
		# initialisation (RNG draws) and state_dict key order stay stable.
		self.lstm1 = nn.LSTM(
			input_size,
			hidden_size,
			bidirectional=True,
			batch_first=True,
		)
		self.dropout = nn.Dropout(0.5)
		# Squeeze the concatenated forward/backward features back to hidden_size.
		self.linear = nn.Sequential(nn.Linear(hidden_size * 2, hidden_size), nn.ReLU())
		self.lstm2 = nn.LSTM(
			hidden_size,
			hidden_size,
			dropout=0,
			bidirectional=True,
			batch_first=True,
		)

	def forward(self, x):
		"""Encode ``x`` and return the final hidden state ``h_n`` of the second LSTM.

		For batched input, nn.LSTM documents ``h_n`` as
		``(num_directions * num_layers, batch, hidden_size)``, i.e.
		``(2, batch, hidden_size)`` here.
		"""
		sequence_out = self.lstm1(x)[0]
		projected = self.linear(self.dropout(sequence_out))
		_, (h_n, _) = self.lstm2(projected)
		return h_n

# Relative similarity threshold used by get_second: column j of row i is a
# mixing candidate only if mat[i][j] > T * max(mat[i]).
T = 0.70


def get_second(i, mat, g, all_genders, ratio=None):
	"""Randomly choose a mixing partner for row ``i`` of the score matrix ``mat``.

	Candidate partners are the columns ``j`` with
	``mat[i][j] > ratio * max(mat[i])``. The ``TYPE`` environment variable
	then optionally filters them by gender:

	* ``"orig"`` (also used when ``TYPE`` is unset) -- keep every candidate;
	* ``"same"`` -- keep only ``j`` with ``all_genders[j] == g``;
	* ``"diff"`` -- keep only ``j`` with ``all_genders[j] != g``.

	Any other ``TYPE`` value prints ``"Invalid run type"`` and exits with
	status 1.

	Args:
		i: Row index into ``mat``.
		mat: Score tensor; assumed 2-D with columns indexing the same samples
			as ``all_genders`` -- TODO confirm.
		g: Gender label of sample ``i``.
		all_genders: Indexable of gender labels, indexed by column of ``mat``.
		ratio: Relative threshold; defaults to the module-level ``T``.

	Returns:
		``(a, l)``: ``a`` is a shape-(1,) integer numpy array holding the chosen
		column index; ``l`` is the shape-(1,) slice of ``mat[i]`` at that
		column, still attached to autograd, used by callers as mixing weight.

	Raises:
		ValueError: if no column survives the threshold and gender filter.
	"""
	if ratio is None:
		ratio = T
	run_type = os.environ.get("TYPE", "orig")
	if run_type not in ("orig", "same", "diff"):
		print("Invalid run type")
		import sys
		sys.exit(1)

	row = mat[i]
	# Only this row is needed, so avoid copying the whole matrix to numpy.
	scores = row.detach().cpu().numpy()
	threshold = (np.amax(scores) * ratio).item()
	# np.argwhere on a 1-D array returns shape (k, 1); that shape is kept so
	# each picked entry is a shape-(1,) array, as callers expect.
	args = np.argwhere(scores > threshold)

	if run_type != "orig":
		# Read the index column directly: squeeze() would turn a single
		# candidate into a 0-d array that cannot be iterated.
		if run_type == "same":
			keep = [bool(all_genders[j] == g) for j in args[:, 0]]
		else:
			keep = [bool(all_genders[j] != g) for j in args[:, 0]]
		args = args[np.asarray(keep, dtype=bool)]

	if len(args) == 0:
		raise ValueError(
			f"get_second: no mixing candidate for row {i} (TYPE={run_type!r})"
		)

	pick = np.random.randint(0, len(args), 1)[0]
	a = args[pick]
	# Index the live tensor (not the numpy copy) so the weight stays
	# differentiable with respect to ``mat``.
	l = row[torch.as_tensor(a, device=row.device)]
	return a, l

class MDRM(nn.Module):
	"""Audio/text classifier with optional score-guided mixup of fused features.

	Each modality is encoded by its own MDRMUnit. The final hidden states of
	both units are fused into one feature vector per sample and mapped to 2
	logits by a small MLP. ``self.mat`` is a score matrix loaded from
	``mat_path`` and handed to get_second to choose mixing partners.
	"""

	def __init__(self, audio_input, text_input, hidden_size, mat_path):
		super().__init__()
		# Registered as an nn.Parameter, so it is saved in state_dict and
		# (by default) requires grad.
		self.mat = nn.Parameter(torch.tensor(np.load(mat_path)))
		self.unit_audio = MDRMUnit(audio_input, hidden_size)
		self.unit_text = MDRMUnit(text_input, hidden_size)
		# The MLP expects hidden_size * 4 features per sample after fusion.
		self.linear = nn.Sequential(
			nn.Linear(hidden_size * 4, 64),
			nn.ReLU(),
			nn.Linear(64, 2),
		)

	def forward(self, audio, text, cur_genders=None, all_genders=None,
			all_audio=None, all_text=None, all_y=None, y=None, ids=None):
		"""Compute logits, optionally after mixing the samples listed in ``ids``.

		Args:
			audio: Audio input passed to ``unit_audio`` (presumably
				(batch, seq, audio_input) -- TODO confirm).
			text: Text input passed to ``unit_text``.
			cur_genders: Indexed with each entry of ``ids``; only used when mixing.
			all_genders: Forwarded to get_second for gender filtering.
			all_audio, all_text, all_y: Accepted for API compatibility;
				currently unused.
			y: Labels indexed like the rows of the fused features; required
				when ``ids`` is given.
			ids: Optional iterable of row indices to mix.

		Returns:
			``logits`` when ``ids`` is None; otherwise ``(logits, mixed_y)``,
			where ``logits`` has one row per entry of ``ids``.

		Raises:
			ValueError: if ``ids`` is given without ``y``.
		"""
		audio_h = self.unit_audio(audio)
		text_h = self.unit_text(text)

		# Concatenate both units' hidden states along dim 0, move dim 1 to the
		# front, then flatten everything after it into one feature vector.
		fused = torch.cat((audio_h, text_h))
		fused = fused.permute(1, 0, 2)
		fused = torch.flatten(fused, 1)

		if ids is None:
			return self.linear(fused)

		if y is None:
			raise ValueError("MDRM.forward: y is required when ids is given")

		mixed_rows, mixed_labels = [], []
		for idx in ids:
			second, lam = get_second(idx, self.mat, cur_genders[idx], all_genders)
			partner = second[0]
			# NOTE(review): ``partner`` is a column index of self.mat but is
			# used to index rows of ``fused`` and ``y`` -- assumes they line up.
			mixed_rows.append(fused[idx] * lam + fused[partner] * (1 - lam))
			mixed_labels.append(y[idx] * lam + y[partner] * (1 - lam))

		# Keep mixed rows as tensors so gradients flow, and stack them into a
		# (len(ids), features) batch for the MLP.
		logits = self.linear(torch.stack(mixed_rows))
		return logits, torch.cat(mixed_labels)


if __name__ == '__main__':
	# Smoke test: run random audio/text batches through the model and print
	# the output shape. MDRM loads its score matrix from disk, so write a
	# throw-away identity matrix to a temporary .npy file first.
	import tempfile

	audio = torch.rand(32, 10, 18)
	text = torch.rand(32, 10, 768)
	with tempfile.TemporaryDirectory() as tmp_dir:
		mat_path = os.path.join(tmp_dir, "mat.npy")
		np.save(mat_path, np.eye(32, dtype=np.float32))
		# np.load reads the file eagerly, so the model outlives the temp dir.
		model = MDRM(18, 768, 100, mat_path)
	# Genders/labels are irrelevant without ``ids``; the batch itself doubles
	# as all_audio / all_text.
	print(model(audio, text, None, None, audio, text, None).shape)
