import torch
import torch.nn as nn
import torch.nn.functional as F

import torchsort
from torchmetrics.functional import pairwise_cosine_similarity
from torchmetrics.functional import pairwise_euclidean_distance
import numpy as np

def _betting_mixture_step(capital, epsilons, jump_rate, martingale, p_value):
    """Advance the betting-function mixture by one p-value.

    A fraction ``jump_rate`` of the total capital is redistributed uniformly
    over the betting parameters, then each component is multiplied by its
    betting function ``1 + eps * (p_value - 0.5)``.

    NOTE(review): this resembles the "Simple Jumper" conformal test
    martingale -- confirm against the reference the project follows.

    Returns:
        Tuple ``(new_capital, new_martingale)`` where ``new_martingale`` is the
        sum of ``new_capital``.
    """
    capital = (1 - jump_rate) * capital + (jump_rate / len(epsilons)) * martingale
    capital = capital * (1 + epsilons * (p_value - 0.5))
    return capital, torch.sum(capital)


def compute_martingale(detect_features_all, device, temperature, softrank_regularization_type, softrank_regularization_factor):
    """Compute a differentiable (soft) martingale and its exact (hard) counterpart.

    Rows of ``detect_features_all`` are processed in order. For every prefix
    made of the first ``i + 1`` rows (``i >= 1``):

    * a nearest-neighbour nonconformity score is derived from pairwise cosine
      distances (softmax-weighted for the soft path, exact ``min`` for the
      hard path);
    * a p-value for the newest row is formed as the (soft or exact) fraction
      of scores smaller than it, plus ``tau`` times the fraction of ties,
      with ``tau`` drawn uniformly from [0, 1);
    * the p-value updates a mixture of betting functions
      ``1 + eps * (p - 0.5)`` for ``eps`` in {-1, -0.5, 0, 0.5, 1}.

    Args:
        detect_features_all: Feature matrix with one row per point (assumed
            2-D ``(N, D)`` as expected by ``pairwise_cosine_similarity`` --
            TODO confirm).
        device: Device on which the helper tensors are created; presumably it
            must match the device of ``detect_features_all``.
        temperature: Multiplier applied to the negated distances inside the
            softmax; larger values put more weight on the nearest neighbour.
        softrank_regularization_type: Forwarded to ``torchsort.soft_rank`` as
            ``regularization``.
        softrank_regularization_factor: Scales the soft-rank regularization
            strength, which is ``factor * N / (i + 1)`` at step ``i``.

    Returns:
        Tuple ``(martingale_soft_av, martingale_hard_av, martingale_hard_max)``:
        ``martingale_soft_av`` is a tensor of shape ``(1,)`` holding the sum of
        the soft martingale over steps ``1..N-1`` divided by ``N``;
        ``martingale_hard_av`` is the same quantity for the hard martingale as
        a Python float; ``martingale_hard_max`` is the largest hard martingale
        value seen, or ``-inf`` when there are fewer than two rows.
    """
    # Detection length
    detect_len = detect_features_all.shape[0]

    # Average martingale
    martingale_soft_av = torch.zeros(1, device=device)
    martingale_hard_av = torch.zeros(1, device=device)

    # Maximum martingale
    martingale_hard_max = -np.inf

    # With fewer than two points the loop below never runs, so the initial
    # accumulators are the answer and no distances need to be computed.
    if detect_len < 2:
        return martingale_soft_av, martingale_hard_av.item(), martingale_hard_max

    # Pairwise cosine distances between features
    cos_dists = 1 - pairwise_cosine_similarity(detect_features_all, detect_features_all)

    # Two views of the distance matrix: infinite diagonal (so a point is never
    # its own nearest neighbour) and zero diagonal (so that the infinite
    # entries do not turn ``0 * inf`` into NaN in the weighted sum).
    on_diagonal = torch.eye(detect_len, device=device).bool()
    dists_inf_diag = torch.where(on_diagonal, torch.inf, cos_dists)
    dists_zero_diag = torch.where(on_diagonal, 0.0, cos_dists)

    # Martingale parameters (shared by the soft and hard paths)
    epsilons = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], device=device)
    n_eps = len(epsilons)
    jump_rate = 0.005
    capital_soft = torch.ones(n_eps, device=device) / n_eps
    capital_hard = torch.ones(n_eps, device=device) / n_eps
    martingale_soft = torch.ones(1, device=device)
    martingale_hard = 1.0

    for i in range(1, detect_len):
        n_seen = i + 1
        prefix_inf = dists_inf_diag[:n_seen, :n_seen]
        prefix_zero = dists_zero_diag[:n_seen, :n_seen]

        # Random number for smoothing conformal p-values (shared by both paths)
        tau = torch.rand(1)[0]

        ############################################################
        # Soft path: softmax-weighted nearest-neighbour distance per row
        softmin_dists = torch.sum(F.softmax(-temperature * prefix_inf, dim=1) * prefix_zero, dim=1)

        regularization_strength = softrank_regularization_factor * detect_len / n_seen
        soft_ranks = torchsort.soft_rank(
            softmin_dists.view(1, n_seen),
            regularization_strength=regularization_strength,
            regularization=softrank_regularization_type,
        )

        # Soft rank of the newest point minus one ~ number of smaller scores
        n_ties_soft = int((softmin_dists == softmin_dists[-1]).sum())
        p_value_soft = (soft_ranks[0][-1] - 1) / n_seen + tau * n_ties_soft / n_seen

        capital_soft, martingale_soft = _betting_mixture_step(
            capital_soft, epsilons, jump_rate, martingale_soft, p_value_soft
        )
        ############################################################

        ############################################################
        # Hard path: exact nearest neighbour and exact counts. Nothing here is
        # differentiable, so skip autograd bookkeeping.
        with torch.no_grad():
            nn_dists = prefix_inf.min(dim=1).values
            n_smaller = int((nn_dists < nn_dists[-1]).sum())
            n_ties_hard = int((nn_dists == nn_dists[-1]).sum())
            p_value_hard = n_smaller / n_seen + tau * n_ties_hard / n_seen

            capital_hard, martingale_hard = _betting_mixture_step(
                capital_hard, epsilons, jump_rate, martingale_hard, p_value_hard
            )
            martingale_hard_av += martingale_hard / detect_len
        ############################################################

        # Update average soft martingale (kept in the autograd graph)
        martingale_soft_av += martingale_soft / detect_len

        # Update maximum martingale
        martingale_hard_max = max(martingale_hard_max, martingale_hard.item())

    return martingale_soft_av, martingale_hard_av.item(), martingale_hard_max


