import torch as t
import torch.nn as nn
import torch.nn.functional as F
from utils.func import *

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import algs.list.vanilla as base
from algs.model import get_classifier

class reduced_set(Dataset):
    """Dataset view restricted to the samples selected by ``list_idx``.

    Every item is a triple ``(image, label, original_index)``, where
    ``original_index`` is the corresponding entry of ``list_idx``, so callers
    can map a sample back to its position in the source arrays.
    """

    def __init__(self, img, label, list_idx):
        """Select ``img[list_idx]`` / ``label[list_idx]`` and remember the indices.

        When ``img`` is 3-dimensional, a singleton axis is inserted at
        position 1 of the selection (presumably a missing channel axis --
        confirm against the source dataset). Any other rank is stored as is.
        """
        picked = img[list_idx]
        lacks_extra_axis = len(img.shape) == 3
        self.img = picked.unsqueeze(1) if lacks_extra_axis else picked
        self.label = label[list_idx]
        self.idxs = list_idx

    def __getitem__(self, idx):
        """Return ``(image, label, original_index)`` for position ``idx``."""
        sample = self.img[idx]
        target = self.label[idx]
        source_index = self.idxs[idx]
        return sample, target, source_index

    def __len__(self):
        """Number of selected samples, measured on the label container."""
        return len(self.label)

class algorithm(base.algorithm):
    """Training algorithm that filters the training set before the final fit.

    The filtering follows an AFLite-style scheme (per the ``aflite_args``
    naming): several throw-away classifiers are trained on random subsets,
    each one is evaluated on the samples it did not see, and the samples that
    were classified correctly least often are kept for the final training.

    ``train`` and ``denoising`` are not defined here; they are expected to be
    provided by ``base.algorithm``.
    """

    def __init__(self, models, noise_model, loaders, args):
        """Store the collaborators and optionally run ``self.denoising()``.

        Args:
            models: mapping of models; ``run`` reads ``models['final']``.
            noise_model: stored as ``self.noise_model``; not used in this class.
            loaders: mapping of data loaders; ``run`` reads ``loaders['train']``.
            args: run configuration. This class reads ``denoise``, ``job``,
                ``epoch``, ``batch``, ``device`` and ``log`` from it.
        """
        self.models = models
        self.noise_model = noise_model
        self.loaders = loaders
        self.args = args
        # ``denoise`` is compared against the string 'True', not a boolean.
        if self.args.denoise == 'True':
            self.denoising()

    def run(self):
        """Filter the training set, then train the final model.

        Does nothing unless ``'final'`` is contained in ``self.args.job``.
        """
        if 'final' in self.args.job:
            aflite_args = {
                'inner_train' : 30,     # number of throw-away training rounds
                'inner_epoch' : 50,     # epochs per throw-away round
                'inner_portion' : 0.2,  # fraction of data each round trains on
                'final_portion' : 0.5   # fraction of data kept after filtering
                }
            idxs = self.split_train(self.loaders['train'], aflite_args)
            # Hand the kept indices to the dataset so it can shrink itself.
            self.loaders['train'].dataset.refine_dataset(idxs)
            cr = nn.CrossEntropyLoss(reduction='none')
            self.train(self.loaders, self.models['final'], self.args.epoch, cr)

    def split_train(self, loader, aflite_args):
        """Return the indices of the samples to keep for the final training.

        For ``aflite_args['inner_train']`` rounds a fresh classifier is
        trained on a random split and evaluated on the held-out part; a
        per-sample counter accumulates how often each sample was classified
        correctly while held out. Samples are then ordered by ascending
        counter and the first ``final_portion`` fraction is returned.

        Note that a sample that was never held out keeps a counter of 0 and
        therefore ranks among the first.

        Args:
            loader: data loader whose ``dataset`` exposes ``label`` (and the
                attributes needed by ``data_split``).
            aflite_args: dict with the keys ``inner_train``, ``inner_epoch``,
                ``inner_portion`` and ``final_portion``.

        Returns:
            1-D index tensor of length ``int(N * final_portion)``, where ``N``
            is ``len(loader.dataset.label)``.
        """
        history = t.zeros(len(loader.dataset.label))
        n_rounds = aflite_args['inner_train']
        for ii in range(n_rounds):
            model = get_classifier(self.args)
            split = self.data_split(loader, aflite_args['inner_portion'])
            cr = nn.CrossEntropyLoss(reduction='none')
            self.train(split, model, aflite_args['inner_epoch'], cr, validation=False)
            corr, idxs = self.get_history(model, split['test'])
            # Indices within one split are unique, so the indexed add is safe.
            history[idxs] += corr

            self.args.log.debug('[%2d/ %2d] Inner train Done...' %(ii+1, aflite_args['inner_train']))

        # The selection must happen after ALL rounds; returning from inside
        # the loop would base it on the first round only.
        n_keep = int(len(history) * aflite_args['final_portion'])
        return history.sort()[1][:n_keep]

    def data_split(self, loader, portion):
        """Randomly split the loader's dataset into train and test loaders.

        Args:
            loader: data loader whose ``dataset`` exposes ``imgs`` and ``label``.
            portion: fraction of the samples assigned to the train split; the
                remainder forms the test split.

        Returns:
            ``{'train': DataLoader, 'test': DataLoader}`` over ``reduced_set``
            views, each yielding ``(image, label, original_index)`` batches.
        """
        data_len = len(loader.dataset.label)
        perm = t.randperm(data_len)
        cut = int(portion * data_len)
        idx_tr = perm[:cut]
        idx_te = perm[cut:]

        data_tr = reduced_set(loader.dataset.imgs, loader.dataset.label, idx_tr)
        data_te = reduced_set(loader.dataset.imgs, loader.dataset.label, idx_te)

        tr_loader = DataLoader(dataset=data_tr, batch_size=self.args.batch, shuffle=True, num_workers=8)
        te_loader = DataLoader(dataset=data_te, batch_size=self.args.batch, shuffle=True, num_workers=8)

        return {'train': tr_loader, 'test': te_loader}

    def get_history(self, model, loader):
        """Evaluate ``model`` on ``loader`` and report per-sample correctness.

        Args:
            model: mapping whose ``'net'`` entry is the network to evaluate.
            loader: loader yielding ``(x, y, index)`` batches, where ``index``
                is a tensor of per-sample identifiers.

        Returns:
            ``(corr, idxs)``: two 1-D long tensors of equal length. ``corr``
            is 1 where the arg-max prediction equals the label and 0
            otherwise; ``idxs`` holds the matching sample identifiers. Both
            are empty when the loader yields no batches.
        """
        net = model['net']
        net.eval()
        corr_chunks, idx_chunks = [], []
        # Pure evaluation: no autograd graph is needed for the forward passes.
        with t.no_grad():
            for data in loader:
                x = data[0].to(self.args.device)
                y = data[1].to(self.args.device)
                logit = net(x)
                corr_chunks.append((logit.max(1)[1] == y).detach().cpu())
                idx_chunks.append(data[2].cpu())

        if not corr_chunks:
            # t.cat rejects an empty list; return empty results explicitly.
            empty = t.zeros(0, dtype=t.long)
            return empty, empty.clone()

        corr = t.cat(corr_chunks).long()
        idxs = t.cat(idx_chunks).long()
        return corr, idxs


