from tqdm import trange, tqdm
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torch.backends.cudnn as cudnn
import os
import glob
import cv2
from torch.utils.data.dataset import Dataset
import numpy as np
from shutil import copyfile
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torch.utils.data.dataset import ConcatDataset
from models.vggnet import vgg11
from dataset import loadEnvData


class Flatten(nn.Module):
	"""Collapse every dimension after the first into a single dimension."""

	def forward(self, x):
		# Keep dim 0 as-is and let view() infer the size of the merged remainder.
		leading = x.size(0)
		return x.view(leading, -1)


class oneHotImgNetwork(nn.Module):
	"""Image CNN with a 4-way "in sight" head and an optional sound-conditioned "exi" head.

	The CNN output is flattened to 128*3*3 features (the size the linear heads
	expect); the layer comments assume a (3, 96, 96) input image.

	Args:
		train_exi: when truthy, the sound MLP and the ``exi`` head are built and
			``forward`` returns ``(inSight, exi)``; otherwise only ``inSight``.
		train_sound: accepted for backward compatibility only. Nothing in this
			class reads it, so it now defaults to ``None`` and single-argument
			construction works.
	"""

	def __init__(self, train_exi, train_sound=None):
		super(oneHotImgNetwork, self).__init__()
		self.train_exi = train_exi

		# Attribute names and layer positions are kept so existing state_dicts still load.
		self.imgCNN = nn.Sequential(
			nn.Conv2d(3, 32, 3, stride=1, padding=1), nn.ReLU(),  # (3, 96, 96)->(32, 96, 96)
			nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.ReLU(),
			nn.MaxPool2d(2, stride=2),  # (32, 96, 96)->(32, 48, 48)
			nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.ReLU(),  # (32, 48, 48)->(64, 48, 48)
			nn.MaxPool2d(2, stride=2),  # (64, 48, 48)->(64, 24, 24)
			nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(),  # (64, 24, 24)->(64, 24, 24)
			nn.MaxPool2d(2, stride=2),  # (64, 24, 24)->(64, 12, 12)
			nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.ReLU(),  # (64, 12, 12)->(128, 12, 12)
			nn.MaxPool2d(2, stride=2),  # (128, 12, 12)->(128, 6, 6)
			nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.ReLU(),  # (128, 6, 6)->(128, 3, 3)
			nn.Flatten()  # parameter-free, so state_dict keys are unaffected
		)

		# Per-class logits computed from the image alone.
		self.inSight = nn.Sequential(
			nn.Linear(128 * 3 * 3, 128), nn.ReLU(),
			nn.Linear(128, 4),
		)

		if train_exi:
			# Single logit computed from the fused image/sound features.
			self.exi = nn.Sequential(
				nn.Linear(128 * 9, 256), nn.ReLU(),
				nn.Linear(256, 128), nn.ReLU(),
				nn.Linear(128, 1),
			)

			# Lifts the 4-dim sound command to 128 features, one per CNN channel.
			self.soundMlp = nn.Sequential(
				nn.Linear(4, 64), nn.ReLU(),
				nn.Linear(64, 128), nn.ReLU(),
				nn.Linear(128, 128), nn.ReLU(),
			)

	def forward(self, img, sound):
		"""Return ``(inSight, exi)`` when ``train_exi`` is set, else just ``inSight``.

		``inSight`` has shape (batch, 4); ``exi`` has shape (batch,). Both are raw
		logits (no sigmoid is applied here). ``sound`` is only read when
		``train_exi`` is set.
		"""
		cnn_feat = self.imgCNN(img)
		in_sight = self.inSight(cnn_feat)
		if not self.train_exi:
			return in_sight

		# (batch, 128) -> (batch, 128, 1) so it broadcasts over the spatial positions.
		sound_feat = self.soundMlp(sound).unsqueeze(-1)
		# (batch, 128*9) -> (batch, 128, 9): one row per channel.
		image_feat = cnn_feat.reshape(img.size(0), 128, -1)
		fusion = torch.flatten(sound_feat + image_feat, start_dim=1)
		exi = self.exi(fusion)[:, 0]
		return in_sight, exi


def _sample_onehot_targets(gt, num_classes=4):
	"""Build "in sight" targets and random one-hot sound commands for one batch.

	For each sample j:
	  * if gt[j] != num_classes, the in-sight row is one-hot at gt[j]; with
	    probability 0.5 the sound command names the same class, otherwise it
	    names a uniformly chosen different class;
	  * if gt[j] == num_classes, the in-sight row stays all zeros and the sound
	    command names a uniformly random class.

	Returns:
		(in_sight, sound_command): two float numpy arrays of shape (len(gt), num_classes).
	"""
	n_samples = len(gt)
	in_sight = np.zeros((n_samples, num_classes))
	sound_command = np.zeros((n_samples, num_classes))
	for j in range(n_samples):
		cls = int(gt[j])  # plain int so numpy indexing does not depend on tensor coercion
		if cls != num_classes:
			in_sight[j, cls] = 1
			if np.random.rand() > 0.5:  # command matches the labelled class
				sound_command[j, cls] = 1
			else:  # command names any other class, uniformly
				prob = np.ones((num_classes,)) / (num_classes - 1)
				prob[cls] = 0
				sound_command[j, np.random.choice(num_classes, p=prob)] = 1
		else:
			sound_command[j, np.random.choice(num_classes)] = 1
	return in_sight, sound_command


