import torch as t
import torch.nn as nn
import torch.nn.functional as F

import algs.list.vanilla as  base
from algs.utils.gce import GeneralizedCELoss


class algorithm(base.algorithm):
    """Bias-model training followed by gradient-based resampling.

    ``run`` optionally trains ``self.models['bias']`` with GeneralizedCELoss,
    then always turns the bias model's per-sample gradients into sampling
    probabilities that are pushed into the training dataset, and finally
    optionally trains ``self.models['final']`` with per-sample cross-entropy.
    """

    def __init__(self, models, noise_model, loaders, args):
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # `args.denoise` is compared to the string 'True', so it is
        # presumably passed through from the command line unconverted.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Execute the stages selected by ``self.args.job``."""
        # train bias model
        if 'bias' in self.args.job:
            cr = GeneralizedCELoss()
            self.train(self.loaders, self.models['bias'], self.args.epoch, cr)

        # Runs even when 'bias' is not in the job list; the bias model is
        # presumably restored elsewhere in that case (TODO confirm).
        self.bias_fix(self.models['bias'], self.loaders['train'], self.args.device)
        if 'final' in self.args.job:
            cr = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.models['final'], self.args.epoch, cr)

    def bias_fix(self, models, loader, device):
        """Compute gradient-based sampling probabilities and enable sampling.

        Args:
            models: dict holding the bias network under ``'net'``.
            loader: training loader; its dataset must expose ``label``,
                ``gt_label``, ``midx``, ``prob_update`` and ``sample_on``.
            device: device on which the bias network is evaluated.

        Side effects:
            Sets ``self.mag_prob``, ``self.ang_prob`` and ``self.tot_prob``,
            pushes ``self.tot_prob`` into the dataset and turns sampling on.
        """
        model_bias = models['net']
        model_bias.eval()

        # Bias fix using gradient
        grads, cidx, _ = self.gradient(model_bias, loader, device)
        grads -= t.mean(grads, dim=0)  # Centering
        grads = self.grads_pick(grads)
        mag_prob = self.magnitude(grads, cidx)
        ang_prob = self.angle(grads, cidx)
        # magnitude/angle give high scores to small-norm samples and to samples
        # aligned with their class mean direction; inverting them makes
        # large-norm / off-direction samples the likely ones to be drawn.
        self.mag_prob = 1. / mag_prob / t.sum(1. / mag_prob)
        self.ang_prob = 1. / ang_prob / t.sum(1. / ang_prob)
        self.tot_prob = self.weight_total(self.mag_prob, self.ang_prob)

        self._log_prob_split(loader.dataset)

        loader.dataset.prob_update(self.tot_prob)
        loader.dataset.sample_on()

    def _log_prob_split(self, dataset):
        """Debug-log the probability mass of each (label correctness, midx) group."""
        groups = (
            t.where((dataset.label == dataset.gt_label) & (dataset.midx == 0)),
            t.where((dataset.label == dataset.gt_label) & (dataset.midx == 1)),
            t.where((dataset.label != dataset.gt_label) & (dataset.midx == 0)),
            t.where((dataset.label != dataset.gt_label) & (dataset.midx == 1)),
        )
        self.args.log.debug('-' * 80)
        self.args.log.debug('*** Sampling probability ***')
        for tag, prob in (('mag', self.mag_prob), ('ang', self.ang_prob), ('tot', self.tot_prob)):
            self.args.log.debug(
                "%s:  correct_maj %.4f /  correct_min %.4f / incorrect_maj %.4f /  incorrect_min %.4f "
                % ((tag,) + tuple(t.sum(prob[g]) for g in groups)))
        self.args.log.debug('-' * 80)

    def magnitude(self, grads, cidx, cla=True):
        """Per-class normalized inverse gradient norm.

        Args:
            grads: 2-D tensor, one row of gradient features per sample.
            cidx: 1-D integer tensor of class indices, aligned with ``grads``.
            cla: unused; kept for interface compatibility.

        Returns:
            1-D tensor summing to 1. Within each class, weights are proportional
            to ``1 / ||grads[i]||`` and sum to 1 before the final renormalization.
        """
        mag = t.zeros(len(grads))
        for idx in range(t.max(cidx) + 1):
            pos = t.where(cidx == idx)
            x = grads[pos]
            mag_tmp = 1. / t.norm(x, p=None, dim=1, keepdim=False)
            # Normalize inside the class so every present class gets equal mass.
            mag[pos] = mag_tmp / t.sum(mag_tmp)
        return mag / t.sum(mag)

    def angle(self, grads, cidx):
        """Per-class normalized von Mises-Fisher density of gradient directions.

        For each class, the rows of ``grads`` are projected onto the unit
        sphere. The concentration is estimated with the approximation
        ``k = r (d - r^2) / (1 - r^2)`` (``r`` = mean resultant length,
        ``d`` = dimension), and each sample is scored by
        ``exp(k * <mean_dir, x>)``. Scores are clamped to ``[1e-9, e^20]``.

        Returns:
            1-D tensor summing to 1; each present class contributes equal mass
            before the final renormalization.
        """
        ret = t.zeros(len(grads))
        for idx in range(t.max(cidx) + 1):
            pos = t.where(cidx == idx)
            norm = t.norm(grads[pos], p=None, dim=1, keepdim=True)
            x = grads[pos] / norm
            u = t.sum(x, dim=0)
            u_norm = t.norm(u, p=None, dim=0, keepdim=False)
            d = x.size()[1]
            r = u_norm / len(x)
            # For a single-sample class r == 1 and k is inf; the clamp below
            # then saturates the score at e^20.
            k = r * (d - r ** 2) / (1 - r ** 2)
            ret_tmp = t.clamp(t.exp(k * t.matmul(u / u_norm, x.T)), min=1e-9, max=t.exp(t.tensor(20.)))
            ret[pos] = ret_tmp / t.sum(ret_tmp)

        return ret / t.sum(ret)

    def weight_total(self, m, a, alpha=1.0):
        """Return ``m + alpha * a`` renormalized to sum to 1."""
        tmp = m + alpha * a
        return tmp / t.sum(tmp)

    def grads_pick(self, grads, K=100):
        """Keep the ``K`` gradient columns with the largest mean absolute value."""
        dominant_theta = t.mean(t.abs(grads), dim=0)
        argsort = t.argsort(dominant_theta, descending=True)
        pick = argsort[:K]
        return grads[:, pick]

    def gradient(self, model, loader, device):
        """Collect per-sample activation gradients of ``model`` over ``loader``.

        Each batch ``data`` is read as ``data[0]`` inputs, ``data[1]`` labels,
        ``data[2]`` per-sample ``midx`` values and ``data[4]`` dataset indices.

        Returns:
            ``(out, cidx, midx)``: ``out`` is a float32 tensor with one row of
            flattened hooked gradients per dataset index; ``cidx`` and ``midx``
            are int64 tensors holding ``data[1]`` and ``data[2]`` at those
            indices.

        Raises:
            ValueError: if ``self.args.arch`` is unsupported or the loader
                yields no batches.
        """
        arch = self.args.arch
        if arch == 'conv0':
            grad = gradient_extract(model=model, module=[model.layer1, model.layer2, model.layer3, model.fc], layer=["0"])
        elif arch == 'conv1':
            grad = gradient_extract(model=model, module=[model.layer1, model.layer2, model.layer3, model.layer4, model.fc], layer=["0"])
        elif 'resnet' in arch:
            grad = gradient_extract(model=model, module=[model.layer3, model.layer4, model.fc], layer=["0", "1"])
        else:
            raise ValueError("unsupported arch for gradient extraction: %r" % (arch,))

        # 'sum' (not the default 'mean') so each sample's activation gradient
        # does not depend on its batch size: with 'mean' a short final batch
        # would be scaled by B / B_last relative to the others. All later
        # steps are scale-invariant, so equal-size batches give the same result.
        criterion = nn.CrossEntropyLoss(reduction='sum')
        n = len(loader.dataset.label)
        out = cidx = midx = None
        for data in loader:
            X, Y, M, I = data[0], data[1], data[2], data[4]
            x = X.to(device).requires_grad_(True)
            y = Y.to(device)
            output = grad(x)
            model.zero_grad()
            loss = criterion(output, y)
            loss.backward()
            grads = grad.gradients.detach().cpu()
            if out is None:
                # Allocate once the gradient width is known from the first batch.
                out = t.zeros((n, grads.size(1)), dtype=t.float32)
                cidx = t.zeros(n, dtype=t.int64)
                midx = t.zeros(n, dtype=t.int64)
            out[I], cidx[I], midx[I] = grads.float(), Y.long(), M.long()
        if out is None:
            raise ValueError("loader yielded no batches")
        return out, cidx, midx


