import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchsort
from torchmetrics.functional import pairwise_cosine_similarity

####################################################################################
def test_model(model, device, test_loader, set_name="test set"):
  '''
  Evaluate the performance of the model on a given test dataset.
  '''
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device).float()
      try:
          output, _ = model(data)
      except:
          output = model(data)
              
      test_loss += F.binary_cross_entropy_with_logits(output, target, reduction='sum').item()  # sum up batch loss
      pred = torch.where(torch.gt(output, torch.Tensor([0.0]).to(device)),
                         torch.Tensor([1.0]).to(device),
                         torch.Tensor([0.0]).to(device))  # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)

  print('\nPerformance on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
    set_name, test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

  return 100. * correct / len(test_loader.dataset)


def erm_train(model, device, train_loader, optimizer, epoch):
  '''
  Run one epoch of Empirical Risk Minimization (ERM) with a BCE-with-logits loss.

  Args:
    model: torch.nn.Module returning logits, or a tuple/list whose first
      element is the logits (e.g. ``(logits, features)``).
    device: torch.device each batch is moved to.
    train_loader: DataLoader yielding ``(data, target)`` batches.
    optimizer: torch optimizer over ``model``'s parameters.
    epoch: epoch number, used only in the progress message.
  '''
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device).float()
    optimizer.zero_grad()
    # Exactly one forward pass per batch: retrying after a failed tuple-unpack
    # would repeat train-mode side effects (e.g. BatchNorm statistics, dropout
    # RNG), and unpacking a plain 2-sample logits tensor would "succeed" with
    # the wrong values.
    output = model(data)
    if isinstance(output, (tuple, list)):
      output = output[0]
    loss = F.binary_cross_entropy_with_logits(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 10 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
               100. * batch_idx / len(train_loader), loss.item()))

def train_and_test_erm(train_loader, test_loader, model, num_training_epochs=5, lr=0.005):
  '''
  Train ``model`` with ERM (Adam optimizer), evaluating after every epoch.

  The final weights are saved to ``model_erm.pth``.

  Args:
    train_loader: DataLoader yielding ``(data, target)`` training batches.
    test_loader: DataLoader yielding ``(data, target)`` test batches.
    model: torch.nn.Module to train; it is moved to the selected device.
    num_training_epochs: number of training epochs.
    lr: Adam learning rate.

  Returns:
    ``(train_accuracy, test_accuracy)`` in percent after the last epoch, or
    those of the untrained model when ``num_training_epochs`` < 1.
  '''
  use_cuda = torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")
  # Every batch is moved to ``device``, so the parameters must live there too.
  model.to(device)

  optimizer = optim.Adam(model.parameters(), lr=lr)

  train_success = test_success = None
  for epoch in range(1, num_training_epochs+1):
    erm_train(model, device, train_loader, optimizer, epoch)
    train_success = test_model(model, device, train_loader, set_name='train set')
    test_success = test_model(model, device, test_loader)

  # Save model
  torch.save(model.state_dict(), "model_erm.pth")

  if train_success is None:
    # No epoch ran: report the accuracy of the model as given instead of
    # failing on unbound results.
    train_success = test_model(model, device, train_loader, set_name='train set')
    test_success = test_model(model, device, test_loader)

  return train_success, test_success
    
    
