import torch
from torch import nn
import random


class L2E(nn.Module):
    """Regularizer that penalizes activations far from their running statistics.

    In training mode, ``forward`` swaps dims 0 and 1 of its input and tracks a
    running mean and unbiased variance over (the new) dim 0, in float32 and
    without gradients.

    * Call with ``h_count == 0`` and no ``h_var``: initialize the statistics
      and a threshold ``thr_v`` on the normalized squared deviation (the value
      exceeded by roughly ``k`` percent of the elements).
    * Calls while ``h_count <= h_upper`` (warm-up): pool the new batch into the
      statistics and average the per-batch thresholds; the loss is zero.
    * Later calls: moving-average update that gives the old statistics weight
      ``h_upper`` and returns :meth:`punish_percent_k` as the loss.

    In eval mode ``forward`` returns zero and leaves all state untouched.

    Args:
        h_upper: number of warm-up calls after the first one, and the weight of
            the old statistics once warm-up ends.
        k: target percentage (0-100) of elements to penalize.
        h_mean, h_var: optional initial running statistics.
        thr_v: optional initial threshold.

    NOTE(review): if ``h_var`` is supplied while ``h_count`` is still 0, the
    next training call takes the warm-up branch with ``n == 0`` and
    ``_merge_stats`` raises ZeroDivisionError (``1. / 0``). Callers presumably
    also advance ``h_count``; confirm.
    """

    def __init__(self, h_upper=10, k=2., h_mean=None, h_var=None, thr_v=None):
        super().__init__()
        self.h_var = h_var      # running unbiased variance
        self.h_mean = h_mean    # running mean
        self.thr_v = thr_v      # threshold on the normalized squared deviation
        self.h_count = 0        # number of training-mode forward calls so far
        self.h_upper = h_upper
        self.k = k / 100.0      # `k` is given in percent

    def forward(self, h):
        """Update the running statistics (training only) and return the loss.

        Returns a scalar tensor. It is zero (with ``h``'s dtype/device) in eval
        mode and during warm-up; afterwards it is ``punish_percent_k(h)`` on the
        transposed input.

        Raises:
            ValueError: if the running variance becomes non-finite (e.g. a
                first batch of a single sample, or NaN/inf inputs).
        """
        h = h.transpose(0, 1)
        loss = torch.tensor(0.0, device=h.device, dtype=h.dtype)
        if not self.training:
            return loss

        # Statistics are tracked in float32 and never receive gradients.
        h_ = h.detach().to(torch.float32)
        h_u = h_.mean(dim=0, keepdim=True)
        h_v = h_.var(dim=0, keepdim=True, unbiased=False)
        batch_size = h_.shape[0]

        if self.h_var is None:
            # First call: start from the unbiased variance of this batch.
            self.h_var = h_v / (batch_size - 1.) * batch_size
            self.h_mean = h_u
            self.thr_v = self._top_k_threshold(self._normalized_dev(h_))
        elif self.h_count <= self.h_upper:
            # Warm-up: pooled statistics over every batch seen so far, plus a
            # running average of the per-batch thresholds.
            self._merge_stats(h_u, h_v, batch_size, self.h_count)
            batch_thr = self._top_k_threshold(self._normalized_dev(h_))
            self.thr_v = (self.thr_v * self.h_count + batch_thr) / (self.h_count + 1.)
        else:
            # Old statistics keep a fixed weight from here on.
            self._merge_stats(h_u, h_v, batch_size, self.h_upper)
            loss = self.punish_percent_k(h, log=self.h_count % 3200 == 0)

        self.h_count = self.h_count + 1
        # Convert to bool explicitly: `tensor is False` would never be true.
        if not torch.isfinite(self.h_var.mean()):
            raise ValueError('h_var is not finite')
        return loss

    def _merge_stats(self, h_u, h_v, batch_size, n):
        """Fold one batch into ``h_mean``/``h_var``.

        ``h_u`` is the batch mean and ``h_v`` its biased variance. The old
        statistics are weighted as ``n`` earlier batches of ``batch_size``
        samples each, which is exact pooling when every batch has that size.
        The variance is updated first because it needs the old mean.
        """
        self.h_var = (self.h_var * (n - 1. / batch_size) + h_v +
                      (h_u - self.h_mean) ** 2 / (1 + 1. / n)) / \
                     (n + 1 - 1. / batch_size)
        self.h_mean = (self.h_mean * n + h_u) / (n + 1.)

    def _dev_scale(self):
        """Return the deviation normalizer: variance plus 1% of its mean.

        The added term keeps the denominator away from zero for low-variance
        positions.
        """
        return self.h_var + self.h_var.mean() / 100

    def _normalized_dev(self, h_):
        """Return the flattened squared deviation from ``h_mean`` divided by ``_dev_scale()``."""
        return ((h_ - self.h_mean).pow(2) / self._dev_scale()).reshape(-1)

    def _top_k_threshold(self, dev):
        """Return the ``n``-th largest entry of 1-D ``dev`` as a shape-(1,) tensor.

        ``n`` is ``k`` percent of ``dev``'s length, clamped to ``[1, len(dev)]``
        so that small inputs do not request the 0-th value.
        """
        n_inh = min(max(int(dev.numel() * self.k), 1), dev.numel())
        return -torch.kthvalue(-dev, n_inh, keepdim=True)[0]

    def punish_percent_k(self, h, log=False):
        """Penalize elements whose normalized deviation reaches ``thr_v``.

        Returns the sum of ``log((h - h_mean) ** 2)`` over the selected
        elements; unselected elements contribute ``log(1) = 0``. Gradients
        reach ``h`` only through that squared deviation. The method also
        shifts ``thr_v`` so the selected count moves toward ``k`` percent of all
        elements. When ``log`` is set and ``h`` is on ``cuda:0``, it prints a
        diagnostic line.
        """
        dev_g = (h - self.h_mean).pow(2)
        dev_ = dev_g.detach() / self._dev_scale()
        update_idx = (dev_ >= self.thr_v).to(dev_.dtype)
        n_filtered = torch.sum(update_idx)
        n_target = int(dev_.nelement() * self.k)

        # Unselected entries are replaced by 1 so that they add log(1) = 0.
        L2E_loss = torch.sum(torch.log(dev_g * update_idx + torch.ones_like(dev_) * (1 - update_idx)))

        # Lower the threshold when too few elements were selected, and raise it
        # when too many were selected.
        self.thr_v = self.thr_v - (n_target - n_filtered) / dev_.nelement()
        if log and str(h.device) == 'cuda:0':
            print("thr:" + str(self.thr_v)+", filtered/to filter: "+str(n_filtered)+"/"+str(n_target)+", MS:"+str(torch.sum(dev_*update_idx).item()))

        return L2E_loss

