import torch as t
import torch.nn as nn
import torch.nn.functional as F

import algs.list.vanilla as  base
from algs.utils.new_fc import new_fc


class algorithm(base.algorithm):
    """Jointly train a 'bias' network, a 'final' network and a scaling head.

    The per-sample loss is the cross entropy of ``factor * final_logits +
    bias_logits`` plus a weighted entropy term computed on the scaled final
    logits.

    NOTE(review): ``denoising``, ``statistics`` and ``model_save`` are not
    defined in this class; presumably they are inherited from
    ``base.algorithm`` -- confirm.
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the collaborators and optionally run the denoising step.

        Args:
            models: mapping with at least the keys ``'bias'`` and ``'final'``,
                each holding ``'net'``, ``'opt'`` and ``'scheduler'`` entries
                (these are the keys read by :meth:`train`).
            noise_model: stored as-is; not used inside this class.
            loaders: mapping of data loaders; :meth:`train` reads ``'train'``.
            args: configuration namespace. This class reads ``denoise``,
                ``job``, ``epoch``, ``device``, ``num_labels``,
                ``lr_decay_step`` and ``lr_decay``.
        """
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # The flag is compared against the *string* 'True', not a boolean.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Start training when the configured job name contains ``'final'``.

        Any other job name makes this method a no-op.
        """
        if 'final' in self.args.job:
            mixin_args = {
                'f_lr': 0.3,     # learning rate of the scaling-head optimizer
                'weight': 0.34,  # coefficient of the entropy term in the loss
            }
            # Per-sample losses are required: `train` adds the entropy term
            # before reducing with a mean.
            ce = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.args.epoch, ce, mixin_args)

    def train(self, loaders, epoch, ce, mixin_args):
        """Run the joint optimisation loop for ``epoch`` epochs.

        Args:
            loaders: mapping of data loaders; ``loaders['train']`` yields
                batches whose element 0 is the input and element 1 the label.
            epoch: number of passes over ``loaders['train']``.
            ce: loss callable applied to ``(logits, labels)``; expected to
                return unreduced, per-sample values.
            mixin_args: mapping with ``'f_lr'`` (learning rate for the scaling
                head) and ``'weight'`` (entropy coefficient).
        """
        b_model = self.models['bias']['net']
        b_opt = self.models['bias']['opt']
        b_scheduler = self.models['bias']['scheduler']
        d_model = self.models['final']['net']
        d_opt = self.models['final']['opt']
        d_scheduler = self.models['final']['scheduler']

        # NOTE(review): `new_fc` is a project helper; presumably it derives a
        # head from `b_model` that outputs the scaling factor -- confirm.
        f_model = new_fc(b_model, self.args.device, self.args.num_labels)
        f_opt = t.optim.SGD(f_model.parameters(), lr=mixin_args['f_lr'])
        f_scheduler = t.optim.lr_scheduler.StepLR(
            f_opt,
            step_size=self.args.lr_decay_step,
            gamma=self.args.lr_decay,
        )

        # Passed to `statistics` every epoch; nothing in this method updates it
        # directly.
        best_res = {'acc': 0, 'loss': float('inf')}

        for e in range(epoch):
            b_model.train()
            d_model.train()

            for data in loaders['train']:
                x = data[0].to(self.args.device)
                y = data[1].to(self.args.device)

                b_logit = b_model(x)  # fpred
                d_logit = d_model(x)  # gpred
                # softplus keeps the scaling factor strictly positive.
                factor = F.softplus(f_model(x))
                # Out-of-place product: modifying the network output in place
                # can invalidate tensors autograd saved for the backward pass.
                d_logit = d_logit * factor

                loss = ce(d_logit + b_logit, y)
                # Batch-mean entropy of softmax(scaled final logits).
                bias_lp = F.log_softmax(d_logit, 1)
                entropy = -(t.exp(bias_lp) * bias_lp).sum(1).mean()

                # `entropy` is a scalar added to every per-sample loss before
                # the final mean reduction.
                loss = (loss + mixin_args['weight'] * entropy).mean()

                f_opt.zero_grad()
                b_opt.zero_grad()
                d_opt.zero_grad()
                loss.backward()
                f_opt.step()
                b_opt.step()
                d_opt.step()

            # Learning-rate schedules advance once per epoch.
            b_scheduler.step()
            d_scheduler.step()
            f_scheduler.step()

            self.statistics(e, epoch, d_model, loaders, best_res)
        self.model_save()