#########################################################################################
def train_and_test_drm(train_loader, test_loader, detect_loader, model, martingale_penalty, temperature=1.0, softrank_regularization_factor=0.001, num_training_epochs=5, lr=0.005, softrank_regularization_type="l2", l2_penalty=0.0, num_epochs_erm=0):
  '''
  Train and test with Deceptive Risk Minimization (DRM).

  Runs ``num_training_epochs`` epochs of ``drm_train_step`` with Adam. During
  the first ``num_epochs_erm`` epochs the martingale term is given weight 0.0
  (ERM warm-up). After every epoch the model is evaluated on the train and
  test sets and its weights are saved to ``model_drm.pth``.

  Args:
    train_loader, test_loader: DataLoaders yielding ``(data, target)``.
    detect_loader: DataLoader yielding ``(data, target, index)``.
    model: module returning ``(logits, features)``; moved to the device.
    martingale_penalty: weight of the soft-martingale term after warm-up.
    temperature, softrank_regularization_factor, softrank_regularization_type,
      l2_penalty: forwarded unchanged to ``drm_train_step``.
    num_training_epochs: total number of epochs.
    lr: Adam learning rate.
    num_epochs_erm: number of initial epochs with martingale weight 0.0.

  Returns:
    ``(train_accuracy, test_accuracy)`` in percent after the last epoch, or
    those of the untrained model when ``num_training_epochs`` < 1.
  '''

  # Device
  use_cuda = torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")
  # All batches are moved to ``device``, so the parameters must live there too.
  model.to(device)

  optimizer = optim.Adam(model.parameters(), lr=lr)

  train_success = test_success = None
  for epoch in range(1, num_training_epochs+1):

    # ERM warm-up: no martingale penalty during the first num_epochs_erm epochs.
    martingale_penalty_in = 0.0 if epoch <= num_epochs_erm else martingale_penalty

    drm_train_step(model, device, train_loader, detect_loader, test_loader, optimizer, epoch, martingale_penalty_in, temperature, softrank_regularization_factor, softrank_regularization_type, l2_penalty)
    train_success = test_model(model, device, train_loader, set_name='train set')
    test_success = test_model(model, device, test_loader)

    # Save model (checkpoint after every epoch)
    torch.save(model.state_dict(), "model_drm.pth")

  if train_success is None:
    # No epoch ran: report the accuracy of the model as given.
    train_success = test_model(model, device, train_loader, set_name='train set')
    test_success = test_model(model, device, test_loader)

  return train_success, test_success

#########################################################################################

