from tqdm import trange, tqdm
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torch.backends.cudnn as cudnn
import os
import glob
import cv2
from torch.utils.data.dataset import Dataset
import numpy as np
from shutil import copyfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torch.utils.data.dataset import ConcatDataset
from models.vggnet import vgg11
from dataset import loadEnvData

# Dataset directory passed to train_with_oneHot() at the bottom of this file,
# which forwards it to loadEnvData(data_dir=...).
# NOTE(review): the value looks like a placeholder -- set the real path before running.
DATAPATH = 'path/to/pretext_training_model'

class oneHotSoundNetwork(nn.Module):
	"""Sound classifier: strided CNN -> bidirectional GRU -> MLP -> 4 logits.

	The CNN downsamples a single-channel 2-D sound input; for a
	``(1, 600, 40)`` input it yields ``(64, 73, 7)``. Each position along the
	first spatial axis becomes one GRU step with ``64 * 7`` features. The
	final hidden states of both GRU directions are concatenated and fed to
	the MLP head.

	NOTE(review): the two spatial axes are presumably time x frequency of a
	spectrogram -- confirm against the dataset code.
	"""

	def __init__(self):
		super(oneHotSoundNetwork, self).__init__()
		# Sub-modules are created in the same order and under the same names as
		# before, so state_dict keys and seeded initialisation are unchanged.
		self.soundRNN = torch.nn.GRU(input_size=64 * 7, hidden_size=512, batch_first=True, bidirectional=True)
		self.soundCNN = nn.Sequential(
			nn.Conv2d(1, 64, (11, 11), stride=(2, 2), padding=(5, 5)), nn.ReLU(),  # (1, 600, 40)->(64, 300, 20)
			nn.Conv2d(64, 64, (11, 5), stride=(2, 2), padding=(5, 5)), nn.ReLU(),  # (64, 300, 20)->(64, 150, 13)
			nn.Conv2d(64, 64, (7, 3), stride=(2, 2), padding=(1, 1)), nn.ReLU(),  # (64, 150, 13)->(64, 73, 7)
		)
		self.soundMLP = nn.Sequential(nn.Linear(2 * 512, 128), nn.ReLU(),
									  nn.Linear(128, 128), nn.ReLU(), )

		self.soundAux = nn.Sequential(
			nn.Linear(128, 64), nn.ReLU(),
			nn.Linear(64, 4),
		)

	def forward(self, sound):
		"""Return the 4 class logits for a batch of sound inputs.

		Args:
			sound: float tensor of shape ``(batch, 1, H, W)``. ``W`` must reduce
				to 7 after the CNN (e.g. ``W == 40``) to match the GRU input
				size; ``H`` may vary and sets the number of GRU steps.

		Returns:
			Tensor of shape ``(batch, 4)`` with unnormalised class scores.
		"""
		cnn_out = self.soundCNN(sound)  # (batch, 64, steps, 7)
		# Put the step axis before the channels, then merge channels and width
		# into one feature vector per step. Flattening per sample (instead of
		# reshaping to a fixed 73 steps with an inferred batch axis) keeps the
		# batch dimension intact for any input length.
		seq = torch.transpose(cnn_out, dim0=1, dim1=2).flatten(start_dim=2)
		_, h_n = self.soundRNN(seq)  # h_n: (2 directions, batch, 512)
		rnn_out = torch.cat((h_n[0, :, :], h_n[1, :, :]), dim=1)
		sound_feat = self.soundMLP(rnn_out)

		return self.soundAux(sound_feat)


def train_with_oneHot(device, dataPath, model_save_dir, batch_size, config):
	"""Train ``oneHotSoundNetwork`` for 30 epochs on one-hot sound targets.

	Args:
		device: torch device that the model and every batch are moved to.
		dataPath: dataset directory, forwarded to ``loadEnvData`` as ``data_dir``.
		model_save_dir: currently unused -- nothing is written to disk.
			NOTE(review): the trained model is discarded when this returns.
		batch_size: batch size forwarded to ``loadEnvData``.
		config: forwarded to ``loadEnvData``; must also expose
			``pretrainedLRDecayEpoch`` and ``pretrainedLRDecayGamma`` for the
			learning-rate schedule.
	"""
	# load data (the second return value is not needed here)
	data_generator, _ = loadEnvData(data_dir=dataPath, config=config, batch_size=batch_size, shuffle=True,
									num_workers=4, drop_last=True)
	# get model
	model = oneHotSoundNetwork()

	# do some settings
	print("Using device:", device)
	# Honour the requested device instead of forcing CUDA, so the model and the
	# batches moved with .to(device) below always live on the same device.
	model.to(device).train()

	cudnn.benchmark = True

	optimizer = optim.Adam(model.parameters(),
						   lr=1e-4,
						   weight_decay=1e-6)

	scheduler = MultiStepLR(optimizer, milestones=config.pretrainedLRDecayEpoch, gamma=config.pretrainedLRDecayGamma)
	criterion_sound = torch.nn.CrossEntropyLoss()

	# main training loop
	for ep in trange(30, position=0):

		loss_ep_sound = []
		for image, sound_positive, sound_negative, gt in data_generator:
			# The optimizer holds exactly model.parameters(), so one zero_grad is enough.
			optimizer.zero_grad()

			# Build the (batch, 4) one-hot targets in one shot. Label 4 keeps an
			# all-zero row, so that sample adds nothing to the soft-target loss.
			# NOTE(review): 4 presumably means "no sound" -- confirm with the dataset.
			labels = torch.as_tensor(gt).reshape(-1).long()
			sound_label = torch.zeros((labels.shape[0], 4), dtype=torch.float32, device=labels.device)
			has_class = labels != 4
			sound_label[has_class, labels[has_class]] = 1.0
			sound_label = sound_label.to(device)

			sound_pred = model(sound_positive.float().to(device))
			loss = criterion_sound(sound_pred, sound_label)
			loss.backward()
			optimizer.step()
			loss_ep_sound.append(loss.item())

		# An epoch without any batch has no mean loss; report nan explicitly.
		if loss_ep_sound:
			mean_loss = np.sum(loss_ep_sound) / len(loss_ep_sound)
		else:
			mean_loss = float('nan')
		print('sound_loss', mean_loss)
		scheduler.step()

	print('Pretext Training Complete')

# Script entry: use the first CUDA device when one is available, otherwise the
# CPU, then start training with a batch size of 64 and no save directory.
# NOTE(review): there is no __main__ guard, so this runs whenever the module is imported.
if torch.cuda.is_available():
	device = torch.device("cuda:0")
else:
	device = torch.device("cpu")
from cfg import config
train_with_oneHot(
	device=device,
	dataPath=DATAPATH,
	model_save_dir='',
	batch_size=64,
	config=config,
)
