import math
import numpy as np
import copy
from torch.utils.data import Dataset

def has_length(dataset):
    """Report whether ``len(dataset)`` can be evaluated.

    Returns True when ``len()`` succeeds, and False when it raises
    ``TypeError`` (e.g. generators or other unsized objects). Any other
    exception raised by ``__len__`` is propagated to the caller.
    """
    try:
        len(dataset)
    except TypeError:
        # "len() of unsized object" -- treat as having no length.
        return False
    return True

class InfoBatch(Dataset):
    """Dataset wrapper that tracks per-sample scores and loss weights for soft pruning.

    Items of the wrapped ``dataset`` must expose ``.items()`` (dict-like).
    Every fetched item is a shallow copy of the wrapped item with two extra
    keys: ``'sample_idx'`` (the index exactly as passed in) and ``'weight'``
    (the current weight of that sample).

    Args:
        dataset: map-style dataset indexable by ``int`` returning dict-like items.
        ratio: fraction of "well learned" samples (score below the mean score)
            that :meth:`prune` keeps; those samples are weighted ``1 / ratio``.
            Should be in (0, 1]: 0 divides by zero in :meth:`prune`, and values
            above 1 ask ``np.random.choice`` for more samples than exist.
        total_samples: read by :class:`InfoBatchSampler`, which stops pruning once
            it has handed out more than ``delta * total_samples`` indices.
            Presumably the total number of indices over all of training -- confirm.
        delta: fraction of ``total_samples`` after which the sampler stops pruning.
        correction: if True, :meth:`prune` rescales all weights by the mean of
            ``1 / weights``.
    """

    def __init__(self, dataset, ratio=0.5, total_samples=None, delta=0.875, correction=False):
        self.dataset = dataset
        self.ratio = ratio
        self.delta = delta
        self.total_samples = total_samples
        # One score and one loss weight per sample of the wrapped dataset.
        self.scores = np.ones([len(self.dataset)])
        self.weights = np.ones(len(self.dataset))
        # Running total of samples skipped by prune() across all epochs.
        self.save_num = 0
        self.correction = correction

    def __setscore__(self, indices, values):
        """Overwrite ``scores[indices]`` with ``values``.

        ``indices`` may be a torch tensor (on any device), a numpy array or a
        list. ``values`` may be a torch tensor on any device and may require
        grad; it is detached and converted to float64 (the dtype of ``scores``)
        before assignment, since numpy cannot read CUDA or grad-tracking tensors.
        """
        if hasattr(indices, 'cpu'):
            indices = indices.cpu()
        if hasattr(values, 'detach'):
            values = values.detach().cpu().double().numpy()
        self.scores[np.asarray(indices)] = values

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a copy of ``dataset[int(index)]`` with ``sample_idx`` and ``weight`` added."""
        idx = int(index)
        data = dict(self.dataset[idx].items())
        data.update({'sample_idx': index, 'weight': self.weights[idx]})
        return data

    def __getitems__(self, index):
        """Batched fetch: return ``[self[i] for i in index]``.

        Raises:
            ValueError: if ``index`` has no length (see ``has_length``).
        """
        if not has_length(index):
            raise ValueError("infobatch __getitems__ got index with no length!")
        return [self[idx] for idx in index]

    def __get_wt_id__(self):
        # NOTE(review): ``self.last_indexes`` is never assigned anywhere in this
        # file, so this raises AttributeError unless external code sets it.
        return self.weights[self.last_indexes], self.last_indexes

    def prune(self):
        """Pick the indices to visit next epoch and update the loss weights.

        Samples whose score is not below the mean are all kept. Of the
        "well learned" samples (score below the mean), a random ``ratio``
        fraction is kept (without replacement). Every well-learned sample gets
        weight ``1 / ratio`` so the kept ones scale up to keep the loss
        estimate about the same. With ``correction``, all weights are then
        multiplied by the mean of ``1 / weights``.

        Returns:
            Shuffled list of the kept indices.
        """
        well_learned_mask = self.scores < self.scores.mean()
        well_learned_samples = np.where(well_learned_mask)[0]
        kept_samples = list(np.where(~well_learned_mask)[0])
        selected = np.random.choice(
            well_learned_samples,
            int(self.ratio * len(well_learned_samples)),
            replace=False,
        )
        self.reset_weights()
        # Weight is assigned to *all* well-learned samples (not only the
        # selected ones) so that the correction below averages the keep
        # probability over the whole dataset.
        self.weights[well_learned_samples] = 1 / self.ratio
        kept_samples.extend(selected)
        if self.correction:
            sample_prob = 1 / self.weights
            self.weights *= sample_prob.mean()

        num_cut = len(self.dataset) - len(kept_samples)
        print('Cut {} samples for next iteration'.format(num_cut))
        self.save_num += num_cut
        np.random.shuffle(kept_samples)
        return kept_samples

    def pruning_sampler(self):
        """Return an ``InfoBatchSampler`` that prunes this dataset each epoch."""
        return InfoBatchSampler(self)

    def no_prune(self):
        """Reset all weights to 1 and return a shuffled list of every index."""
        self.reset_weights()
        samples = list(range(len(self.dataset)))
        np.random.shuffle(samples)
        return samples

    def mean_score(self):
        return self.scores.mean()

    def normal_sampler_no_prune(self):
        """Return a sampler that never prunes: every epoch uses ``no_prune()``.

        ``InfoBatchSampler`` needs the dataset itself (it reads
        ``total_samples``/``delta`` from it), so a subclass that always takes
        the no-prune path is built here.
        """

        class _NoPruneSampler(InfoBatchSampler):
            def reset(self):
                self.seq = self.infobatch_dataset.no_prune()
                self.samples += len(self.seq)

        return _NoPruneSampler(self)

    def get_weights(self, indexes):
        return self.weights[indexes]

    def total_save(self):
        """Return the total number of samples skipped by ``prune`` so far."""
        return self.save_num

    def reset_weights(self):
        self.weights = np.ones(len(self.dataset))



class InfoBatchSampler():
    """Per-epoch index sampler that asks an InfoBatch-like dataset what to visit.

    On every ``__iter__`` a fresh index sequence is requested:
    ``infobatch_dataset.prune()`` while the cumulative number of indices
    handed out is at most ``delta * total_samples``, and
    ``infobatch_dataset.no_prune()`` after that. When ``total_samples`` is
    None there is no cut-off and every epoch is pruned.
    """

    def __init__(self, infobatch_dataset):
        self.infobatch_dataset = infobatch_dataset
        self.seq = None  # current epoch's index sequence; built by reset()
        self.seed = 0  # NOTE(review): not read anywhere in this file
        self.samples = 0  # cumulative number of indices handed out so far
        self.total_samples = infobatch_dataset.total_samples
        self.delta = infobatch_dataset.delta

    def reset(self):
        """Build the index sequence for the next epoch and update the counter."""
        # total_samples defaults to None in InfoBatch; multiplying by None
        # would raise, so None means "never switch off pruning".
        stop_pruning = (
            self.total_samples is not None
            and self.samples > self.delta * self.total_samples
        )
        if stop_pruning:
            self.seq = self.infobatch_dataset.no_prune()
        else:
            self.seq = self.infobatch_dataset.prune()
        self.samples += len(self.seq)

    def __len__(self):
        """Length of the current epoch's sequence.

        Before the first iteration no sequence exists yet, so the full
        dataset length is reported instead of raising.
        """
        if self.seq is None:
            return len(self.infobatch_dataset)
        return len(self.seq)

    def __iter__(self):
        self.reset()
        yield from self.seq
