import numpy as np
import random
import torch
import torch.nn as nn
import torchinfo
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
from sklearn.metrics import roc_curve, auc
from torch.autograd import Variable
import torch.nn.functional as F

from utils.log import log 
from trainer.T_Com import T_Com

import torch.optim as optim
from model.AE import *
from model.Conv import *

from utils.make_data import make_data
from utils.dataset import Dataset

class T_NCAE_UAD(T_Com):
	"""Adversarial autoencoder trainer with a latent and an image discriminator.

	The trainer optimises a ``DSADAutoencoder`` (``self.model``) together with
	two discriminators taken from the base class:

	* ``self.D_s`` is fed input batches and decoder outputs.
	* ``self.D_l`` is fed encoder latents and samples drawn around ``self.mu``.

	Attributes read here but not defined in this class (``batch``, ``channel``,
	``rep_dim``, ``device``, ``lr``, ``lr_milestones``, ``epochs``, ``cfg``,
	``D_l``, ``D_s``, ``save`` and the data loaders) are presumably provided by
	``T_Com`` -- TODO confirm.
	"""

	def __init__(self, train_loader, valid_loader, test_loader, cfg):
		"""Forward the loaders and config to ``T_Com`` and set hyper-parameters."""
		super().__init__(train_loader, valid_loader, test_loader, cfg)
		# Standard deviation of the latent samples fed to the decoder.
		self.std = 0.1
		# Weight of the reconstruction term in the autoencoder loss.
		self.lam = 0.1
		# Number of samples per batch whose reconstruction target is replaced
		# by a generated sample (10% of the configured batch size).
		self.topk = int(self.batch * 0.1)

	def load(self, path):
		"""Load autoencoder and discriminator weights saved by ``train``.

		Args:
			path: Path of any one of the three checkpoint files, e.g.
				``"dir/<model>_<epoch>_AE.tar"``. The part after the last
				``_`` is dropped and ``_AE.tar``, ``_D_l.tar`` and ``_D_s.tar``
				are appended to locate the three files.
		"""
		directory, sep, filename = path.rpartition("/")
		prefix = filename.rsplit("_", 1)[0]
		base = f"{directory}/{prefix}" if sep else prefix

		# ``self.model`` is only created inside ``train``; build it here so that
		# checkpoints can be loaded without training first.
		if getattr(self, "model", None) is None:
			self.model = DSADAutoencoder(self.channel, self.rep_dim).to(self.device)

		self.model.load_state_dict(torch.load(f"{base}_AE.tar", map_location=self.device))
		self.D_l.load_state_dict(torch.load(f"{base}_D_l.tar", map_location=self.device))
		self.D_s.load_state_dict(torch.load(f"{base}_D_s.tar", map_location=self.device))

	def _sample_latent(self, n, std):
		"""Draw ``n`` latent vectors from a Gaussian centred on ``self.mu``.

		Args:
			n: Number of rows to sample.
			std: Tensor of per-dimension standard deviations with as many
				elements as ``self.mu``.

		Returns:
			Tensor of shape ``(n, self.mu.numel())`` on the device of ``self.mu``.
		"""
		mean = self.mu.view(1, -1)
		unit = torch.randn(n, mean.size(1), device=mean.device, dtype=mean.dtype)
		return mean + std.view(1, -1) * unit

	def train(self):
		"""Train the autoencoder and both discriminators.

		Each batch performs, in order: a decoder (generator) step against
		``D_s``, a ``D_s`` step, an autoencoder step, a ``D_l`` step and an
		update of ``self.mu``. After every epoch the three networks are saved
		through ``self.save`` and the learning-rate schedulers advance.

		Returns:
			Wall-clock training time in seconds (also stored in
			``self.train_time``).
		"""
		# Class indices used as discriminator targets.
		real_label = 1
		fake_label = 0

		# Element-wise reconstruction loss and the discriminator loss.
		criterion = nn.MSELoss(reduction='none').to(self.device)
		criterion_D = nn.CrossEntropyLoss()

		self.model = DSADAutoencoder(self.channel, self.rep_dim).to(self.device)
		netD_l = self.D_l.to(self.device)
		netD_S = self.D_s.to(self.device)

		# Latent centre and the two standard deviations used for sampling.
		self.mu = self.init_center_c(self.train_loader, self.model.enc)
		self.std_mtx = torch.ones(self.mu.size(), device=self.device) * self.std
		self.idt_mtx = torch.ones(self.mu.size(), device=self.device)

		# Optimizers: the whole autoencoder, the decoder alone, and each discriminator.
		self.weight_decay = 0.5e-6
		optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
		optimizer_d = optim.Adam(self.model.dec.parameters(), lr=0.0005, betas=(0.5, 0.999))
		optimizer_l = optim.Adam(netD_l.parameters(), lr=0.0005, betas=(0.5, 0.999))
		optimizer_s = optim.Adam(netD_S.parameters(), lr=0.0005, betas=(0.5, 0.999))

		schedulers = [
			optim.lr_scheduler.MultiStepLR(opt, milestones=self.lr_milestones, gamma=0.1)
			for opt in (optimizer, optimizer_d, optimizer_l, optimizer_s)
		]
		scheduler = schedulers[0]

		start_time = time.time()
		# Defined up front so the return value exists even when no epoch runs.
		self.train_time = 0.0
		self.model.train()
		for self.epoch in tqdm(range(1, self.epochs + 1)):
			# Step size of the mu update follows the autoencoder learning rate.
			mu_lr = scheduler.get_last_lr()[0]
			for data in self.train_loader:
				_, inputs, _, _ = data
				inputs = inputs.to(self.device)
				n = inputs.size(0)
				# Labels live on the same device as the networks.
				real_targets = torch.full((n,), real_label, dtype=torch.long, device=self.device)
				fake_targets = torch.full((n,), fake_label, dtype=torch.long, device=self.device)

				###########################
				# (1) Update Generator network (G)
				###########################
				optimizer_d.zero_grad()
				noise = self._sample_latent(n, self.std_mtx)
				fake = self.model.dec(noise)
				output = netD_S(fake).squeeze()
				errG = criterion_D(output, real_targets)
				errG.backward()
				optimizer_d.step()

				###########################
				# (2) Update D_S network
				###########################
				optimizer_s.zero_grad()
				output = netD_S(inputs).squeeze()
				errD_S_real = criterion_D(output, real_targets)
				errD_S_real.backward()

				# Regenerate with the decoder weights updated in step (1).
				fake = self.model.dec(noise)
				output = netD_S(fake.detach()).squeeze()
				errD_S_fake = criterion_D(output, fake_targets)
				errD_S_fake.backward()
				optimizer_s.step()

				###########################
				# (3) Update Encoder network (f) + Generator (G)
				###########################
				optimizer.zero_grad()
				# NOTE(review): this term is evaluated on the sampled noise, not on
				# encoder latents, so it does not depend on the autoencoder
				# parameters -- confirm this is intended.
				output = netD_l(noise).squeeze()
				errE = criterion_D(output, real_targets)

				clatent = self.model.enc(fake.detach())
				# The decoder output for the re-encoded fakes is not used by any
				# loss. The forward pass is kept, without autograd, in case the
				# decoder has layers that update state on forward -- TODO confirm.
				with torch.no_grad():
					self.model.dec(clatent.detach())

				res = self.model(inputs)
				rec = res['output']
				latent = res['latent']

				# Mean cosine similarity between each real latent and all latents
				# of generated samples.
				latent_unit = latent.div(latent.norm(p=2, dim=1, keepdim=True).expand_as(latent))
				clatent_unit = clatent.div(clatent.norm(p=2, dim=1, keepdim=True).expand_as(clatent))
				_gamma = torch.mean(latent_unit.mm(clatent_unit.t()), 1)
				index_sorted = torch.argsort(_gamma, dim=0, descending=False)

				# D_s output on the reconstructions is not used by any loss; kept
				# without autograd for the same reason as above -- TODO confirm.
				with torch.no_grad():
					netD_S(rec)

				# The topk least similar samples are reconstructed towards the
				# generated samples, all others towards their own input.
				least_similar = index_sorted[0:self.topk]
				remaining = index_sorted[self.topk:]
				rec_loss_1 = criterion(rec[remaining], inputs[remaining])
				rec_loss_2 = criterion(rec[least_similar], fake[least_similar])
				rec_loss = torch.cat((rec_loss_1, rec_loss_2), 0)

				loss = torch.mean(rec_loss)
				total_loss = errE + self.lam * loss
				total_loss.backward()
				optimizer.step()

				###########################
				# (4) Update D_L network
				###########################
				optimizer_l.zero_grad()
				latent = self.model.enc(inputs)
				output = netD_l(latent).squeeze()
				errD_l_real = criterion_D(output, real_targets)
				errD_l_real.backward()

				# Unit-variance samples around mu are labelled fake for D_l.
				noise = self._sample_latent(n, self.idt_mtx)
				output = netD_l(noise.detach()).squeeze()
				errD_l_fake = criterion_D(output, fake_targets)
				errD_l_fake.backward()
				optimizer_l.step()

				###########################
				# (5) Update mu
				###########################
				# Done without autograd so that mu does not keep the encoder
				# graph of every past batch alive.
				# NOTE(review): the mean is taken over all elements, so every
				# coordinate of mu moves by the same scalar -- confirm intended.
				with torch.no_grad():
					self.mu = self.mu - mu_lr * torch.mean(self.mu - latent)

			self.D_l = netD_l
			self.D_s = netD_S
			self.save(self.model.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}_AE.tar")
			self.save(self.D_l.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}_D_l.tar")
			self.save(self.D_s.state_dict(), f"{self.cfg['model']}_{str(self.epoch)}_D_s.tar")

			for sched in schedulers:
				sched.step()

			self.train_time = time.time() - start_time

		return self.train_time

	def test(self):
		"""Score the test set by reconstruction error and report AUROC.

		The anomaly score of a sample is the squared reconstruction error
		summed over dimensions 1-3 of the input batch.

		Returns:
			Tuple ``(elapsed_ms, AUROC)`` where ``elapsed_ms`` is the time spent
			in the scoring loop in milliseconds.
		"""
		self.model.eval()
		gt_list = []
		score_list = []
		with torch.no_grad():
			start = time.time()
			for _, X, Y, _ in self.test_loader:
				X = X.to(self.device)
				X_hat = self.model(X)['output']

				dist = torch.sum((X - X_hat) ** 2, dim=(1, 2, 3))
				gt_list += list(Y.numpy())
				score_list += list(dist.to('cpu').numpy())
			elapsed = time.time() - start

		fpr, tpr, _ = roc_curve(gt_list, score_list)
		AUROC = auc(fpr, tpr)
		log(f"AUROC: {AUROC}")

		return np.float64(elapsed * 1000), AUROC

	def init_center_c(self, train_loader, encoder, eps=0.1):
		"""Initialize hypersphere center c as the mean from an initial forward pass on the data.

		Args:
			train_loader: Iterable yielding 4-tuples whose second item is the
				input batch.
			encoder: Module exposing ``rep_dim``; it is switched to eval mode
				and left in that mode.
			eps: Minimum absolute value allowed for any coordinate of ``c``.

		Returns:
			Tensor of size ``encoder.rep_dim`` on ``self.device``.

		Raises:
			ValueError: If ``train_loader`` yields no samples.
		"""
		n_samples = 0
		c = torch.zeros(encoder.rep_dim, device=self.device)

		encoder.eval()
		with torch.no_grad():
			for data in train_loader:
				_, inputs, _, _ = data
				inputs = inputs.to(self.device)
				outputs = encoder(inputs)
				n_samples += outputs.shape[0]
				c += torch.sum(outputs, dim=0)

		if n_samples == 0:
			raise ValueError("init_center_c: train_loader yielded no samples")

		c /= n_samples

		# If c_i is too close to 0, set to +-eps. Reason: a zero unit can be
		# trivially matched with zero weights. Exact zeros are moved to +eps.
		near_zero = c.abs() < eps
		c[near_zero & (c < 0)] = -eps
		c[near_zero & (c >= 0)] = eps

		return c
