from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F

# Default compute device: the CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



def param_init(model, std=0.01):
    """Re-initialise, in place, the parameters of every submodule of ``model``.

    Every module yielded by ``model.modules()`` (including ``model`` itself) is
    visited:

    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight ~ N(0, std**2); the conv
      bias (if any) is left untouched.
    - ``nn.Linear``: weight ~ N(0, std**2); bias (if present) set to 0.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight set to 1 and bias set to 0,
      when those affine parameters exist.

    Args:
        model: the ``nn.Module`` whose submodules are re-initialised.
        std: standard deviation of the normal weight initialisation
            (default 0.01, the previously hard-coded value).

    Returns:
        None; parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0, std)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # With affine=False, BatchNorm stores weight/bias as None; skip them
            # instead of crashing on attribute access.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)



class conv_layer_module(nn.Module):
    """A Conv2d -> BatchNorm2d -> ReLU(inplace=True) block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution padding.
        bias: whether the convolution has a bias term (default False; the
            BatchNorm2d that follows has its own affine shift by default).
    """

    def __init__(self, in_ch, out_ch, k, s, p, bias=False):
        super(conv_layer_module, self).__init__()
        # Attribute names and their registration order determine state_dict keys.
        self.conv = nn.Conv2d(in_ch, out_ch, k, stride=s, padding=p, bias=bias)
        self.bat = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation, then ReLU."""
        return self.relu(self.bat(self.conv(x)))



