import pandas as pd
import pickle
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler

from models import MDRM

import torch
import torch.nn as nn
import numpy as np
import os

# Training hyper-parameters read by train().
ARGS = dict(num_epochs=100)

def _pad_sequences(seqs, max_len, fill):
    """Return copies of ``seqs`` right-padded with constant rows up to ``max_len``.

    Each sequence is assumed to be a list of equal-width feature rows; every
    padding row repeats ``fill`` across that width. The input lists are not
    modified.
    """
    # Take the row width from the first non-empty sequence so that an empty
    # sample can still be padded instead of crashing on ``s[0]``.
    width = next((len(s[0]) for s in seqs if len(s) > 0), 0)
    return [list(s) + [[fill] * width for _ in range(max_len - len(s))] for s in seqs]


def make_data(data_dir="../data/ec/final"):
    """Load, pad and normalise the audio/text/gender/label data.

    Args:
        data_dir: directory containing ``audio.pkl``, ``text.pkl``,
            ``gender.pkl`` and ``vol.csv``.

    Returns:
        ``(audio, text, gender, y)`` where
        - ``audio`` is padded with rows of 1 to a common length, NaNs are
          replaced by 0, and the whole tensor is standardised with its global
          mean and standard deviation;
        - ``text`` is padded with rows of 1e-5 to the same length;
        - ``gender`` holds 0 for ``'M'`` and 1 for anything else;
        - ``y`` is a ``(num_samples, 2)`` float one-hot tensor whose class is 0
          when ``past_3 > future_3`` and 1 otherwise.

    Raises:
        ValueError: if there are no audio samples, or ``vol.csv`` does not have
            one row per sample.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # data_dir at trusted files.
    with open(os.path.join(data_dir, "audio.pkl"), "rb") as handle:
        audio = pickle.load(handle)
    with open(os.path.join(data_dir, "text.pkl"), "rb") as handle:
        text = pickle.load(handle)
    with open(os.path.join(data_dir, "gender.pkl"), "rb") as handle:
        gender = pickle.load(handle)
    gender = torch.tensor([0 if g == 'M' else 1 for g in gender])

    if not audio:
        raise ValueError(f"no audio samples found in {data_dir}")

    # Pad both modalities to one shared length so they stay aligned per time
    # step. Using only the audio maximum made the text padding loop forever
    # on any text sample longer than the longest audio sample.
    max_len = max(len(s) for s in (*audio, *text))

    audio = torch.tensor(_pad_sequences(audio, max_len, 1))
    audio = torch.nan_to_num(audio, 0)
    audio = (audio - torch.mean(audio)) / torch.std(audio)

    text = torch.tensor(_pad_sequences(text, max_len, 1e-5))

    df = pd.read_csv(os.path.join(data_dir, "vol.csv"))
    df['label'] = 1
    df.loc[df['past_3'] > df['future_3'], 'label'] = 0
    labels = torch.tensor(df['label'].to_list())
    if len(labels) != len(text):
        # A mismatch previously left trailing samples with all-zero labels
        # (or raised IndexError when the CSV had extra rows).
        raise ValueError(
            f"vol.csv has {len(labels)} rows but there are {len(text)} samples"
        )
    y_tensor = torch.nn.functional.one_hot(labels, num_classes=2).float()

    return audio, text, gender, y_tensor

# Load the whole corpus once at import time; get_second() and MDRM.forward()
# look up mixup partners in these module-level tensors by dataset index.
master_audio, master_text, master_gender, master_y = make_data()
# TensorDataset only accepts tensors, so the per-sample ids must be one too.
master_ids = torch.arange(master_audio.shape[0])

train_dataset = TensorDataset(master_ids, master_audio, master_text, master_gender, master_y)

class MDRMUnit(nn.Module):
    """Two-stage bidirectional LSTM encoder for a single input modality.

    The input runs through a BiLSTM, dropout (p=0.5), a Linear+ReLU
    projection back to ``hidden_size`` features, and a second BiLSTM. Only
    the final hidden state of the second BiLSTM is returned.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Submodules are created in the original order with the original
        # attribute names, so initialisation and state_dict keys are unchanged.
        self.lstm1 = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.5)
        projection = [nn.Linear(2 * hidden_size, hidden_size), nn.ReLU()]
        self.linear = nn.Sequential(*projection)
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, dropout=0, bidirectional=True, batch_first=True)

    def forward(self, x):
        """Return ``h_n`` of the second LSTM.

        With ``batch_first=True`` the input is (batch, seq, input_size);
        ``h_n`` is (2, batch, hidden_size), one slice per direction.
        """
        seq_out = self.lstm1(x)[0]
        seq_out = self.linear(self.dropout(seq_out))
        _, (final_hidden, _) = self.lstm2(seq_out)
        return final_hidden