def drm_train_step(model, device, train_loader, detect_loader, test_loader, optimizer, epoch, martingale_penalty, temperature, softrank_regularization_factor, softrank_regularization_type, l2_penalty):
    '''
    One epoch of Deceptive Risk Minimization (DRM) training.

    For every training batch:
      1. compute the BCE-with-logits (ERM) loss on the batch;
      2. draw one batch ``(data, target, index)`` from a fresh iterator over
         ``detect_loader``, sort it by ``index`` and embed it with ``model``
         (second output);
      3. walk the ordered detection sequence from i = 10 on, computing a
         label-conditional conformal p-value for item i from its
         nearest-neighbour cosine distance among same-label items 0..i, and
         update two betting martingales (a mixture over the bets in ``E`` with
         switching rate ``J``, resembling Vovk's Simple Jumper):
           * a differentiable "soft" martingale (softmin distances +
             ``torchsort.soft_rank``), whose average enters the loss;
           * an exact "hard" martingale (min distances + exact counts), used
             only for monitoring;
      4. minimise ``0.001 * erm_loss + martingale_penalty * soft_average``.

    Every 10 batches the model is evaluated on the full train and test sets;
    afterwards it is switched back to train mode.

    Args:
      model: module whose forward returns ``(logits, features)``.
      device: torch.device for all tensors.
      train_loader: DataLoader yielding ``(data, target)``.
      detect_loader: DataLoader yielding ``(data, target, index)``.
      test_loader: DataLoader used for the periodic test-set evaluation.
      optimizer: optimizer over ``model``'s parameters.
      epoch: epoch number (for printing).
      martingale_penalty: weight of the soft-martingale average in the loss.
      temperature: softmin sharpness; larger is closer to a hard minimum.
      softrank_regularization_factor: base ``regularization_strength`` for
        ``torchsort.soft_rank``, scaled by ``detect_len / (i + 1)``.
      softrank_regularization_type: ``regularization`` argument passed to
        ``torchsort.soft_rank``.
      l2_penalty: not used by this function.

    Returns:
      ``(hard_martingale_average, hard_martingale_max)`` for the last training
      batch; ``(0.0, -inf)`` if ``train_loader`` yields no batches.
    '''

    model.train()

    # Values returned if train_loader yields no batches.
    martingale_hard_av = torch.zeros(1, device=device)
    martingale_hard_max = -np.inf

    for batch_idx, (data, target) in enumerate(train_loader):

        # Compute cross entropy loss
        data, target = data.to(device), target.to(device).float()
        optimizer.zero_grad()
        output, _ = model(data)
        erm_loss = F.binary_cross_entropy_with_logits(output, target)

        # Load detection set for this batch (``iter`` is re-created every time,
        # so this is the first batch of a fresh pass over detect_loader).
        detect_data_batch, detect_target_batch, indices_batch = next(iter(detect_loader))
        detect_data_batch, detect_target_batch, indices_batch = detect_data_batch.to(device), detect_target_batch.to(device).float(), indices_batch.to(device)

        # Permute detection data so that indices are in increasing order
        order = torch.argsort(indices_batch)
        detect_data_batch = detect_data_batch[order]
        detect_target_batch = detect_target_batch[order]
        detect_len = len(detect_target_batch)

        # Martingale state: one capital per bet in E, mixed with switching rate J
        martingale_soft = torch.ones(1, device=device)
        martingale_hard = 1.0
        E = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], device=device)
        E_hard = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], device=device)
        C_eps = (1/len(E))*torch.ones(len(E), device=device)
        C_eps_hard = (1/len(E))*torch.ones(len(E), device=device)
        J = 0.005

        # Average martingale. NOTE: divided by detect_len although only
        # detect_len - 10 steps contribute (the loop starts at i = 10).
        martingale_soft_av = torch.zeros(1, device=device)
        martingale_hard_av = torch.zeros(1, device=device)

        # Maximum martingale
        martingale_hard_max = -np.inf

        # Get features
        _, detect_features_all = model(detect_data_batch)

        # Pairwise cosine distances between all detection features
        dists_ij_all = 1 - pairwise_cosine_similarity(detect_features_all, detect_features_all)

        # Loop-invariant masks, sliced to the first i+1 items inside the loop
        eye_all = torch.eye(detect_len, device=device).bool()
        off_diag_all = 1 - torch.eye(detect_len, device=device)
        label_match_all = detect_target_batch.view(-1, 1) == detect_target_batch.view(1, -1)

        for i in range(10, detect_len): # Starting with 10 to avoid infs/nans
            n = i + 1

            # Targets of items 0..i and which of them share item i's label
            detect_target = detect_target_batch[:n]
            same_label = detect_target == detect_target[i]
            num_labels_ij_eq = len(torch.where(same_label)[0])

            dists_ij = dists_ij_all[:n, :n]
            label_match = label_match_all[:n, :n]

            # Distances with different-label pairs and the diagonal set to +inf,
            # so they are excluded from the nearest-neighbour search
            dists_ij_prime = dists_ij.clone()
            dists_ij_prime[~label_match] = torch.inf
            dists_ij_prime[eye_all[:n, :n]] = torch.inf

            # The same distances with those entries set to 0 instead
            dists_ij_zeroed = dists_ij*label_match
            dists_ij_zeroed = dists_ij_zeroed * off_diag_all[:n, :n]

            # Random number for smoothing conformal p-values (shared by both martingales)
            tau = torch.rand(1)[0]

            ############################################################
            # Soft (differentiable) martingale.
            # Softmin nearest-neighbour distance per row; higher temperature is
            # closer to the hard minimum. A row whose label has no other member
            # among items 0..i is all +inf, and softmax over it yields NaN.
            softmin_dists = torch.sum(F.softmax(-temperature*dists_ij_prime, dim=1)*dists_ij_zeroed, dim=1)

            regularization_strength = softrank_regularization_factor*detect_len/(i+1)

            # Scores of item i's label group; item i is the group's last element
            softmin_dists_i = softmin_dists[same_label]

            # Soft rank of item i minus 1 ~ number of group members with a smaller
            # score; normalise and add the tau-smoothed tie term
            p_value_i = torchsort.soft_rank(softmin_dists_i.view(1,num_labels_ij_eq), regularization_strength=regularization_strength, regularization=softrank_regularization_type)[0][-1]-1
            p_value_i /= num_labels_ij_eq
            p_value_i += tau*len(torch.where(softmin_dists_i == softmin_dists[i])[0])/num_labels_ij_eq

            # Update soft martingale
            C_eps = (1-J)*C_eps + (J/len(E))*martingale_soft
            f_eps = 1 + E*(p_value_i - 0.5)
            C_eps = C_eps*f_eps
            martingale_soft = torch.sum(C_eps)
            ############################################################

            ############################################################
            # Exact (monitoring-only) martingale; no gradient needed
            non_conformity_scores = torch.min(dists_ij_prime.detach(), axis=1).values
            p_value_i_hard = len(torch.where(torch.logical_and(same_label, non_conformity_scores[:i+1] < non_conformity_scores[i]))[0])
            p_value_i_hard /= num_labels_ij_eq
            p_value_i_hard += tau*len(torch.where(torch.logical_and(same_label, non_conformity_scores[:i+1] == non_conformity_scores[i]))[0])/num_labels_ij_eq

            C_eps_hard = (1-J)*C_eps_hard + (J/len(E_hard))*martingale_hard
            f_eps_hard = 1 + E_hard*(p_value_i_hard - 0.5)
            C_eps_hard = C_eps_hard*f_eps_hard
            martingale_hard = sum(C_eps_hard)
            ############################################################

            # Update average martingale
            martingale_soft_av += martingale_soft / detect_len
            martingale_hard_av += martingale_hard / detect_len

            # Update maximum martingale
            if martingale_hard > martingale_hard_max:
                martingale_hard_max = martingale_hard.item()

        ############################################################
        # Print progress
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tERM Loss: {:.6f}\tMartingale av. (soft): {:.6f}\tMartingale av.: {:.6f}\tMartingale max.: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), erm_loss.item(), martingale_soft_av.item(), martingale_hard_av.item(), martingale_hard_max))

        if batch_idx % 10 == 0:
            test_model(model, device, train_loader, set_name='train set')
            test_model(model, device, test_loader)
            # test_model leaves the model in eval mode; restore training mode so
            # the remaining batches train with dropout/BatchNorm active.
            model.train()
        ############################################################

        # Optimization step
        loss = 0.001*erm_loss + martingale_penalty*martingale_soft_av

        loss.backward()
        optimizer.step()

    return martingale_hard_av.item(), martingale_hard_max
  
  