def compute_martingale_batch(detect_features_all, device, temperature, softrank_regularization_type, softrank_regularization_factor):
    """Compute the soft and hard martingales for a set of detection features.

    This entry point performs exactly the same computation as
    ``compute_martingale`` and therefore delegates to it instead of carrying a
    duplicated copy of the loop. It is kept as a separate name so existing
    callers continue to work.

    NOTE(review): despite the name, no batched code path exists here --
    confirm whether a vectorised variant was intended.

    Args:
        detect_features_all: Feature matrix with one row per point (assumed
            2-D ``(N, D)`` -- TODO confirm).
        device: Device on which the helper tensors are created.
        temperature: Softmax temperature for the soft nearest-neighbour score.
        softrank_regularization_type: Forwarded to ``torchsort.soft_rank`` as
            ``regularization``.
        softrank_regularization_factor: Scales the soft-rank regularization
            strength.

    Returns:
        Tuple ``(martingale_soft_av, martingale_hard_av, martingale_hard_max)``
        as returned by ``compute_martingale``.
    """
    return compute_martingale(
        detect_features_all,
        device,
        temperature,
        softrank_regularization_type,
        softrank_regularization_factor,
    )




def compute_martingale_hard(detect_features_all, device):
    """Compute the exact (non-differentiable) martingale over a feature sequence.

    Rows of ``detect_features_all`` are processed in order. At step ``i``
    (``i >= 1``) each of the first ``i + 1`` rows gets a nonconformity score
    equal to its smallest cosine distance to another row in that prefix. The
    p-value of the newest row is the fraction of scores strictly smaller than
    its own plus ``tau`` times the fraction of ties (``tau`` uniform in
    [0, 1)). The p-value then updates a mixture of betting functions
    ``1 + eps * (p - 0.5)`` for ``eps`` in {-1, -0.5, 0, 0.5, 1}.

    All outputs depend on the features only through comparisons and counts,
    so the whole computation runs without autograd tracking.

    Args:
        detect_features_all: Feature matrix with one row per point (assumed
            2-D ``(N, D)`` as expected by ``pairwise_cosine_similarity`` --
            TODO confirm).
        device: Device on which the helper tensors are created; presumably it
            must match the device of ``detect_features_all``.

    Returns:
        Tuple ``(martingale_hard_av, martingale_hard_max, martingale_hard_all)``:
        the sum of the martingale over steps ``1..N-1`` divided by ``N`` as a
        float; the largest martingale value seen as a float (``-inf`` when
        there are fewer than two rows); and a length-``N`` tensor with the
        martingale value at every step (index 0 is left at 0).
    """
    # Detection length
    detect_len = detect_features_all.shape[0]

    # Average, maximum and per-step martingale
    martingale_hard_av = torch.zeros(1, device=device)
    martingale_hard_max = -np.inf
    martingale_hard_all = torch.zeros(detect_len, device=device)

    # With fewer than two points the loop below never runs, so the initial
    # accumulators are the answer and no distances need to be computed.
    if detect_len < 2:
        return martingale_hard_av.item(), martingale_hard_max, martingale_hard_all

    # Martingale parameters
    epsilons = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], device=device)
    n_eps = len(epsilons)
    jump_rate = 0.005
    capital = torch.ones(n_eps, device=device) / n_eps
    martingale_hard = 1.0

    with torch.no_grad():
        # Pairwise cosine distances; infinite diagonal so a point is never its
        # own nearest neighbour.
        dists = 1 - pairwise_cosine_similarity(detect_features_all, detect_features_all)
        dists[torch.eye(detect_len, device=device).bool()] = torch.inf

        for i in range(1, detect_len):
            n_seen = i + 1

            # Random number for smoothing conformal p-values
            tau = torch.rand(1)[0]

            # Nearest-neighbour distance of every point within the prefix
            nn_dists = dists[:n_seen, :n_seen].min(dim=1).values
            newest = nn_dists[-1]
            n_smaller = int((nn_dists < newest).sum())
            n_ties = int((nn_dists == newest).sum())
            p_value = n_smaller / n_seen + tau * n_ties / n_seen

            # Redistribute a fraction of the capital, then apply the bets
            capital = (1 - jump_rate) * capital + (jump_rate / n_eps) * martingale_hard
            capital = capital * (1 + epsilons * (p_value - 0.5))
            martingale_hard = torch.sum(capital)

            # Update per-step, average and maximum martingale
            martingale_hard_all[i] = martingale_hard
            martingale_hard_av += martingale_hard / detect_len
            martingale_hard_max = max(martingale_hard_max, martingale_hard.item())

    return martingale_hard_av.item(), martingale_hard_max, martingale_hard_all

class ContrastiveLoss(nn.Module):
    """
    Contrastive loss computed from precomputed pair distances.

    ``forward`` receives one distance per pair together with a target that is
    1 when the pair comes from the same class and 0 otherwise. Same-class
    pairs are penalised by the distance itself; other pairs are penalised by
    the squared hinge ``relu(margin - sqrt(distance + eps))**2``.
    NOTE(review): the square root suggests ``distances`` holds squared
    distances -- confirm against the caller.

    Adapted from https://github.com/adambielski/siamese-triplet/blob/master/losses.py
    """

    def __init__(self, margin=1):
        super().__init__()
        # Small constant added before the square root in forward()
        self.eps = 1e-9
        self.margin = margin

    def forward(self, distances, target, size_average=True):
        """Return the mean (``size_average=True``) or the sum of per-pair losses."""
        same_class = target.float()
        # Written as 1 + -1 * target (rather than 1 - target) so that integer
        # and boolean label tensors are handled alike.
        different_class = (1 + -1 * target).float()

        margin_gap = F.relu(self.margin - torch.sqrt(distances + self.eps))
        per_pair = 0.5 * (same_class * distances + different_class * margin_gap.pow(2))

        if size_average:
            return per_pair.mean()
        return per_pair.sum()

