import numpy as np
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
from sklearn.metrics import roc_curve, auc
from sklearn.svm import OneClassSVM

from utils.log import log 
from trainer.T_Com import T_Com

class T_ITSR(T_Com):
	def __init__(self, train_loader, valid_loader, test_loader, cfg):
		"""Set up the trainer and its fixed refinement hyper-parameters.

		All four arguments are forwarded unchanged to the parent constructor.
		"""
		super().__init__(train_loader, valid_loader, test_loader, cfg)

		# Epoch schedule consumed by train(): the first phase runs
		# pretraining_epochs + refinement_epochs * reptition epochs.
		self.pretraining_epochs, self.refinement_epochs, self.reptition = 50, 10, 10

		# nu is handed to the one-class SVM that labels latent codes as
		# inliers (+1) or outliers (-1).
		self.nu = 0.02
		self.ocsvm = OneClassSVM(nu=self.nu)

	def load(self, path):
		"""Restore encoder, decoder and discriminator weights from a checkpoint.

		Args:
			path: Path of any one of the three checkpoint files, shaped like
				``<dir>/<name>_<TAG>.tar``. Only ``<dir>`` and ``<name>`` are used;
				the sibling files ``<name>_ENC.tar``, ``<name>_DEC.tar`` and
				``<name>_DIS.tar`` are all loaded from ``<dir>``.

		Raises:
			ValueError: If the file name contains no ``_`` separator.
		"""
		directory, sep, filename = path.rpartition("/")
		if not sep:
			# Bare file name: look for the checkpoints in the current directory.
			directory = "."

		name, sep, _ = filename.rpartition("_")
		if not sep:
			raise ValueError(f"Checkpoint name must look like '<name>_<TAG>.tar', got '{filename}'")

		for module, tag in ((self.enc, "ENC"), (self.dec, "DEC"), (self.dis, "DIS")):
			# map_location lets checkpoints saved on another device (e.g. a GPU)
			# be restored on whatever device this trainer runs on.
			state = torch.load(f"{directory}/{name}_{tag}.tar", map_location=self.device)
			module.load_state_dict(state)

	def _weighted_recon_loss(self, X, X_hat, w_batch):
		"""Return the batch mean of the per-sample squared error scaled by its weight."""
		per_sample = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3))
		# The weights are stored as an (N, 1) column. Multiplying a (B, 1) slice
		# with the (B,) error vector would broadcast to a (B, B) outer product,
		# so flatten the weights first to get a true per-sample product.
		weights = w_batch.reshape(-1).to(per_sample.dtype)
		return (weights * per_sample).mean()

	def _run_train_epoch(self, w):
		"""Run one epoch of weighted reconstruction + adversarial latent training.

		Args:
			w: Per-sample weight tensor of shape (N_train, 1) on self.device,
				indexed by the sample indices yielded by the train loader.

		Returns:
			Number of samples with label 1 seen during the epoch.
		"""
		self.enc.train()
		self.dec.train()
		self.dis.train()

		n_abnormal = 0
		for index, X, Y, _ in self.train_loader:
			X = X.to(self.device)

			# Reconstruction step (encoder + decoder)
			z = self.enc(X)
			X_hat = self.dec(z)
			loss = self._weighted_recon_loss(X, X_hat, w[index])

			self.optim_enc.zero_grad()
			self.optim_dec.zero_grad()
			loss.backward()
			self.optim_enc.step()
			self.optim_dec.step()

			# Discriminator step: prior samples are "real", encoder codes are "fake"
			z_real = torch.randn(X.size()[0], self.fea_dim).to(X.device)
			z_d_real = self.dis(z_real)

			z_fake = self.enc(X).detach()
			z_d_fake = self.dis(z_fake)

			loss = -torch.mean(torch.log(z_d_real + self.eps) + torch.log(1 - z_d_fake + self.eps))

			self.optim_dis.zero_grad()
			loss.backward()
			self.optim_dis.step()

			# Generator step: the encoder tries to fool the discriminator
			z_fake = self.enc(X)
			z_d_fake = self.dis(z_fake)

			loss = -torch.mean(torch.log(z_d_fake + self.eps))

			self.optim_reg_enc.zero_grad()
			loss.backward()
			self.optim_reg_enc.step()

			n_abnormal += int((Y == 1).sum().item())

		return n_abnormal

	def _detect_inliers(self):
		"""Encode the training set and fit the one-class SVM on the latent codes.

		Returns:
			(w, y): w is an (N_train, 1) tensor on self.device holding 1 for
			predicted inliers and 0 for predicted outliers; y is the (N_train,)
			CPU tensor of labels gathered from the train loader.
		"""
		z = torch.ones((self.cfg['N_train'], self.fea_dim))
		y = torch.zeros((self.cfg['N_train']))
		self.enc.eval()
		with torch.no_grad():
			for index, X, Y, _ in self.train_loader:
				X = X.to(self.device)
				z[index] = self.enc(X).to('cpu').detach().view(X.size()[0], -1)
				y[index] = Y.float()

		# fit_predict yields +1 for inliers and -1 for outliers; outliers get weight 0.
		w = torch.from_numpy(self.ocsvm.fit_predict(z.numpy()))
		w = w.view(-1, 1).to(self.device)
		w[w == -1] = 0
		return w, y

	def _train_set_recon_errors(self):
		"""Return (errors, labels): per-sample squared reconstruction error and label, both (N_train,) on CPU."""
		dist = torch.ones((self.cfg['N_train']))
		y = torch.zeros((self.cfg['N_train']))
		self.enc.eval()
		self.dec.eval()
		with torch.no_grad():
			for index, X, Y, _ in self.train_loader:
				X = X.to(self.device)
				X_hat = self.dec(self.enc(X))
				dist[index] = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3)).to('cpu').detach()
				y[index] = Y.float()
		return dist, y

	def _count_rejected(self, w, y):
		"""Count rejected samples (weight < 1) by label; returns (n_label_1, n_label_0)."""
		rejected = torch.where(w < 1)[0]
		labels = y.to(w.device)[rejected]
		return int((labels == 1).sum().item()), int((labels == 0).sum().item())

	def train(self):
		"""Train the model with iterative training-set refinement.

		Phase 1 runs ``pretraining_epochs + refinement_epochs * reptition``
		epochs. Whenever ``epoch >= pretraining_epochs`` and
		``epoch % refinement_epochs == 1``, the one-class SVM is refit on the
		latent codes of the training set and predicted outliers get weight 0.

		Phase 2 runs ``self.epochs`` epochs. After each epoch, rejected samples
		whose reconstruction error exceeds the 0.8 quantile of all rejected
		samples get weight -1, so the reconstruction loss pushes their error up.
		Each phase-2 epoch is validated and checkpointed; the epoch with the
		lowest validation loss is reloaded at the end.

		Two diagnostic plots (``z_score.png`` and ``ratio.png``) are written to
		``cfg['backup_path']`` and ``./ckpt/<model>/``.

		Returns:
			Total training time in milliseconds (validation time excluded).
		"""
		train_time = []
		valid_loss_list = []
		fpr_list = []
		tpr_list = []
		# Despite the names these are counts of rejected normal (fpr) and
		# rejected abnormal (tpr) samples, not rates.
		fpr = tpr = 0
		# Defined up front so the plot below still works if no epoch runs.
		Nab = 0

		w = torch.ones((self.cfg['N_train'], 1)).to(self.device)

		# Phase 1: pretraining followed by periodic detection and refinement.
		n_phase1_epochs = self.pretraining_epochs + self.refinement_epochs * self.reptition
		for self.epoch in tqdm(range(1, n_phase1_epochs + 1)):
			start = time.time()

			Nab = self._run_train_epoch(w)

			# Counts are recorded before this epoch's refinement, as they
			# describe the weights that were used during the epoch.
			fpr_list.append(fpr)
			tpr_list.append(tpr)

			if self.epoch >= self.pretraining_epochs and self.epoch % self.refinement_epochs == 1:
				log(f"{self.epoch}: Detection and Refinement")
				w, y = self._detect_inliers()
				tpr, fpr = self._count_rejected(w, y)

			train_time.append(time.time() - start)

		# Phase 2: retraining with negative weights for the worst rejected samples.
		for self.epoch in tqdm(range(1, self.epochs + 1)):
			start = time.time()

			Nab = self._run_train_epoch(w)

			fpr_list.append(fpr)
			tpr_list.append(tpr)

			dist, y = self._train_set_recon_errors()
			# Keep errors on the same device as the weights so both can be
			# indexed with the same index tensor.
			dist = dist.to(w.device)

			rejected = torch.where(w < 1)[0]
			# torch.quantile rejects empty input; nothing to reweight in that case.
			if rejected.numel() > 0:
				rejected_dist = dist[rejected]
				q = torch.quantile(rejected_dist, 0.8)
				w[rejected[rejected_dist > q]] = -1

			tpr, fpr = self._count_rejected(w, y)

			train_time.append(time.time() - start)

			#Valid
			valid_loss = self.valid()
			log(f"Valid Loss: {valid_loss}")
			valid_loss_list.append(float(valid_loss))

			self.save(self.enc.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}_ENC.tar")
			self.save(self.dec.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}_DEC.tar")
			self.save(self.dis.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}_DIS.tar")

		# Without any phase-2 epoch there is no checkpoint to select.
		if valid_loss_list:
			valid_loss_list = np.array(valid_loss_list)
			best_epoch = np.argmin(valid_loss_list) + 1
			self.load(f"./ckpt/{self.cfg['model']}/model/{self.cfg['model']}_{str(best_epoch)}_ENC.tar")
			log(f"Best Epoch: {best_epoch}")

		log("True Positive (Abnormal)")
		log(f"\n{tpr_list}")
		log("False Positive (Normal)")
		log(f"\n{fpr_list}")

		plt.figure(0)
		plt.plot(tpr_list, label='Number of Rejected Abnormal Data')
		plt.plot(fpr_list, label='Number of Rejected Normal Data')
		plt.plot([Nab] * len(fpr_list), '--', label='Number of Abnormal Data')

		plt.title("Prediction as Abnormal")
		plt.legend(loc='best')
		plt.savefig(f"{self.cfg['backup_path']}/z_score.png")
		plt.savefig(f"./ckpt/{self.cfg['model']}/z_score.png")

		plt.figure(1)
		fpr_list = np.array(fpr_list)
		tpr_list = np.array(tpr_list)
		# Samples of each class that remain in the training set after rejection.
		n_list = self.cfg['Nn'] - fpr_list
		ab_list = self.cfg['Nab'] - tpr_list
		ratio_list = n_list / (n_list + ab_list)
		plt.plot(ratio_list, label='Ratio of Normal data with Rejection')
		plt.plot([self.cfg['Nn'] / (self.cfg["Nn"] + self.cfg['Nab']) for _ in range(len(n_list))], label='Ratio of Normal data in Training samples')
		plt.title("Normal Data Ratio in Training Samples")
		plt.legend(loc='best')
		plt.savefig(f"{self.cfg['backup_path']}/ratio.png")
		plt.savefig(f"./ckpt/{self.cfg['model']}/ratio.png")

		return np.sum(train_time) * 1000

	def valid(self):
		"""Compute the mean reconstruction error over retained validation samples.

		The one-class SVM is refit on the latent codes of the training set.
		Samples it marks as outliers get weight 0 and are excluded from both the
		loss sum and the sample count.

		Returns:
			Mean per-sample squared reconstruction error as a float, or
			``float('inf')`` when no validation sample is retained.
		"""
		self.enc.eval()
		self.dec.eval()
		self.dis.eval()

		with torch.no_grad():
			# Latent codes of the whole training set, placed by sample index.
			z = torch.ones((self.cfg['N_train'], self.fea_dim))
			for index, X, _, _ in self.train_loader:
				X = X.to(self.device)
				z[index] = self.enc(X).to('cpu').detach().view(X.size()[0], -1)

			# fit_predict yields +1 for inliers and -1 for outliers; outliers get weight 0.
			w = torch.from_numpy(self.ocsvm.fit_predict(z.numpy()))
			w = w.view(-1).to(self.device)
			w[w == -1] = 0

			total_loss = 0.0
			n_samples = 0
			for index, X, _, _ in self.valid_loader:
				X = X.to(self.device)

				X_hat = self.dec(self.enc(X))
				recon_loss = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3))

				# NOTE(review): the weights were fitted on the training set but are
				# looked up with validation-loader indices -- confirm that both
				# loaders share one index space.
				# The weights are kept flat so this mask is (B,), matching recon_loss;
				# a (B, 1) weight slice would broadcast the product to (B, B).
				keep = w[index] >= 1

				total_loss += recon_loss[keep].sum().item()
				n_samples += int(keep.sum().item())

		if n_samples == 0:
			# Nothing retained: report the worst possible loss so this epoch is
			# never selected as the best one.
			return float("inf")
		return total_loss / n_samples


	def test(self):
		"""Score the test set by reconstruction error and report the AUROC.

		Each sample's anomaly score is its summed squared reconstruction error;
		the labels yielded by the test loader are the ground truth for the ROC.

		Returns:
			(elapsed_ms, AUROC): time spent in the scoring loop in milliseconds
			and the area under the ROC curve.
		"""
		for net in (self.enc, self.dec, self.dis):
			net.eval()

		elapsed = []
		labels = []
		scores = []

		with torch.no_grad():
			t0 = time.time()

			for _, X, Y, _ in self.test_loader:
				X = X.to(self.device)

				recon = self.dec(self.enc(X))
				err = torch.sum((X - recon) ** 2, dim=(1, 2, 3))

				labels.extend(Y.numpy())
				scores.extend(err.to('cpu').numpy())

			elapsed.append(time.time() - t0)

		fpr, tpr, _ = roc_curve(labels, scores)
		AUROC = auc(fpr, tpr)
		log(f"AUROC: {AUROC}")

		return np.sum(elapsed) * 1000, AUROC

