import numpy as np
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
from sklearn.metrics import roc_curve, auc

from utils.log import log 
from utils.rejection import reject
from trainer.T_Com import T_Com
from utils.entropy_loss import EntropyLossEncap 

class T_MemAE(T_Com):
	"""Trainer for a memory-augmented autoencoder model.

	The model is called as ``self.model(X)`` and must return a dict with at
	least ``'output'`` (the reconstruction of ``X``) and ``'att'`` (presumably
	memory-addressing weights; they are passed to ``EntropyLossEncap``).
	Training minimises ``recon_loss + alpha * entropy_loss``.
	"""

	def __init__(self, train_loader, valid_loader, test_loader, cfg):
		super().__init__(train_loader, valid_loader, test_loader, cfg)
		# Weight of the entropy term on 'att' in the total training loss.
		self.alpha = 0.0002
		self.entropy_loss_fn = EntropyLossEncap().to(self.device)

	def train(self):
		"""Train for ``self.epochs`` epochs, then reload the best checkpoint.

		Each epoch makes one pass over ``self.train_loader``, evaluates with
		``self.valid()`` and saves a checkpoint. After the last epoch, the
		checkpoint with the lowest non-NaN validation loss is reloaded.

		Returns:
			Total wall-clock time of the training passes in milliseconds
			(validation and checkpoint saving are excluded).
		"""
		valid_loss_list = []
		train_time = []

		for self.epoch in tqdm(range(1, self.epochs + 1)):
			start = time.time()
			self.model.train()
			for _index, X, Y, abnormality in self.train_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				recon_res = self.model(X)
				X_hat = recon_res['output']
				att_w = recon_res['att']

				# reject() yields 'q_indices', which self.loss treats specially; presumably
				# the q-fraction of samples rejected as likely anomalies (verify in utils.rejection).
				loss_dict = reject(X, X_hat, Y, abnormality, 0, self.cfg['q'])
				recon_loss = self.loss(X, X_hat, abnormality, loss_dict['q_indices'], 0, self.cfg['s'], self.cfg['a']).mean()
				entropy_loss = self.entropy_loss_fn(att_w).mean()
				loss = recon_loss + entropy_loss * self.alpha

				self.optim.zero_grad()
				loss.backward()
				self.optim.step()

			train_time.append(time.time() - start)

			valid_loss = self.valid()
			log(f"Valid Loss: {valid_loss}")
			valid_loss_list.append(valid_loss)

			self.save(self.model.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}.tar")

		best_epoch = self._best_epoch(valid_loss_list)
		self.load(f"./ckpt/{self.cfg['model']}/model/{self.cfg['model']}_{str(best_epoch)}.tar")
		log(f"Best Epoch: {best_epoch}")

		return np.sum(train_time) * 1000

	@staticmethod
	def _best_epoch(valid_losses):
		"""Return the 1-based epoch number with the lowest validation loss, skipping NaNs.

		Raises:
			ValueError: if ``valid_losses`` is empty.
		"""
		# float() accepts 0-d tensors, numpy scalars and Python numbers alike.
		losses = np.array([float(v) for v in valid_losses], dtype=np.float64)
		if losses.size == 0:
			raise ValueError("no validation losses recorded; cannot pick a best epoch")
		if np.all(np.isnan(losses)):
			# Every epoch produced NaN; keep the old np.argmin behaviour (first epoch).
			return 1
		# np.argmin returns the first NaN's index when NaNs are present, which would
		# reload a diverged epoch; nanargmin ignores them.
		return int(np.nanargmin(losses)) + 1

	def valid(self):
		"""Return the mean per-sample squared reconstruction error on the validation set.

		Samples at the positions given by ``reject(...)['q_indices']`` are masked
		out of both the summed error and the sample count. If no samples remain,
		the final 0/0 division is expected to give NaN rather than raise.
		"""
		self.model.eval()
		valid_loss = []
		n_samples = 0
		with torch.no_grad():
			for _index, X, Y, abnormality in self.valid_loader:
				X = X.to(self.device)
				Y = Y.to(self.device)
				abnormality = abnormality.to(self.device)

				X_hat = self.model(X)['output']
				loss_dict = reject(X, X_hat, Y, abnormality, 0, self.cfg['q'])
				# Per-sample squared error summed over dims 1-3 (assumes 4-D inputs).
				recon_loss = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3))

				# Zero the rejected samples so they count toward neither sum nor denominator.
				mask = torch.ones_like(recon_loss)
				mask[loss_dict['q_indices']] = 0
				n_samples += torch.sum(mask).to('cpu')
				valid_loss.append((recon_loss * mask).sum().to('cpu'))

		return np.sum(valid_loss) / n_samples

	def test(self):
		"""Score the test set by reconstruction error and compute AUROC against ``Y``.

		The anomaly score of a sample is its squared reconstruction error summed
		over dims 1-3.

		Returns:
			Tuple of (scoring-pass wall-clock time in milliseconds, AUROC).
		"""
		gt_list = []
		score_list = []
		self.model.eval()
		with torch.no_grad():
			start = time.time()
			for _index, X, Y, _abnormality in self.test_loader:
				X = X.to(self.device)
				X_hat = self.model(X)['output']
				dist = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3))
				# .cpu() is a no-op for CPU tensors but keeps this safe if a loader yields device tensors.
				gt_list.extend(Y.cpu().numpy())
				score_list.extend(dist.cpu().numpy())
			elapsed_ms = (time.time() - start) * 1000

		fpr, tpr, _ = roc_curve(gt_list, score_list)
		AUROC = auc(fpr, tpr)
		log(f"AUROC: {AUROC}")

		return elapsed_ms, AUROC
