import json
import torch
from torch import nn


class EntropyCurriculum(nn.Module):
    """Per-sample confidence weights from a per-class sigmoid schedule.

    For every element, ``conf = sigmoid(c1 * (training_progress - c2))``, where
    ``c1``/``c2`` are looked up in ``cfg`` by the element's class id. When
    ``avgloss`` is enabled, this is blended with ``sigmoid(avg - loss)``, where
    ``avg`` is an exponential moving average of a percentile of the loss.

    Args:
        cfg_idx: Suffix of the config file ``cfg/c1c2_<cfg_idx>.json`` (path is
            relative to the current working directory). Only used when ``cfg``
            is ``None``.
        epochs: Stored as ``self.epochs``; not used by this class itself.
        avgloss: If True, blend the schedule with the loss-relative term.
        alpha: Weight of the schedule term in the blend; ``1 - alpha`` goes to
            the loss-relative term.
        decay: EMA decay applied to the running loss percentile.
        percentile: Quantile ``q`` passed to ``torch.quantile`` (must be in [0, 1]).
        cfg: Mapping ``class_id -> {'c1': ..., 'c2': ...}``. Keys must match the
            values found in ``ent_class`` (ints). If ``None``, the mapping is
            read from the JSON file and its keys are converted with ``int``.
    """

    def __init__(self, cfg_idx, epochs,
                 avgloss=False,
                 alpha=0.5,
                 decay=0.9, percentile=0.7,
                 cfg=None):
        super().__init__()

        # NOTE(review): a None buffer is excluded from state_dict, so a
        # checkpoint saved after the first forward (when 'avg' is a tensor)
        # may report 'avg' as an unexpected key when loaded into a fresh
        # instance — confirm how checkpoints are restored.
        self.register_buffer('avg', None)
        self.alpha = alpha
        self.decay = decay
        self.percentile = percentile
        self.avgloss = avgloss

        self.epochs = epochs
        # `is not None` so that an explicitly passed (even empty) mapping is
        # honoured instead of silently triggering a file read.
        if cfg is not None:
            self.cfg = cfg
        else:
            # JSON is UTF-8 by specification; don't depend on the locale.
            with open('cfg/c1c2_%s.json' % cfg_idx, encoding='utf-8') as f:
                raw = json.load(f)
            # JSON object keys are always strings; class ids are ints.
            self.cfg = {int(k): v for k, v in raw.items()}

    def forward(self, loss, training_progress, ent_class):
        """Return per-element confidence weights.

        Args:
            loss: Loss tensor. Presumably per-sample with the same shape as
                ``ent_class`` (TODO confirm against callers); it is flattened
                by ``torch.quantile`` for the running average.
            training_progress: Scalar (number or 0-d tensor) plugged into the
                sigmoid schedule — presumably the fraction of training done.
            ent_class: Tensor of class ids used as keys into ``self.cfg``.
                Any shape is accepted; the coefficient tensors take its shape.

        Returns:
            Tensor of confidences in (0, 1).
        """
        # The running percentile is updated on every call, regardless of
        # `avgloss` or train/eval mode, and never receives gradients.
        with torch.no_grad():
            q = torch.quantile(loss, self.percentile)
            if self.avg is None:
                self.avg = q
            else:
                self.avg = self.decay * self.avg + (1 - self.decay) * q

        # Flatten so any-rank ent_class works (tolist() on a 0-d tensor gives a
        # scalar and on an N-d tensor gives unhashable nested lists), then
        # restore the original shape.
        shape = ent_class.shape
        entries = [self.cfg[c] for c in ent_class.reshape(-1).tolist()]
        c1 = torch.tensor([e['c1'] for e in entries], device=loss.device).reshape(shape)
        c2 = torch.tensor([e['c2'] for e in entries], device=loss.device).reshape(shape)
        conf = torch.sigmoid(c1 * (training_progress - c2))

        if self.avgloss:
            # NOTE(review): `loss` is not detached here, so gradients flow
            # through the returned weights when avgloss=True — confirm intended.
            conf2 = torch.sigmoid(self.avg - loss)
            conf = self.alpha * conf + (1 - self.alpha) * conf2

        return conf