def train_with_oneHot(device, dataPath, model_save_dir, batch_size, config):
	"""Train ``oneHotImgNetwork`` on images paired with random one-hot sound commands.

	Runs 30 epochs of Adam (lr 4e-4, weight decay 1e-6) with a MultiStepLR
	schedule taken from ``config.pretrainedLRDecayEpoch`` /
	``config.pretrainedLRDecayGamma``. Both heads are trained with
	BCEWithLogitsLoss and the two losses are summed.

	Args:
		device: device the model and every batch tensor are moved to.
		dataPath: forwarded to ``loadEnvData`` as ``data_dir``.
		model_save_dir: directory for checkpoints; created if missing. A
			state_dict is saved every 10 epochs as ``<ep>.pt`` with the
			zero-based epoch index (9.pt, 19.pt, 29.pt).
		batch_size: forwarded to ``loadEnvData``.
		config: forwarded to ``loadEnvData`` and read for the LR schedule.
	"""
	train_exi = True
	# load data; each batch is unpacked as (image, sound_positive, sound_negative, gt)
	data_generator, _ = loadEnvData(data_dir=dataPath, config=config, batch_size=batch_size, shuffle=True,
									num_workers=4, drop_last=True)
	# get model (the constructor requires train_sound but does not read it)
	model = oneHotImgNetwork(train_exi, train_sound=True)

	# do some settings
	print("Using device:", device)
	# Put the model on the same device the batch tensors are sent to, instead of
	# unconditionally on the default CUDA device.
	model.to(device).train()

	cudnn.benchmark = True

	optimizer = optim.Adam(model.parameters(),
						   lr=0.0004,
						   weight_decay=1e-6)

	scheduler = MultiStepLR(optimizer, milestones=config.pretrainedLRDecayEpoch, gamma=config.pretrainedLRDecayGamma)
	criterion_insight = torch.nn.BCEWithLogitsLoss()
	criterion_exi = torch.nn.BCEWithLogitsLoss()

	# main training loop
	for ep in trange(30, position=0):

		loss_ep_inSight = []
		loss_ep_exi = []
		# The sounds shipped with the batch are not used; commands are synthesised from gt.
		for image, _sound_positive, _sound_negative, gt in data_generator:
			# The optimizer holds every model parameter, so one zero_grad is enough.
			optimizer.zero_grad()

			in_sight_np, sound_np = _sample_onehot_targets(gt)
			sound_command = torch.from_numpy(sound_np).float().to(device)
			inSight = torch.from_numpy(in_sight_np).float().to(device)
			# 1 exactly when the commanded class is the one marked in sight.
			exi_label = torch.sum(sound_command * inSight, dim=1)

			inSight_pred, exi_pred = model(image.to(device), sound_command)

			loss1 = criterion_insight(inSight_pred, inSight)
			loss2 = criterion_exi(exi_pred, exi_label)
			loss = loss1 + loss2
			loss.backward()
			optimizer.step()

			loss_ep_inSight.append(loss1.item())
			loss_ep_exi.append(loss2.item())

		print('inSight_loss', np.sum(loss_ep_inSight) / len(loss_ep_inSight))
		print('exi_loss', np.sum(loss_ep_exi) / len(loss_ep_exi))
		scheduler.step()

		if (ep + 1) % 10 == 0:
			# exist_ok avoids the check-then-create race of a separate exists() test.
			os.makedirs(model_save_dir, exist_ok=True)
			fname = os.path.join(model_save_dir, str(ep) + '.pt')

			torch.save(model.state_dict(), fname)
			print('Model saved to ' + fname)
	print('Pretext Training Complete')


