import torch as t
import torch.nn as nn
import torch.nn.functional as F
from utils.func import *

import algs.list.vanilla as base
from algs.utils.gce import GeneralizedCELoss



class moving_average:
    """Per-sample exponential moving average (EMA) of training losses.

    One EMA slot is kept for every entry of ``data.dataset.label``; slots are
    addressed by the sample indices passed to :meth:`update` and
    :meth:`extract`.

    Args:
        alpha: Weight given to the *new* loss in the EMA update
            (``new = alpha * loss + (1 - alpha) * old``).
        num_labels: Number of classes; :meth:`extract` normalises classes
            ``0 .. num_labels - 1``.
        data: Object exposing ``data.dataset.label`` with one label per
            sample (a tensor, a NumPy array or a plain sequence).
    """

    def __init__(self, alpha, num_labels, data):
        self.alpha = alpha
        self.num_labels = num_labels
        labels = data.dataset.label
        self.loss = t.zeros(len(labels))
        # Marks slots that already received a first value, so that each
        # sample can be initialised individually (see ``update``).
        self.seen = t.zeros(len(labels), dtype=t.bool)
        # Store the labels as a NumPy array whatever their original
        # container, so the ``== c`` comparisons below are element-wise.
        if t.is_tensor(labels):
            labels = labels.detach().cpu().numpy()
        self.cidx = np.asarray(labels)

    def update(self, loss, idx):
        """Fold a batch of losses into the EMA slots selected by ``idx``.

        A sample seen for the first time takes the incoming loss as-is;
        every later update blends it in with weight ``self.alpha``. The
        decision is made per sample, so a batch may mix new and already
        tracked samples.

        Args:
            loss: Loss values, one per entry of ``idx``.
            idx: Sample indices into the dataset.
        """
        idx = t.as_tensor(idx, dtype=t.long)
        loss = t.as_tensor(loss).detach().to(self.loss)
        first_time = ~self.seen[idx]
        alpha = t.where(first_time,
                        t.ones_like(loss),
                        t.full_like(loss, self.alpha))
        self.loss[idx] = alpha * loss + (1 - alpha) * self.loss[idx]
        self.seen[idx] = True

    def get_max(self, c):
        """Return the largest tracked EMA value among samples labelled ``c``.

        Raises:
            RuntimeError: If no sample in the dataset carries label ``c``
                (``torch.max`` of an empty tensor).
        """
        pos = t.as_tensor(np.where(self.cidx == c)[0], dtype=t.long)
        return t.max(self.loss[pos])

    def extract(self, loss, idx):
        """Return ``loss`` with each entry divided by its class maximum.

        For every class ``c`` in ``range(self.num_labels)`` that occurs in
        the batch, the entries belonging to ``c`` are divided by
        ``get_max(c)``. Classes absent from the batch are skipped, and a
        class whose maximum is not positive is left unscaled to avoid
        producing NaN/inf.

        Args:
            loss: Loss values, one per entry of ``idx``.
            idx: Sample indices into the dataset.

        Returns:
            A new tensor; the ``loss`` argument itself is not modified.
        """
        # Work on a copy: the argument may share storage with a tensor the
        # caller still needs.
        out = t.as_tensor(loss).detach().clone()
        if t.is_tensor(idx):
            idx = idx.detach().cpu().numpy()
        batch_labels = self.cidx[np.asarray(idx)]
        for c in range(self.num_labels):
            members = t.as_tensor(batch_labels == c, device=out.device)
            if not members.any():
                continue
            peak = self.get_max(c)
            if peak > 0:
                out[members] = out[members] / peak
        return out


class algorithm(base.algorithm):
    """Joint training of a 'bias' network and a 'final' network.

    The bias network is trained with the generalized cross-entropy loss; the
    final network is trained with a per-sample re-weighted cross-entropy,
    where the weights come from the class-normalised losses of both
    networks (see :meth:`train`).

    ``denoising``, ``statistics`` and ``model_save`` are not defined here;
    presumably they are inherited from ``base.algorithm`` (confirm there).
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the collaborators and optionally run ``self.denoising()``.

        Args:
            models: Mapping with ``'bias'`` and ``'final'`` entries, each
                holding ``'net'``, ``'opt'`` and ``'scheduler'``.
            noise_model: Stored as ``self.noise_model``; not used in this
                class.
            loaders: Mapping of data loaders; ``'train'`` is read here.
            args: Configuration object; ``denoise``, ``job``, ``epoch``,
                ``num_labels`` and ``device`` are read by this class.
        """
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # NOTE: the flag is compared against the *string* 'True'.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Train with GCE / per-sample CE when ``'final'`` is in ``args.job``."""
        if 'final' in self.args.job:
            gce = GeneralizedCELoss()
            ce = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.args.epoch, gce, ce)

    def train(self, loaders, epoch, gce, ce):
        """Train both networks for ``epoch`` epochs, then save the models.

        Args:
            loaders: Mapping of data loaders. Each ``loaders['train']`` batch
                is indexed as ``data[0]`` (inputs), ``data[1]`` (targets) and
                ``data[4]`` (dataset indices of the samples).
            epoch: Number of epochs to run.
            gce: Loss callable applied to the bias network's logits.
            ce: Per-sample (unreduced) loss callable used for the final
                network and for the loss tracking of both networks.
        """
        bias, final = self.models['bias'], self.models['final']
        b_model, b_opt, b_scheduler = bias['net'], bias['opt'], bias['scheduler']
        d_model, d_opt, d_scheduler = final['net'], final['opt'], final['scheduler']

        # Size the trackers from the loader that is actually iterated below.
        b_ma = moving_average(0.7, self.args.num_labels, loaders['train'])
        d_ma = moving_average(0.7, self.args.num_labels, loaders['train'])

        best_res = {'acc': 0, 'loss': float('inf')}
        device = self.args.device

        for e in range(epoch):
            b_model.train()
            d_model.train()

            for data in loaders['train']:
                x = data[0].to(device)
                y = data[1].to(device)
                i = data[4]
                b_logit = b_model(x)
                d_logit = d_model(x)

                # Per-sample CE of the final network: computed once and used
                # both for tracking (detached) and for the weighted objective.
                d_ce = ce(d_logit, y)
                # The bias network's CE is only used for tracking/weighting,
                # so no autograd graph is needed for it.
                with t.no_grad():
                    b_raw = ce(b_logit, y).cpu()
                d_raw = d_ce.detach().cpu()

                b_ma.update(b_raw, i)
                d_ma.update(d_raw, i)

                # ``d_raw`` may share storage with ``d_ce`` (detach/cpu do not
                # copy on a CPU device); hand copies to the normaliser so the
                # tensor kept in the graph cannot be altered.
                b_norm = b_ma.extract(b_raw.clone(), i)
                d_norm = d_ma.extract(d_raw.clone(), i)

                # Relative difficulty weight in [0, 1); 1e-6 avoids 0/0.
                weight = b_norm / (b_norm + d_norm + 1e-6)

                b_update = gce(b_logit, y)
                d_update = d_ce * weight.to(device)

                loss = b_update.mean() + d_update.mean()

                b_opt.zero_grad()
                d_opt.zero_grad()
                loss.backward()
                b_opt.step()
                d_opt.step()
            b_scheduler.step()
            d_scheduler.step()

            self.statistics(e, epoch, d_model, loaders, best_res)
        self.model_save()