import torch as t
import torch.nn as nn
import torch.nn.functional as F

import algs.list.vanilla as  base
from algs.utils.penultimate import *
from algs.utils.HSIC import *



class algorithm(base.algorithm):
    """Train a "bias" network and a "final" network with opposing HSIC terms.

    ``self.models`` must map ``'bias'`` and ``'final'`` to dicts with
    ``'net'``, ``'opt'`` and ``'scheduler'`` entries. Every epoch first runs a
    full pass over ``loaders['train']`` updating only the bias network
    (cross-entropy + ``MinusRbfHSIC``), then a second full pass updating only
    the final network (cross-entropy + ``RbfHSIC``). Both HSIC terms are
    evaluated on the features returned by ``gen_pen(net, break_layer='fc')``
    (presumably the penultimate layer), always in (bias, final) argument order.
    Judging by the class names, the bias network is pushed to *increase* the
    dependence between the two representations and the final network to
    *decrease* it -- NOTE(review): confirm against ``algs.utils.HSIC``.
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the training state and optionally run ``self.denoising()``.

        Args:
            models: dict with ``'bias'`` and ``'final'`` entries, each a dict
                holding ``'net'``, ``'opt'`` and ``'scheduler'``.
            noise_model: stored as ``self.noise_model``; not read elsewhere in
                this class (presumably consumed by ``denoising``).
            loaders: dict of data loaders; ``train`` iterates
                ``loaders['train']``.
            args: config namespace; this class reads ``denoise``, ``job``,
                ``epoch`` and ``device``.
        """
        # NOTE(review): base.algorithm.__init__ is not called (unchanged from
        # the original) -- confirm the base initializer sets nothing required.
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # ``denoise`` is compared against the *string* 'True' (presumably a
        # string-valued CLI flag); a boolean True would not trigger this.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Train the bias/final pair when ``args.job`` contains ``'final'``."""
        if 'final' in self.args.job:
            # Per-sample losses: the HSIC penalty is added before averaging.
            ce = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.args.epoch, ce)

    def train(self, loaders, epoch, ce):
        """Alternate bias-network and final-network training for ``epoch`` epochs.

        Args:
            loaders: dict of data loaders; ``loaders['train']`` is iterated
                twice per epoch (once for each network).
            epoch: number of epochs to run.
            ce: per-sample loss, called as ``ce(logit, y)``.

        After every epoch ``self.statistics`` is called with the final network
        and ``best_res`` (which it presumably updates in place); the models are
        saved once via ``self.model_save()`` when training ends.
        """
        b_model = self.models['bias']['net']
        b_opt = self.models['bias']['opt']
        b_scheduler = self.models['bias']['scheduler']
        b_pen = gen_pen(b_model, break_layer='fc')
        d_model = self.models['final']['net']
        d_opt = self.models['final']['opt']
        d_scheduler = self.models['final']['scheduler']
        d_pen = gen_pen(d_model, break_layer='fc')

        b_HSIC = MinusRbfHSIC()
        d_HSIC = RbfHSIC()

        best_res = {'acc': 0, 'loss': float('inf')}

        for e in range(epoch):
            # Re-enter train mode each epoch (statistics() may switch modes --
            # presumably; it is defined outside this view).
            b_model.train()
            d_model.train()

            self._fit_epoch(loaders['train'], b_model, b_opt, b_scheduler,
                            ce, b_HSIC, b_pen, d_pen, update_bias=True)
            self._fit_epoch(loaders['train'], d_model, d_opt, d_scheduler,
                            ce, d_HSIC, b_pen, d_pen, update_bias=False)

            self.statistics(e, epoch, d_model, loaders, best_res)
        self.model_save()

    def _fit_epoch(self, loader, net, opt, scheduler, ce, hsic, b_pen, d_pen,
                   update_bias):
        """Run one pass over ``loader`` updating only ``net``, then step its scheduler.

        Args:
            loader: iterable of batches whose items 0 and 1 are inputs and labels.
            net, opt, scheduler: the network being updated and its
                optimizer / learning-rate scheduler.
            ce: per-sample loss, called as ``ce(logit, y)``.
            hsic: dependence penalty, called as ``hsic(bias_feat, final_feat)``.
            b_pen, d_pen: feature extractors of the bias / final networks.
            update_bias: True when ``net`` is the bias network, False when it
                is the final network.
        """
        device = self.args.device
        for data in loader:
            x = data[0].to(device)
            y = data[1].to(device)

            logit = net(x)
            # The other network is frozen during this pass: compute its
            # features without autograd so backward() neither traverses it nor
            # leaves unused gradients on its parameters. The forward order
            # (net, b_pen, d_pen) matches the original, so any train-mode side
            # effects (e.g. BatchNorm running stats, dropout RNG draws) are
            # unchanged.
            if update_bias:
                b_feat = b_pen(x)
                with t.no_grad():
                    d_feat = d_pen(x)
            else:
                with t.no_grad():
                    b_feat = b_pen(x)
                d_feat = d_pen(x)

            # Per-sample CE plus the HSIC penalty (presumably a scalar that
            # broadcasts), averaged over the batch.
            loss = ce(logit, y) + hsic(b_feat, d_feat)
            opt.zero_grad()
            loss.mean().backward()
            opt.step()

        scheduler.step()
