import numpy as np
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
from sklearn.metrics import roc_curve, auc

from utils.log import log 
from utils.rejection import reject
from trainer.T_Com import T_Com
from utils.entropy_loss import EntropyLossEncap 

class T_DSAD(T_Com):
	def __init__(self, train_loader, valid_loader, test_loader, cfg):
		"""Create the trainer by forwarding every argument to ``T_Com``.

		Nothing is initialized here; the attributes the other methods read
		(models, optimizers, device, ...) are presumably set up by the parent
		initializer -- verify against ``T_Com``.

		Args:
			train_loader: Loader iterated during training.
			valid_loader: Loader iterated during validation.
			test_loader: Loader iterated during testing.
			cfg: Configuration mapping (read elsewhere via keys such as ``'q'``).
		"""
		super().__init__(
			train_loader,
			valid_loader,
			test_loader,
			cfg,
		)

	def load(self, path):
		"""Restore the center ``c`` and the model weights from a checkpoint.

		The checkpoint must be a dict holding the center under ``'c'`` and the
		model state dict under ``'model'`` (the layout written by ``train``).
		The loaded dict is also kept on ``self.model_params``.

		Args:
			path: Checkpoint location handed to ``torch.load``.
		"""
		# Remap stored tensors onto the trainer's device so that a checkpoint
		# saved on one device (e.g. a GPU) can be restored on another (e.g. CPU)
		# and ``self.c`` ends up on the same device as the model outputs.
		self.model_params = torch.load(path, map_location=self.device)
		self.c = self.model_params['c']
		self.model.load_state_dict(self.model_params['model'])

	def set_pred_model(self):
		"""Copy the autoencoder weights that the prediction model shares.

		Every entry of the autoencoder state dict whose key also exists in the
		model state dict replaces the model's entry; all remaining model
		entries keep their current values.
		"""
		source_state = self.ae_model.state_dict()
		target_state = self.model.state_dict()

		# Only keys already present in the model are overwritten, so the set of
		# keys handed back to load_state_dict is unchanged.
		for name, tensor in source_state.items():
			if name in target_state:
				target_state[name] = tensor

		self.model.load_state_dict(target_state)

	def set_c(self):
		"""Compute the center ``c`` as the mean model output over the training set.

		For every training batch the autoencoder reconstruction is handed to
		``reject``; the entries it reports under ``'q_indices'`` (positions
		within the ``abnormality == 0`` subset of the batch) are left out of
		the mean. Samples with ``abnormality != 0`` are always included.

		Coordinates of the mean whose magnitude is below 0.1 are moved to -0.1
		or +0.1 according to their sign; coordinates that are exactly zero are
		left as they are.

		Returns:
			torch.Tensor: The center, of shape ``(self.rep_dim,)``, created on
			``self.device``.
		"""
		c = torch.zeros(self.rep_dim, device=self.device)
		n_samples = 0

		self.model.eval()
		with torch.no_grad():
			for _, X, Y, abnormality in self.train_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				X_hat = self.ae_model(X)['output']
				loss_dict = reject(X, X_hat, Y, abnormality, 0, self.cfg['q'])

				# assumes the model output is (batch, rep_dim), as required by
				# the accumulation into ``c`` below
				hypothesis = self.model(X)

				# One keep-flag per sample. Counting the elements of a mask
				# shaped like ``hypothesis`` would inflate the sample count by
				# a factor of rep_dim and shrink the center accordingly.
				keep = torch.ones(
					hypothesis.shape[0],
					dtype=hypothesis.dtype,
					device=hypothesis.device,
				)
				normal_rows = torch.where(abnormality == 0)
				# Indexing with an index tuple yields a copy, so edit the copy
				# and write it back.
				normal_keep = keep[normal_rows]
				normal_keep[loss_dict['q_indices']] = 0
				keep[normal_rows] = normal_keep

				c += torch.sum(keep.unsqueeze(1) * hypothesis, dim=0)
				n_samples += int(keep.sum().item())

			c /= n_samples

			# Push near-zero coordinates away from zero (presumably so the
			# network cannot match them trivially -- confirm against the
			# method's reference implementation).
			c[(abs(c) < 0.1) & (c < 0)] = -0.1
			c[(abs(c) < 0.1) & (c > 0)] = 0.1

		return c


	def train(self):
		"""Run both training phases and return the accumulated training time.

		Phase 1 trains the autoencoder for ``self.ae_epochs`` epochs. After
		each epoch it validates with ``ae_valid`` and saves a checkpoint; at
		the end the epoch with the lowest validation loss is reloaded.

		Phase 2 copies the shared autoencoder weights into the model
		(``set_pred_model``), fixes the center ``self.c`` (``set_c``) and
		trains the model for ``self.epochs`` epochs, again validating and
		checkpointing every epoch and finally reloading the best one via
		``load``.

		NOTE(review): the best checkpoints are read back from
		``./ckpt/<model>/model/``; this presumes ``self.save`` writes there.

		Returns:
			Sum of the time spent in the batch loops of both phases, in
			milliseconds. Validation and checkpointing are not timed.
		"""
		train_time = []
		valid_loss_list = []

		# Autoencoder Train
		for self.ae_epoch in tqdm(range(1, self.ae_epochs + 1)):
			start = time.time()

			self.ae_model.train()
			for _, X, Y, abnormality in self.train_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				X_hat = self.ae_model(X)['output']

				loss_dict = reject(X, X_hat, Y, abnormality, 0, self.cfg['q'])
				loss = self.loss(X, X_hat, abnormality, loss_dict['q_indices'], 0, self.cfg['s'], self.cfg['a'])
				# Samples flagged abnormal contribute nothing to this phase.
				loss[torch.where(abnormality == 1)] = 0
				loss = loss.mean()

				self.ae_optim.zero_grad()
				loss.backward()
				self.ae_optim.step()

			train_time.append(time.time() - start)

			valid_loss = self.ae_valid()
			log(f"Valid Loss: {valid_loss}")
			valid_loss_list.append(valid_loss)

			self.save(self.ae_model.state_dict(), f"AE_{self.cfg['model']}_{str(self.ae_epoch)}.tar")

		# Epochs are 1-based, hence the +1 on the argmin index.
		valid_loss_list = np.array(valid_loss_list)
		best_epoch = np.argmin(valid_loss_list) + 1
		# map_location keeps the reload working when the checkpoint was written
		# from a different device than the one currently in use.
		self.ae_model.load_state_dict(
			torch.load(
				f"./ckpt/{self.cfg['model']}/model/AE_{self.cfg['model']}_{str(best_epoch)}.tar",
				map_location=self.device,
			)
		)
		log(f"Best Epoch: {best_epoch}")

		# Load Best epoch Autoencoder model
		self.set_pred_model()

		# Set center point
		self.c = self.set_c()

		valid_loss_list = []

		# DSAD Train
		for self.epoch in tqdm(range(1, self.epochs + 1)):
			start = time.time()

			self.model.train()
			for _, X, Y, abnormality in self.train_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				outputs = self.model(X)

				loss_dict = reject(self.c, outputs, Y, abnormality, 0, self.cfg['q'])
				loss = self.loss(self.c, outputs, abnormality, loss_dict['q_indices'], 0, self.cfg['s'], self.cfg['a'])

				# Abnormal samples use the inverse squared distance to the
				# center, so minimizing the loss moves them away from it; the
				# small constant avoids a division by zero.
				ab_loss = (torch.sum((outputs - self.c) ** 2, dim=1) + 1e-12) ** -1
				loss = torch.where(abnormality == 0, loss, ab_loss)
				loss = loss.mean()

				self.optim.zero_grad()
				loss.backward()
				self.optim.step()

			train_time.append(time.time() - start)

			# Valid
			valid_loss = self.valid()
			log(f"Valid Loss: {valid_loss}")
			valid_loss_list.append(valid_loss)

			self.save({"c": self.c, "model": self.model.state_dict()}, f"{self.cfg['model']}_{str(self.epoch)}.tar")

		valid_loss_list = np.array(valid_loss_list)
		best_epoch = np.argmin(valid_loss_list) + 1
		self.load(f"./ckpt/{self.cfg['model']}/model/{self.cfg['model']}_{str(best_epoch)}.tar")
		log(f"Best Epoch: {best_epoch}")

		return np.sum(train_time) * 1000

	def ae_valid(self):
		"""Return the mean reconstruction error of the autoencoder on the validation set.

		The per-sample error is the squared difference between input and
		reconstruction summed over dimensions 1-3 (so inputs are expected to
		be 4-D). Samples with ``abnormality == 1`` and the normal samples that
		``reject`` lists under ``'q_indices'`` are excluded from both the sum
		and the sample count.

		Returns:
			Sum of the kept per-sample errors divided by the number of kept
			samples.
		"""
		self.ae_model.eval()
		valid_loss = []
		n_samples = 0
		with torch.no_grad():
			for _, X, Y, abnormality in self.valid_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				X_hat = self.ae_model(X)['output']

				loss_dict = reject(X, X_hat, Y, abnormality, 0, self.cfg['q'])
				loss = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3))

				# Keep-mask over the batch: start from all ones, zero the
				# rejected entries of the normal subset, then zero every
				# abnormal sample.
				mask = torch.ones_like(loss)
				normal_rows = torch.where(abnormality == 0)
				# Indexing with an index tuple yields a copy, so edit the copy
				# and write it back.
				normal_mask = mask[normal_rows]
				normal_mask[loss_dict['q_indices']] = 0
				mask[normal_rows] = normal_mask
				mask[torch.where(abnormality == 1)] = 0

				n_samples += torch.sum(mask).to('cpu')
				valid_loss.append((loss * mask).sum().to('cpu'))

		return np.sum(valid_loss) / n_samples



	def valid(self):
		"""Return the mean validation loss of the model with respect to ``self.c``.

		Samples with ``abnormality == 0`` contribute their squared distance to
		the center; all other samples contribute the inverse of that distance
		(with 1e-12 added before inverting). Normal samples that ``reject``
		lists under ``'q_indices'`` are excluded from both the sum and the
		sample count.

		Returns:
			Sum of the kept per-sample losses divided by the number of kept
			samples.
		"""
		self.model.eval()
		batch_totals = []
		kept_count = 0
		with torch.no_grad():
			for _, X, Y, abnormality in self.valid_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				embedding = self.model(X)

				rejected = reject(self.c, embedding, Y, abnormality, 0, self.cfg['q'])['q_indices']

				sq_dist = ((embedding - self.c) ** 2).sum(dim=1)
				inverse_dist = (sq_dist + 1e-12) ** -1
				per_sample = torch.where(abnormality == 0, sq_dist, inverse_dist)

				# Keep-mask: ones everywhere except the rejected entries of
				# the normal subset.
				keep = torch.ones_like(per_sample)
				normal_rows = torch.where(abnormality == 0)
				normal_keep = keep[normal_rows]
				normal_keep[rejected] = 0
				keep[normal_rows] = normal_keep

				kept_count += keep.sum().to('cpu')
				batch_totals.append((per_sample * keep).sum().to('cpu'))

		return np.sum(batch_totals) / kept_count



	def test(self):
		"""Score the test set with the model and report the AUROC.

		The anomaly score of a sample is the squared distance between the
		model output and ``self.c``; the scores are compared against the ``Y``
		values delivered by the test loader.

		Returns:
			tuple: ``(elapsed_ms, AUROC)`` where ``elapsed_ms`` is the time
			spent in the scoring loop in milliseconds.
		"""
		labels = []
		scores = []
		elapsed = []

		self.model.eval()
		with torch.no_grad():
			tic = time.time()
			for _index, X, Y, _abnormality in self.test_loader:
				X = X.to(self.device)

				embedding = self.model(X)
				sq_dist = ((embedding - self.c) ** 2).sum(dim=1)

				labels.extend(Y.numpy())
				scores.extend(sq_dist.to('cpu').numpy())
			elapsed.append(time.time() - tic)

		fpr, tpr, _ = roc_curve(labels, scores)
		AUROC = auc(fpr, tpr)
		log(f"AUROC: {AUROC}")

		return np.sum(elapsed) * 1000, AUROC

	def ae_test(self):
		"""Score the test set with the autoencoder and report the AUROC.

		The anomaly score of a sample is its reconstruction error: the squared
		difference between input and reconstruction summed over dimensions
		1-3. Scores are compared against the ``Y`` values delivered by the
		test loader.

		Returns:
			tuple: ``(elapsed_ms, AUROC)`` where ``elapsed_ms`` is the time
			spent in the scoring loop in milliseconds.
		"""
		self.ae_model.eval()

		labels = []
		scores = []
		elapsed = []

		with torch.no_grad():
			tic = time.time()
			for _index, X, Y, _abnormality in self.test_loader:
				X = X.to(self.device)

				X_hat = self.ae_model(X)['output']
				recon_err = ((X - X_hat) ** 2).sum(dim=(1, 2, 3))

				labels.extend(Y.numpy())
				scores.extend(recon_err.to('cpu').numpy())
			elapsed.append(time.time() - tic)

		fpr, tpr, _ = roc_curve(labels, scores)
		AUROC = auc(fpr, tpr)
		log(f"AUROC: {AUROC}")

		return np.sum(elapsed) * 1000, AUROC