#########################################################################################

#########################################################################################
def train_and_test_drm_batch(train_loader, test_loader, detect_loader, model, martingale_penalty, temperature=1.0, softrank_regularization_factor=0.0001, num_training_epochs=5, lr=0.005, softrank_regularization_type="l2", l2_penalty=0.0, num_epochs_erm=0, num_detect_sets=5, detect_batch_size=1000):
  '''
  DRM training with multiple fixed batches of detection sets.

  The first ``num_detect_sets`` batches of ``detect_loader`` are loaded once,
  each reordered by its dataset index, and reused for every training batch
  by ``drm_train_step_batch``. During the first ``num_epochs_erm`` epochs the
  martingale term gets weight 0.0. After every epoch the model is evaluated
  and its weights are saved to ``model_drm.pth``.

  Args:
    train_loader, test_loader: DataLoaders yielding ``(data, target)``.
    detect_loader: DataLoader yielding ``(data, target, index)`` batches of
      exactly ``detect_batch_size`` samples.
    model: module returning ``(logits, features)``; moved to the device.
    martingale_penalty: weight of the soft-martingale term after warm-up.
    temperature, softrank_regularization_factor, softrank_regularization_type,
      l2_penalty: forwarded unchanged to ``drm_train_step_batch``.
    num_training_epochs: total number of epochs.
    lr: Adam learning rate.
    num_epochs_erm: number of initial epochs with martingale weight 0.0.
    num_detect_sets: number of detection batches to load.
    detect_batch_size: number of samples per detection batch.

  Returns:
    ``(train_accuracy, test_accuracy)`` in percent after the last epoch, or
    those of the untrained model when ``num_training_epochs`` < 1.

  Raises:
    ValueError: if ``detect_loader`` yields fewer than ``num_detect_sets``
      batches or a batch whose size differs from ``detect_batch_size``.
  '''

  # Device
  use_cuda = torch.cuda.is_available()
  device = torch.device("cuda" if use_cuda else "cpu")
  # All tensors are moved to ``device``, so the parameters must live there too.
  model.to(device)

  # Define optimizer
  optimizer = optim.Adam(model.parameters(), lr=lr)

  ############################################################
  # Load multiple detection sets, each sorted by dataset index
  detect_loader_iter = iter(detect_loader)

  detect_sets = None
  detect_targets = torch.zeros(num_detect_sets, detect_batch_size, device=device)

  for i in range(num_detect_sets):
    try:
      detect_sets_i, detect_targets_i, detect_indices_i = next(detect_loader_iter)
    except StopIteration:
      raise ValueError(
        f"detect_loader yielded only {i} batches, but num_detect_sets={num_detect_sets}") from None
    if len(detect_sets_i) != detect_batch_size:
      raise ValueError(
        f"detection batch {i} has {len(detect_sets_i)} samples, expected detect_batch_size={detect_batch_size}")
    if detect_sets is None:
      # Per-sample shape is taken from the data rather than assumed.
      detect_sets = torch.zeros(num_detect_sets, detect_batch_size, *detect_sets_i.shape[1:], device=device)
    order = torch.argsort(detect_indices_i)
    detect_sets[i] = detect_sets_i[order]
    detect_targets[i] = detect_targets_i[order]

  if detect_sets is None:
    # num_detect_sets == 0: nothing was loaded and no set will be indexed.
    detect_sets = torch.zeros(0, detect_batch_size, device=device)
  ############################################################

  train_success = test_success = None
  for epoch in range(1, num_training_epochs+1):

    # ERM warm-up: no martingale penalty during the first num_epochs_erm epochs.
    martingale_penalty_in = 0.0 if epoch <= num_epochs_erm else martingale_penalty

    drm_train_step_batch(model, device, train_loader, test_loader, optimizer, epoch, martingale_penalty_in, temperature, softrank_regularization_factor, softrank_regularization_type, l2_penalty, num_detect_sets, detect_batch_size, detect_sets, detect_targets, num_epochs_erm)
    train_success = test_model(model, device, train_loader, set_name='train set')
    test_success = test_model(model, device, test_loader)

    # Save model (checkpoint after every epoch)
    torch.save(model.state_dict(), "model_drm.pth")

  if train_success is None:
    # No epoch ran: report the accuracy of the model as given.
    train_success = test_model(model, device, train_loader, set_name='train set')
    test_success = test_model(model, device, test_loader)

  return train_success, test_success