class repImgNetwork(nn.Module):
	"""Image network that predicts a unit-norm representation and, optionally, an "exi" logit.

	Reads from ``config``: ``name`` (selects the CNN backbone),
	``pretrainedUseExi`` (enables the sound branch and the exi head) and
	``representationDim`` (output width of the inSight head).
	"""

	def __init__(self, config):
		super().__init__()
		self.config = config

		# Sub-modules are created in a fixed order: backbone, cnnMlp, exi, soundMlp, inSight.
		self.buildCNN()

		self.cnnMlp = nn.Sequential(
			nn.Linear(128 * 3 * 3, 512),
			nn.ReLU(),
			nn.Linear(512, 256),
			nn.ReLU(),
		)

		if config.pretrainedUseExi:
			# Single logit computed from the fused 256-dim image/sound feature.
			self.exi = nn.Sequential(
				nn.Linear(256, 256),
				nn.ReLU(),
				nn.Linear(256, 128),
				nn.ReLU(),
				nn.Linear(128, 1),
			)

			# Lifts a 3-dim sound input to the 256-dim image feature space.
			# NOTE(review): the input width 3 is hard-coded — confirm it matches the caller's sound feature size.
			self.soundMlp = nn.Sequential(
				nn.Linear(3, 128),
				nn.ReLU(),
				nn.Linear(128, 256),
				nn.ReLU(),
				nn.Linear(256, 256),
				nn.ReLU(),
			)

		self.inSight = nn.Sequential(
			nn.Linear(256, 128),
			nn.ReLU(),
			nn.Linear(128, config.representationDim),
		)

	def buildCNN(self):
		"""Create ``self.imgCNN``: vgg11 for 'TurtleBotConfig', otherwise a small conv stack."""
		if self.config.name == 'TurtleBotConfig':
			self.imgCNN = vgg11(pretrained=False, progress=False)
			return

		def conv_relu(c_in, c_out, stride=1):
			# 3x3 convolution with padding 1, followed by ReLU
			return [nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1), nn.ReLU()]

		def pool():
			# 2x2 max pooling, halves the spatial size
			return [nn.MaxPool2d(2, stride=2)]

		layers = []
		layers += conv_relu(3, 32)               # (3, 96, 96)->(32, 96, 96)
		layers += conv_relu(32, 32)
		layers += pool()                         # (32, 96, 96)->(32, 48, 48)
		layers += conv_relu(32, 64)              # (32, 48, 48)->(64, 48, 48)
		layers += pool()                         # (64, 48, 48)->(64, 24, 24)
		layers += conv_relu(64, 64)              # (64, 24, 24)->(64, 24, 24)
		layers += pool()                         # (64, 24, 24)->(64, 12, 12)
		layers += conv_relu(64, 128)             # (64, 12, 12)->(128, 12, 12)
		layers += pool()                         # (128, 12, 12)->(128, 6, 6)
		layers += conv_relu(128, 128, stride=2)  # (128, 6, 6)->(128, 3, 3)
		layers.append(Flatten())
		self.imgCNN = nn.Sequential(*layers)

	def forward(self, img, sound):
		"""Return ``(exi, inSight)``.

		``inSight`` is L2-normalised along dim 1. ``exi`` is the first column of
		the exi head's output when ``config.pretrainedUseExi`` is set, else
		``None`` (``sound`` is not read in that case).
		"""
		feat = self.cnnMlp(self.imgCNN(img))
		representation = F.normalize(self.inSight(feat), p=2, dim=1)

		if not self.config.pretrainedUseExi:
			return None, representation

		# Fuse by element-wise addition in the shared 256-dim space.
		sound_feat = self.soundMlp(sound)
		exi_logit = self.exi(sound_feat + feat)[:, 0]
		return exi_logit, representation


def _mix_sounds_for_exi(sp, sn, gt, batch_size, rsi_ver):
	"""Pick, per sample, either the matching sound (label 1) or a non-matching one (label 0).

	Each sample keeps its own sound ``sp[j]`` with probability 0.5. Otherwise:
	  * rsi_ver 2: the provided negative ``sn[j]`` is used;
	  * rsi_ver 3: ``sp`` of a random sample in the batch whose ``gt`` differs
	    is used. If no such sample exists, the positive pair is kept (label 1)
	    instead of failing on an empty random range.

	Returns:
		(sound, label): numpy arrays; ``sound`` is shaped like ``sp`` and
		``label`` has shape (batch_size,).
	"""
	sound = np.zeros_like(sp)
	label = np.zeros((batch_size,))
	for j in range(batch_size):
		if torch.rand(1) > 0.5:  # choose sound_positive and label is 1
			sound[j] = sp[j]
			label[j] = 1.
			continue

		if rsi_ver == 2:
			sound[j] = sn[j]
		elif rsi_ver == 3:
			idxs = torch.where(gt != gt[j])[0]
			if len(idxs) == 0:
				# Every sample in the batch shares this label, so there is no negative to draw.
				sound[j] = sp[j]
				label[j] = 1.
				continue
			idx = torch.randint(0, len(idxs), size=())
			sound[j] = sp[idxs[idx]]
		else:
			raise NotImplementedError
		label[j] = 0.
	return sound, label