class fc_layer_module(nn.Module):
    """A Linear -> BatchNorm1d -> ReLU(inplace=True) block.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        bias: whether the linear layer has a bias term (default False; the
            BatchNorm1d that follows has its own affine shift by default).
    """

    def __init__(self, in_dim, out_dim, bias=False):
        super(fc_layer_module, self).__init__()
        # Attribute names and their registration order determine state_dict keys.
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.bat = nn.BatchNorm1d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the linear map, batch normalisation, then ReLU."""
        return self.relu(self.bat(self.fc(x)))



class RelCoordNet(nn.Module):
    """Encode a 4-channel input into a unit-norm embedding and a 2-D offset.

    The two inputs are concatenated along dim 1 and passed through:
    four stride-2 conv blocks (4 -> 9 -> 16 -> 32 -> 64 channels) and a
    flatten (``self.conv``); a 256 -> 128 fc block (``self.proj``); and a
    128 -> 2 linear head (``self.head``). All parameters are re-initialised with
    ``param_init`` at construction.
    """

    def __init__(self):
        super(RelCoordNet, self).__init__()
        # Construction order (conv, proj, head) is kept so RNG consumption and
        # state_dict keys are unchanged.
        encoder_blocks = [
            conv_layer_module(4, 9, 3, 2, 1),
            conv_layer_module(9, 16, 3, 2, 1),
            conv_layer_module(16, 32, 3, 2, 1),
            conv_layer_module(32, 64, 3, 2, 1),
            nn.Flatten(1),
        ]
        self.conv = nn.Sequential(*encoder_blocks)
        # 256 = 64 channels x 4 spatial positions, so the encoder must end on a
        # 4-element map (presumably 2x2, e.g. a 32x32 input) — TODO confirm.
        self.proj = nn.Sequential(fc_layer_module(256, 128))
        self.head = nn.Sequential(nn.Linear(128, 2))

        param_init(self)

    def forward(self, o, reference_c):
        """Return ``(z, relative_c)`` for inputs ``o`` and ``reference_c``.

        ``o`` and ``reference_c`` are concatenated along dim 1 and must together
        have 4 channels (how they split is not visible here — TODO confirm).
        ``z`` is the 128-d projection L2-normalised along dim 1; ``relative_c``
        is ``2 * tanh(head(z))``, so each of its 2 components lies in (-2, 2).
        """
        stacked = torch.cat((o, reference_c), dim=1)
        projected = self.proj(self.conv(stacked))
        z = F.normalize(projected, dim=1)
        relative_c = 2 * torch.tanh(self.head(z))
        return z, relative_c



class PWConLoss(nn.Module):
    """Supervised-contrastive-style loss whose positives are distance-weighted.

    Each sample has two views. Logits are scaled dot products between anchors
    and all ``2B`` contrast features, divided by ``temperature``. An anchor's
    similarity with itself is excluded. Every other pair gets a weight of
    ``relationship * w`` (see ``_distance_weighted_mask``). The loss is the
    negative weighted mean log-probability over pairs with non-zero weight,
    averaged over anchors.

    Args:
        mode: ``'one'`` to use only the first view of each sample as anchors,
            or ``'all'`` to use both views as anchors.
        temperature: logit temperature (default 0.05).
        view_size: scale of the distance cut-off; pairs whose summed absolute
            distance is ``>= view_size * 4`` get zero weight (default 27, the
            previously hard-coded value).
    """

    def __init__(self, mode, temperature=0.05, view_size=27):
        super(PWConLoss, self).__init__()
        self.temperature = temperature
        self.mode = mode
        self.view_size = view_size

    @staticmethod
    def _distance_weighted_mask(distance, relationship, w_factor):
        """Scale ``relationship`` by a distance-based weight.

        The weight is ``1 + (w_factor - |distance|.sum(-1)) / w_factor``. It
        lies in (1, 2] when the summed absolute distance is below ``w_factor``
        and is set to 0 otherwise.
        """
        weight = 1 + (w_factor - torch.abs(distance).sum(-1)) / w_factor
        weight = torch.where(weight > 1.0, weight, torch.zeros_like(weight))
        return weight * relationship

    def forward(self, features,
                r1_r1_distance, r1_r1_relationship,
                r1_r2_distance, r1_r2_relationship,
                r2_r2_distance, r2_r2_relationship):
        """Compute the scalar loss.

        Args:
            features: ``[B, 2, d]`` tensor holding two views per sample
                (presumably L2-normalised embeddings — TODO confirm).
            r1_r1_distance, r1_r2_distance, r2_r2_distance: pairwise
                distances whose ``.sum(-1)`` must give ``[B, B]`` (e.g.
                ``[B, B, k]``) for view1-view1, view1-view2 and view2-view2.
            r1_r1_relationship, r1_r2_relationship, r2_r2_relationship:
                pair weights/indicators broadcastable against ``[B, B]``.
                The ``r2_r2_*`` inputs are only used in ``'all'`` mode.

        Returns:
            0-dim tensor with the loss.

        Raises:
            ValueError: if ``self.mode`` is neither ``'one'`` nor ``'all'``.
        """
        if self.mode == 'one':
            anchor_count = 1
        elif self.mode == 'all':
            anchor_count = 2
        else:
            raise ValueError("mode must be 'one' or 'all', got {!r}".format(self.mode))

        B = features.size(0)
        w_factor = self.view_size * 4

        # [B, 2, d] -> [2B, d]: all first views followed by all second views.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        anchor_feature = features[:, 0] if anchor_count == 1 else contrast_feature

        anchor_dot_contrast = torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        # Subtract the per-row max for numerical stability; it cancels in log-softmax.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        mask_r1_r1 = self._distance_weighted_mask(r1_r1_distance, r1_r1_relationship, w_factor)
        mask_r1_r2 = self._distance_weighted_mask(r1_r2_distance, r1_r2_relationship, w_factor)
        mask = torch.cat((mask_r1_r1, mask_r1_r2), 1)

        if anchor_count == 2:
            mask_r2_r2 = self._distance_weighted_mask(r2_r2_distance, r2_r2_relationship, w_factor)
            # The view2 -> view1 block is taken as the transpose of view1 -> view2.
            mask_ = torch.cat((mask_r1_r2.T, mask_r2_r2), 1)
            mask = torch.cat((mask, mask_), 0)

        # Zero column i of row i (each anchor vs. itself). Build the index on
        # the inputs' device rather than a global one, so CPU and GPU inputs
        # both work.
        self_index = torch.arange(B * anchor_count, device=mask.device).view(-1, 1)
        logits_mask = torch.scatter(torch.ones_like(mask), 1, self_index, 0)
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1E-6)

        # Average over pairs with non-zero weight; the epsilon keeps rows with
        # no positives at 0 instead of NaN.
        num_positive = (mask != 0).sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / (num_positive + 1E-6)

        loss = -1 * mean_log_prob_pos.view(anchor_count, B).mean()
        return loss



