import torch
import torch.nn as nn
import torchinfo
import numpy as np

from model.Conv import *
from model.AE import *
from model.MemAE import MemAE
from model.DSAD import DSAD 
from model.RVAEBFA import *
from utils.log import log
from utils.adaptive import AdaptiveLossFunction

class T_Com:
	def __init__(self, train_loader, valid_loader, test_loader, cfg):
		#Configuration
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.test_loader = test_loader
		self.pre_loader = None
		self.cfg=cfg
		self.lr = self.cfg['lr']
		self.epochs = self.cfg['epochs']
		self.batch=self.cfg['batch']
		self.eps= np.finfo(np.float32).eps

		if self.cfg['device'] == "cuda":
			if torch.cuda.is_available():
				self.device = torch.device("cuda") 
			else:
				raise NotImplementedError("Device is set to Cuda, but cuda is not available()")
		else:
			self.device = torch.device('cpu')


		# Dataset Dimension
		if "MNIST" in self.cfg['dataset']:
			self.height=self.width=28
			self.channel=1
			self.mem_dim = 100
			self.shrink_thres = 0.0025
			self.rep_dim= 32
			self.c=64
			self.h = self.w = 3
		else:
			self.height=self.width=32
			self.channel=3
			self.mem_dim = 500
			self.shrink_thres = 0.0025
			self.rep_dim= 128
			self.c=256
			self.h = self.w = 2
		self.fea_dim = self.c*self.h*self.w


		# Model Selection
		weight_decay=0
		if self.cfg['model'] == "AE":
			self.model = Autoencoder(self.channel).to(self.device)
			if self.cfg['loss'] == "GA":
				self.adaptive = AdaptiveLossFunction(num_dims=1, float_dtype=np.float32, device='cuda:0')
				self.optim = torch.optim.Adam(list(self.model.parameters())+list(self.adaptive.parameters()), lr=self.lr, weight_decay=weight_decay)
			else:
				self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=weight_decay)

		elif self.cfg['model'] == "MemAE":
			self.model = MemAE(self.channel, mem_dim=self.mem_dim, shrink_thres=self.shrink_thres).to(self.device)
			if self.cfg['loss'] == "GA":
				self.adaptive = AdaptiveLossFunction(num_dims=1, float_dtype=np.float32, device='cuda:0')
				self.optim = torch.optim.Adam(list(self.model.parameters())+list(self.adaptive.parameters()), lr=self.lr, weight_decay=weight_decay)
			else:
				self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=weight_decay)


		elif self.cfg['model'] == "DSAD" or self.cfg['model'] == "DSVDD":
			weight_decay=1e-6
			self.ae_epochs = self.cfg['ae_epochs']
			self.ae_lr= self.cfg['ae_lr']

			self.ae_model = DSADAutoencoder(self.channel, self.rep_dim, weight_init=True).to(self.device)
			self.model = DSAD(self.channel, self.rep_dim).to(self.device)

			if self.cfg['loss'] == "GA":
				self.adaptive = AdaptiveLossFunction(num_dims=1, float_dtype=np.float32, device='cuda:0')
				self.ae_optim = torch.optim.Adam(list(self.ae_model.parameters())+list(self.adaptive.parameters()), lr=self.ae_lr, weight_decay=weight_decay)
				self.optim = torch.optim.Adam(list(self.model.parameters())+list(self.adaptive.parameters()), lr=self.lr, weight_decay=weight_decay)
			else:
				self.ae_optim = torch.optim.Adam(self.ae_model.parameters(), lr=self.ae_lr, weight_decay=weight_decay)
				self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=weight_decay)


		elif self.cfg['model'] == "ITSR":
			if self.channel ==1:
				self.enc = MNISTEncoder(self.channel).to(self.device)
				self.dec = MNISTDecoder(self.channel).to(self.device)
			else:
				self.enc = CIFAREncoder(self.channel).to(self.device)
				self.dec = CIFARDecoder(self.channel).to(self.device)
			self.dis = Discriminator(self.channel).to(self.device)

			self.reg_lr= self.cfg['reg_lr']

			self.optim_enc = torch.optim.Adam(self.enc.parameters(), lr=self.lr)
			self.optim_dec = torch.optim.Adam(self.dec.parameters(), lr=self.lr)

			self.optim_reg_enc = torch.optim.Adam(self.enc.parameters(), lr=self.reg_lr)
			self.optim_dis = torch.optim.Adam(self.dis.parameters(), lr=self.reg_lr)

		elif self.cfg['model'] == "RVAEBFA":
			self.model = RVAEBFA(self.channel, rep_dim=32).to(self.device)
			self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=weight_decay)

		elif self.cfg['model'] == "NCAE_UAD":
			self.lr_milestones= tuple([20,50,70,90])

			self.model= DSADAutoencoder(self.channel, self.rep_dim).to(self.device)
			self.D_l = Discriminator_L(self.channel, self.rep_dim).to(self.device)

			if self.channel==1:
				self.D_s = MNIST_Discriminator_S(self.rep_dim).to(self.device)
			else:
				self.D_s = CIFAR10_Discriminator_S(self.rep_dim).to(self.device)

		else:
			raise ValueError(f"{self.cfg['model']} Unknown model name")



		# Loss Selection
		if self.cfg['loss'] == "MSE":
			self.loss = self.MSE_loss
		elif self.cfg['loss'] == "pseudo_Huber":
			self.loss = self.pseudo_Huber_loss
		elif self.cfg['loss'] == "LOE":
			self.loss = self.LOE_loss
		elif self.cfg['loss'] == "GA":
			self.loss = self.GA_loss
		elif self.cfg['loss'] == "H-loss":
			self.loss = self.H_loss

		
	def save(self, params, file_name):
		torch.save(params, f"./ckpt/{self.cfg['model']}/model/{file_name}")
		torch.save(params, f"{self.cfg['backup_path']}/model/{file_name}")

	def load(self, path):
		self.model.load_state_dict(torch.load(path))

	def soft_reject(func):
		def wrapper(self, X, X_hat, abnormality, q_indices, target, Ts, *args):
			
			tmp_loss = func(self, X, X_hat, abnormality, q_indices, target, *args)

			tmp_loss[q_indices] = Ts*tmp_loss[q_indices]
			loss = torch.zeros((X_hat.size()[0])).to(X.device)
			loss[torch.where(abnormality==target)] = tmp_loss
			return loss
		return wrapper

	@soft_reject
	def MSE_loss(self, X, X_hat, abnormality, q_indices, target, *arg):
		indices = torch.where(abnormality==target)
		dim = tuple([i for i in range(1,X_hat.dim())])
		dist= torch.sum((X-X_hat)**2, dim=dim)
		mse_dist = dist[indices]

		return mse_dist



	@soft_reject
	def GA_loss(self, X, X_hat, abnormality, q_indices, target, *arg):
		indices = torch.where(abnormality==target)
		dim = tuple([i for i in range(1,X_hat.dim())])
		e = torch.sqrt(torch.sum((X-X_hat)**2, dim=dim).unsqueeze(1))
		dist = self.adaptive.lossfun(e).squeeze(1)
		dist = dist[indices]
		return dist


	@soft_reject
	def pseudo_Huber_loss(self, X, X_hat, abnormality, q_indices, target, *arg):
		indices = torch.where(abnormality==target)
		dim = tuple([i for i in range(1,X_hat.dim())])
		e = torch.sum((X-X_hat)**2, dim=dim)
		dist = torch.sqrt(e+1)-1
		dist = dist[indices]
		return dist

	@soft_reject
	def LOE_loss(self, X, X_hat, abnormality, q_indices, target, *arg):
		indices = torch.where(abnormality==target)
		dim = tuple([i for i in range(1,X_hat.dim())])
		dist= torch.sum(torch.abs(X-X_hat), dim=dim)

		n_dist = dist[indices]

		machine_epsilon = torch.tensor(np.finfo(np.float32).eps).to('cuda')
		a_dist = 1/(dist[indices]+machine_epsilon)

		ct=0.1
		alpha=0.5
		score=n_dist-a_dist

		_, idx_n = torch.topk(score, int(dist.shape[0]*(1-ct)), largest=False, sorted=False)
		_, idx_a = torch.topk(score, int(dist.shape[0]*ct), largest=True, sorted=False)
		dist=torch.cat([n_dist[idx_n], (1-alpha)*n_dist[idx_a] + alpha*a_dist[idx_a]],0)


		return dist



	@soft_reject
	def H_loss(self, X, X_hat, abnormality, q_indices, target, *args):
		dim = tuple([i for i in range(1,X_hat.dim())])
		indices = torch.where(abnormality==target)
		dist= torch.sum((X-X_hat)**2, dim=dim)

		mse_dist = dist[indices]
		mask = torch.zeros_like(mse_dist)
		mask[q_indices] = 1

		median = torch.median(mse_dist)
		mad = torch.median(torch.abs(mse_dist-median))
		z_score = 0.6745*(mse_dist-median)/mad

		MAX=max(3.5,torch.max(torch.abs(z_score)))
		MIN=0


		norm_score = (z_score-MIN)/(MAX-MIN)

		norm_score[torch.where(z_score<=0)]=0


		q=2-(2-args[0])*norm_score**2
		q[torch.where(z_score<=0)] = 2
		a=torch.clamp(q,min=args[0],max=2)
	


		c=np.sqrt(0.5)

		machine_epsilon = torch.tensor(np.finfo(np.float32).eps).to('cuda')
		beta_safe = torch.max(machine_epsilon, torch.abs(a- 2.))
		alpha_safe = torch.where(a>= 0, torch.ones_like(a),
							 -torch.ones_like(a)) * torch.max(
								 machine_epsilon, torch.abs(a))

		other= (beta_safe/ alpha_safe) * (
			torch.pow(mse_dist/(c**2)/ beta_safe + 1., 0.5 * a) - 1.)
		loss = torch.where(a==2, 0.5*mse_dist/(c**2),
						torch.where(a==0, torch.log1p(0.5*mse_dist/(c**2)),
								other)
						)
		return loss 


