#!/usr/bin/env python
# coding: utf-8

import numpy as np
import random
import time
import torch
import torch.nn as nn
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt 
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--winv', type=float, help='weight of invertibility', default=10.0)
parser.add_argument('--wind', type=float, help='weight of independence', default=0.01)
parser.add_argument('--mini', help='enable minimal intrinsic dimension', action='store_true')
args = parser.parse_args()

# Loss weights used by Model.step: w_inv scales the reconstruction (invertibility)
# loss and w_ind the CLUB independence penalty.
w_inv = args.winv
w_ind = args.wind
# NOTE(review): w_mini is only printed and embedded in the output filename; no loss
# term reads it. --mini itself makes run() build the model with the true latent
# dimensions instead of 3 per block.
w_mini = 1.0 if args.mini else 0.0

print('Using weight w_inv:%.2f, w_ind:%.2f, w_mini:%.2f' % (w_inv, w_ind, w_mini))

# Use the first GPU when one is available; otherwise fall back to the CPU instead of
# failing on the first .to(device) call on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def set_seed(seed):
	"""Seed Python's `random`, NumPy's global RNG and PyTorch's RNG with one value."""
	for seeder in (random.seed, np.random.seed, torch.manual_seed):
		seeder(seed)
set_seed(1111)

# Feed-forward Network
class FFN(nn.Module):
	"""Plain multilayer perceptron: n_layers Linear layers separated by one activation.

	n_layers == 1 gives a single Linear(input_dim, output_dim); otherwise the widths are
	input_dim -> hidden -> ... -> hidden -> output_dim. No activation follows the last
	Linear layer. The module is moved to the global `device` on construction.
	"""
	def __init__(self, input_dim, output_dim, hidden=512, n_layers=3, activation='relu'):
		super().__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.hidden = hidden
		assert n_layers >= 1
		self.n_layers = n_layers
		assert activation in ['relu', 'tanh', 'sigmoid']
		act_types = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}
		# One parameter-free activation instance is shared between all hidden layers.
		self.activation = act_types[activation]()

		widths = [input_dim] + [hidden] * (n_layers - 1) + [output_dim]
		self.layers = nn.ModuleList()
		for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
			if depth:
				self.layers.append(self.activation)
			self.layers.append(nn.Linear(fan_in, fan_out))
		self.to(device)
	def forward(self, x):
		for module in self.layers:
			x = module(x)
		return x
	def info(self):
		"""Print the matrix rank of every Linear layer's weight (debugging aid)."""
		for lin in (m for m in self.layers if isinstance(m, nn.Linear)):
			weight = lin.weight.detach().cpu().numpy()
			print('Rank of linear weight:', np.linalg.matrix_rank(weight))

