'''
PCL: Proxy-based Contrastive Learning for Domain Generalization (CVPR'22)
'''
import numpy as np

import torch
import torch.nn as nn

import torch.nn.functional as F

class ProxyPLoss(nn.Module):
	'''
	Proxy-based contrastive loss following the original PCL implementation.

	As in the original, the only negatives are other-class *samples*; the
	anchor's similarities to the other classes' proxies are not used:

	    loss_i = -log( exp(s*<z_i, w_{y_i}>) /
	                   (exp(s*<z_i, w_{y_i}>) + sum_{j in neg(i)} exp(s*<z_i, z_j>)) )

	z: L2-normalised features, w: L2-normalised proxies, s: ``scale``,
	neg(i): samples j with y_j != y_i and <z_i, z_j> >= 1e-6.
	The loss is averaged over the batch.
	'''

	def __init__(self, num_classes, scale):
		super(ProxyPLoss, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		self.scale = scale

	def forward(self, feature, target, proxy):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).
			proxy: (C, dim) one proxy vector per class.

		Returns:
			Scalar loss averaged over the N samples.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C) sample-proxy similarity

		# (N, C) mask, True at [i, target[i]]. masked_select walks the mask in
		# row-major order, so selecting from the (N, C) layout keeps the positives
		# in sample order and row i lines up with row i of the sample-sample block
		# below (selecting from pred.T would order them by class instead).
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1) positive pairs

		sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N) sample-sample similarity
		same_class = target.unsqueeze(1) == target.unsqueeze(0)
		# Only different-class pairs with similarity >= 1e-6 count as negatives.
		sim = sim.masked_fill(same_class | (sim < 1e-6), float('-inf'))

		logits = torch.cat([pred_p, sim], dim=1)  # (N, 1 + N); column 0 is the positive
		zeros = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
		return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), zeros)

class ProxyPLoss2(nn.Module):
	'''
	Proxy-based contrastive loss that also uses the anchor's similarities to
	the other classes' proxies as negatives:

	    loss_i = -log( exp(s*<z_i, w_{y_i}>) /
	                   (sum_c exp(s*<z_i, w_c>) + sum_{j in neg(i)} exp(s*<z_i, z_j>)) )

	z: L2-normalised features, w: L2-normalised proxies, s: ``scale``,
	neg(i): samples j with y_j != y_i and <z_i, z_j> >= 1e-6.
	The loss is averaged over the batch.
	'''

	def __init__(self, num_classes, scale):
		super(ProxyPLoss2, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		self.scale = scale

	def forward(self, feature, target, proxy):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).
			proxy: (C, dim) one proxy vector per class.

		Returns:
			Scalar loss averaged over the N samples.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C) sample-proxy similarity

		n = feature.size(0)
		# (N, C) mask, True at [i, target[i]]; row-major selection keeps sample order.
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1) positive pairs
		pred_n = torch.masked_select(pred, ~pos_mask).view(n, -1)  # (N, C-1) anchor vs other proxies

		sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N) sample-sample similarity
		same_class = target.unsqueeze(1) == target.unsqueeze(0)
		# Only different-class pairs with similarity >= 1e-6 count as negatives.
		sim = sim.masked_fill(same_class | (sim < 1e-6), float('-inf'))

		logits = torch.cat([pred_p, pred_n, sim], dim=1)  # (N, C+N); column 0 is the positive
		zeros = torch.zeros(n, dtype=torch.long, device=logits.device)
		return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), zeros)

