import torch as t
import torch.nn as nn
import torch.nn.functional as F
from utils.func import *

import algs.list.vanilla as  base
from algs.utils.new_fc import new_fc




class algorithm(base.algorithm):
    """Jointly train the 'final' model and an auxiliary model built by ``new_fc``.

    For every batch the main model's logits are multiplied element-wise by
    the sigmoid of the auxiliary model's logits before the loss is taken.
    The hyper-parameter dict is called ``rubi_args``, which suggests the
    RUBi scheme -- NOTE(review): confirm against the project documentation.

    ``denoising``, ``statistics`` and ``model_save`` are not defined in this
    class; presumably they are inherited from ``base.algorithm``.
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the training resources and optionally trigger denoising.

        Args:
            models: Mapping of model bundles. ``train`` reads
                ``models['final']['net']``, ``['opt']`` and ``['scheduler']``.
            noise_model: Stored on the instance; not used elsewhere in this
                class.
            loaders: Mapping of data loaders; ``train`` iterates
                ``loaders['train']``.
            args: Configuration object. This class reads ``denoise``, ``job``,
                ``epoch``, ``device``, ``num_labels``, ``lr_decay_step`` and
                ``lr_decay`` from it.
        """
        # NOTE(review): the parent initializer is not called here; kept as in
        # the original -- confirm that this is intentional.
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # The flag is compared against the *string* 'True', not a boolean.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Start training when ``'final'`` is contained in ``args.job``.

        Does nothing otherwise. Training uses a per-sample cross-entropy
        loss (``reduction='none'``) and fixed hyper-parameters: learning rate
        0.3 for the auxiliary optimizer and weight 0.34 for the auxiliary
        loss.
        """
        if 'final' not in self.args.job:
            return

        rubi_args = {
            'f_lr': 0.3,
            'weight': 0.34,
        }
        # Per-sample losses are required because ``train`` combines the two
        # loss vectors element-wise before averaging.
        ce = nn.CrossEntropyLoss(reduction='none')
        self.train(self.loaders, self.args.epoch, ce, rubi_args)

    def train(self, loaders, epoch, ce, rubi_args):
        """Optimise the main and auxiliary models together for ``epoch`` epochs.

        Args:
            loaders: Mapping with a ``'train'`` entry whose batches expose the
                inputs at index 0 and the targets at index 1.
            epoch: Number of epochs to run.
            ce: Loss callable returning a per-sample loss (it is combined
                element-wise and then averaged here).
            rubi_args: Dict with ``'f_lr'`` (learning rate of the auxiliary
                SGD optimizer) and ``'weight'`` (multiplier applied to the
                auxiliary loss).

        After every epoch both schedulers are stepped and ``self.statistics``
        is called with the main model; ``self.model_save`` is called once at
        the end.
        """
        d_model = self.models['final']['net']
        d_opt = self.models['final']['opt']
        d_scheduler = self.models['final']['scheduler']

        # NOTE(review): ``new_fc`` is an opaque project helper; presumably it
        # derives a new classification head/model from ``d_model`` -- confirm.
        b_model = new_fc(d_model, self.args.device, self.args.num_labels)
        b_opt = t.optim.SGD(b_model.parameters(), lr=rubi_args['f_lr'])
        b_scheduler = t.optim.lr_scheduler.StepLR(
            b_opt,
            step_size=self.args.lr_decay_step,
            gamma=self.args.lr_decay,
        )
        # Passed to ``self.statistics`` every epoch; presumably updated there.
        best_res = {'acc': 0, 'loss': float('inf')}

        # Loop-invariant lookups hoisted out of the batch loop.
        device = self.args.device
        b_weight = rubi_args['weight']

        for e in range(epoch):
            # Re-enable training mode each epoch (``statistics`` may have
            # switched the models to eval mode -- not visible here).
            b_model.train()
            d_model.train()

            # Index access (rather than tuple unpacking) is kept so batches
            # carrying extra fields beyond (input, target) still work.
            for data in loaders['train']:
                x = data[0].to(device)
                y = data[1].to(device)

                b_logit = b_model(x)
                # Main logits are gated by the sigmoid of the auxiliary logits.
                d_logit = d_model(x) * t.sigmoid(b_logit)
                b_loss = ce(b_logit, y)
                d_loss = ce(d_logit, y)

                # Weighted auxiliary loss plus gated main loss, averaged over
                # the batch.
                loss = (b_loss * b_weight + d_loss).mean()

                d_opt.zero_grad()
                b_opt.zero_grad()
                loss.backward()
                d_opt.step()
                b_opt.step()

            # Schedulers advance once per epoch, after the optimizer steps.
            d_scheduler.step()
            b_scheduler.step()

            self.statistics(e, epoch, d_model, loaders, best_res)
        self.model_save()
        