#!/usr/bin/env python
# coding: utf-8

import numpy as np
import random
import time
import torch
import torch.nn as nn
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt 
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--winv', type=float, help='weight of invertibility', default=10.0)
parser.add_argument('--wind', type=float, help='weight of independence', default=0.01)
parser.add_argument('--mini', help='enable minimal intrinsic dimension', action='store_true')
parser.add_argument('--dataset', choices=['concat', 'split', 'fusion'], help='type of synthetic dataset to generate', default='split')
args = parser.parse_args()

w_inv = args.winv
w_ind = args.wind
# w_mini is a 1.0/0.0 indicator of --mini; in this file it only appears in the
# log line below and in the output file name (args.mini itself picks the latent sizes).
w_mini = 1.0 if args.mini else 0.0

print('Using dataset %s' % args.dataset)
print('Using weight w_inv:%.2f, w_ind:%.2f, w_mini:%.2f' % (w_inv, w_ind, w_mini))

# Prefer the first GPU, but fall back to CPU so the script still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def set_seed(seed):
	"""Seed Python's `random`, NumPy's global RNG and PyTorch with one value."""
	for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
		seed_fn(seed)
set_seed(1111)

# Feed-forward Network
class FFN(nn.Module):
	"""Multilayer perceptron with `n_layers` Linear layers.

	With n_layers == 1 this is a single Linear(input_dim, output_dim);
	otherwise every hidden layer has width `hidden` and one shared activation
	module sits between consecutive Linear layers. The module is moved to the
	global `device` when constructed.
	"""
	def __init__(self, input_dim, output_dim, hidden=512, n_layers=3, activation='relu'):
		super(FFN, self).__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.hidden = hidden
		assert n_layers >= 1
		self.n_layers = n_layers
		assert activation in ['relu', 'tanh', 'sigmoid']
		self.activation = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}[activation]()

		# Widths from input to output, e.g. [in, h, h, out] for n_layers == 3.
		widths = [input_dim] + [hidden] * (n_layers - 1) + [output_dim]
		self.layers = nn.ModuleList()
		for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
			if depth:
				# The same activation instance is reused between every pair of Linear layers.
				self.layers.append(self.activation)
			self.layers.append(nn.Linear(fan_in, fan_out))
		self.to(device)
	def forward(self, x):
		for layer in self.layers:
			x = layer(x)
		return x

