import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch

import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.utils.data as Data

class NewIBBase(nn.Module):
	"""Fully-connected VAE-style bottleneck with two adversarial heads.

	Components built in ``__init__``:

	* ``encoder`` / ``fc_mu`` / ``fc_var``: encode ``input`` alone.
	* ``encoder_2`` / ``fc_mu_2`` / ``fc_var_2``: encode ``input`` concatenated
	  with ``output``.
	* ``decoder_input`` / ``decoder`` / ``final_layer``: predict ``output`` from
	  ``[z * mu_2, z, mu_2]``.
	* ``*_adv_x`` and ``*_adv_y``: try to reconstruct ``input`` and ``output``
	  from ``mu_2`` alone.

	Two SGD optimizers are created: ``main_optimizer`` (both encoders and the
	main decoder) and ``adv_optimizer`` (the two adversarial heads). ``step``
	alternates between them.
	"""

	def __init__(self,
				input_shape: int,
				output_shape: int,
				latent_dim: int,
				hidden_dims=None,
				**kwargs) -> None:
		"""Build all sub-networks and their optimizers.

		:param input_shape: number of input features
		:param output_shape: number of output (target) features
		:param latent_dim: size of each latent code
		:param hidden_dims: hidden layer widths of the encoders; defaults to
			``[32, 32]``. The list is not modified.
		:param kwargs: accepted and ignored
		"""
		super(NewIBBase, self).__init__()

		self.latent_dim = latent_dim
		self.input_shape = input_shape
		self.output_shape = output_shape

		self.embed_dim = 16

		# Work on private copies: the caller's list must not be reversed in
		# place as a side effect of constructing the model.
		widths = [32, 32] if hidden_dims is None else list(hidden_dims)
		mirrored = widths[::-1]

		# Encoder 1: input -> hidden -> (mu, log_var)
		self.encoder = self._mlp([input_shape] + widths)
		self.fc_mu = nn.Linear(widths[-1], latent_dim)
		self.fc_var = nn.Linear(widths[-1], latent_dim)

		# Encoder 2: [input, output] -> hidden -> (mu_2, log_var_2)
		self.encoder_2 = self._mlp([input_shape + output_shape] + widths)
		self.fc_mu_2 = nn.Linear(widths[-1], latent_dim)
		self.fc_var_2 = nn.Linear(widths[-1], latent_dim)

		# Main decoder: consumes the three concatenated latent-sized tensors.
		self.decoder_input = nn.Linear(latent_dim * 3, widths[-1])
		self.decoder = self._mlp(mirrored)
		self.final_layer = nn.Linear(mirrored[-1], output_shape)

		# Adversarial decoder for the input. It walks the widths in encoder
		# order (not mirrored); this layout is kept so layer shapes stay
		# identical to previously saved models.
		self.decoder_input_adv_x = nn.Linear(latent_dim, mirrored[-1])
		self.decoder_adv_x = self._mlp(widths)
		self.final_layer_adv_x = nn.Linear(widths[-1], input_shape)

		# Adversarial decoder for the output.
		self.decoder_input_adv_y = nn.Linear(latent_dim, widths[-1])
		self.decoder_adv_y = self._mlp(mirrored)
		self.final_layer_adv_y = nn.Linear(mirrored[-1], output_shape)

		self.build_optimizers()

	@staticmethod
	def _mlp(sizes):
		"""Return a stack of ``Linear -> ReLU`` blocks joining consecutive ``sizes``."""
		return nn.Sequential(*[
			nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())
			for n_in, n_out in zip(sizes, sizes[1:])
		])

	def build_optimizers(self):
		"""Create one SGD optimizer for the main model and one for the adversaries."""
		main_modules = (
			self.encoder, self.fc_mu, self.fc_var,
			self.encoder_2, self.fc_mu_2, self.fc_var_2,
			self.decoder_input, self.decoder, self.final_layer,
		)
		self.main_param = [p for module in main_modules for p in module.parameters()]
		self.main_optimizer = torch.optim.SGD(self.main_param, lr=0.0001)

		adv_modules = (
			self.decoder_input_adv_x, self.decoder_adv_x, self.final_layer_adv_x,
			self.decoder_input_adv_y, self.decoder_adv_y, self.final_layer_adv_y,
		)
		self.adv_param = [p for module in adv_modules for p in module.parameters()]
		self.adv_optimizer = torch.optim.SGD(self.adv_param, lr=0.0001)

	def encode(self, input, output):
		"""
		Run both encoders and return their latent statistics.
		:param input: (Tensor) [B x input_shape]
		:param output: (Tensor) [B x output_shape]
		:return: list ``[mu, log_var, mu_2, log_var_2]``, each [B x latent_dim]
		"""
		result = self.encoder(input)
		mu = self.fc_mu(result)
		log_var = self.fc_var(result)

		result_2 = self.encoder_2(torch.cat([input, output], dim=1))
		mu_2 = self.fc_mu_2(result_2)
		# The second head must read the second encoder's features.
		log_var_2 = self.fc_var_2(result_2)

		return [mu, log_var, mu_2, log_var_2]

	def decode(self, z, mu_2):
		"""
		Decode the latent codes with the main decoder and both adversaries.
		:param z: (Tensor) [B x latent_dim]
		:param mu_2: (Tensor) [B x latent_dim]
		:return: tuple ``(result, result_adv_x, result_adv_y)`` with shapes
			[B x output_shape], [B x input_shape], [B x output_shape]
		"""
		result = self.decoder_input(torch.cat([z * mu_2, z, mu_2], dim=1))
		result = self.final_layer(self.decoder(result))

		# Both adversaries only ever see mu_2.
		result_adv_x = self.decoder_input_adv_x(mu_2)
		result_adv_x = self.final_layer_adv_x(self.decoder_adv_x(result_adv_x))

		result_adv_y = self.decoder_input_adv_y(mu_2)
		result_adv_y = self.final_layer_adv_y(self.decoder_adv_y(result_adv_y))

		return result, result_adv_x, result_adv_y

	def reparameterize(self, mu, logvar):
		"""
		Reparameterization trick to sample from N(mu, var) from
		N(0,1).
		:param mu: (Tensor) Mean of the latent Gaussian [B x D]
		:param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
		:return: (Tensor) [B x D]
		"""
		std = torch.exp(0.5 * logvar)
		eps = torch.randn_like(std)
		return eps * std + mu

	def forward(self, input, output, **kwargs):
		"""
		Encode, sample ``z`` from the first encoder and decode.
		:return: list ``[recons, recons_adv_x, recons_adv_y, input, output,
			mu, log_var]``, the positional layout expected by
			``loss_function`` and ``loss_function_adv``
		"""
		mu, log_var, mu_2, log_var_2 = self.encode(input, output)
		z = self.reparameterize(mu, log_var)
		return [*self.decode(z, mu_2), input, output, mu, log_var]

	def loss_function(self,
					*args,
					**kwargs) -> dict:
		r"""
		Computes the loss used to train the main model.
		KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
		:param args: the list returned by ``forward``
		:param kwargs: accepted and ignored
		:return: dict with ``'loss'`` (differentiable) and the detached entries
			``'Reconstruction_Loss'``, ``'KLD'`` (negated KL term) and
			``'unpred'`` (the adversarial term)
		"""
		recons, recons_adv_x, recons_adv_y, input, output, mu, log_var = args[:7]

		kld_weight = 0  # the KL term is currently switched off
		adv_weight = 1
		recons_loss = F.mse_loss(recons, output)

		kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

		# Negated, clamped adversary errors: minimising this term pushes the
		# adversaries' reconstruction error up, by at most 10 per head.
		adv_loss = -F.mse_loss(recons_adv_x, input).clamp(0, 10) - F.mse_loss(recons_adv_y, output).clamp(0, 10)

		loss = recons_loss + kld_weight * kld_loss + adv_weight * adv_loss
		return {'loss': loss, 'Reconstruction_Loss': recons_loss.detach(), 'KLD': -kld_loss.detach(), 'unpred': adv_loss.detach()}

	def loss_function_adv(self,
					*args,
					**kwargs):
		"""
		Loss used to train the adversaries: the sum of their reconstruction MSEs.
		:param args: the list returned by ``forward``
		:param kwargs: accepted and ignored
		:return: (Tensor) scalar loss
		"""
		recons_adv_x, recons_adv_y, input, output = args[1:5]
		return F.mse_loss(recons_adv_x, input) + F.mse_loss(recons_adv_y, output)

	def step(self, input, output):
		"""
		Run one main-model update followed by one adversary update.
		:param input: (Tensor) [B x input_shape]
		:param output: (Tensor) [B x output_shape]
		:return: dict of Python floats: every entry of ``loss_function`` plus
			``'adv_loss'``
		"""
		# Main update: only main_optimizer's parameters are stepped.
		self.main_optimizer.zero_grad()
		main_losses = self.loss_function(*self.forward(input, output))
		main_losses['loss'].backward()
		ret = {key: value.item() for key, value in main_losses.items()}
		self.main_optimizer.step()

		# Adversary update on a fresh forward pass through the updated model.
		self.adv_optimizer.zero_grad()
		adv_loss = self.loss_function_adv(*self.forward(input, output))
		ret['adv_loss'] = adv_loss.item()
		adv_loss.backward()
		self.adv_optimizer.step()

		return ret

	def generate(self, input):
		"""Return the first encoder's mean for ``input``, detached from the graph."""
		return self.fc_mu(self.encoder(input)).detach()

