import torch as t
import torch.nn as nn
import torch.nn.functional as F
import os
from utils.func import *

class algorithm:
    """Training driver with an optional label-denoising pass.

    Holds the model bundles, the data loaders and the run configuration.
    When ``args.denoise`` is the string ``'True'`` the constructor
    immediately runs :meth:`denoising`, which may train ``noise_model`` and
    then asks the training dataset to keep only the samples judged clean.

    Attributes read from ``args`` by this class: ``denoise``, ``job``,
    ``epoch``, ``device``, ``save``, ``ckpt_path``, ``out_path`` and ``log``.
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the collaborators and optionally denoise the training set.

        Args:
            models: Mapping of model bundles; ``models['final']`` is the
                bundle trained by :meth:`run`.
            noise_model: Bundle (``'net'``, ``'opt'``, ``'scheduler'``) used
                for the denoising pass.
            loaders: Mapping with at least ``'train'`` and ``'val'`` loaders.
            args: Run configuration; boolean-like flags are compared against
                the string ``'True'``.
        """
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Train the final model when ``'final'`` is part of ``args.job``."""
        if 'final' in self.args.job:
            cr = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.models['final'], self.args.epoch, cr)

    def train(self, loaders, models, epoch, cr, validation=True):
        """Run the optimisation loop for one model bundle.

        Args:
            loaders: Mapping of loaders; ``loaders['train']`` is iterated for
                optimisation, and :meth:`statistics` also reads
                ``loaders['val']`` when validation is enabled.
            models: Bundle with keys ``'net'``, ``'opt'`` and ``'scheduler'``.
            epoch: Number of epochs to run.
            cr: Criterion returning per-sample losses; their mean is
                back-propagated.
            validation: When ``True``, log train/validation statistics after
                every epoch.
        """
        model = models['net']
        opt = models['opt']
        scheduler = models['scheduler']
        device = self.args.device
        best_res = {'acc': 0, 'loss': float('inf')}
        for e in range(epoch):
            # evaluate() switches the network to eval mode, so training mode
            # has to be restored at the start of every epoch.
            model.train()
            for data in loaders['train']:
                x = data[0].to(device)
                y = data[1].to(device)
                loss = cr(model(x), y).mean()
                opt.zero_grad()
                loss.backward()
                opt.step()
            scheduler.step()
            # Kept as an equality test on purpose: flags in this file are
            # passed around as strings, and a string such as 'False' must not
            # be treated as truthy here.
            if validation == True:  # noqa: E712
                self.statistics(e, epoch, model, loaders, best_res)
        self.model_save()

    def noise_fix(self, models, loader, device):
        """Select presumably clean samples and hand them to the dataset.

        For every sample the cross-entropy loss against its given label and
        the entropy of the predicted distribution are computed. Both are
        scaled by their maximum, and a sample is kept when its
        ``loss / entropy`` ratio is below ``mean(loss) / mean(entropy)``.
        The result of ``torch.where`` (a tuple holding one index tensor) is
        passed to ``loader.dataset.refine_dataset``.

        Args:
            models: Bundle whose ``'net'`` entry is used for inference.
            loader: Loader yielding batches where ``data[0]`` is the input,
                ``data[1]`` the label and ``data[4]`` a per-sample index
                tensor used to put results back into index order.
            device: Device the batches are moved to.
        """
        model = models['net']
        model.eval()
        cr = nn.CrossEntropyLoss(reduction='none')
        idx_chunks = []
        loss_chunks = []
        entropy_chunks = []
        # Pure inference: no autograd graph is needed.
        with t.no_grad():
            for data in loader:
                x = data[0].to(device)
                y = data[1].to(device)
                logit = model(x)
                # Entropy from log-probabilities: softmax followed by log()
                # yields 0 * -inf = NaN once a probability underflows to 0,
                # and a single NaN would poison the max/mean below.
                log_p = F.log_softmax(logit, dim=1)
                entropy_chunks.append(-(log_p.exp() * log_p).sum(dim=1).cpu())
                loss_chunks.append(cr(logit, y).cpu())
                idx_chunks.append(data[4].cpu())

        # Reorder the per-sample values by ascending sample index.
        order = t.cat(idx_chunks).sort()[1]
        entropy = t.cat(entropy_chunks)[order]
        loss = t.cat(loss_chunks)[order]

        # Normalise both quantities to a maximum of 1.
        entropy /= t.max(entropy)
        loss /= t.max(loss)
        # NOTE(review): a sample with zero loss and zero entropy gives 0/0
        # and is therefore not selected — confirm that this is intended.
        clean_candidates = t.where(loss / entropy < t.mean(loss) / t.mean(entropy))
        loader.dataset.refine_dataset(clean_candidates)

    def denoising(self):
        """Optionally train the noise model, then filter the training set.

        The noise model is trained only when ``'noise'`` is part of
        ``args.job``; :meth:`noise_fix` is applied to the training loader in
        either case.
        """
        if 'noise' in self.args.job:
            cr = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.noise_model, self.args.epoch, cr)
        self.noise_fix(self.noise_model, self.loaders['train'], self.args.device)

    def model_save(self, opt='end'):
        """Persist the models when ``args.save`` is the string ``'True'``.

        Delegates to the module-level ``model_save`` helper (not defined in
        this class; presumably provided by the ``utils.func`` star import).

        Args:
            opt: Tag forwarded to the helper, e.g. ``'end'`` or ``'best'``.
        """
        if self.args.save == 'True':
            model_save(self.models, self.args.ckpt_path, opt)
            if 'noise' in self.args.job:
                # NOTE(review): the noise model goes to ``out_path`` while the
                # main models go to ``ckpt_path`` — confirm this is intended.
                model_save(self.noise_model, self.args.out_path, 'noise_' + opt)

    def statistics(self, e, epoch, model, loaders, best_res):
        """Evaluate on train/val, track the best validation result and log.

        Args:
            e: Zero-based index of the epoch that just finished.
            epoch: Total number of epochs (used for the log line only).
            model: Network to evaluate.
            loaders: Mapping with ``'train'`` and ``'val'`` loaders.
            best_res: Dict with ``'acc'`` and ``'loss'``; updated in place
                when the validation accuracy improves.
        """
        tr_res = self.evaluate(model, loaders['train'])
        val_res = self.evaluate(model, loaders['val'])
        if val_res['acc'] > best_res['acc']:
            best_res['acc'] = val_res['acc']
            best_res['loss'] = val_res['loss']
            # model_save() itself checks args.save, so no extra guard here.
            self.model_save(opt='best')
        self.res_summary(e + 1, epoch, {'train': tr_res, 'valid': val_res, 'best': best_res})

    def res_summary(self, e, epoch, dict):
        """Log one line per result group via ``args.log.info``.

        Args:
            e: One-based epoch number shown in the log.
            epoch: Total number of epochs shown in the log.
            dict: Mapping of group name to a mapping of metric name to value.
                The parameter name shadows the builtin; it is kept so that
                existing keyword callers continue to work.
        """
        log = self.args.log
        log.info('-' * 50)
        for group, metrics in dict.items():
            output = "[%3d/%3d] %s \t-- " % (e, epoch, group)
            for metric, value in metrics.items():
                output += "%s: %.3f " % (metric, value)
            log.info(output)
        log.info('-' * 50)

    def evaluate(self, model, loader):
        """Compute accuracy (in percent) and mean loss over a loader.

        The network is switched to eval mode and left in that mode.

        Args:
            model: Network to evaluate.
            loader: Iterable of batches where ``data[0]`` is the input and
                ``data[1]`` the label.

        Returns:
            Dict with 0-dim tensors ``'acc'`` and ``'loss'``; both are NaN
            when the loader yields no batches.
        """
        model.eval()
        cr = nn.CrossEntropyLoss(reduction='none')
        device = self.args.device
        corr_chunks = []
        loss_chunks = []
        # Evaluation needs no gradients; skipping the autograd graph saves
        # memory and time.
        with t.no_grad():
            for data in loader:
                x, y = data[0].to(device), data[1].to(device)
                logit = model(x)
                corr_chunks.append((logit.max(1)[1] == y).float().cpu())
                loss_chunks.append(cr(logit, y).cpu())

        corr = t.cat(corr_chunks) if corr_chunks else t.empty(0)
        loss = t.cat(loss_chunks) if loss_chunks else t.empty(0)
        return {'acc': corr.mean() * 100,
                'loss': loss.mean()}