# Generating data with graph s1->v1<-z, z->v2<-s2
class DataGenerator():
	"""Base class for synthetic data over the graph s1 -> v1 <- z -> v2 <- s2.

	Subclasses implement `_batch_gen(batch_size)` returning the five tensors
	[s1, z, s2, v1, v2]. If `n_samples` is given, a fixed dataset of that size
	is generated once and `sampling` cycles through it; otherwise `sampling`
	generates fresh data on every call.
	"""
	# 'dims' should be list or tuple with 5 elements, representing dimension of s1, z, s2, v1, v2, respectively
	def __init__(self, dims, n_samples=None):
		self.dims = dims
		assert len(self.dims) == 5
		self.static_data = False
		if n_samples is not None:
			self.static_data = True
			self.data = self._gen(n_samples)
			self.pointer = 0
		print('Created data with dimensions:')
		for name, dim in zip(['s1', 'z', 's2', 'v1', 'v2'], dims):
			print('%s: %d' % (name, dim))
	def _batch_gen(self, batch_size=128):
		"""Return [s1, z, s2, v1, v2] for one batch; must be overridden by subclasses."""
		# Fail loudly: returning None here would only surface later as an opaque error.
		raise NotImplementedError('%s must implement _batch_gen' % type(self).__name__)
	def _gen(self, n):
		"""Generate n samples as a list of 5 tensors, in chunks of at most 200 rows."""
		bs = 200
		if n <= bs:
			return self._batch_gen(n)
		# ceil(n / bs) full chunks, stacked per variable and trimmed to exactly n rows.
		batches = [self._gen(bs) for _ in range((n - 1) // bs + 1)]
		return [torch.vstack([batch[i] for batch in batches])[:n] for i in range(5)]
	def sampling(self, n):
		"""Return n samples as numpy arrays [s1, z, s2, v1, v2].

		With a static dataset, rows are read sequentially from an internal
		cursor and wrap around to the start (n may exceed the dataset size).
		Without one, fresh samples are generated.

		Raises ValueError if the static dataset is empty.
		"""
		if self.static_data:
			total = self.data[0].shape[0]
			if total == 0:
				raise ValueError('cannot sample from an empty static dataset')
			# Cyclic row indices starting at the cursor; the modulo handles wrap-around.
			idx = (torch.arange(n, device=self.data[0].device) + self.pointer) % total
			output = [var[idx] for var in self.data]
			self.pointer = (self.pointer + n) % total
		else:
			output = self._gen(n)
		return [var.detach().cpu().numpy() for var in output]
	def generate(self, n):
		"""Generate n fresh samples (never drawn from the static dataset) as numpy arrays."""
		output = self._gen(n)
		return [var.detach().cpu().numpy() for var in output]

class ConcatGenerator(DataGenerator):
	"""Synthetic data where v1 = f1([s1, z]) and v2 = f2([z, s2]).

	f1 and f2 are 2-layer tanh FFNs with hidden width 64. 'dims' should be a
	list or tuple with 5 elements: the dimensions of s1, z, s2, v1, v2, in
	that order.
	"""
	def __init__(self, dims, n_samples=None):
		mlp_kwargs = dict(hidden=64, n_layers=2, activation='tanh')
		# The mixing networks must exist before the base constructor runs,
		# because it may pre-generate a static dataset through _batch_gen.
		self.func_v1 = FFN(dims[0] + dims[1], dims[3], **mlp_kwargs)
		self.func_v2 = FFN(dims[1] + dims[2], dims[4], **mlp_kwargs)
		super().__init__(dims, n_samples)
	def _batch_gen(self, batch_size=128):
		with torch.no_grad():
			# Standard-normal sources, drawn in the order s1, z, s2.
			s1, z, s2 = [torch.randn([batch_size, self.dims[k]]).to(device) for k in range(3)]
			v1 = self.func_v1(torch.hstack([s1, z]))
			v2 = self.func_v2(torch.hstack([z, s2]))
		return [t.detach() for t in (s1, z, s2, v1, v2)]

class SplitGenerator(DataGenerator):
	"""Synthetic data where v1 sees only the positive part of z and v2 only its negative part.

	v1 = f1([s1, relu(z)]) and v2 = f2([-relu(-z), s2]); the returned z is the
	unsplit source. f1 and f2 are 2-layer tanh FFNs with hidden width 64.
	'dims' should be a list or tuple with 5 elements: the dimensions of s1,
	z, s2, v1, v2, in that order.
	"""
	def __init__(self, dims, n_samples=None):
		mlp_kwargs = dict(hidden=64, n_layers=2, activation='tanh')
		# Built before the base constructor, which may call _batch_gen.
		self.func_v1 = FFN(dims[0] + dims[1], dims[3], **mlp_kwargs)
		self.func_v2 = FFN(dims[1] + dims[2], dims[4], **mlp_kwargs)
		super().__init__(dims, n_samples)
	def _batch_gen(self, batch_size=128):
		with torch.no_grad():
			# Standard-normal sources, drawn in the order s1, z, s2.
			s1, z, s2 = [torch.randn([batch_size, self.dims[k]]).to(device) for k in range(3)]
			v1 = self.func_v1(torch.hstack([s1, nn.functional.relu(z)]))
			v2 = self.func_v2(torch.hstack([- nn.functional.relu(-z), s2]))
		return [t.detach() for t in (s1, z, s2, v1, v2)]

class FusionGenerator(DataGenerator):
	"""Synthetic data where v1 = g1(s1) + g2(z) and v2 = f([z, s2]).

	g1, g2 and f are 2-layer tanh FFNs with hidden width 64. 'dims' should be
	a list or tuple with 5 elements: the dimensions of s1, z, s2, v1, v2, in
	that order.
	"""
	def __init__(self, dims, n_samples=None):
		mlp_kwargs = dict(hidden=64, n_layers=2, activation='tanh')
		# Built before the base constructor, which may call _batch_gen.
		self.func_s1 = FFN(dims[0], dims[3], **mlp_kwargs)
		self.func_z = FFN(dims[1], dims[3], **mlp_kwargs)
		self.func_v2 = FFN(dims[1] + dims[2], dims[4], **mlp_kwargs)
		super().__init__(dims, n_samples)
	def _batch_gen(self, batch_size=128):
		with torch.no_grad():
			# Standard-normal sources, drawn in the order s1, z, s2.
			s1, z, s2 = [torch.randn([batch_size, self.dims[k]]).to(device) for k in range(3)]
			# Additive fusion of the separately transformed s1 and z.
			v1 = self.func_s1(s1) + self.func_z(z)
			v2 = self.func_v2(torch.hstack([z, s2]))
		return [t.detach() for t in (s1, z, s2, v1, v2)]

class CLUBModel(nn.Module):
	"""Gaussian conditional model q(y | x) with mean mu(x) and log-variance logvar(x).

	forward returns two scalars: an unnormalised Gaussian NLL of y (used to
	fit q) and a CLUB-style contrast between shuffled and matched (x, y)
	pairs (used as a dependence penalty).
	"""
	def __init__(self, input_dim, output_dim):
		super(CLUBModel, self).__init__()
		self.input_dim, self.output_dim = input_dim, output_dim
		# Separate FFN heads for the mean and the log-variance.
		self.pred_mu = FFN(input_dim, output_dim)
		self.pred_logvar = FFN(input_dim, output_dim)
	def forward(self, x, y):
		mu = self.pred_mu(x)
		logvar = self.pred_logvar(x)
		# Twice the Gaussian NLL, up to an additive constant.
		nll_loss = torch.mean((mu - y) ** 2 / logvar.exp() + logvar)
		# Shuffling y within the batch breaks the pairing with x.
		y_shuffled = y[torch.randperm(y.shape[0])]
		club_loss = torch.mean((mu - y_shuffled) ** 2 / logvar.exp() / 2.0) - torch.mean((mu - y) ** 2 / logvar.exp() / 2.0)
		return nll_loss, club_loss

class Decoder(nn.Module):
	"""Deterministic FFN decoder scored by mean squared reconstruction error."""
	def __init__(self, input_dim, output_dim):
		super(Decoder, self).__init__()
		self.input_dim, self.output_dim = input_dim, output_dim
		self.pred_mu = FFN(input_dim, output_dim)
	def forward(self, x, y):
		"""Decode x and return the mean squared error against the target y."""
		residual = self.pred_mu(x) - y
		return torch.mean(residual ** 2)

# Identification model
class Model(nn.Module):
	"""Encoder/decoder model that recovers latent blocks (s1, z, s2) from (v1, v2).

	The encoder maps [v1, v2] to one latent code that is split into
	[s1, z, s2]. Two decoders reconstruct v1 from [s1, z] and v2 from [z, s2]
	(invertibility loss). Three CLUBModel heads each predict the other two
	latent blocks from one block (independence loss).

	'vis_dims' should be a list or tuple with 2 elements: the dimensions of
	v1 and v2. If 'latent_dims' is None the latent sizes default to
	d_s1 = d_v1, d_z = d_v1 + d_v2, d_s2 = d_v2; otherwise it must be a
	3-element sequence (d_s1, d_z, d_s2).
	"""
	def __init__(self, vis_dims, latent_dims=None):
		super(Model, self).__init__()
		assert len(vis_dims) == 2
		self.v1_dim, self.v2_dim = vis_dims
		if latent_dims is None:
			self.z_dim = self.v1_dim + self.v2_dim
			self.s1_dim = self.v1_dim
			self.s2_dim = self.v2_dim
		else:
			assert len(latent_dims) == 3
			self.s1_dim, self.z_dim, self.s2_dim = latent_dims
		# Parameters are split into two groups, each with its own optimizer.
		self.enc_ml = nn.ModuleList()
		self.aux_ml = nn.ModuleList()

		# encoder & decoder
		self.encoder = FFN(self.v1_dim + self.v2_dim, self.z_dim + self.s1_dim + self.s2_dim) # [v1, v2] -> [s1, z, s2]
		self.decoder_v1 = Decoder(self.z_dim + self.s1_dim, self.v1_dim) # [s1, z] -> v1
		self.decoder_v2 = Decoder(self.z_dim + self.s2_dim, self.v2_dim) # [z, s2] -> v2
		self.enc_ml.extend([self.encoder, self.decoder_v1, self.decoder_v2])

		# auxiliary predictors: one latent block -> concatenation of the other two
		self.indm_s1_zs2 = CLUBModel(self.s1_dim, self.z_dim + self.s2_dim)
		self.indm_z_s1s2 = CLUBModel(self.z_dim, self.s1_dim + self.s2_dim)
		self.indm_s2_s1z = CLUBModel(self.s2_dim, self.s1_dim + self.z_dim)
		self.aux_ml.extend([self.indm_s1_zs2, self.indm_z_s1s2, self.indm_s2_s1z])

		# Both learning rates are multiplied by 0.2 every 5000 scheduler steps;
		# step() advances each scheduler once per call.
		self.optim_enc = torch.optim.AdamW(self.enc_ml.parameters(), lr=1e-3)
		self.scheduler_enc = torch.optim.lr_scheduler.StepLR(self.optim_enc, step_size=5000, gamma=0.2)
		self.optim_aux = torch.optim.AdamW(self.aux_ml.parameters(), lr=1e-3)
		self.scheduler_aux = torch.optim.lr_scheduler.StepLR(self.optim_aux, step_size=5000, gamma=0.2)
		self.to(device)
	def forward(self, v1, v2, mode='training'):
		"""Encode (v1, v2) and compute all losses.

		When mode == 'training', returns a dict with:
		- 'loss_inv': mean reconstruction MSE of the two decoders;
		- 'loss_ind': mean CLUB estimate over the three latent blocks;
		- 'loss_pred': mean NLL of the three CLUB heads.

		For any other mode, returns the tuple (s1, z, s2). The losses are
		computed in either case.
		"""
		latent_code = self.encoder(torch.hstack([v1, v2]))
		s1 = latent_code[:, :self.s1_dim]
		z = latent_code[:, self.s1_dim:latent_code.shape[1]-self.s2_dim]
		s2 = latent_code[:, latent_code.shape[1]-self.s2_dim:]

		loss_dic = {}

		loss_inv = (self.decoder_v1(torch.hstack([s1, z]), v1) + self.decoder_v2(torch.hstack([z, s2]), v2)) / 2
		loss_dic['loss_inv'] = loss_inv

		nll_zs2, ind_s1 = self.indm_s1_zs2(s1, torch.hstack([z, s2]))
		nll_s1s2, ind_z = self.indm_z_s1s2(z, torch.hstack([s1, s2]))
		nll_s1z, ind_s2 = self.indm_s2_s1z(s2, torch.hstack([s1, z]))
		loss_ind = (ind_s1 + ind_z + ind_s2) / 3
		loss_pred = (nll_zs2 + nll_s1s2 + nll_s1z) / 3
		loss_dic['loss_ind'] = loss_ind
		loss_dic['loss_pred'] = loss_pred

		if mode == 'training':
			return loss_dic
		else:
			return s1, z, s2

	def step(self, v1, v2):
		"""Run one training iteration on numpy batches v1, v2.

		The iteration has two phases:
		1. Take 5 optimizer steps on the auxiliary CLUB heads, minimizing
		   'loss_pred'.
		2. Take one step on the encoder/decoders, minimizing
		   w_inv * loss_inv + w_ind * loss_ind (module-level weights).

		Returns the loss dict of the encoder phase, extended with 'loss_aux'
		(the last auxiliary loss) and 'loss_enc'. All values are detached and
		on CPU.
		"""
		self.train()
		v1, v2 = torch.tensor(v1).to(device), torch.tensor(v2).to(device)

		for i in range(5):
			loss_dic = self.forward(v1, v2)
			loss_aux = loss_dic['loss_pred']
			self.optim_aux.zero_grad()
			# This backward also writes gradients into the encoder; they are
			# cleared by optim_enc.zero_grad() before the encoder update below.
			loss_aux.backward()
			self.optim_aux.step()

		loss_dic = self.forward(v1, v2)
		loss_enc = w_inv * loss_dic['loss_inv'] + w_ind * loss_dic['loss_ind']

		self.optim_enc.zero_grad()
		loss_enc.backward()
		self.optim_enc.step()

		self.scheduler_enc.step()
		self.scheduler_aux.step()

		loss_dic['loss_aux'] = loss_aux
		loss_dic['loss_enc'] = loss_enc
		for k, v in loss_dic.items():
			loss_dic[k] = v.detach().cpu()
		return loss_dic
	def predict(self, v1, v2):
		"""Encode numpy arrays (v1, v2) in batches of 200; return numpy arrays [s1, z, s2]."""
		self.eval()
		v1, v2 = torch.tensor(v1).to(device), torch.tensor(v2).to(device)
		batch_size = 200
		num = v1.shape[0]
		all_pred = []
		# Inference only: skip autograd so no graph is built and kept per batch.
		with torch.no_grad():
			for i in range((num - 1) // batch_size + 1):
				batch_v1 = v1[batch_size*i:batch_size*(i+1)]
				batch_v2 = v2[batch_size*i:batch_size*(i+1)]
				pred = self.forward(batch_v1, batch_v2, mode='test')
				all_pred.append([value.detach().cpu().numpy() for value in pred])
		return [np.vstack([batch[i] for batch in all_pred]) for i in range(3)]

class Predictor(nn.Module):
	"""FFN regressor used to measure how well one set of variables predicts another.

	Inputs are numpy arrays. `fit_with_val` trains with AdamW, a
	ReduceLROnPlateau schedule and early stopping on a random 80/20
	train/validation split. `evaluate` reports MSE and R^2 on given data;
	negative R^2 is clipped to 0 and both are averaged over output dimensions.
	"""
	def __init__(self, input_dim, output_dim):
		super(Predictor, self).__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.batch_size = 200
		# Relative improvement threshold shared by the LR scheduler and early stopping.
		self.threshold = 1e-2
		self.model = FFN(input_dim, output_dim, hidden=1024, n_layers=3)
		self.optim = torch.optim.AdamW(self.parameters(), lr=1e-3)
		# `verbose` is not passed: False is its default, and the argument is
		# deprecated in recent PyTorch releases.
		self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
			self.optim,
			mode='min',
			factor=0.1,
			patience=5,
			threshold=self.threshold,
			threshold_mode='rel',
			cooldown=0,
			min_lr=0,
			eps=1e-05)
		self.to(device)
	def forward(self, x):
		return self.model(x)
	def _state_snapshot(self):
		"""Return an independent copy of the current state_dict.

		state_dict() returns tensors that share storage with the live
		parameters. Without cloning, a saved "best" state would silently
		follow every later update.
		"""
		return {k: v.detach().clone() for k, v in self.state_dict().items()}
	def _step(self, x, y):
		"""Take one AdamW step on a numpy mini-batch and return its MSE (detached, on CPU)."""
		x, y = torch.tensor(x).to(device), torch.tensor(y).to(device)
		loss_fn = nn.MSELoss()
		pred_y = self.forward(x)
		loss = loss_fn(pred_y, y)
		self.optim.zero_grad()
		loss.backward()
		self.optim.step()
		return loss.detach().cpu()
	def predict(self, x):
		"""Predict outputs for numpy array x in batches; returns a numpy array."""
		self.eval()
		x = torch.tensor(x).to(device)
		num = x.shape[0]
		all_pred_y = []
		# Inference only: no autograd graph is needed.
		with torch.no_grad():
			for i in range((num - 1) // self.batch_size + 1):
				batch_x = x[self.batch_size*i:self.batch_size*(i+1)]
				pred_y = self.forward(batch_x)
				all_pred_y.append(pred_y.detach().cpu().numpy())
		return np.vstack(all_pred_y)
	def fit_with_val(self, x, y, silent=True):
		"""Train on numpy arrays (x, y) with validation-based early stopping.

		A random 80% of the rows is used for training and the rest for
		validation. Training runs for at most 500 epochs. It stops once the
		validation loss has not improved by a relative `self.threshold` for
		10 consecutive epochs. The weights of the best validation epoch are
		restored before returning.

		Returns the validation loss of the last epoch that was run.
		"""
		self.train()
		num = x.shape[0]
		permed_index = np.random.permutation(num)
		num_train = int(np.round(num * 0.8))
		num_val = num - num_train
		train_x, train_y = x[permed_index[:num_train]], y[permed_index[:num_train]]
		val_x, val_y = x[permed_index[num_train:]], y[permed_index[num_train:]]
		n_epochs = 500
		es_steps = 10 # early stop steps
		es_count = 0
		es_loss = np.inf
		best_params = self._state_snapshot()
		best_ep = 0
		for e in range(n_epochs):
			for i in range((num_train - 1) // self.batch_size + 1):
				batch_x = train_x[self.batch_size*i:self.batch_size*(i+1)]
				batch_y = train_y[self.batch_size*i:self.batch_size*(i+1)]
				self._step(batch_x, batch_y)
			# Per-sample MSE (averaged over output dims), averaged over the validation set.
			val_loss = 0.0
			for i in range((num_val - 1) // self.batch_size + 1):
				batch_x = torch.tensor(val_x[self.batch_size*i:self.batch_size*(i+1)]).to(device)
				batch_y = torch.tensor(val_y[self.batch_size*i:self.batch_size*(i+1)]).to(device)
				with torch.no_grad():
					pred_y = self.forward(batch_x)
					loss = torch.sum(torch.mean((pred_y - batch_y) ** 2, dim=-1))
				val_loss += loss.detach().cpu()
			val_loss /= num_val
			self.scheduler.step(val_loss)
			if not silent and e % 10 == 0:
				print('Epoch %d, validation loss: %f' % (e, val_loss))
			if val_loss < es_loss * (1 - self.threshold):
				es_loss = val_loss
				es_count = 0
				best_params = self._state_snapshot()
				best_ep = e
			else:
				es_count += 1
				if es_count == es_steps:
					if not silent:
						print('Early stopped at epoch %d, use params of epoch %d, loss: %f' % (e, best_ep, es_loss))
					break
		# Restore the best validated weights, whether training stopped early or ran all epochs.
		self.load_state_dict(best_params)
		return val_loss
	def evaluate(self, x, y):
		"""Return (mean MSE, mean R^2) of predicting y from x, with negative R^2 clipped to 0."""
		self.eval()
		pred_y = self.predict(x)
		mse = np.mean((pred_y - y) ** 2, axis=0)
		avg_y = np.mean(y, axis=0)
		r_square = 1.0 - mse / (np.mean((y - avg_y) ** 2, axis=0) + 1e-20)
		r_square[r_square<0] = 0.0
		return np.mean(mse), np.mean(r_square)
	
def normalize(x):
	"""Standardize each column of x to zero mean and unit standard deviation.

	Columns with zero standard deviation (constant columns) are only centered,
	so they come back as zeros instead of NaN/inf from a division by zero.
	float32 input stays float32.
	"""
	mu = np.mean(x, axis=0)
	std = np.std(x, axis=0)
	# ones_like keeps std's dtype, so float32 inputs are not promoted to float64.
	std = np.where(std == 0, np.ones_like(std), std)
	return (x - mu) / std
def eval(x, y, silence=True):
	"""Measure how well x and y predict each other with two Predictor models.

	NOTE: this module-level name shadows the builtin ``eval``; it is kept
	because the training loop in this file calls it by this name.

	Both arrays are column-standardized. The last quarter of the rows is
	held out for testing.

	Returns (f_r2, b_r2): the mean clipped R^2 of predicting y from x
	(forward) and x from y (backward) on the held-out rows.
	"""
	n, dim_x = x.shape
	dim_y = y.shape[-1]
	x, y = normalize(x), normalize(y)
	# Split at an explicit index: x[:-n_test] would be empty when n_test == 0.
	n_train = n - n // 4
	train_x, test_x = x[:n_train], x[n_train:]
	train_y, test_y = y[:n_train], y[n_train:]
	forward_pred = Predictor(dim_x, dim_y)
	forward_pred.fit_with_val(train_x, train_y, silence)
	_, f_r2 = forward_pred.evaluate(test_x, test_y)
	backward_pred = Predictor(dim_y, dim_x)
	backward_pred.fit_with_val(train_y, train_x, silence)
	_, b_r2 = backward_pred.evaluate(test_y, test_x)
	return f_r2, b_r2

def F1(p1, p2):
	"""Harmonic mean of two scores, each floored at 1e-20 so the result is always defined."""
	floor = 1e-20
	a, b = max(p1, floor), max(p2, floor)
	return 2 * a * b / (a + b)

var_dims = [3, 5, 4, 10, 10] # dimensions of s1, z, s2, v1, v2
bs = 100 # mini-batch size of each training step
n_epochs = 20001 # training steps per run (one mini-batch per step)
_DATASETS = {
	'concat': ConcatGenerator,
	'split': SplitGenerator,
	'fusion': FusionGenerator,
}
if args.dataset not in _DATASETS:
	# argparse `choices` already rejects other values; this guards against drift between the two lists.
	raise ValueError('Invalid dataset type: %s' % args.dataset)
sampler = _DATASETS[args.dataset](var_dims, n_samples=100000)

# Evaluation set: generate() draws fresh samples rather than reading the static training pool.
test_data = sampler.generate(20000)
# One metric curve per call to run(); stacked and saved at the end of the script.
all_curves = []
# For different random draws on every launch, re-seed from the clock:
#set_seed(int(time.time()))

def _evaluate_snapshot(model, epid, t0):
	"""Evaluate `model` on the global `test_data`, printing progress, and return one metric row.

	The row is [avg_latent_r2, avg_inv_r2, avg_ind_r2, s1_fr2, s1_br2,
	z_fr2, z_br2, s2_fr2, s2_br2]:
	- avg_latent_r2: mean R2-F1 between each predicted latent block and its
	  ground truth;
	- avg_inv_r2: R2-F1 between the full predicted code and [v1, v2];
	- avg_ind_r2: mean F1 of (1 - R2) between each predicted block and the
	  other two;
	- the remaining six: raw forward/backward R2 for each latent block.
	"""
	print('Time: %d' % (time.time() - t0))
	print('Epoch %d, Learning Rate: %f ...' % (epid, model.scheduler_enc.get_last_lr()[0]))
	s1, z, s2, v1, v2 = test_data
	pred_s1, pred_z, pred_s2 = model.predict(v1, v2)

	# Identifiability: each predicted block against its ground-truth counterpart.
	print('Epoch %d, Evaluating ...' % epid)
	latent_r2 = []
	for label, pred, truth in (('s1->s1_gt', pred_s1, s1), ('z->z_gt', pred_z, z), ('s2->s2_gt', pred_s2, s2)):
		fwd_r2, bwd_r2 = eval(pred, truth)
		print(label, fwd_r2, bwd_r2)
		latent_r2 += [fwd_r2, bwd_r2]
	s1_fr2, s1_br2, z_fr2, z_br2, s2_fr2, s2_br2 = latent_r2
	avg_latent_r2 = (F1(s1_fr2, s1_br2) + F1(z_fr2, z_br2) + F1(s2_fr2, s2_br2)) / 3
	print('>'*20, 'Epoch %d, Average R2-F1: %f' % (epid, avg_latent_r2))

	# Invertibility: the full predicted code against the observations.
	print('Epoch %d, Evaluating invertibility ...' % epid)
	observed = np.hstack([v1, v2])
	code = np.hstack([pred_s1, pred_z, pred_s2])
	fr2, br2 = eval(code, observed)
	avg_inv_r2 = F1(fr2, br2)
	print('c->v:', fr2, br2, avg_inv_r2)

	# Independence: each predicted block against the other two, scored with 1 - R2.
	print('Epoch %d, Evaluating independence ...' % epid)
	ind_r2 = []
	for label, block, others in (
			('s1->zs2', pred_s1, [pred_z, pred_s2]),
			('z->s1s2', pred_z, [pred_s1, pred_s2]),
			('s2->s1z', pred_s2, [pred_s1, pred_z])):
		fwd_r2, bwd_r2 = eval(block, np.hstack(others))
		print(label, fwd_r2, bwd_r2)
		ind_r2.append((fwd_r2, bwd_r2))
	avg_ind_r2 = np.mean([F1(1 - fwd_r2, 1 - bwd_r2) for fwd_r2, bwd_r2 in ind_r2])
	print('average independence: %f' % avg_ind_r2)

	return [avg_latent_r2, avg_inv_r2, avg_ind_r2, s1_fr2, s1_br2, z_fr2, z_br2, s2_fr2, s2_br2]

def run(index, max_index):
	"""Train one freshly built Model and record its evaluation curve.

	`index` and `max_index` only appear in the progress message. Every 1000
	training steps the model is evaluated and one 9-entry metric row is
	recorded. The resulting array is appended to the global `all_curves`.
	"""
	latent_dims = [3, 5, 4] if args.mini else [3, 7, 4]
	model = Model(var_dims[-2:], latent_dims=latent_dims)
	print('Training %d/%d...' % (index, max_index))

	saved_r2 = []
	t0 = time.time()
	for epid in range(n_epochs):
		# Training uses only the observations; the ground-truth latents are discarded.
		*_, batch_v1, batch_v2 = sampler.sampling(bs)
		loss_dic = model.step(batch_v1, batch_v2)

		if epid % 100 == 0:
			print('Epoch %d, Encoder loss: %f, Auxiliary loss: %f Invertibility: %f, Independence: %f'
				% (epid, loss_dic['loss_enc'], loss_dic['loss_aux'], loss_dic['loss_inv'], loss_dic['loss_ind']))

		if epid % 1000 == 0:
			saved_r2.append(_evaluate_snapshot(model, epid, t0))
	all_curves.append(np.array(saved_r2))

# Repeat training from scratch; nothing is re-seeded between runs, so each run
# continues from the RNG state left by the previous one.
total_runs = 5
for run_id in range(total_runs):
	run(run_id, total_runs)
# Shape (runs, checkpoints, 9 metrics).
all_curves = np.array(all_curves)
# NOTE(review): the file name encodes only the loss weights, not args.dataset,
# so runs on different datasets with the same weights overwrite each other.
result_path = '%.2f-%.2f-%.2f.npy' % (w_inv, w_ind, w_mini)
np.save(result_path, all_curves)
print('done')