class NewIB(nn.Module):
	"""Holds one ``NewIBBase`` model per agent and trains them side by side."""

	def __init__(self, input_shape, args):
		"""
		:param input_shape: number of input features given to every per-agent model
		:param args: config object; ``n_agents`` and ``comm_embed_dim`` are read
		"""
		super(NewIB, self).__init__()
		self.args = args
		self.n_agents = args.n_agents

		# Use the GPU when there is one, but do not crash on CPU-only hosts.
		device = th.device("cuda" if th.cuda.is_available() else "cpu")

		# NOTE: kept as a plain Python list, so these models are NOT registered
		# as submodules: they do not show up in parameters()/state_dict() and
		# are not moved by .to()/.cuda() on this module. Each one trains itself
		# with its own optimizers inside NewIBBase.step.
		self.comms = [
			NewIBBase(input_shape, args.comm_embed_dim, args.comm_embed_dim).to(device)
			for _ in range(self.n_agents)
		]

		# Mini-batch size used by step().
		self.batch_size = 8

	def forward(self, inputs):
		"""
		Concatenate every model's detached latent mean along dim 1.
		:param inputs: (Tensor) passed unchanged to each model's ``generate``
		:return: tuple ``(mu, sigma)``; ``sigma`` is all zeros with the shape of ``mu``
		"""
		mu = th.cat([comm.generate(inputs) for comm in self.comms], dim=1)
		sigma = th.zeros_like(mu)

		return mu, sigma

	def step(self, batch, outputs):
		"""
		Train every per-agent model on mini-batches built from ``batch``.
		:param batch: mapping with an ``'obs'`` tensor; indexed here as
			[batch, time, agent, feature] (assumed from the indexing, confirm with caller)
		:param outputs: (Tensor) targets, reshaped to [-1, n_agents, comm_embed_dim]
		:return: None; the loss dicts of the last mini-batch are printed
		"""
		obs = batch['obs']
		# Append a one-hot agent id to every observation, on the data's device.
		agent_ids = th.eye(self.n_agents, device=obs.device)
		agent_ids = agent_ids.unsqueeze(0).unsqueeze(0).expand(obs.shape[0], obs.shape[1], -1, -1)
		inputs = th.cat([obs, agent_ids], dim=-1)
		# Drop the last entry along dim 1, then flatten to one row per sample.
		inputs = inputs[:, :-1].reshape(-1, inputs.shape[-1])
		outputs = outputs.reshape(-1, self.n_agents, self.args.comm_embed_dim)

		loader = Data.DataLoader(
			Data.TensorDataset(inputs, outputs),
			batch_size=self.batch_size,
			shuffle=True,
		)

		# Bound before the loop so the print below is safe even with no batches.
		losses = []
		for x, y in loader:
			# Model i is trained against slice i of the targets.
			losses = [comm.step(x, y[:, i, :]) for i, comm in enumerate(self.comms)]
		print(losses)
		return