# Generating data
class DataGenerator():
	"""Synthetic data with 7 latent blocks s_0..s_6 and 3 observed variables v1..v3.

	'dims' should be a list or tuple with 10 elements (size-3 multi-variable model): the
	dimensions of s_0..s_6 followed by those of v1, v2, v3. Each s_i is drawn from
	torch.randn, and each v_j is the output of a randomly initialized 2-layer tanh FFN
	applied to its parent blocks:
	v1 <- (s_0, s_2, s_4, s_6), v2 <- (s_1, s_2, s_5, s_6), v3 <- (s_3, s_4, s_5, s_6).

	With n_samples given, a fixed pool of that many samples is generated once and
	sampling() cycles through it; otherwise every sampling() call draws new data.
	"""
	def __init__(self, dims, n_samples=None):
		self.dims = dims
		assert len(self.dims) == 10

		def fan_in(parents):
			return sum(dims[k] for k in parents)

		self.func_v1 = FFN(fan_in((0, 2, 4, 6)), dims[7], hidden=64, n_layers=2, activation='tanh')
		self.func_v2 = FFN(fan_in((1, 2, 5, 6)), dims[8], hidden=64, n_layers=2, activation='tanh')
		self.func_v3 = FFN(fan_in((3, 4, 5, 6)), dims[9], hidden=64, n_layers=2, activation='tanh')
		self.static_data = n_samples is not None
		if self.static_data:
			self.data = self._gen(n_samples)
			self.pointer = 0
		print('Created data with dimensions:')
		for idx, d in enumerate(self.dims[:7]):
			print('s%d: %d' % (idx, d))
		for idx, d in enumerate(self.dims[7:]):
			print('v%d: %d' % (idx, d))
	@staticmethod
	def _to_numpy(tensors):
		"""Detach each tensor and convert it to a CPU numpy array."""
		return [t.detach().cpu().numpy() for t in tensors]
	def _batch_gen(self, batch_size=128):
		"""Draw one batch: the 7 latent tensors followed by v1, v2, v3."""
		with torch.no_grad():
			latents = [torch.randn([batch_size, d]).to(device) for d in self.dims[:7]]
			mixers = (
				(self.func_v1, (0, 2, 4, 6)),
				(self.func_v2, (1, 2, 5, 6)),
				(self.func_v3, (3, 4, 5, 6)),
			)
			observed = [func(torch.hstack([latents[k] for k in parents])) for func, parents in mixers]
		return [t.detach() for t in latents + observed]
	def _gen(self, n):
		"""Generate n samples in chunks of 200 rows and return the 10 stacked tensors."""
		chunk = 200
		if n <= chunk:
			return self._batch_gen(n)
		n_chunks = (n - 1) // chunk + 1
		batches = [self._batch_gen(chunk) for _ in range(n_chunks)]
		return [torch.vstack([b[k] for b in batches])[:n] for k in range(10)]
	def sampling(self, n):
		"""Return n samples as numpy arrays, cycling through the fixed pool if there is one."""
		if not self.static_data:
			return self._to_numpy(self._gen(n))
		total = self.data[0].shape[0]
		start, stop = self.pointer, self.pointer + n
		output = [var[start:stop] for var in self.data]
		missing = n - output[0].shape[0]
		if missing > 0:
			# Wrap around: top the batch up with rows from the start of the pool.
			output = [torch.vstack([head, var[:missing]]) for head, var in zip(output, self.data)]
			self.pointer = missing
		else:
			self.pointer = stop if stop < total else 0
		return self._to_numpy(output)
	def generate(self, n):
		"""Return n freshly generated samples as numpy arrays (never uses the fixed pool)."""
		return self._to_numpy(self._gen(n))

class CLUBModel(nn.Module):
	"""Sample-based CLUB estimator of the dependence between x and y.

	Two FFNs predict the mean and the log-variance of a diagonal Gaussian q(y | x).
	forward(x, y) returns
	  - nll_loss: unnormalized Gaussian negative log-likelihood of y under q (fits q);
	  - club_loss: half the variance-scaled squared error against a randomly permuted y
	    minus the same quantity against the paired y (the sampled CLUB estimate).
	"""
	def __init__(self, input_dim, output_dim):
		super(CLUBModel, self).__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.pred_mu = FFN(input_dim, output_dim)
		self.pred_logvar = FFN(input_dim, output_dim)
	def forward(self, x, y):
		mu = self.pred_mu(x)
		logvar = self.pred_logvar(x)

		def scaled_sq_err(target):
			# squared error of the predicted mean, divided by the predicted variance
			return (mu - target) ** 2 / logvar.exp()

		nll_loss = torch.mean(scaled_sq_err(y) + logvar) # unnormalized
		shuffled_y = y[torch.randperm(y.shape[0])]
		club_loss = torch.mean(scaled_sq_err(shuffled_y) / 2.0) - torch.mean(scaled_sq_err(y) / 2.0)
		return nll_loss, club_loss

class Decoder(nn.Module):
	"""Deterministic decoder: an FFN prediction scored by mean squared error against y."""
	def __init__(self, input_dim, output_dim):
		super(Decoder, self).__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.pred_mu = FFN(input_dim, output_dim)
	def forward(self, x, y):
		residual = self.pred_mu(x) - y
		return residual.pow(2).mean()