def drm_train_step_batch(model, device, train_loader, test_loader, optimizer, epoch, martingale_penalty, temperature, softrank_regularization_factor, softrank_regularization_type, l2_penalty, num_detect_sets, detect_batch_size, detect_sets, detect_targets, num_epochs_erm):
    '''
    One epoch of DRM training with multiple pre-loaded detection sets.

    For every training batch, the BCE-with-logits (ERM) loss is computed. Then,
    unless ``epoch <= num_epochs_erm``, each detection set ``detect_sets[ii]``
    (targets ``detect_targets[ii]``) is embedded with ``model`` (second output).
    Its ordered sequence is walked from i = 10 on to update a differentiable
    "soft" and an exact "hard" betting martingale. This uses label-conditional
    conformal p-values built from a sharpened cosine distance
    ``1 - sign(cs) * cs**2``. Both martingales are averaged over all steps and
    sets, and the loss is ``0.001 * erm_loss + martingale_penalty * soft_average``.

    Every 10 batches the model is evaluated on the full train and test sets;
    afterwards it is switched back to train mode.

    Args:
      model: module whose forward returns ``(logits, features)``.
      device: torch.device for all tensors.
      train_loader: DataLoader yielding ``(data, target)``.
      test_loader: DataLoader used for the periodic test-set evaluation.
      optimizer: optimizer over ``model``'s parameters.
      epoch: current epoch number (also compared with ``num_epochs_erm``).
      martingale_penalty: weight of the soft-martingale average in the loss.
      temperature: softmin sharpness; larger is closer to a hard minimum.
      softrank_regularization_factor: base ``regularization_strength`` for
        ``torchsort.soft_rank``, scaled by ``detect_len / (i + 1)``.
      softrank_regularization_type: ``regularization`` argument passed to
        ``torchsort.soft_rank``.
      l2_penalty: not used by this function.
      num_detect_sets: number of detection sets in ``detect_sets``.
      detect_batch_size: number of items per detection set.
      detect_sets: tensor indexed ``[set, item, ...]`` of detection inputs.
      detect_targets: tensor indexed ``[set, item]`` of detection labels.
      num_epochs_erm: epochs up to which detection sets are skipped.

    Returns:
      ``(hard_martingale_average, hard_martingale_max)`` for the last training
      batch; ``(0.0, -inf)`` if ``train_loader`` yields no batches or during
      ERM epochs.
    '''

    model.train()

    detect_len = detect_batch_size

    # Loop-invariant masks, sliced to the first i+1 items inside the loop
    eye_all = torch.eye(detect_len, device=device).bool()
    off_diag_all = 1 - torch.eye(detect_len, device=device)

    # If ERM, skip detection set computations
    num_active_sets = num_detect_sets if epoch > num_epochs_erm else 0

    # Values returned if train_loader yields no batches
    martingale_hard_av = torch.zeros(1, device=device)
    martingale_hard_max = -np.inf

    for batch_idx, (data, target) in enumerate(train_loader):

        ############################################################
        # Compute cross entropy loss
        data, target = data.to(device), target.to(device).float()
        optimizer.zero_grad()
        output, _ = model(data)
        erm_loss = F.binary_cross_entropy_with_logits(output, target)

        ############################################################
        # Overall average (over all steps of all sets) and maximum martingales.
        # NOTE: the average divides by detect_len * num_detect_sets although
        # each set contributes only detect_len - 10 steps.
        martingale_soft_av = torch.zeros(1, device=device)
        martingale_hard_av = torch.zeros(1, device=device)
        martingale_hard_max = -np.inf

        for ii in range(num_active_sets):

            detect_data_batch = detect_sets[ii]
            detect_target_batch = detect_targets[ii]

            # Martingale state: one capital per bet in E, mixed with switching rate J
            martingale_soft = torch.ones(1, device=device)
            martingale_hard = 1.0
            E = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], device=device)
            E_hard = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0], device=device)
            C_eps = (1/len(E))*torch.ones(len(E), device=device)
            C_eps_hard = (1/len(E))*torch.ones(len(E), device=device)
            J = 0.005

            # Get features
            _, detect_features_all = model(detect_data_batch)

            # Sharpened cosine distance: 1 - sign(cs) * cs^2 (elementwise)
            cs_matrix = pairwise_cosine_similarity(detect_features_all, detect_features_all)
            sign_matrix = torch.sign(cs_matrix)
            dists_ij_all = 1 - torch.mul(sign_matrix, torch.pow(cs_matrix, 2))

            label_match_all = detect_target_batch.view(-1, 1) == detect_target_batch.view(1, -1)

            for i in range(10, detect_len):  # Starting with 10 to avoid infs/nans
                n = i + 1

                # Targets of items 0..i and which of them share item i's label
                detect_target = detect_target_batch[:n]
                same_label = detect_target == detect_target[i]
                num_labels_ij_eq = len(torch.where(same_label)[0])

                dists_ij = dists_ij_all[:n, :n]
                label_match = label_match_all[:n, :n]

                # Distances with different-label pairs and the diagonal set to
                # +inf, so they are excluded from the nearest-neighbour search
                dists_ij_prime = dists_ij.clone()
                dists_ij_prime[~label_match] = torch.inf
                dists_ij_prime[eye_all[:n, :n]] = torch.inf

                # The same distances with those entries set to 0 instead
                dists_ij_zeroed = dists_ij*label_match
                dists_ij_zeroed = dists_ij_zeroed * off_diag_all[:n, :n]

                # Random number for smoothing conformal p-values (shared by both martingales)
                tau = torch.rand(1)[0]

                ############################################################
                # Soft (differentiable) martingale.
                # Softmin nearest-neighbour distance per row; a row whose label
                # has no other member among items 0..i is all +inf, and softmax
                # over it yields NaN.
                softmin_dists = torch.sum(F.softmax(-temperature*dists_ij_prime, dim=1)*dists_ij_zeroed, dim=1)

                regularization_strength = softrank_regularization_factor*detect_len/(i+1)

                # Scores of item i's label group; item i is the group's last element
                softmin_dists_i = softmin_dists[same_label]

                # Soft rank of item i minus 1 ~ number of group members with a
                # smaller score; normalise and add the tau-smoothed tie term
                p_value_i = torchsort.soft_rank(softmin_dists_i.view(1,num_labels_ij_eq), regularization_strength=regularization_strength, regularization=softrank_regularization_type)[0][-1]-1
                p_value_i /= num_labels_ij_eq
                p_value_i += tau*len(torch.where(softmin_dists_i == softmin_dists[i])[0])/num_labels_ij_eq

                # Update soft martingale
                C_eps = (1-J)*C_eps + (J/len(E))*martingale_soft
                f_eps = 1 + E*(p_value_i - 0.5)
                C_eps = C_eps*f_eps
                martingale_soft = torch.sum(C_eps)
                ############################################################

                ############################################################
                # Exact (monitoring-only) martingale; no gradient needed
                non_conformity_scores = torch.min(dists_ij_prime.detach(), axis=1).values
                p_value_i_hard = len(torch.where(torch.logical_and(same_label, non_conformity_scores[:i+1] < non_conformity_scores[i]))[0])
                p_value_i_hard /= num_labels_ij_eq
                p_value_i_hard += tau*len(torch.where(torch.logical_and(same_label, non_conformity_scores[:i+1] == non_conformity_scores[i]))[0])/num_labels_ij_eq

                C_eps_hard = (1-J)*C_eps_hard + (J/len(E_hard))*martingale_hard
                f_eps_hard = 1 + E_hard*(p_value_i_hard - 0.5)
                C_eps_hard = C_eps_hard*f_eps_hard
                martingale_hard = sum(C_eps_hard)
                ############################################################

                # Update overall average martingale
                martingale_soft_av += martingale_soft / (detect_len * num_detect_sets)
                martingale_hard_av += martingale_hard / (detect_len * num_detect_sets)

                # Update maximum martingale
                if martingale_hard > martingale_hard_max:
                    martingale_hard_max = martingale_hard.item()

        ############################################################
        # Print progress
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tERM Loss: {:.6f}\tMartingale av. (soft): {:.6f}\tMartingale av.: {:.6f}\tMartingale max.: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), erm_loss.item(), martingale_soft_av.item(), martingale_hard_av.item(), martingale_hard_max))

        if batch_idx % 10 == 0:
            test_model(model, device, train_loader, set_name='train set')
            test_model(model, device, test_loader)
            # test_model leaves the model in eval mode; restore training mode so
            # the remaining batches train with dropout/BatchNorm active.
            model.train()
        ############################################################

        # Optimization step
        loss = 0.001*erm_loss + martingale_penalty*martingale_soft_av

        loss.backward()
        optimizer.step()

    return martingale_hard_av.item(), martingale_hard_max