class gradient_extract:
    """Run ``model`` child-by-child and hook gradients of selected activations.

    The forward pass iterates ``model._modules`` in registration order. This
    assumes that order matches the model's real ``forward`` (TODO confirm for
    each architecture). For each top-level child listed in ``module``:

    * if its name contains ``fc``, ``avgpool`` or ``hidden``, the child's
      output is hooked (``avgpool`` output is flattened first);
    * otherwise each of its own children runs in order, and the outputs of
      those named in ``layer`` are hooked.

    After ``backward()``, ``self.gradients`` holds the hooked gradients, each
    flattened per sample and concatenated along dim 1 in the order the hooks
    fired.
    """

    def __init__(self, model, module, layer):
        self.model = model
        self.module = module
        self.layer = layer
        # Filled by save_gradient during backward; None until a hook fires.
        self.gradients = None

    def save_gradient(self, grad):
        """Backward hook: flatten ``grad`` per sample and append it to ``self.gradients``."""
        grad = grad.reshape([len(grad), -1])
        if self.gradients is None:
            self.gradients = grad
        else:
            self.gradients = t.cat((self.gradients, grad), 1)

    def __call__(self, x):
        """Forward ``x`` through ``self.model``, registering gradient hooks."""
        self.gradients = None
        for name_m, module_m in self.model._modules.items():
            lname = name_m.lower()
            if module_m in self.module:
                if 'fc' in lname:
                    x = module_m(x)
                    x.register_hook(self.save_gradient)
                elif 'avgpool' in lname:
                    x = module_m(x)
                    # Flatten like the un-hooked avgpool path below, so a
                    # following fc layer receives a 2-D input.
                    x = x.view(x.size(0), -1)
                    x.register_hook(self.save_gradient)
                elif 'hidden' in lname:
                    x = module_m(x)
                    x.register_hook(self.save_gradient)
                else:
                    for name_l, module_l in module_m._modules.items():
                        x = module_l(x)
                        if name_l in self.layer:
                            x.register_hook(self.save_gradient)
            elif 'avgpool' in lname:
                x = module_m(x)
                x = x.view(x.size(0), -1)
            else:
                x = module_m(x)
        return x