# Identification model
class Model(nn.Module):
	"""Identification model for the 3-view, 7-latent-block structure.

	An encoder maps the concatenated observations (v1, v2, v3) to a latent code that is
	split into 7 blocks s_0..s_6. Each v_j is reconstructed by its own Decoder from a
	fixed subset of blocks, and one CLUBModel per block scores the dependence between
	s_i and all the other blocks.

	'vis_dims' should be a list or tuple with 3 elements: the dimensions of v1, v2, v3.
	'latent_dims' should have 7 elements: the dimensions of s_0..s_6.
	"""
	def __init__(self, vis_dims, latent_dims):
		super(Model, self).__init__()
		assert len(vis_dims) == 3
		self.vdims = np.array(vis_dims)
		self.sdims = np.array(latent_dims)
		assert len(latent_dims) == 7

		# Encoder/decoders and the auxiliary CLUB models have separate optimizers.
		self.enc_ml = nn.ModuleList()
		self.aux_ml = nn.ModuleList()

		# encoder & decoders
		self.encoder = FFN(sum(self.vdims), sum(self.sdims)) # v -> concatenated latent code
		self.decoder_v1 = Decoder(sum(self.sdims[[0,2,4,6]]), self.vdims[0])
		self.decoder_v2 = Decoder(sum(self.sdims[[1,2,5,6]]), self.vdims[1])
		self.decoder_v3 = Decoder(sum(self.sdims[[3,4,5,6]]), self.vdims[2])
		self.enc_ml.extend([self.encoder, self.decoder_v1, self.decoder_v2, self.decoder_v3])

		# auxiliary predictors: CLUB model i predicts all other blocks from s_i
		self.indms = []
		for i in range(7):
			input_dim = self.sdims[i]
			output_dim = sum(self.sdims) - input_dim
			self.indms.append(CLUBModel(input_dim, output_dim))
		self.aux_ml.extend(self.indms)

		# Both learning rates are multiplied by 0.2 every 5000 calls to step().
		self.optim_enc = torch.optim.AdamW(self.enc_ml.parameters(), lr=1e-3)
		self.scheduler_enc = torch.optim.lr_scheduler.StepLR(self.optim_enc, step_size=5000, gamma=0.2)
		self.optim_aux = torch.optim.AdamW(self.aux_ml.parameters(), lr=1e-3)
		self.scheduler_aux = torch.optim.lr_scheduler.StepLR(self.optim_aux, step_size=5000, gamma=0.2)
		self.to(device)

	def _encode(self, v1, v2, v3):
		"""Encode (v1, v2, v3) and split the code column-wise into the 7 latent blocks."""
		latent_code = self.encoder(torch.hstack([v1, v2, v3]))
		s = []
		for i in range(7):
			s.append(latent_code[:, np.sum(self.sdims[:i]):np.sum(self.sdims[:i+1])])
		return s

	def _aux_losses(self, s):
		"""Return (loss_pred, loss_ind): CLUB NLL and CLUB estimate averaged over the 7 blocks.

		CLUB model i gets s_i as input and the concatenation of all other blocks as target.
		"""
		nlls = []
		indlosses = []
		for i in range(7):
			nll, indloss = self.indms[i](s[i], torch.hstack(s[:i] + s[i+1:]))
			nlls.append(nll)
			indlosses.append(indloss)
		return sum(nlls) / 7, sum(indlosses) / 7

	def forward(self, v1, v2, v3, mode='training'):
		"""In 'training' mode return a dict with 'loss_inv', 'loss_ind' and 'loss_pred';
		in any other mode return the list of the 7 latent blocks."""
		s = self._encode(v1, v2, v3)
		if mode != 'training':
			# Inference only needs the latent codes; skip all loss computations.
			return s

		loss_dic = {}

		loss_v1 = self.decoder_v1(torch.hstack([s[i] for i in [0,2,4,6]]), v1)
		loss_v2 = self.decoder_v2(torch.hstack([s[i] for i in [1,2,5,6]]), v2)
		loss_v3 = self.decoder_v3(torch.hstack([s[i] for i in [3,4,5,6]]), v3)
		loss_dic['loss_inv'] = (loss_v1 + loss_v2 + loss_v3) / 3

		loss_pred, loss_ind = self._aux_losses(s)
		loss_dic['loss_ind'] = loss_ind
		loss_dic['loss_pred'] = loss_pred
		return loss_dic

	def step(self, v1, v2, v3):
		"""Run one training iteration on a batch of numpy arrays.

		First 5 optimizer steps on the CLUB models minimizing their NLL (loss_pred), then
		one step on encoder + decoders minimizing w_inv * loss_inv + w_ind * loss_ind.
		Both LR schedulers advance once. Returns the losses as detached CPU tensors.
		"""
		self.train()
		v1, v2, v3 = torch.tensor(v1).to(device), torch.tensor(v2).to(device), torch.tensor(v3).to(device)

		# The encoder is not updated during the auxiliary phase, so encode once without
		# autograd: back-propagating loss_pred into the encoder would only produce
		# gradients that optim_enc.zero_grad() below clears anyway.
		with torch.no_grad():
			s = self._encode(v1, v2, v3)
		for _ in range(5):
			loss_aux = self._aux_losses(s)[0]
			self.optim_aux.zero_grad()
			loss_aux.backward()
			self.optim_aux.step()

		loss_dic = self.forward(v1, v2, v3)
		loss_enc = w_inv * loss_dic['loss_inv'] + w_ind * loss_dic['loss_ind']

		self.optim_enc.zero_grad()
		loss_enc.backward()
		self.optim_enc.step()

		self.scheduler_enc.step()
		self.scheduler_aux.step()

		loss_dic['loss_aux'] = loss_aux
		loss_dic['loss_enc'] = loss_enc
		for k, v in loss_dic.items():
			loss_dic[k] = v.detach().cpu()
		return loss_dic

	def predict(self, v1, v2, v3):
		"""Encode numpy inputs in batches of 200 rows.

		Returns a list of 7 numpy arrays (one per latent block), rows aligned with the inputs.
		"""
		self.eval()
		v1, v2, v3 = torch.tensor(v1).to(device), torch.tensor(v2).to(device), torch.tensor(v3).to(device)
		batch_size = 200
		num = v1.shape[0]
		all_pred = []
		# No autograd graph is needed for inference.
		with torch.no_grad():
			for start in range(0, num, batch_size):
				stop = start + batch_size
				pred = self.forward(v1[start:stop], v2[start:stop], v3[start:stop], mode='test')
				all_pred.append([value.cpu().numpy() for value in pred])
		return [np.vstack([batch[i] for batch in all_pred]) for i in range(7)]

