import torch
import numpy as np
from abc import ABC, abstractmethod, ABCMeta

class Metric(metaclass=ABCMeta):
    """
    -   reset() in the begining of every epoch.
    -   update_per_batch() after every batch.
    -   update_per_epoch() after every epoch.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def reset(self):
        pass

    @abstractmethod
    def update_per_batch(self, output):
        pass

    @abstractmethod
    def update_per_epoch(self):
        pass

class Top_K_Metric(Metric):
    """
    Stores accuracy (score), loss and timing info
    """
    def __init__(self, topnum=[1,3,10]):
        super().__init__()
        self.topnum = topnum
        self.k_num = len(self.topnum)
        self.reset()

    def reset(self):
        self.total_loss = 0
        self.correct_list = [0] * self.k_num
        self.acc_list = [0] * self.k_num
        self.acc_all = 0
        self.num_examples = 0
        self.num_epoch = 0

        self.mrr = 0
        self.mr = 0
        self.mrr_all = 0
        self.mr_all = 0

    def update_per_batch(self, loss, ans, pred):
        self.total_loss += loss
        self.num_epoch += 1
        self.top_k_list = self.batch_accuracy(pred, ans)
        self.num_examples += self.top_k_list[0].shape[0]
        for i in range(self.k_num):
            self.correct_list[i] += self.top_k_list[i].sum().item()

        # mrr
        mrr_tmp, mr_tmp =  self.batch_mr_mrr(pred, ans)
        self.mrr_all += mrr_tmp.sum().item()
        self.mr_all += mr_tmp.sum().item()



    def update_per_epoch(self):
        for i in range(self.k_num):
            self.acc_list[i] = 100 * (self.correct_list[i] / self.num_examples)

        self.mr = self.mr_all / self.num_examples
        self.mrr = self.mrr_all / self.num_examples
        self.total_loss = self.total_loss / self.num_epoch
        self.acc_all = sum(self.acc_list)


    def batch_accuracy(self, predicted, true):
        """ Compute the accuracies for a batch of predictions and answers """
        if len(true.shape) == 3:
            true = true[0]
        _, ok = predicted.topk(max(self.topnum), dim=1)
        agreeing_all = torch.zeros([predicted.shape[0], 1], dtype=torch.float).cuda()
        top_k_list = [0]*self.topnum
        for i in range(max(self.topnum)):
            tmp = ok[:, i].reshape(-1, 1)
            agreeing_all += true.gather(dim=1, index=tmp)
            for k in range(self.k_num):
                if i == self.topnum[k] - 1:
                    top_k_list[k] = (agreeing_all * 0.3).clamp(max=1)
                    break

        return top_k_list



    def batch_mr_mrr(self, predicted, true):
        if len(true.shape) == 3:
            true = true[0]


        top_rank = predicted.shape[1]
        batch_size = predicted.shape[0]
        _, predict_ans_rank = predicted.topk(top_rank, dim=1) 
        _, real_ans = true.topk(1, dim=1) 


        real_ans = real_ans.expand(batch_size, top_rank)
        ans_different = torch.abs(predict_ans_rank - real_ans)

        _, real_ans_list = ans_different.topk(top_rank, dim=1) 
        real_ans_list = real_ans_list + 1.0
        mr = real_ans_list[:,-1].reshape(-1,1).to(torch.float64)
        mrr = 1.0 / mr

        return mrr,mr


if __name__ == '__main__':
    pass