import math
import numpy as np
import copy
from torch.utils.data import Dataset

import torch

def has_length(dataset):
    """
    Report whether ``len(dataset)`` can be evaluated.

    Returns ``False`` when ``len()`` raises ``TypeError`` (the object is
    unsized); any other exception raised by ``len()`` propagates unchanged.
    """
    try:
        size = len(dataset)
    except TypeError:
        # TypeError: len() of unsized object
        return False
    return size is not None

class HistoricDataset(Dataset):
    """Dataset wrapper that sub-samples each epoch from a per-sample score history.

    Every item of the wrapped dataset is expected to be a mapping (it is copied
    through ``.items()``).  Items returned by this wrapper additionally carry
    ``'sample_idx'`` and ``'weight'`` entries.

    Scores are reported through :meth:`__setscore__`; :meth:`prune` turns the
    accumulated history into per-sample keep probabilities, draws the subset
    for the next epoch and stores the matching re-weighting factors in
    ``self.weights``.
    """

    def __init__(self, dataset: Dataset, budget_num_epochs: int, historic_beta: float, smooth_factor: float, sampling_generator, full_sample_freq: int = -1,
                 prune_ratio: float = 0.5, delta: float = 10, warmup=-1, dynamic_ratio=False, heuristic_ratio=False, count_discount=1, total_samples=None,
                 device="cuda"):
        """
        Args:
            dataset: wrapped dataset; must support ``len()`` and integer indexing.
            budget_num_epochs: epoch budget used to derive ``self.num_epochs``.
            historic_beta: decay factor of the per-sample squared-score EMA.
            smooth_factor: blend factor towards the mean score in :meth:`prune`.
            sampling_generator: ``torch.Generator`` passed to ``torch.rand``
                when drawing the kept subset.
            full_sample_freq: stored as-is; not used inside this class.
            prune_ratio: peak fraction of samples to drop.
            delta: stored as-is; read by ``HistoricSampler``.
            warmup: ``-1`` disables the warm-up ramp; otherwise the ramp lasts
                ``int(num_epochs * warmup)`` calls of :meth:`prune`.
            dynamic_ratio: derive the keep ratio from ``self.correct``.
            heuristic_ratio: use the RMS score as reference instead of the
                bisection search.
            count_discount: base of the per-sample discount ``base ** c`` where
                ``c`` is how many more updates a sample received than the
                least-updated one.
            total_samples: stored as-is; read by ``HistoricSampler``.
            device: device of the history tensors (previously hard-coded to
                ``"cuda"``, which remains the default).
        """
        self.dataset = dataset
        self.total_samples = total_samples
        self.num_samples = len(self.dataset)
        self.peak_prune_ratio = prune_ratio
        self.warmup = warmup
        if warmup == -1:
            self.num_epochs = budget_num_epochs / (1 - prune_ratio)
        else:
            self.num_epochs = budget_num_epochs / (1 + (1 - prune_ratio)) * 2
        self.delta = delta
        self.historic_beta = historic_beta
        self.count_discount = count_discount
        self.smooth_factor = smooth_factor
        self.sampling_generator = sampling_generator
        self.full_sample_freq = full_sample_freq
        self.dynamic_ratio = dynamic_ratio
        self.heuristic_ratio = heuristic_ratio
        self.device = device
        self.weights = torch.ones(self.num_samples)
        self.num_pruned_samples = 0
        # Running total of samples left out by prune(); returned by total_save().
        self.save_num = 0
        self.cur_batch_index = None
        # Number of prune() calls so far; drives the warm-up ramp.
        self.prune_calls = 0

        self.history_approx_grad_sq = torch.zeros(self.num_samples, device=device)
        self.history_count = torch.zeros(self.num_samples, device=device, dtype=torch.int)
        self.correct = torch.zeros(self.num_samples, device=device)
        self.grad_sq_ema = 0
        self.grad_sq_ema_count = 0

    def __setscore__(self, indices, values):
        """Fold the scores ``values`` of the samples ``indices`` into the history.

        The squared scores are normalised by a bias-corrected global EMA
        (decay 0.99) of the weighted mean squared score, then blended into the
        per-sample EMA with decay ``historic_beta``.
        """
        stats_device = self.history_approx_grad_sq.device
        sample_approx_grad = values.to(stats_device)
        # The weights may still live on another device (e.g. before the first
        # prune()); .to() is a no-op when they are already on `stats_device`.
        batch_weights = self.weights.to(stats_device)[indices]

        grad_sq = sample_approx_grad ** 2
        grad_sq_mean = (grad_sq * batch_weights).mean()
        self.grad_sq_ema = 0.99 * self.grad_sq_ema + 0.01 * grad_sq_mean
        self.grad_sq_ema_count += 1

        # Bias correction for the zero-initialised EMA.
        grad_sq_ema = self.grad_sq_ema / (1 - 0.99 ** self.grad_sq_ema_count)
        self.history_approx_grad_sq[indices] = (
            self.historic_beta * self.history_approx_grad_sq[indices]
            + (1 - self.historic_beta) * (grad_sq / grad_sq_ema)
        )

        self.history_count[indices] += 1

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return a copy of item ``index`` extended with ``'sample_idx'`` and ``'weight'``."""
        idx = int(index)
        data = dict(self.dataset[idx].items())
        data.update({'sample_idx': index, 'weight': self.weights[idx]})
        return data

    def __getitems__(self, index):
        """Batched variant of ``__getitem__``; ``index`` must be a sized iterable.

        Raises:
            ValueError: if ``index`` has no length.
        """
        if not has_length(index):
            raise ValueError("infobatch __getitems__ got index with no length!")
        data = []
        for idx in index:
            item = dict(self.dataset[int(idx)].items())
            item['sample_idx'] = idx
            item['weight'] = self.weights[int(idx)]
            data.append(item)
        return data

    def __get_wt_id__(self):
        # NOTE(review): `self.last_indexes` is not assigned anywhere in this
        # class, so this raises AttributeError unless a caller sets it first.
        return self.weights[self.last_indexes], self.last_indexes

    def prune(self):
        """Draw the sample subset for the next epoch and update ``self.weights``.

        Each sample is kept with probability ``min(score / ref, 1)`` where
        ``ref`` is chosen so that the mean probability matches the target keep
        ratio (or is the RMS score when ``heuristic_ratio`` is set).  The new
        weights are ``mean_probability / probability``.

        Returns:
            numpy.ndarray: shuffled indices of the kept samples.
        """
        # prune samples that are well learned, rebalance the weight by scaling up remaining
        # well learned samples' learning rate to keep estimation about the same
        # for the next version, also consider new class balance
        iterations = self.prune_calls
        self.prune_calls += 1

        if self.warmup == -1:
            keep_ratio = 1 - self.peak_prune_ratio
        else:
            peak = int(self.num_epochs * self.warmup)
            # Linear ramp from keeping everything down to the peak prune ratio.
            # `peak > 0` guards the division when the warm-up rounds to zero.
            if peak > 0 and iterations <= peak:
                keep_ratio = (iterations * (1 - self.peak_prune_ratio) + (peak - iterations) * 1) / peak
            else:
                keep_ratio = 1 - self.peak_prune_ratio
            keep_ratio = min(max(keep_ratio, 0), 1)
        if self.dynamic_ratio:
            # NOTE(review): `self.correct` is never written inside this class,
            # and this keeps `peak_prune_ratio` (not 1 - peak_prune_ratio) of
            # the "correct" share - confirm both against the caller.
            correct_ratio = self.correct.mean()
            keep_ratio = (1 - correct_ratio) + correct_ratio * self.peak_prune_ratio

        # Bias-corrected per-sample EMA; the division allocates a new tensor,
        # so the in-place edits below do not touch the stored history.
        sample_approx_grad_sq = self.history_approx_grad_sq / (1 - torch.pow(self.historic_beta, self.history_count))
        # Samples that were never scored (0/0 above) get a neutral score of 1.
        sample_approx_grad_sq[self.history_count == 0] = 1

        extra_updates = self.history_count - self.history_count.min()
        sample_approx_grad_sq *= torch.pow(self.count_discount, extra_updates)

        # Blend every score towards the mean score.
        sample_approx_grad_sq = (1 - self.smooth_factor) * sample_approx_grad_sq + self.smooth_factor * sample_approx_grad_sq.mean()
        sample_approx_grad = torch.sqrt(sample_approx_grad_sq).double()

        if self.heuristic_ratio:
            ref = torch.sqrt(sample_approx_grad_sq.mean())
        else:
            # Bisection on `ref`: a larger reference lowers every probability.
            lb = 0
            # TODO check this
            rb = sample_approx_grad.mean() / keep_ratio
            for _ in range(50):
                mid = (lb + rb) / 2
                if (sample_approx_grad / mid).clamp(max=1).mean() < keep_ratio:
                    rb = mid
                else:
                    lb = mid
            ref = mid

        sample_prob = (sample_approx_grad / ref).clamp(max=1)
        keep_ratio = sample_prob.mean()

        # TODO check this
        self.weights = 1 / sample_prob * keep_ratio

        keep_mask = torch.rand(self.num_samples, generator=self.sampling_generator) < sample_prob.cpu()
        selected_indices = np.arange(self.num_samples)[keep_mask.numpy()]
        self.save_num += self.num_samples - len(selected_indices)

        # in-place shuffle
        np.random.shuffle(selected_indices)
        return selected_indices

    def pruning_sampler(self):
        """Return a ``HistoricSampler`` that draws its epochs from this dataset."""
        return HistoricSampler(self)

    def no_prune(self):
        """Return all indices in shuffled order and reset the weights to 1.

        Every sample is drawn, so the re-weighting factors left behind by an
        earlier :meth:`prune` call no longer apply.
        """
        self.reset_weights()
        samples = list(range(len(self.dataset)))
        np.random.shuffle(samples)
        return samples

    def mean_score(self):
        # NOTE(review): `self.scores` is not assigned anywhere in this class,
        # so this raises AttributeError unless a caller sets it first.
        return self.scores.mean()

    def normal_sampler_no_prune(self):
        """Return a ``HistoricSampler`` that never prunes."""
        sampler = HistoricSampler(self)
        # The sampler stops pruning once `samples > delta * total_samples`;
        # a negative threshold makes that true from the first epoch on.
        sampler.total_samples = 1
        sampler.delta = -1
        return sampler

    def get_weights(self, indexes):
        """Return the current weights of the samples ``indexes``."""
        return self.weights[indexes]

    def total_save(self):
        """Return how many samples :meth:`prune` has left out in total so far."""
        return self.save_num

    def reset_weights(self):
        """Reset all weights to 1, keeping them a torch tensor on the same device."""
        self.weights = torch.ones(len(self.dataset), device=self.weights.device)



class HistoricSampler():
    """Sampler that asks its dataset for a fresh index list on every epoch.

    ``dataset`` must provide ``full_sample_freq``, ``total_samples`` and
    ``delta`` attributes plus ``prune()`` and ``no_prune()`` methods, each
    returning a sized iterable of indices.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.iterations = 0          # number of completed reset() calls
        self.sample_indices = None   # indices of the current epoch
        self.full_sample_freq = dataset.full_sample_freq
        self.seq = None

        self.samples = 0             # total number of indices handed out so far
        self.total_samples = self.dataset.total_samples
        self.delta = self.dataset.delta

    def reset(self):
        """Fetch the index list for the next epoch.

        Once more than ``delta * total_samples`` indices have been handed out,
        the full dataset is used (``no_prune``); before that, ``prune``.  When
        ``total_samples`` is ``None`` there is no threshold to compare against,
        so pruning is always used.
        """
        stop_pruning = (
            self.total_samples is not None
            and self.samples > self.delta * self.total_samples
        )
        if stop_pruning:
            self.sample_indices = self.dataset.no_prune()
        else:
            self.sample_indices = self.dataset.prune()

        self.samples += len(self.sample_indices)
        self.iterations += 1

    def __len__(self):
        """Return the size of the current epoch.

        Before the first iteration no indices have been drawn yet, so the
        length of the underlying dataset is reported instead.
        """
        if self.sample_indices is None:
            return len(self.dataset)
        return len(self.sample_indices)

    def __iter__(self):
        self.reset()
        yield from self.sample_indices