class Predictor(nn.Module):
	"""Non-linear regressor used to score how well one variable predicts another.

	Wraps a 3-layer ReLU FFN (hidden width 1024) trained with AdamW (lr 1e-3),
	ReduceLROnPlateau on the validation loss, and early stopping.
	"""
	def __init__(self, input_dim, output_dim):
		super(Predictor, self).__init__()
		self.input_dim = input_dim
		self.output_dim = output_dim
		self.batch_size = 200
		# Relative improvement required by both the LR scheduler and early stopping.
		self.threshold = 1e-2
		self.model = FFN(input_dim, output_dim, hidden=1024, n_layers=3)
		self.optim = torch.optim.AdamW(self.parameters(), lr=1e-3)
		# `verbose` is not passed: False is its default behaviour, and recent PyTorch
		# releases emit a deprecation warning whenever it is given explicitly.
		self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
			self.optim,
			mode='min',
			factor=0.1,
			patience=5,
			threshold=self.threshold,
			threshold_mode='rel',
			cooldown=0,
			min_lr=0,
			eps=1e-05)
		self.to(device)
	@staticmethod
	def _snapshot_state(module):
		"""Return a copy of module.state_dict() that later parameter updates cannot modify.

		state_dict() holds references to the live tensors, so an un-cloned "best"
		snapshot would silently follow every subsequent optimizer step.
		"""
		return {k: v.detach().clone() for k, v in module.state_dict().items()}
	def forward(self, x):
		return self.model(x)
	def _step(self, x, y):
		"""Take one optimizer step on the numpy batch (x, y); return its MSE as a CPU tensor."""
		x, y = torch.tensor(x).to(device), torch.tensor(y).to(device)
		loss_fn = nn.MSELoss()
		pred_y = self.forward(x)
		loss = loss_fn(pred_y, y)
		self.optim.zero_grad()
		loss.backward()
		self.optim.step()
		return loss.detach().cpu()
	def predict(self, x):
		"""Predict for a numpy array x in batches; return a numpy array."""
		self.eval()
		x = torch.tensor(x).to(device)
		num = x.shape[0]
		all_pred_y = []
		with torch.no_grad():
			for i in range((num - 1) // self.batch_size + 1):
				batch_x = x[self.batch_size*i:self.batch_size*(i+1)]
				all_pred_y.append(self.forward(batch_x).cpu().numpy())
		return np.vstack(all_pred_y)
	def fit_with_val(self, x, y, silent=True):
		"""Train on a random 80/20 train/validation split of (x, y) with early stopping.

		Runs at most 500 epochs. An epoch counts as an improvement only if its validation
		MSE is below (1 - threshold) times the best so far; after 10 consecutive epochs
		without improvement the parameters of the best epoch are restored and training stops.

		Args:
			x, y: numpy arrays with the same number of rows.
			silent: if False, print progress every 10 epochs and on early stop.

		Returns:
			The validation MSE of the parameters the model ends up with: the best epoch's
			loss when early stopping triggered, otherwise the last epoch's loss.
		"""
		self.train()
		num = x.shape[0]
		permed_index = np.random.permutation(num)
		num_train = np.round(num * 0.8).astype(np.int32)
		num_val = num - num_train
		train_x, train_y = x[permed_index[:num_train]], y[permed_index[:num_train]]
		val_x, val_y = x[permed_index[num_train:]], y[permed_index[num_train:]]
		n_epochs = 500
		es_steps = 10 # early stop steps
		es_count = 0
		es_loss = np.inf
		best_params = self._snapshot_state(self)
		best_ep = 0
		for e in range(n_epochs):
			for i in range((num_train - 1) // self.batch_size + 1):
				batch_x = train_x[self.batch_size*i:self.batch_size*(i+1)]
				batch_y = train_y[self.batch_size*i:self.batch_size*(i+1)]
				self._step(batch_x, batch_y)
			val_loss = 0.0
			for i in range((num_val - 1) // self.batch_size + 1):
				batch_x = torch.tensor(val_x[self.batch_size*i:self.batch_size*(i+1)]).to(device)
				batch_y = torch.tensor(val_y[self.batch_size*i:self.batch_size*(i+1)]).to(device)
				with torch.no_grad():
					pred_y = self.forward(batch_x)
					# per-sample MSE summed over the batch; normalized by num_val below
					loss = torch.sum(torch.mean((pred_y - batch_y) ** 2, dim=-1))
				val_loss += loss.detach().cpu()
			val_loss /= num_val
			self.scheduler.step(val_loss)
			if not silent and e % 10 == 0:
				print('Epoch %d, validataion loss: %f' % (e, val_loss))
			if val_loss < es_loss * (1 - self.threshold):
				es_loss = val_loss
				es_count = 0
				best_params = self._snapshot_state(self)
				best_ep = e
			else:
				es_count += 1
				if es_count == es_steps:
					if not silent:
						print('Early stopped at epoch %d, use params of epoch %d, loss: %f' % (e, best_ep, es_loss))
					self.load_state_dict(best_params)
					return es_loss
		return val_loss
	def evaluate(self, x, y):
		"""Return (mean MSE, mean R^2) over output dimensions; per-dimension R^2 is clipped at 0."""
		self.eval()
		pred_y = self.predict(x)
		mse = np.mean((pred_y - y) ** 2, axis=0)
		avg_y = np.mean(y, axis=0)
		r_square = 1.0 - mse / (np.mean((y - avg_y) ** 2, axis=0) + 1e-20)
		r_square[r_square<0] = 0.0
		return np.mean(mse), np.mean(r_square)
	
def normalize(x):
	"""Standardize each column of x to zero mean and unit standard deviation.

	Columns with zero standard deviation (constant columns) are only centered,
	instead of being divided by zero and turned into NaN. The dtype of float
	inputs is preserved.
	"""
	mu = np.mean(x, axis=0)
	std = np.std(x, axis=0)
	std = np.where(std == 0, np.ones_like(std), std)
	return (x - mu) / std
def eval(x, y, silence=True):
	"""Fit regressors in both directions between x and y and return their test R^2.

	NOTE: this module-level function shadows the builtin `eval` within this script.

	Both arrays are standardized column-wise; the first 3/4 of the rows are used for
	training (Predictor.fit_with_val makes its own 80/20 validation split) and the last
	1/4 for testing.

	Args:
		x: 2-D array of shape (n, dim_x).
		y: array with n rows; its last axis is the target dimension.
		silence: passed to fit_with_val as `silent` (suppresses progress prints).

	Returns:
		(f_r2, b_r2): mean clipped R^2 of predicting y from x, and of x from y.

	Raises:
		ValueError: if x has fewer than 4 rows, so no test rows would remain.
	"""
	n, dim_x = x.shape
	dim_y = y.shape[-1]
	n_test = n // 4
	if n_test == 0:
		raise ValueError('eval() needs at least 4 samples, got %d' % n)
	x, y = normalize(x), normalize(y)
	n_train = n - n_test
	train_x, test_x = x[:n_train], x[n_train:]
	train_y, test_y = y[:n_train], y[n_train:]
	forward_pred = Predictor(dim_x, dim_y)
	forward_pred.fit_with_val(train_x, train_y, silence)
	_, f_r2 = forward_pred.evaluate(test_x, test_y)
	backward_pred = Predictor(dim_y, dim_x)
	backward_pred.fit_with_val(train_y, train_x, silence)
	_, b_r2 = backward_pred.evaluate(test_y, test_x)
	return f_r2, b_r2

def F1(p1, p2):
	"""Harmonic mean of two scores, each floored at 1e-20 to avoid division by zero."""
	eps = 1e-20
	a, b = (max(p, eps) for p in (p1, p2))
	return 2*a*b/(a+b)

# Ground truth: 7 latent blocks s_0..s_6 of dimension 2, then 3 observed variables v1..v3 of dimension 10.
var_dims = [2] * 7 + [10] * 3
# Fixed pool of 100000 training samples that sampler.sampling() cycles through.
# It is created after set_seed(1111), so it should be the same on every invocation
# (given the same software/hardware).
sampler = DataGenerator(var_dims, n_samples=100000)
# Training batch size, and number of training iterations per run (run() calls model.step() once per iteration).
bs = 100
n_epochs = 20001

# Held-out evaluation set: 20000 freshly generated samples from the same generating networks.
test_data = sampler.generate(20000)
# One R2-F1 curve per run, appended by run().
all_curves = []
# Re-seed from the wall clock so that everything drawn from here on
# (model initialization, predictor splits, ...) differs between invocations.
set_seed(int(time.time()))

def run(index, max_index):
	"""Train one freshly initialized Model and append its evaluation curve to all_curves.

	Every 1000 iterations the model encodes test_data; each predicted latent block is
	compared with its ground-truth block via eval(), and the F1 (harmonic mean) of the
	forward and backward R^2 is averaged over the 7 blocks.
	"""
	latent_dims = var_dims[:7] if args.mini else [3]*7
	model = Model(var_dims[-3:], latent_dims=latent_dims)

	print('Training %d/%d...' % (index, max_index))

	curve = []
	start_time = time.time()
	for step_idx in range(n_epochs):
		*_, obs_v1, obs_v2, obs_v3 = sampler.sampling(bs)
		losses = model.step(obs_v1, obs_v2, obs_v3)

		if step_idx % 100 == 0:
			print('Epoch %d, Encoder loss: %f, Auxiliary loss: %f Invertibility: %f, Independence: %f'
				% (step_idx, losses['loss_enc'], losses['loss_aux'], losses['loss_inv'], losses['loss_ind']))

		if step_idx % 1000 != 0:
			continue

		# Periodic evaluation on the held-out set.
		print('Time: %d' % (time.time() - start_time))
		print('Epoch %d, Learning Rate: %f ...' % (step_idx, model.scheduler_enc.get_last_lr()[0]))
		gt_latents = test_data[:7]
		pred_s = model.predict(*test_data[-3:])

		print('Epoch %d, Evaluating ...' % step_idx)
		scores = []
		for k, (pred, truth) in enumerate(zip(pred_s, gt_latents)):
			fr2, br2 = eval(pred, truth)
			scores.append(F1(fr2, br2))
			print('s%d->s%d_gt' % (k+1, k+1), fr2, br2)
		avg_r2 = np.mean(scores)
		print('>'*20, 'Epoch %d, Average R2-F1: %f' % (step_idx, avg_r2))
		curve.append(avg_r2)
	all_curves.append(np.array(curve))

# Five independent training runs share the data above; their curves are stacked into
# one array (one row per run) and saved under a name built from the loss weights.
n_runs = 5
for run_id in range(n_runs):
	run(run_id, n_runs)
all_curves = np.array(all_curves)
out_path = '%.2f-%.2f-%.2f.npy' % (w_inv, w_ind, w_mini)
np.save(out_path, all_curves)
print('done')
