import torch as t
import torch.nn as nn
import torch.nn.functional as F

import algs.list.vanilla as  base


class algorithm(base.algorithm):
    """Two-stage training driven by learned per-sample weights.

    Stage one (``b_train``) trains the ``'bias'`` network while jointly
    learning one scalar weight per training sample.  Stage two
    (``sample_selection``) uses those weights to filter or re-weight the
    training set, after which ``run`` trains the ``'final'`` model.

    ``denoising``, ``train``, ``statistics`` and ``model_save`` are not defined
    in this class; they are presumably inherited from ``base.algorithm``.
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the collaborators and optionally run the denoising step.

        Args:
            models: Mapping of model bundles; ``run``/``b_train`` read the
                ``'bias'`` and ``'final'`` entries and add a ``'weight'`` entry.
            noise_model: Stored as-is on the instance; not used in this class.
            loaders: Mapping of data loaders; the ``'train'`` entry is used here.
            args: Run configuration.  This class reads ``denoise``, ``job``,
                ``epoch``, ``num_labels`` and ``device``.
        """
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # The flag is compared against the string 'True', not a boolean.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Run the full pipeline when ``'final'`` is part of ``args.job``.

        Does nothing for any other job.
        """
        if 'final' not in self.args.job:
            return

        repair_args = {
            'w_lr': 10.0,
            'portion': 0.5,
            'method': 'hard',
        }
        # Per-sample losses are required so they can be re-weighted in b_train.
        ce = nn.CrossEntropyLoss(reduction='none')
        train_loader = self.loaders['train']
        self.models['weight'] = self.weight_param(
            len(train_loader.dataset.label), repair_args['w_lr']
        )

        self.b_train(self.loaders, self.models, self.args.epoch, ce, repair_args)
        self.sample_selection(
            train_loader, self.models['weight']['param'], repair_args
        )
        self.train(self.loaders, self.models['final'], self.args.epoch, ce)

    def sample_selection(self, loader, weight, repair_args):
        """Apply the learned sample weights to the training dataset.

        Args:
            loader: Data loader whose ``dataset`` is modified.
            weight: 1-D tensor with one learned weight per training sample.
            repair_args: Dict with ``'method'`` (``'hard'`` or ``'soft'``) and,
                for the hard method, ``'portion'`` (fraction of samples kept).

        With ``'hard'`` the indices of the highest-weighted samples are passed
        to ``dataset.refine_dataset``.  With ``'soft'`` the sigmoid of the
        weights, normalised to sum to one, is passed to ``dataset.prob_update``
        and ``dataset.sample_on`` is called.  Any other method is a no-op.
        """
        method = repair_args['method']
        if method == 'hard':
            keep = int(len(loader.dataset.label) * repair_args['portion'])
            _, order = weight.sort(descending=True)
            loader.dataset.refine_dataset(order[:keep])
        elif method == 'soft':
            # Detach first: the dataset only needs the values, and the weights
            # are a trainable parameter whose autograd graph must not leak out.
            prob = t.sigmoid(weight.detach())
            prob = prob / prob.sum()
            loader.dataset.prob_update(prob)
            loader.dataset.sample_on()

    def b_train(self, loaders, model, epoch, ce, repair_args):
        """Alternately train the bias network and the per-sample weights.

        For every batch the bias network takes one optimiser step on the
        weighted cross-entropy, then the sample weights take one step on
        ``1 - weighted_loss / weighted_class_entropy`` evaluated with the
        just-updated network in eval mode.

        Args:
            loaders: Mapping of data loaders; ``'train'`` is iterated here and
                the whole mapping is forwarded to ``self.statistics``.
            model: Mapping with a ``'weight'`` entry (``'param'``, ``'opt'``)
                and a ``'bias'`` entry (``'net'``, ``'opt'``, ``'scheduler'``).
            epoch: Number of epochs to run.
            ce: Loss callable returning one value per sample.
            repair_args: Accepted for signature compatibility; not read here.
        """
        weight_bundle = model['weight']
        bias_bundle = model['bias']
        w_param, w_opt = weight_bundle['param'], weight_bundle['opt']
        b_model = bias_bundle['net']
        b_opt = bias_bundle['opt']
        b_scheduler = bias_bundle['scheduler']
        device = self.args.device

        # Row c is a 0/1 mask of the training samples labelled c, so that
        # ``cls_idx @ w`` sums the sample weights per class.  It lives on the
        # same device as the batches instead of a hard-coded CUDA device.
        labels = loaders['train'].dataset.label
        cls_idx = t.stack(
            [labels == c for c in range(self.args.num_labels)]
        ).float().to(device)

        best_res = {'acc': 0, 'loss': float('inf')}

        for e in range(epoch):
            for data in loaders['train']:
                x = data[0].to(device)
                y = data[1].to(device)
                # data[4] indexes into the per-sample weight vector.
                i = data[4].to(device)

                w = t.sigmoid(w_param)
                z = w[i] / w.mean()
                cls_w = cls_idx @ w
                q = cls_w / cls_w.sum()

                # --- bias network step ---
                # Re-enter train mode on every batch: the weight step below
                # switches the network to eval mode, and without this only the
                # first batch of each epoch would be trained in train mode.
                b_model.train()
                logit = b_model(x)
                # The sample weights are constants for this step, so detach
                # them; this also removes the need for retain_graph=True.
                loss = (ce(logit, y) * z.detach()).mean()
                b_opt.zero_grad()
                loss.backward()
                b_opt.step()

                # --- sample weight step ---
                b_model.eval()
                # Only the gradient w.r.t. the weights (through z and q) is
                # used, so the network forward pass needs no autograd graph.
                with t.no_grad():
                    per_sample_loss = ce(b_model(x), y)
                loss = (per_sample_loss * z).mean()
                entropy = -(q[y].log() * z).mean()
                loss_w = 1 - loss / entropy
                w_opt.zero_grad()
                loss_w.backward()
                w_opt.step()
            b_scheduler.step()

            self.statistics(e, epoch, b_model, loaders, best_res)
        self.model_save()

    def weight_param(self, data_len, lr):
        """Create the trainable per-sample weights and their SGD optimiser.

        Args:
            data_len: Number of training samples (length of the weight vector).
            lr: Learning rate for the weight optimiser.

        Returns:
            Dict with ``'param'`` (zero-initialised ``nn.Parameter`` on
            ``args.device``) and ``'opt'`` (``torch.optim.SGD`` over it).
        """
        w_param = nn.Parameter(t.zeros(data_len, device=self.args.device))
        w_optim = t.optim.SGD([w_param], lr=lr)
        return {'param': w_param, 'opt': w_optim}