# Fraction of a similarity row's maximum that a candidate mixup partner must
# exceed to be considered (used by get_second).
T = 0.7


def get_second(i, mat, g, all_genders=None, threshold_ratio=None):
    """Pick a mixup partner for sample ``i`` from row ``i`` of ``mat``.

    Candidates are the columns ``j`` with ``mat[i][j] > threshold_ratio * max(mat[i])``.
    The environment variable ``TYPE`` selects a filter:
    ``"orig"`` keeps every candidate, ``"same"`` keeps only candidates whose
    gender equals ``g``, and ``"diff"`` keeps only those whose gender differs.
    One remaining candidate is chosen uniformly at random. If none remain,
    sample ``i`` itself is returned, which makes the subsequent mixup a no-op.

    Args:
        i: row / dataset index of the anchor sample (int or 0-d tensor).
        mat: 2-D similarity tensor (may require grad).
        g: gender code of the anchor sample (int or 0-d tensor).
        all_genders: indexable gender codes for every column of ``mat``;
            defaults to the module-level ``master_gender``.
        threshold_ratio: fraction of the row maximum a candidate must exceed;
            defaults to the module-level ``T``.

    Returns:
        ``(a, l)``, where ``a`` is a numpy int array of shape (1,) holding the
        chosen column and ``l`` is ``mat[i][a]``, a tensor of shape (1,) that
        stays in ``mat``'s autograd graph.

    Raises:
        KeyError: if ``TYPE`` is not set.
        ValueError: if ``TYPE`` is not one of ``orig``/``same``/``diff``.
    """
    run_type = os.environ["TYPE"]
    if run_type not in ("orig", "same", "diff"):
        raise ValueError(f"Invalid run type: {run_type!r}")
    genders = master_gender if all_genders is None else all_genders
    ratio = T if threshold_ratio is None else threshold_ratio

    # Only row i is needed; avoid copying the whole matrix to the CPU.
    row = mat[i].detach().cpu().numpy()
    candidates = np.flatnonzero(row > np.amax(row) * ratio)

    if run_type != "orig":
        anchor = int(g)
        want_same = run_type == "same"
        candidates = np.array(
            [c for c in candidates if (int(genders[c]) == anchor) == want_same],
            dtype=np.int64,
        )

    if len(candidates) == 0:
        # No eligible partner: mix the sample with itself (identity mixup)
        # instead of crashing.
        candidates = np.array([int(i)], dtype=np.int64)

    pick = np.random.randint(0, len(candidates))
    a = candidates[pick:pick + 1]
    # Index the tensor itself (not a numpy copy) so lam keeps its gradient.
    l = mat[i][a]
    return a, l

