import numpy as np
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
from sklearn.metrics import roc_curve, auc

from utils.log import log 
from trainer.T_Com import T_Com

class T_RVAEBFA(T_Com):
	def __init__(self, train_loader, valid_loader, test_loader, cfg):
		"""Initialise the base trainer, then pin the constants used by the GMM loss.

		Args:
			train_loader: forwarded unchanged to the base-class constructor.
			valid_loader: forwarded unchanged to the base-class constructor.
			test_loader: forwarded unchanged to the base-class constructor.
			cfg: forwarded unchanged to the base-class constructor.
		"""
		super().__init__(train_loader, valid_loader, test_loader, cfg)
		# Small constant: added to covariance diagonals and inside the energy log.
		self.eps = 1e-12
		# Loss weights for the sample-energy term and the inverse-covariance-diagonal penalty.
		self.lambda_energy, self.lambda_cov_diag = 0.1, 0.005

	def compute_gmm_params(self, gmm_input, gamma):
		N = gamma.size(0)
		# K
		sum_gamma = torch.sum(gamma, dim=0)

		# K
		phi = (sum_gamma / N)

 
		# K x D
		mu = torch.sum(gamma.unsqueeze(-1) * gmm_input.unsqueeze(1), dim=0) / sum_gamma.unsqueeze(-1)

		# z_mu = N x K x D
		z_mu = (gmm_input.unsqueeze(1)- mu.unsqueeze(0))

		# z_mu_outer = N x K x D x D
		z_mu_outer = z_mu.unsqueeze(-1) * z_mu.unsqueeze(-2)

		# K x D x D
		cov = torch.sum(gamma.unsqueeze(-1).unsqueeze(-1) * z_mu_outer, dim = 0) / sum_gamma.unsqueeze(-1).unsqueeze(-1)

		return phi.detach(), mu.detach(), cov.detach()

	def compute_energy(self, gmm_input, phi=None, mu=None, cov=None, size_average=True):

		k, D, _ = cov.size()

		z_mu = (gmm_input.unsqueeze(1)- mu.unsqueeze(0))

		cov_inverse = []
		det_cov = []
		cov_diag = 0
		for i in range(k):
			# K x D x D
			cov_k = cov[i] + (torch.eye(D)*self.eps).to(self.device)
			cov_inverse.append(torch.inverse(cov_k).unsqueeze(0))
			det_cov.append((torch.linalg.cholesky(cov_k.cpu() * (2 * np.pi)).diag().prod()).unsqueeze(0))
			cov_diag = cov_diag + torch.sum(1 / cov_k.diag())

		# K x D x D
		cov_inverse = torch.cat(cov_inverse, dim=0)
		# K
		
		det_cov = torch.cat(det_cov).cuda()

		# N x K
		exp_term_tmp = -0.5 * torch.sum(torch.sum(z_mu.unsqueeze(-1) * cov_inverse.unsqueeze(0), dim=-2) * z_mu, dim=-1)
		max_val = torch.max((exp_term_tmp).clamp(min=0), dim=1, keepdim=True)[0]
		exp_term = torch.exp(exp_term_tmp - max_val)

		sample_energy = -max_val.squeeze() - torch.log(torch.sum(phi.unsqueeze(0) * exp_term / (torch.sqrt(det_cov)).unsqueeze(0), dim = 1) + self.eps)
		if size_average:
			sample_energy = torch.mean(sample_energy)
		return sample_energy, cov_diag


	def loss_function(self, X, recon_res):
		mu = recon_res['mu']
		logvar = recon_res['logvar']
		latent = recon_res['latent']
		X_hat = recon_res['x_hat']
		gmm_input = recon_res['gmm_input']
		gamma = recon_res['gamma']

		recon_error= torch.sum((X-X_hat)**2, dim=1).mean()

		# kl divergence
		KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(),dim=1)
		KLD = torch.mean(KLD)

		# gmm loss
		phi, mu_gmm, cov = self.compute_gmm_params(gmm_input, gamma)
		sample_energy, cov_diag = self.compute_energy(gmm_input, phi, mu_gmm, cov)
		loss = recon_error + 0.1 * KLD + self.lambda_energy * sample_energy + self.lambda_cov_diag * cov_diag
		return loss


	def train(self):
		"""Train for ``self.epochs`` epochs, checkpoint each epoch, then reload the best one.

		After every epoch the model is validated with ``self.valid()`` and its
		state dict is saved. When training finishes, the checkpoint of the epoch
		with the lowest validation loss is loaded back. A NaN validation loss is
		ranked as +inf so a diverged epoch is never chosen over a finite one.

		Returns:
			Total time spent in the training loops (validation and checkpointing
			excluded), in milliseconds.
		"""
		train_time = []
		valid_loss_list = []

		# The epoch counter lives on ``self``; presumably base-class helpers such as
		# ``save`` read it -- verify before turning it into a local.
		for self.epoch in tqdm(range(1, self.epochs + 1)):
			start = time.time()

			self.model.train()
			# Only the inputs are needed for this unsupervised loss
			for _, X, _, _ in self.train_loader:
				X = X.to(self.device)

				recon_res = self.model(X)
				loss = self.loss_function(X.view(X.size(0), -1), recon_res)

				self.optim.zero_grad()
				loss.backward()
				self.optim.step()

			train_time.append(time.time() - start)

			# Valid
			valid_loss = self.valid()
			log(f"Valid Loss: {valid_loss}")
			valid_loss_list.append(valid_loss)

			self.save(self.model.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}.tar")

		# np.argmin would return the position of the first NaN; rank NaN as +inf instead
		valid_losses = np.asarray(valid_loss_list, dtype=np.float64)
		valid_losses = np.where(np.isnan(valid_losses), np.inf, valid_losses)
		best_epoch = np.argmin(valid_losses) + 1
		self.load(f"./ckpt/{self.cfg['model']}/model/{self.cfg['model']}_{str(best_epoch)}.tar")
		log(f"Best Epoch: {best_epoch}")

		return np.sum(train_time) * 1000

	def valid(self):
		self.model.eval()
		valid_loss = []
		with torch.no_grad():
			n_samples=0
			N = mu_sum = cov_sum = gamma_sum = 0
			for idx, (index, X,Y,abnormality) in enumerate(self.train_loader):
				X=X.to(self.device)

				recon_res = self.model(X)
				gmm_input = recon_res['gmm_input']
				gamma = recon_res['gamma']

				_, mu, cov = self.compute_gmm_params(gmm_input, gamma)

				batch_gamma_sum = torch.sum(gamma, dim=0)

				gamma_sum += batch_gamma_sum
				mu_sum += mu*batch_gamma_sum.unsqueeze(-1)
				cov_sum += cov*batch_gamma_sum.unsqueeze(-1).unsqueeze(-1)

				N += X.size(0)
			train_phi = gamma_sum/N
			train_mu = mu_sum / gamma_sum.unsqueeze(-1)
			train_cov = cov_sum / gamma_sum.unsqueeze(-1).unsqueeze(-1)

			for idx, (index, X,Y,abnormality) in enumerate(self.valid_loader):
				X=X.to(self.device)
				Y=Y.to(self.device)
				abnormality=abnormality.to(self.device)

				recon_res= self.model(X)
				sample_energy, cov_diag = self.compute_energy(recon_res['gmm_input'], train_phi, train_mu, train_cov, size_average=False)

				loss = sample_energy

				valid_loss.append(loss.detach().to('cpu').numpy())


		return np.mean(valid_loss)

	def test(self):
		"""Score the test set with the GMM energy and report AUROC.

		GMM parameters are re-estimated from the whole training loader by pooling
		per-batch estimates weighted by their membership mass. Each test sample's
		energy is then used as its anomaly score and compared with the labels
		``Y`` through a ROC curve.

		Returns:
			Tuple ``(elapsed_ms, AUROC)``. The elapsed time covers both the
			parameter estimation over the training loader and the scoring of the
			test loader.
		"""
		self.model.eval()
		test_time = []
		gt_list = []
		score_list = []

		with torch.no_grad():
			N = mu_sum = cov_sum = gamma_sum = 0
			start = time.time()
			for _, X, _, _ in self.train_loader:
				X = X.to(self.device)
				recon_res = self.model(X)
				gmm_input = recon_res['gmm_input']
				gamma = recon_res['gamma']

				_, mu, cov = self.compute_gmm_params(gmm_input, gamma)

				batch_gamma_sum = torch.sum(gamma, dim=0)

				# Undo the per-batch normalisation so batches can be pooled by mass.
				# NOTE(review): each batch covariance is centred on that batch's own
				# mean, so the pooled covariance ignores between-batch mean shifts.
				gamma_sum += batch_gamma_sum
				mu_sum += mu * batch_gamma_sum.unsqueeze(-1)
				cov_sum += cov * batch_gamma_sum.unsqueeze(-1).unsqueeze(-1)

				N += X.size(0)

			train_phi = gamma_sum / N
			train_mu = mu_sum / gamma_sum.unsqueeze(-1)
			train_cov = cov_sum / gamma_sum.unsqueeze(-1).unsqueeze(-1)

			for _, X, Y, _ in self.test_loader:
				X = X.to(self.device)

				recon_res = self.model(X)
				sample_energy, _ = self.compute_energy(recon_res['gmm_input'], train_phi, train_mu, train_cov, size_average=False)

				# detach + cpu so labels convert regardless of where the loader put them
				gt_list.extend(np.atleast_1d(Y.detach().to('cpu').numpy()))
				score_list.extend(np.atleast_1d(sample_energy.detach().to('cpu').numpy()))

			test_time.append(time.time() - start)

		fpr, tpr, _ = roc_curve(gt_list, score_list)
		AUROC = auc(fpr, tpr)
		log(f"AUROC: {AUROC}")

		return np.sum(test_time) * 1000, AUROC

