import math
import torch
import random
import torch.nn as nn  
    
    
class ReductionDump(nn.Module):
    """Pick the instances of a batch that should be filtered out.

    Two groups are selected from the per-sample losses:

    * "ill-matched" instances: the ones with the highest losses;
    * "redundant" instances: the ones with the lowest losses.

    The selection either uses fixed thresholds (``global_max`` /
    ``global_min``) when at least one of them differs from its default, or
    otherwise the top ``del_ratio_max`` / bottom ``del_ratio_min`` fraction
    of the batch.

    Horovod is not used in this implementation. Per the original author's
    note, (del_ratio_max + del_ratio_min) of the instances are normally
    removed for the next epoch; the removal itself happens in the caller.
    """

    def __init__(self,
                 global_max=1e8,
                 global_min=0.0,
                 del_ratio_max=0.25,
                 del_ratio_min=0.25,
                 ):
        """Store the filtering configuration.

        Args:
            global_max: Loss threshold above which an instance is treated as
                ill-matched. The default (1e8) disables threshold mode.
            global_min: Loss threshold below which an instance is treated as
                redundant. The default (0.0) disables threshold mode.
            del_ratio_max: Fraction of highest-loss instances selected in
                ratio mode.
            del_ratio_min: Fraction of lowest-loss instances selected in
                ratio mode.
        """
        super(ReductionDump, self).__init__()
        self.global_max = global_max
        self.global_min = global_min

        self.del_ratio_max = del_ratio_max
        self.del_ratio_min = del_ratio_min

    @staticmethod
    def _split_by_ratio(losses, ratio_max, ratio_min):
        """Return (highest-loss, lowest-loss) positions for the given ratios."""
        order = torch.argsort(losses, descending=True)
        total = len(order)
        n_top = min(max(int(total * ratio_max), 0), total)
        n_bottom = min(max(int(total * ratio_min), 0), total)
        # Slice the tail from an explicit start offset: ``order[-0:]`` would
        # return the whole batch when the count rounds down to zero.
        return order[:n_top], order[total - n_bottom:]

    @staticmethod
    def _make_keys(idx, text_tokens, positions):
        """Build hashable ``(id, token-prefix)`` keys for the given positions."""
        positions = positions.detach().cpu()
        return [
            (id_key, tuple(token))
            for id_key, token in zip(idx[positions].tolist(),
                                     text_tokens[positions].tolist())
        ]

    @staticmethod
    def _merge_keys(image_keys, text_keys):
        """Intersect both key lists, padding randomly when they agree too little."""
        image_set, text_set = set(image_keys), set(text_keys)
        matched = image_set & text_set

        if len(matched) < 0.9 * len(image_keys):
            missing = len(image_keys) - len(matched)
            # random.sample() needs a sequence (sets are rejected since
            # Python 3.11); sorting also keeps the draw reproducible for a
            # fixed random seed.
            pool = sorted((image_set | text_set) - matched)
            # Never ask for more elements than the pool holds.
            matched.update(random.sample(pool, k=min(missing, len(pool))))
        return matched

    def forward(self, idx=None, text_tokens=None, losses: dict = None):
        """Select the redundant and ill-matched instances of a batch.

        Args:
            idx: Tensor of instance ids, indexed along its first dimension.
                Assumed to be 1-D with one id per sample -- TODO confirm
                against the caller.
            text_tokens: 2-D tensor of token ids; only the first five columns
                are used as part of the lookup key.
            losses: Dict with per-sample loss tensors under the keys
                ``'i2t'`` and ``'t2i'`` (assumed 1-D and aligned with
                ``idx`` -- TODO confirm).

        Returns:
            Tuple ``(redundant_set, ill_match_set)``. Each is a set of
            ``(id, tuple_of_first_five_tokens)`` keys. A key is included when
            both loss directions select it; if that intersection holds fewer
            than 90% of the instances selected by the ``'i2t'`` losses, it is
            padded with random keys selected by only one direction.
        """
        losses_image, losses_text = losses['i2t'], losses['t2i']

        if self.global_max < 1e8 or self.global_min > 0.0:
            # Threshold mode: pre-defined global bounds decide membership.
            cond_ill_image = torch.nonzero(losses_image > self.global_max).view(-1)
            cond_ill_text = torch.nonzero(losses_text > self.global_max).view(-1)

            cond_red_image = torch.nonzero(losses_image < self.global_min).view(-1)
            cond_red_text = torch.nonzero(losses_text < self.global_min).view(-1)
        else:
            # Ratio mode: a fixed fraction from each end of the sorted losses.
            cond_ill_image, cond_red_image = self._split_by_ratio(
                losses_image, self.del_ratio_max, self.del_ratio_min)
            cond_ill_text, cond_red_text = self._split_by_ratio(
                losses_text, self.del_ratio_max, self.del_ratio_min)

        # The id combined with the leading tokens serves as the search key.
        text_tokens = text_tokens[:, :5].detach().cpu()
        idx = idx.detach().cpu()

        redundant_set = self._merge_keys(
            self._make_keys(idx, text_tokens, cond_red_image),
            self._make_keys(idx, text_tokens, cond_red_text))

        ill_match_set = self._merge_keys(
            self._make_keys(idx, text_tokens, cond_ill_image),
            self._make_keys(idx, text_tokens, cond_ill_text))

        return redundant_set, ill_match_set
        
        
class CosineAnnealing(nn.Module):
    """Cyclic cosine schedule that maps an epoch index to a ratio.

    The epoch is wrapped into ``[0, T_max]``. Within one cycle the returned
    value rises from ``eta_min`` (wrapped epoch 0) to ``eta_min + rho_init``
    (wrapped epoch ``T_max``) along half a cosine wave.
    """

    def __init__(self, T_max=3, eta_min=0):
        """Store the schedule parameters.

        Args:
            T_max: Last epoch index of a cycle; the cycle length is
                ``T_max + 1``. Must be non-zero because the phase is divided
                by it.
            eta_min: Offset added to the cosine term (the value returned at
                the start of each cycle).
        """
        super(CosineAnnealing, self).__init__()
        self.T_max = T_max
        self.eta_min = eta_min
        self.rho_init = 1.0

    def forward(self, epoch):
        """Return the ratio for ``epoch``.

        Args:
            epoch: Epoch index (a number or a one-element tensor).

        Returns:
            ``eta_min + rho_init * 0.5 * (1 + cos((T_max - e) / T_max * pi))``
            where ``e = epoch % (T_max + 1)``.
        """
        # Plain (non-augmented) modulo so a tensor argument is not modified
        # in place for the caller.
        wrapped = epoch % (self.T_max + 1)
        # Evaluate in double precision: routing the angle through a default
        # (float32) tensor would round it before the cosine is taken.
        angle = float((self.T_max - wrapped) / self.T_max * math.pi)
        return self.eta_min + self.rho_init * (0.5 * (1 + math.cos(angle)))
    