import math
import random
import torch
import torch.nn as nn  
    

# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Collect `tensor` from every process in the default group and join the
    collected pieces along dimension 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, and this
    function additionally runs under `torch.no_grad()`.
    """
    world_size = torch.distributed.get_world_size()

    # One receive buffer per process, each shaped like the local tensor.
    buffers = []
    for _ in range(world_size):
        buffers.append(torch.ones_like(tensor))

    torch.distributed.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)


class ReductionDump(nn.Module):
    """ Pick the sample ids whose two losses are both very high or both very low.

        Two selection modes are supported:

        * threshold mode (used when `global_max < 1e8` or `global_min > 0.0`):
          samples with loss above `global_max` are "ill", samples with loss
          below `global_min` are "redundant";
        * ratio mode (default): the `del_ratio_max` fraction with the highest
          loss is "ill", the `del_ratio_min` fraction with the lowest loss is
          "redundant".

        We omit the use of horovod in this implementation; when a
        torch.distributed process group is initialised with more than one
        process, the selected ids are gathered with `concat_all_gather`.
    """
    def __init__(self,
                 global_max=1e8,
                 global_min=0.0,
                 del_ratio_max=0.25,
                 del_ratio_min=0.25,
                 ):
        super(ReductionDump, self).__init__()
        self.global_max = global_max
        self.global_min = global_min

        self.del_ratio_max = del_ratio_max
        self.del_ratio_min = del_ratio_min

    @staticmethod
    def _world_size():
        """ Return the distributed world size, or 1 if no process group exists. """
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size()
        return 1

    def _split_by_rank(self, loss):
        """ Return (highest-loss positions, lowest-loss positions) by ratio. """
        order = torch.argsort(loss, descending=True)
        total = len(order)

        top = order[:int(total * self.del_ratio_max)]

        n_bottom = int(total * self.del_ratio_min)
        # `order[-0:]` would be the whole tensor, so a zero count needs an
        # explicit empty slice.
        bottom = order[-n_bottom:] if n_bottom > 0 else order[:0]
        return top, bottom

    @staticmethod
    def _agree_or_fill(ids_1, ids_2):
        """ Intersect two id lists; top up randomly if the overlap is small.

            When the intersection holds fewer than 90% of `len(ids_1)` ids, ids
            that appear in only one of the lists are sampled at random until
            the result has `len(ids_1)` entries (or the candidates run out).
        """
        set_1, set_2 = set(ids_1), set(ids_2)
        agreed = set_1 & set_2
        if len(agreed) < 0.9 * len(ids_1):
            shortfall = len(ids_1) - len(agreed)
            # random.sample no longer accepts a set (TypeError on 3.11+);
            # sorting also keeps the draw reproducible under a fixed seed.
            candidates = sorted((set_1 | set_2) - agreed)
            # Duplicated ids in `ids_1` can make the shortfall larger than the
            # pool, which random.sample rejects.
            agreed.update(random.sample(candidates, k=min(shortfall, len(candidates))))
        return agreed

    def forward(self, idx=None, losses: dict = None):
        """ Select redundant and ill-matched sample ids from two losses.

            Args:
                idx: 1-D tensor of sample ids, aligned with the loss tensors.
                losses: dict with per-sample tensors under 'loss_1' and 'loss_2'.

            Returns:
                (redundant_set, ill_match_set): two Python sets of sample ids.
        """
        loss_1, loss_2 = losses['loss_1'], losses['loss_2']

        if self.global_max < 1e8 or self.global_min > 0.0:
            ill_pos_1 = torch.nonzero(loss_1 > self.global_max).view(-1)
            ill_pos_2 = torch.nonzero(loss_2 > self.global_max).view(-1)

            red_pos_1 = torch.nonzero(loss_1 < self.global_min).view(-1)
            red_pos_2 = torch.nonzero(loss_2 < self.global_min).view(-1)
        else:
            ill_pos_1, red_pos_1 = self._split_by_rank(loss_1)
            ill_pos_2, red_pos_2 = self._split_by_rank(loss_2)

        # Map positions inside the batch to sample ids; indexing with a tensor
        # already yields fresh tensors, so only `idx` needs detaching.
        idx = idx.detach()
        red_ids_1 = idx[red_pos_1]
        red_ids_2 = idx[red_pos_2]
        ill_ids_1 = idx[ill_pos_1]
        ill_ids_2 = idx[ill_pos_2]

        if self._world_size() > 1:
            # NOTE(review): in threshold mode each process may select a
            # different number of ids, while concat_all_gather allocates its
            # receive buffers shaped like the local tensor - confirm this is
            # safe for the backend in use.
            red_ids_1 = concat_all_gather(red_ids_1)
            red_ids_2 = concat_all_gather(red_ids_2)
            ill_ids_1 = concat_all_gather(ill_ids_1)
            ill_ids_2 = concat_all_gather(ill_ids_2)

        # Redundant ids are resolved first so the order of random draws
        # matches the order ids are returned in.
        redundant_set = self._agree_or_fill(red_ids_1.tolist(), red_ids_2.tolist())
        ill_match_set = self._agree_or_fill(ill_ids_1.tolist(), ill_ids_2.tolist())

        return redundant_set, ill_match_set
        
        
class CosineAnnealing(nn.Module):
    """ Cyclic half-cosine ramp mapping an epoch number to a scalar ratio.

        Within one cycle of `T_max + 1` epochs the ratio rises from `eta_min`
        (first epoch of the cycle) to `eta_min + rho_init` (last epoch of the
        cycle), then the cycle restarts. `T_max` must be non-zero because the
        phase is divided by it.
    """
    def __init__(self, T_max=3, eta_min=0):
        super(CosineAnnealing, self).__init__()
        self.T_max = T_max
        self.eta_min = eta_min
        self.rho_init = 1.0

    def forward(self, epoch):
        """ Return the ratio for `epoch` as a Python float.

            `epoch` may be a number or a single-element tensor; it is never
            modified.
        """
        # Rebind instead of `%=` so a tensor argument is not changed in place.
        step = epoch % (self.T_max + 1)
        # Keep the angle in double precision; a float32 round-trip would make
        # the first epoch of a cycle return slightly more than eta_min.
        angle = float((self.T_max - step) / self.T_max * math.pi)
        return self.eta_min + self.rho_init * (0.5 * (1 + math.cos(angle)))
    