'''Modified from https://github.com/alinlab/LfF/blob/master/util.py'''
import torch
import logging
import os

class EMA:
    """Per-sample exponential moving average (EMA) of a scalar signal.

    Keeps one running value per sample in ``parameter`` and the most recent
    raw value per sample in ``instance``.  The first update of a sample stores
    the incoming value directly (its ``updated`` flag is 0 and ``parameter``
    starts at 0); later updates blend it in as ``alpha * old + (1 - alpha) * new``.

    Args:
        label: 1-D tensor with one class label per sample; its length fixes
            the number of tracked samples.
        num_classes: number of classes, used only to size ``self.max``.  When
            ``None``, ``self.max`` is ``None`` (previously this crashed).
        alpha: smoothing factor used by ``update`` when no ``curve`` is given.
        device: device for ``label``, ``instance`` and the initial buffers.
            Defaults to CUDA when available, otherwise CPU (the original code
            unconditionally called ``.cuda()``).
    """

    def __init__(self, label, num_classes=None, alpha=0.9, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        num_samples = label.size(0)

        self.label = label.to(self.device)
        self.alpha = alpha
        # Running EMA per sample; moved to the device of ``data`` on each update.
        self.parameter = torch.zeros(num_samples, device=self.device)
        # 0/1 flag per sample: 0 until the sample has been updated at least once.
        self.updated = torch.zeros(num_samples, device=self.device)
        self.num_classes = num_classes
        # torch.zeros(None) raises, so only allocate when a size is given.
        # NOTE(review): nothing in this class writes to ``self.max``.
        self.max = None if num_classes is None else torch.zeros(num_classes, device=self.device)
        # Most recent raw (unsmoothed) value per sample.
        self.instance = torch.zeros(num_samples, device=self.device)

    def update(self, data, index, curve=None, iter_range=None, step=None):
        """Blend ``data`` into the running averages of the samples at ``index``.

        Args:
            data: 1-D tensor of new values, aligned element-wise with ``index``.
            index: 1-D integer tensor of sample indices.
            curve: optional base for a step-dependent smoothing factor
                ``curve ** -(step / iter_range)`` used instead of ``self.alpha``.
            iter_range: required when ``curve`` is given.
            step: required when ``curve`` is given.

        Raises:
            ValueError: if ``curve`` is given without ``iter_range`` or ``step``.
        """
        if curve is None:
            alpha = self.alpha
        else:
            if iter_range is None or step is None:
                raise ValueError('iter_range and step are required when curve is given')
            alpha = curve ** -(step / iter_range)

        # Raw values are stored on the instance buffer's device regardless of
        # where ``data`` lives.
        self.instance[index.to(self.instance.device)] = data.to(self.instance.device)

        # The running buffers follow the device of the incoming data.
        self.parameter = self.parameter.to(data.device)
        self.updated = self.updated.to(data.device)
        index = index.to(data.device)

        # For never-updated samples (updated == 0) the second coefficient is 1,
        # and since parameter is still 0 this stores ``data`` as-is.
        self.parameter[index] = alpha * self.parameter[index] + (1 - alpha * self.updated[index]) * data
        self.updated[index] = 1

    def max_loss(self, label):
        """Return the largest running value among samples whose label is ``label``.

        Raises an error (from ``Tensor.max``) if no sample has that label.
        """
        label_index = torch.where(self.label == label)[0]
        # ``self.label`` and ``self.parameter`` may live on different devices
        # because ``parameter`` follows the device of the data passed to update().
        return self.parameter[label_index.to(self.parameter.device)].max()


class Hook:
    """Record what a module's forward (or backward) hook receives.

    Every invocation appends ``output`` to ``self.feature`` and overwrites
    ``self.input`` / ``self.output`` with that invocation's values.

    Args:
        module: the ``torch.nn.Module`` to attach to.
        backward: if truthy, register a backward hook; otherwise a forward hook.
    """

    def __init__(self, module, backward=False):
        # Outputs (forward) or grad_outputs (backward) of every call, in order.
        self.feature = []
        if backward:
            # NOTE(review): register_backward_hook is deprecated in favour of
            # register_full_backward_hook, whose grad_input semantics differ;
            # kept as-is so callers keep receiving the same values.
            self.hook = module.register_backward_hook(self.hook_fn)
        else:
            self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Store one hook invocation.

        For a forward hook ``input``/``output`` are the module's positional
        inputs and its return value; for a backward hook they are
        ``grad_input``/``grad_output``.
        """
        self.input = input
        self.output = output
        self.feature.append(output)

    def close(self):
        """Detach the hook from the module; already-captured data is kept."""
        self.hook.remove()


class logger:
    """Log messages to stderr and to ``<log_dir>/log.txt`` with level prefixes.

    On construction it writes a ``---<dataset>---`` header followed by every
    attribute of ``args`` except ``logger``.

    Args:
        log_dir: directory that receives ``log.txt``.
        args: namespace-like object; ``args.log_overwrite`` selects truncate
            ('w') vs append ('a') mode and ``args.dataset`` is used in the header.
    """

    def __init__(self, log_dir, args):
        self.logger = logging.getLogger('Evaluation')
        self.logger.setLevel(logging.INFO)

        # getLogger returns a process-wide singleton: drop and close handlers
        # already attached (e.g. by an earlier instance) so every message is
        # emitted once and stale file handles are released.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
            handler.close()

        formatter = logging.Formatter('%(message)s')

        strm_handler = logging.StreamHandler()
        strm_handler.setFormatter(formatter)

        mode = 'w' if args.log_overwrite else 'a'
        # Explicit UTF-8: level prefixes contain '└', which locale encodings
        # such as cp1252 cannot encode.
        file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'), mode=mode, encoding='utf-8')
        file_handler.setFormatter(formatter)

        self.logger.addHandler(strm_handler)
        self.logger.addHandler(file_handler)

        self(f'---{args.dataset}---', level=1)
        self.arg_logging(args)

    def __call__(self, message, level=0):
        """Log ``message`` at INFO with a prefix chosen by ``level``.

        level 0 -> no prefix; level 1 -> '--->'; any other level ->
        ``level`` spaces followed by '└>'.
        """
        if level == 0:
            prefix = ''
        elif level == 1:
            prefix = '--->'
        else:
            prefix = ' ' * level + '└' + '>'

        self.logger.info(f'{prefix}{message}')

    def arg_logging(self, argument):
        """Log each attribute of ``argument`` (via ``vars``) except ``logger``."""
        self('Argument', level=1)
        for key, value in vars(argument).items():
            if key == 'logger':
                continue  # presumably a logger object stored on args; not worth printing
            self(f'{key:12s}: {value}', level=2)