class MPCLoss(nn.Module):
	"""
	Proxy contrastive loss combining the current proxies with a second set of
	proxies (``Mproxy``). With p / n the anchor's similarity to its own / the
	other classes' proxies, Mp / Mn the same for ``Mproxy``, sim the kept
	sample-sample negatives, s = ``scale`` and m = ``mweight``:

	v=0: cross-entropy over s * [0.5*(p + m*Mp), 0.5*(n + m*Mn), sim] with the
	     first column as the target.
	v=1: -log( (e^{s p} + m e^{s Mp}) /
	           (e^{s p} + m e^{s Mp} + sum e^{s n} + m sum e^{s Mn} + sum e^{s sim}) )
	v=2: -log( e^{s p} / (e^{s p} + sum e^{s n} + m sum e^{s Mn} + sum e^{s sim}) )
	     (only the negative pairs of ``Mproxy`` are used)
	v=3: like v=1 but without the sample-sample negatives.

	All similarities are cosine similarities; the sample-sample negatives are
	the different-class pairs with similarity >= 1e-6. Losses are averaged
	over the batch.
	"""

	def __init__(self, num_classes, scale):
		super(MPCLoss, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		self.scale = scale

	def forward(self, feature, target, proxy, Mproxy, mweight=1, v=0):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).
			proxy: (C, dim) current class proxies.
			Mproxy: (C, dim) second set of class proxies.
			mweight: weight applied to the ``Mproxy`` terms.
			v: loss variant, one of 0, 1, 2, 3 (see the class docstring).

		Returns:
			Scalar loss averaged over the N samples.

		Raises:
			ValueError: if ``v`` is not a supported variant.
		'''
		if v not in (0, 1, 2, 3):
			raise ValueError('unsupported MPCLoss variant v={!r}; expected 0, 1, 2 or 3'.format(v))

		feature = F.normalize(feature, p=2, dim=1)
		pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C) sample vs proxy
		Mpred = F.linear(feature, F.normalize(Mproxy, p=2, dim=1))  # (N, C) sample vs Mproxy

		n = feature.size(0)
		# (N, C) mask, True at [i, target[i]]; row-major selection keeps sample order.
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1)
		pred_n = torch.masked_select(pred, ~pos_mask).view(n, -1)  # (N, C-1)
		Mpred_p = torch.masked_select(Mpred, pos_mask).unsqueeze(1)  # (N, 1)
		Mpred_n = torch.masked_select(Mpred, ~pos_mask).view(n, -1)  # (N, C-1)

		sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N) sample-sample similarity
		same_class = target.unsqueeze(1) == target.unsqueeze(0)
		# Only different-class pairs with similarity >= 1e-6 count as negatives.
		sim = sim.masked_fill(same_class | (sim < 1e-6), float('-inf'))

		if v == 0:
			logits = torch.cat([0.5 * (pred_p + mweight * Mpred_p), 0.5 * (pred_n + mweight * Mpred_n), sim], dim=1)  # (N, C+N)
			zeros = torch.zeros(n, dtype=torch.long, device=logits.device)
			return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), zeros)

		# Per-sample exponentiated terms, each of shape (N,).
		exp_p = torch.exp(self.scale * pred_p.squeeze(1))
		exp_Mp = mweight * torch.exp(self.scale * Mpred_p.squeeze(1))
		exp_n = torch.exp(self.scale * pred_n).sum(dim=1)
		exp_Mn = mweight * torch.exp(self.scale * Mpred_n).sum(dim=1)
		exp_sim = torch.exp(self.scale * sim).sum(dim=1)  # masked (-inf) pairs contribute exp(-inf) = 0

		if v == 1:
			numer = exp_p + exp_Mp
			denom = exp_p + exp_Mp + exp_n + exp_Mn + exp_sim
		elif v == 2:
			numer = exp_p
			denom = exp_p + exp_n + exp_Mn + exp_sim
		else:  # v == 3
			numer = exp_p + exp_Mp
			denom = exp_p + exp_Mp + exp_n + exp_Mn
		return -torch.log(numer / denom).mean()


class FPLoss(nn.Module):
	"""
	Forward Prototype contrastive loss:
	(positive center * feature) / (all center * feature)

	The class centres are the per-class means of the L2-normalised features in
	the current batch (a class absent from the batch gets a zero centre). Each
	sample is then scored against every centre by dot product, and the loss is
	the batch-averaged cross-entropy of those raw scores w.r.t. its own class.
	"""

	def __init__(self, num_classes, scale):
		super(FPLoss, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		# NOTE(review): forward() does not apply ``scale`` to the logits — confirm intended.
		self.scale = scale

	def forward(self, feature, target):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).

		Returns:
			Scalar loss averaged over the N samples.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		num_classes = self.label.numel()

		# Build the one-hot matrix directly on the feature's device/dtype; indexing a
		# CPU eye() with CUDA indices fails in recent PyTorch versions.
		onehot_label = torch.eye(num_classes, dtype=feature.dtype, device=feature.device)[target]  # (N, C)
		center = torch.matmul(onehot_label.T, feature)  # (C, dim) per-class feature sums
		proxy = center / (1e-8 + onehot_label.sum(dim=0)[:, None])  # per-class means; 1e-8 avoids 0/0

		pred = F.linear(feature, proxy)  # (N, C) sample-centre similarity

		n = feature.size(0)
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)  # (N, C)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1) positive pairs
		pred_n = torch.masked_select(pred, ~pos_mask).view(n, -1)  # (N, C-1) negative pairs

		logits = torch.cat([pred_p, pred_n], dim=1)  # (N, C); column 0 is the positive
		zeros = torch.zeros(n, dtype=torch.long, device=logits.device)
		return F.nll_loss(F.log_softmax(logits, dim=1), zeros)

class MFPLoss(nn.Module):
	"""
	Forward Prototype contrastive loss with optional old prototypes.

	The prototypes are the per-class means of the L2-normalised batch features
	(zero for a class absent from the batch). When ``old_proxy`` is given, the
	prototypes become (proxy + old_proxy) / 2. Each sample is then scored
	against every prototype by dot product, and the loss is the batch-averaged
	cross-entropy of those raw scores w.r.t. its own class.
	"""

	def __init__(self, num_classes, scale):
		super(MFPLoss, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		# NOTE(review): forward() does not apply ``scale`` to the logits — confirm intended.
		self.scale = scale

	def forward(self, feature, target, old_proxy=None, mweight=1):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).
			old_proxy: optional (C, dim) prototypes averaged 50/50 with the batch means.
			mweight: not used; kept for call-site compatibility.

		Returns:
			Scalar loss averaged over the N samples.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		num_classes = self.label.numel()

		# Build the one-hot matrix directly on the feature's device/dtype; indexing a
		# CPU eye() with CUDA indices fails in recent PyTorch versions.
		onehot_label = torch.eye(num_classes, dtype=feature.dtype, device=feature.device)[target]  # (N, C)
		center = torch.matmul(onehot_label.T, feature)  # (C, dim) per-class feature sums
		proxy = center / (1e-8 + onehot_label.sum(dim=0)[:, None])  # per-class means; 1e-8 avoids 0/0

		if old_proxy is not None:
			proxy = (proxy + old_proxy) / 2

		pred = F.linear(feature, proxy)  # (N, C) sample-prototype similarity

		n = feature.size(0)
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)  # (N, C)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1) positive pairs
		pred_n = torch.masked_select(pred, ~pos_mask).view(n, -1)  # (N, C-1) negative pairs

		logits = torch.cat([pred_p, pred_n], dim=1)  # (N, C); column 0 is the positive
		zeros = torch.zeros(n, dtype=torch.long, device=logits.device)
		return F.nll_loss(F.log_softmax(logits, dim=1), zeros)


class WeightPCL(nn.Module):
	'''
	Confidence-weighted proxy contrastive loss.

	The per-sample loss is the ProxyPLoss2 objective: the positive proxy
	against all proxies plus the different-class samples with similarity
	>= 1e-6. Each per-sample loss is multiplied by the maximum softmax
	probability of ``pred_logits`` for that sample, then averaged over the batch.
	'''

	def __init__(self, num_classes, scale):
		super(WeightPCL, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		self.scale = scale

	def forward(self, feature, target, proxy, pred_logits):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).
			proxy: (C, dim) one proxy vector per class.
			pred_logits: (N, K) logits; softmax over dim 1 gives the per-sample weights.

		Returns:
			Scalar weighted loss averaged over the N samples.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C) sample-proxy similarity

		n = feature.size(0)
		# (N, C) mask, True at [i, target[i]]; row-major selection keeps sample order.
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1) positive pairs
		pred_n = torch.masked_select(pred, ~pos_mask).view(n, -1)  # (N, C-1) anchor vs other proxies

		sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N) sample-sample similarity
		same_class = target.unsqueeze(1) == target.unsqueeze(0)
		# Only different-class pairs with similarity >= 1e-6 count as negatives.
		sim = sim.masked_fill(same_class | (sim < 1e-6), float('-inf'))

		logits = torch.cat([pred_p, pred_n, sim], dim=1)  # (N, C+N); column 0 is the positive
		zeros = torch.zeros(n, dtype=torch.long, device=logits.device)
		per_sample = F.nll_loss(F.log_softmax(self.scale * logits, dim=1), zeros, reduction='none')  # (N,)

		# NOTE(review): the weight is not detached, so gradients also flow into
		# pred_logits through it — confirm this is intended.
		weight = F.softmax(pred_logits, dim=1).max(dim=1).values  # (N,)
		return (weight * per_sample).mean()


class ProxyOnlyLoss(nn.Module):
	'''
	Proxy-only contrastive loss (no sample-sample negatives):

	    loss_i = -log( exp(s*<z_i, w_{y_i}>) / sum_c exp(s*<z_i, w_c>) )

	z: L2-normalised features, w: L2-normalised proxies, s: ``scale``.
	The loss is averaged over the batch.
	'''

	def __init__(self, num_classes, scale):
		super(ProxyOnlyLoss, self).__init__()
		self.soft_plus = nn.Softplus()  # not used by forward()
		# Class indices 0..C-1. Kept on CPU and moved to the inputs' device in
		# forward(), so the loss works on CPU as well as on GPU.
		self.label = torch.arange(num_classes, dtype=torch.long)
		self.scale = scale

	def forward(self, feature, target, proxy):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			target: (N,) integer class labels in [0, num_classes).
			proxy: (C, dim) one proxy vector per class.

		Returns:
			Scalar loss averaged over the N samples.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C) sample-proxy similarity

		n = feature.size(0)
		# (N, C) mask, True at [i, target[i]]; row-major selection keeps sample order.
		pos_mask = target.unsqueeze(1) == self.label.to(target.device).unsqueeze(0)
		pred_p = torch.masked_select(pred, pos_mask).unsqueeze(1)  # (N, 1) positive pairs
		pred_n = torch.masked_select(pred, ~pos_mask).view(n, -1)  # (N, C-1) negative pairs

		logits = torch.cat([pred_p, pred_n], dim=1)  # (N, C); column 0 is the positive
		zeros = torch.zeros(n, dtype=torch.long, device=logits.device)
		return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), zeros)


def encoder(args, in_dim, out_dim):
	'''
	Build a two-layer MLP: Linear -> BatchNorm1d -> ReLU(inplace) -> Linear.

	Args:
		args: object with a ``dataset`` attribute; the hidden width is 256 when
			``args.dataset == 'dg5'`` and 512 otherwise.
		in_dim: input feature size.
		out_dim: output feature size.

	Returns:
		nn.Sequential with the four layers at indices 0..3.
	'''
	if args.dataset == 'dg5':
		width = 256
	else:
		width = 512
	layers = [
		nn.Linear(in_dim, width),
		nn.BatchNorm1d(width),
		nn.ReLU(inplace=True),
		nn.Linear(width, out_dim),
	]
	return nn.Sequential(*layers)

def fea_proj(args, dim):
	'''
	Create a linear feature projector and a standalone square weight matrix.

	Args:
		args: experiment arguments; not read here (kept for call-site compatibility).
		dim: feature dimension.

	Returns:
		(fea_proj, fc_proj): an ``nn.Sequential`` holding one ``nn.Linear(dim, dim)``,
		and a trainable (dim, dim) float32 ``nn.Parameter``.
	'''
	fc_proj = nn.Parameter(torch.empty(dim, dim, dtype=torch.float32))
	projector = nn.Sequential(nn.Linear(dim, dim))
	# torch.empty / torch.FloatTensor(dim, dim) return uninitialised memory that may
	# hold arbitrary values (even NaN/inf). Give fc_proj the same default init as an
	# nn.Linear weight. This runs after the Linear layer is built so the layer's own
	# random init consumes the RNG first.
	nn.init.kaiming_uniform_(fc_proj, a=5 ** 0.5)
	return projector, fc_proj





class ProxyLoss(nn.Module):
	'''
	Classifier logits augmented with hard sample-sample negatives.

	For each sample i, the C classifier logits ``pred[i]`` are concatenated
	with its cosine similarities to the other samples in the batch. Same-class
	samples (including i itself) and similarities below ``thres`` are masked
	out. The loss is the batch-averaged cross-entropy of
	``scale * [pred | negatives]`` w.r.t. ``target``.
	'''

	def __init__(self, scale=1, thres=0.1):
		super(ProxyLoss, self).__init__()
		self.scale = scale
		self.thres = thres

	def forward(self, feature, pred, target):
		'''
		Args:
			feature: (N, dim) sample embeddings.
			pred: (N, C) classifier logits.
			target: (N,) class labels in [0, C).

		Returns:
			Scalar loss averaged over the batch.
		'''
		feature = F.normalize(feature, p=2, dim=1)
		sim = torch.matmul(feature, feature.transpose(1, 0))  # (N, N) sample-sample similarity
		same_class = target.unsqueeze(1) == target.unsqueeze(0)
		# Mask same-class pairs explicitly so they are excluded for any thres. Zeroing
		# them and relying on "< thres" would keep them whenever thres <= 0.
		sim = sim.masked_fill(same_class | (sim < self.thres), float('-inf'))
		logits = torch.cat([pred, sim], dim=1)  # (N, C+N)

		return F.nll_loss(F.log_softmax(self.scale * logits, dim=1), target)

class PosAlign(nn.Module):
	'''
	Positive-pair term over sample-sample cosine similarities.

	Returns softplus(logsumexp(S_ij over all (i, j) with y_i == y_j)), where
	S = Z Z^T for the L2-normalised features Z. The diagonal (i == j) is included.
	NOTE(review): minimising this value lowers the positive-pair similarities,
	i.e. it pushes same-class samples apart. Confirm that the sign/weight the
	caller applies is intended.
	'''

	def __init__(self):
		super(PosAlign, self).__init__()
		self.soft_plus = nn.Softplus()

	def forward(self, feature, target):
		unit = F.normalize(feature, p=2, dim=1)
		sims = unit @ unit.transpose(1, 0)  # (N, N)
		same_class = target[:, None] == target[None, :]
		positives = sims[same_class]  # 1-D, row-major order
		return self.soft_plus(positives.logsumexp(dim=0))