class MDRM(nn.Module):
    """Audio + text classifier with optional similarity-guided mixup.

    Each modality is encoded by its own ``MDRMUnit``. The two final hidden
    states are concatenated into a ``4 * hidden_size`` feature vector and
    passed through a small MLP that outputs 2 logits.

    ``self.mat`` is a learnable matrix loaded from ``mat_path``. When ``ids``
    is passed to ``forward``, row ``ids[k]`` of it is handed to ``get_second``
    to choose a partner sample from the module-level ``master_*`` tensors.
    Features and labels are then interpolated with the returned weight.
    """

    def __init__(self, audio_input, text_input, hidden_size, mat_path):
        super().__init__()
        self.mat = nn.Parameter(torch.tensor(np.load(mat_path)))
        self.unit_audio = MDRMUnit(audio_input, hidden_size)
        self.unit_text = MDRMUnit(text_input, hidden_size)
        self.linear = nn.Sequential(
            nn.Linear(hidden_size * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        )

    def _encode(self, audio, text):
        """Return (batch, 4 * hidden_size) features for an audio/text batch."""
        audio_h = self.unit_audio(audio)  # (2, batch, hidden): one slice per LSTM direction
        text_h = self.unit_text(text)
        stacked = torch.cat((audio_h, text_h))  # (4, batch, hidden)
        return torch.flatten(stacked.permute(1, 0, 2), 1)

    def forward(self, audio, text, gender, y, ids=None):
        """Classify a batch, optionally applying mixup.

        Args:
            audio, text: batch inputs for the two encoders.
            gender: per-sample gender codes, indexed by position in the batch.
            y: per-sample label rows (used only when ``ids`` is given).
            ids: dataset indices of the batch samples. If ``None``, no mixup is
                done.

        Returns:
            ``logits`` when ``ids`` is ``None``. Otherwise ``(logits, mixed_y)``,
            where both the features and the labels have been interpolated with
            each sample's chosen partner.
        """
        cat = self._encode(audio, text)
        if ids is None:
            return self.linear(cat)

        partner_ids, lams = [], []
        for pos, sample_id in enumerate(ids):
            # Batch tensors are indexed by position; the similarity row and the
            # master_* tensors are indexed by dataset id.
            second, lam = get_second(int(sample_id), self.mat, gender[pos], master_gender)
            partner_ids.append(int(np.asarray(second).reshape(-1)[0]))
            # lam may be a tensor or an array, and self.mat may not share cat's
            # dtype (np.load keeps the file's dtype), so normalise it here.
            # as_tensor keeps the autograd link when lam is a tensor.
            lams.append(torch.as_tensor(lam, dtype=cat.dtype, device=cat.device).reshape(1))

        partner_ids = torch.tensor(partner_ids, dtype=torch.long)
        lam = torch.cat(lams).unsqueeze(1)  # (batch, 1): broadcasts over features / classes

        second_cat = self._encode(
            master_audio[partner_ids].to(audio.device),
            master_text[partner_ids].to(text.device),
        )
        second_y = master_y[partner_ids].to(device=y.device, dtype=y.dtype)

        mixed = cat * lam + second_cat * (1 - lam)
        mixed_y = y * lam + second_y * (1 - lam)
        return self.linear(mixed), mixed_y

def train():
    """Train an MDRM model with similarity-guided mixup on the loaded corpus.

    Builds a shuffled DataLoader over the module-level ``master_*`` tensors.
    It then optimises ``BCELoss(softmax(logits), mixed_labels)`` with Adam for
    ``ARGS["num_epochs"]`` epochs and prints the mean per-sample loss after
    each epoch. CUDA is used when available.
    """
    num_epochs = ARGS["num_epochs"]
    # No batch size was ever configured (DataLoader would default to 1);
    # 32 is a placeholder that can be overridden via ARGS["batch_size"].
    batch_size = ARGS.get("batch_size", 32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MDRM(18, 768, 100, '../data/ec/final/train.npy').to(device)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.BCELoss()
    softmax = nn.Softmax(dim=1)

    # ids are dataset indices, so the model can look up mixup partners by them.
    dataset = TensorDataset(
        torch.arange(master_audio.shape[0]), master_audio, master_text, master_gender, master_y
    )
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        model.train()
        running_loss, total = 0.0, 0
        # Field order matches the TensorDataset above.
        for ids, audio, text, gender, y in train_dataloader:
            audio = audio.to(device)
            text = text.to(device)
            gender = gender.to(device)
            y = y.to(device)

            output, mixed_y = model(audio, text, gender, y, ids)
            loss = criterion(softmax(output), mixed_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight it by batch size so the epoch figure
            # is a true per-sample mean.
            running_loss += loss.item() * mixed_y.shape[0]
            total += mixed_y.shape[0]

        print(f"epoch {epoch}: loss {running_loss / max(total, 1):.6f}")


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()