import torch
import torch.nn as nn
import torch.nn.functional as F

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Module-wide floating-point dtype (torch.float is an alias of torch.float32).
dtype = torch.float32


def param_init(model):
    """Re-initialise the parameters of every supported sub-module of *model*.

    The model is walked with ``model.modules()`` and modified in place:

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.01);
      the bias (if any) is left untouched.
    * ``nn.Linear``: weights drawn from N(0, 0.01); the bias, when present,
      is set to 0.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight set to 1 and bias set
      to 0. Layers built with ``affine=False`` have no such parameters and
      are skipped instead of raising ``AttributeError``.

    Args:
        model: an ``nn.Module`` whose sub-modules should be initialised.

    Returns:
        None.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0, 0.01)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            # Identity check: comparing a tensor with ``!= None`` is fragile.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # weight/bias are None when the layer was created with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)



class conv_layer_module(nn.Module):
    """Building block: 2-D convolution, then batch norm, then ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the BatchNorm2d feature count).
        k: convolution kernel size.
        s: convolution stride.
        p: convolution padding.
        bias: whether the convolution has a bias term (off by default).
    """

    def __init__(self, in_ch, out_ch, k, s, p, bias=False):
        super().__init__()
        # Attribute names are kept as-is: they determine the state_dict keys.
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.bat = nn.BatchNorm2d(num_features=out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.relu(self.bat(self.conv(x)))



class fc_layer_module(nn.Module):
    """Building block: linear layer, then 1-D batch norm, then ReLU.

    Args:
        in_dim: size of each input feature vector.
        out_dim: size of each output feature vector.
        bias: whether the linear layer has a bias term (off by default).
    """

    def __init__(self, in_dim, out_dim, bias=False):
        super().__init__()
        # Attribute names are kept as-is: they determine the state_dict keys.
        self.fc = nn.Linear(in_features=in_dim, out_features=out_dim, bias=bias)
        self.bat = nn.BatchNorm1d(num_features=out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the linear map, batch normalisation and ReLU, in that order."""
        return self.relu(self.bat(self.fc(x)))

    

class FeatNet(nn.Module):
    """Small convolutional encoder producing L2-normalised 128-d embeddings.

    Four stride-2 3x3 conv blocks (6 -> 9 -> 16 -> 32 -> 64 channels) are
    followed by a flatten and a two-layer projection head (256 -> 256 -> 128).
    The first projection layer expects 256 flattened features, i.e. the last
    conv block must output 64 channels over 4 spatial positions
    (presumably 6-channel 32x32 inputs -- confirm against the data loader).
    """

    def __init__(self):
        super().__init__()

        # Channel progression of the encoder; every stage uses k=3, s=2, p=1.
        widths = (6, 9, 16, 32, 64)
        stages = [
            conv_layer_module(c_in, c_out, 3, 2, 1)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.conv = nn.Sequential(*stages, nn.Flatten(1))

        self.proj = nn.Sequential(
            fc_layer_module(256, 256),
            nn.Linear(256, 128),
        )

        # Overwrite the default PyTorch initialisation for all sub-modules.
        param_init(self)

    def forward(self, x):
        """Encode *x* and return embeddings normalised along dim 1."""
        embedding = self.proj(self.conv(x))
        return F.normalize(embedding, dim=1)
    
    

class PWConLoss(nn.Module):
    """Contrastive loss whose positive mask is weighted by label distances.

    The first view of every sample is the anchor; all views of all samples
    form the contrast set. Each anchor/contrast pair gets a weight of
    ``(1 + exp(-distance_scale * sum(|distance|))) * relationship``; pairs
    with a non-zero weight act as positives in a log-softmax over the
    contrast set (the anchor's own entry is excluded).

    Args:
        temperature: divisor applied to the anchor/contrast dot products.
        distance_scale: decay rate applied to the summed absolute distances
            (defaults to the previously hard-coded 0.025).
    """

    def __init__(self, temperature=0.05, distance_scale=0.025):
        super().__init__()
        self.temperature = temperature
        self.distance_scale = distance_scale

    def _weighted_mask(self, distance, relationship, labels):
        """Return per-anchor pair weights selected by each anchor's label."""
        weight = 1 + torch.exp(-self.distance_scale * torch.abs(distance).sum(-1))
        # 1 + exp(.) is >= 1 for every non-NaN input, so this only zeroes
        # entries that came out as NaN (NaN comparisons are False).
        weight = torch.where(weight >= 1.0, weight, torch.zeros_like(weight))
        weight = weight * relationship
        # Index with labels on the same device as the tensor being indexed
        # rather than relying on a module-level device.
        return weight.T[labels.to(weight.device)]

    def forward(self, features, anchor_labels, a_l_distance, a_l_relationship, r_l_distance, r_l_relationship):
        """Compute the scalar loss.

        Args:
            features: tensor indexed as ``features[:, 0]`` and unbound along
                dim 1 -- assumes (batch, views, dim) with two views, since the
                mask built below has ``2 * batch`` columns (TODO confirm).
            anchor_labels: integer labels, flattened and used to index the
                transposed weight matrices.
            a_l_distance, r_l_distance: distances reduced with ``abs().sum(-1)``;
                presumably (batch, num_labels, k) -- verify against caller.
            a_l_relationship, r_l_relationship: multiplicative masks with the
                shape of the reduced distances.

        Returns:
            A 0-dim tensor: the negated mean over anchors of the weighted
            mean log-probability of the positives.
        """
        B = features.size(0)

        anchor_feature = features[:, 0]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        # Temperature-scaled similarities, shifted by the row max for
        # numerical stability (the shift carries no gradient).
        anchor_dot_contrast = torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        labels = anchor_labels.view(-1)
        mask_aa = self._weighted_mask(a_l_distance, a_l_relationship, labels)
        mask_ap = self._weighted_mask(r_l_distance, r_l_relationship, labels)
        mask = torch.cat((mask_aa, mask_ap), 1)

        # Zero column i of row i so an anchor is never contrasted with itself.
        self_index = torch.arange(B, device=mask.device).view(-1, 1)
        logits_mask = torch.scatter(torch.ones_like(mask), 1, self_index, 0)
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1E-6)

        # The epsilon keeps rows without any positive at 0 instead of NaN.
        num_positive = (mask != 0).sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (num_positive + 1E-6)

        loss = -1 * mean_log_prob_pos.view(1, B).mean()
        return loss