def train_with_representation(device, pretextModel, dataPath, model_save_dir, batch_size, config):
	"""Train ``repImgNetwork`` to regress the pretext model's image feature (and, optionally, exi).

	``pretextModel`` is called under ``torch.no_grad()`` as
	``pretextModel(img, sound, None)`` and must return a 5-tuple. Its first
	entry is the MSE target for the inSight head and its second entry is fed to
	the network as the sound input. When ``config.pretrainedUseExi`` is set, a
	BCEWithLogitsLoss on the exi head is added to the MSE loss.

	Batch layout by ``config.RSI_ver``:
	  * 2: ``img, sp, sn, gt = data[0], data[1], data[2], data[3]``
	  * 3: ``img, sp, gt = data[0][0], data[1][0], data[2]``

	Args:
		device: device the model and every batch tensor are moved to.
		pretextModel: frozen feature provider, see above.
		dataPath: forwarded to ``loadEnvData`` as ``data_dir``.
		model_save_dir: directory for checkpoints; created if missing. A
			state_dict is saved every 10 epochs as ``<ep>.pt`` with the
			zero-based epoch index.
		batch_size: forwarded to ``loadEnvData``; also the number of samples
			per batch assumed when mixing sounds (batches are loaded with
			``drop_last=True``).
		config: read for RSI_ver, pretextDataset, pretrainedUseExi,
			pretrainedLR, pretrainedLRDecayEpoch, pretrainedLRDecayGamma and
			pretrainedEpoch.

	Raises:
		NotImplementedError: if ``config.RSI_ver`` is neither 2 nor 3.
	"""
	# Fail before loading data or building the model if the batch layout is unknown.
	if config.RSI_ver not in (2, 3):
		raise NotImplementedError

	# load data
	data_generator, _ = loadEnvData(data_dir=dataPath, config=config, batch_size=batch_size, shuffle=True,
									num_workers=4, drop_last=True, dtype=config.pretextDataset)
	# get model
	model = repImgNetwork(config)

	# do some settings
	print("Begin training pretrain network")
	# Put the model on the same device the batch tensors are sent to, instead of
	# unconditionally on the default CUDA device.
	model.to(device).train()
	optimizer = optim.Adam(model.parameters(),
						   lr=config.pretrainedLR,
						   weight_decay=1e-6)

	scheduler = MultiStepLR(optimizer, milestones=config.pretrainedLRDecayEpoch, gamma=config.pretrainedLRDecayGamma)
	criterion_exi = torch.nn.BCEWithLogitsLoss()
	criterion_inSight = torch.nn.MSELoss()

	# main training loop
	for ep in trange(config.pretrainedEpoch, position=0):
		loss_ep_exi = []
		loss_ep_inSight = []
		for data in data_generator:
			if config.RSI_ver == 2:
				img, sp, sn, gt = data[0], data[1], data[2], data[3]
			else:  # RSI_ver == 3: no explicit negative sound in the batch
				img, sp, gt = data[0][0], data[1][0], data[2]
				sn = None

			# The optimizer holds every model parameter, so one zero_grad is enough.
			optimizer.zero_grad()

			if config.pretrainedUseExi:
				# randomly generate sound_label
				sound, label = _mix_sounds_for_exi(sp, sn, gt, batch_size, config.RSI_ver)
			else:
				# Without the exi head the sound input is an all-zero placeholder.
				sound, label = np.zeros_like(sp), None
			sound = torch.from_numpy(sound).float()

			img = img.to(device)  # moved once, used by both networks
			# NOTE(review): pretextModel's train/eval mode is not changed here — confirm the caller sets it.
			with torch.no_grad():
				image_feat, sound_feat, _, _, _ = pretextModel(img, sound.to(device), None)

			exi_pred, inSight_pred = model(img, sound_feat.to(device))
			inSight_loss = criterion_inSight(inSight_pred, image_feat)
			loss_ep_inSight.append(inSight_loss.item())

			if config.pretrainedUseExi:
				exi_label = torch.from_numpy(label).float().to(device)
				exi_loss = criterion_exi(exi_pred, exi_label)
				loss_ep_exi.append(exi_loss.item())
				loss = inSight_loss + exi_loss
			else:
				loss = inSight_loss

			loss.backward()
			optimizer.step()

		if config.pretrainedUseExi:
			print('exi_loss', np.sum(loss_ep_exi) / len(loss_ep_exi))
		print('inSight_loss', np.sum(loss_ep_inSight) / len(loss_ep_inSight))

		scheduler.step()

		if (ep + 1) % 10 == 0:
			# exist_ok avoids the check-then-create race of a separate exists() test.
			os.makedirs(model_save_dir, exist_ok=True)
			fname = os.path.join(model_save_dir, str(ep) + '.pt')

			torch.save(model.state_dict(), fname)
			print('Model saved to ' + fname)

	print('Pretrained Network Training Complete')
