import numpy as np
import torch
from torch.utils.data import Subset, DistributedSampler, SubsetRandomSampler
import math
import torch.distributed as dist

def pruning_indices(dataset, ratio=0.5, epoch=0, num_epochs=None, delta=0.875):
    """Pick the sample indices to train on for ``epoch`` and set ``dataset.weights``.

    Samples are split at the mean of ``dataset.scores``: scores strictly below
    the mean count as "well learned", all others as "poorly learned". Selection
    then runs in two steps:

    1. Keep a random ``ratio`` fraction of the well-learned samples and every
       poorly-learned sample.
    2. If that is more than ``int(ratio * len(dataset))`` samples, subsample
       both groups down to that budget, preserving their step-1 proportions.

    Every kept sample gets the weight ``group_size / kept_in_group``; samples
    that were not kept keep weight 1.

    Args:
        dataset: Object supporting ``len()`` and exposing ``scores`` (any 1-D
            array-like with one score per sample). Its ``weights`` attribute
            is overwritten.
        ratio: Fraction used both for the step-1 well-learned draw and for the
            overall budget.
        epoch: Current epoch; seeds the random draws, so a given epoch always
            yields the same selection.
        num_epochs: Total number of epochs, or ``None`` to prune at every epoch.
        delta: Once ``epoch > num_epochs * delta`` pruning is switched off and
            the full dataset is returned with unit weights.

    Returns:
        list[int]: Selected indices, poorly-learned samples first.
    """
    total = len(dataset)
    pruning_disabled = num_epochs is not None and epoch > num_epochs * delta
    # An empty dataset has nothing to prune (and would divide by zero below).
    if pruning_disabled or total == 0:
        dataset.weights = np.ones(total)
        return list(range(total))

    # Accept lists and other array-likes, not only ndarrays.
    scores = np.asarray(dataset.scores)
    well_learned_mask = scores < scores.mean()
    well_learned_indices = np.where(well_learned_mask)[0]
    poorly_learned_indices = np.where(~well_learned_mask)[0]

    well_learned_total = len(well_learned_indices)
    poorly_learned_total = len(poorly_learned_indices)

    # Step 1: thin out the well-learned group, keep all poorly-learned samples.
    # The count is capped so that ratio > 1 cannot over-draw the population.
    well_learned_selected_count = min(
        int(ratio * well_learned_total), well_learned_total
    )
    if well_learned_selected_count > 0:
        rng = np.random.default_rng(epoch)
        selected_well_indices = rng.choice(
            well_learned_indices, well_learned_selected_count, replace=False
        )
    else:
        selected_well_indices = np.array([], dtype=int)

    step1_well_count = len(selected_well_indices)
    step1_poor_count = poorly_learned_total
    step1_total = step1_well_count + step1_poor_count

    # Step 2: enforce the overall budget.
    final_total = int(ratio * total)
    if step1_total > final_total:
        # step1_total > final_total >= 0 here, so the division is safe.
        well_ratio = step1_well_count / step1_total
        final_well_count = int(final_total * well_ratio)
        final_poor_count = final_total - final_well_count

        # Separate stream from step 1; the well-learned draw must come first
        # so that selections stay reproducible for a given epoch.
        rng = np.random.default_rng(epoch + 1000)

        if final_well_count > 0 and step1_well_count > 0:
            final_well_indices = rng.choice(
                selected_well_indices,
                min(final_well_count, step1_well_count),
                replace=False,
            )
        else:
            final_well_indices = np.array([], dtype=int)

        if final_poor_count > 0 and poorly_learned_total > 0:
            final_poor_indices = rng.choice(
                poorly_learned_indices,
                min(final_poor_count, poorly_learned_total),
                replace=False,
            )
        else:
            final_poor_indices = np.array([], dtype=int)
    else:
        final_well_indices = selected_well_indices
        final_poor_indices = poorly_learned_indices

    final_indices = np.concatenate([final_poor_indices, final_well_indices])

    # Kept samples are up-weighted by the inverse of their group's keep rate.
    weights = np.ones(total)
    if len(final_well_indices) > 0:
        weights[final_well_indices] = well_learned_total / len(final_well_indices)
    if len(final_poor_indices) > 0:
        weights[final_poor_indices] = poorly_learned_total / len(final_poor_indices)

    dataset.weights = weights

    print(f"Kept samples: {len(final_indices)}/{len(dataset)} ({len(final_indices)/len(dataset)*100:.2f}%)")
    print(f"Step1: well={step1_well_count}/{well_learned_total}, poor={step1_poor_count}/{poorly_learned_total}, total={step1_total}")
    print(f"Final: well={len(final_well_indices)}, poor={len(final_poor_indices)}, target={final_total}")
    return final_indices.tolist()


class PruningSampler(DistributedSampler):
    """Distributed sampler that re-selects a pruned subset at every new epoch.

    On ``set_epoch`` the original dataset is pruned with ``pruning_indices``
    and the sampler's ``dataset`` is replaced by a ``Subset`` over the kept
    indices, after which the per-replica sizes are recomputed.

    NOTE(review): after ``set_epoch`` the sampler is bound to the ``Subset``,
    so it presumably yields positions within that subset rather than indices
    into the original dataset -- confirm the DataLoader wraps ``self.dataset``.
    """

    def __init__(self, dataset, ratio=0.5, num_epochs=None, delta=0.875, **kwargs):
        """Store pruning settings; ``kwargs`` go to ``DistributedSampler``."""
        super().__init__(dataset, **kwargs)
        self.ori_dataset = dataset
        self.baseratio = self.ratio = ratio
        self.num_epochs, self.delta = num_epochs, delta
        # None means no epoch has been prepared yet.
        self.current_epoch = None

    def _set_dataset_and_stats(self, dataset):
        """Bind ``dataset`` and recompute ``num_samples`` / ``total_size``."""
        self.dataset = dataset
        size, replicas = len(dataset), self.num_replicas
        uneven = size % replicas != 0
        if self.drop_last and uneven:
            # Drop the tail so every replica gets the same number of samples.
            per_replica = math.ceil((size - replicas) / replicas)
        else:
            per_replica = math.ceil(size / replicas)
        self.num_samples = per_replica
        self.total_size = per_replica * replicas

    def set_epoch(self, epoch):
        """Prune for ``epoch`` (once per distinct epoch) and forward to the base."""
        if self.current_epoch != epoch:
            kept = pruning_indices(
                self.ori_dataset, self.ratio, epoch, self.num_epochs, self.delta
            )
            self._set_dataset_and_stats(Subset(self.ori_dataset, kept))
            self.current_epoch = epoch
        super().set_epoch(epoch)



class PrevPruningSampler(SubsetRandomSampler):
    """Single-process sampler whose index pool is re-pruned at each new epoch.

    Starts with every index of the dataset; ``set_epoch`` replaces
    ``self.indices`` with the output of ``pruning_indices``.
    """

    def __init__(self, dataset, ratio=0.5, num_epochs=None, delta=0.875):
        """Store pruning settings and start from the full index range."""
        self.ori_dataset = dataset
        self.baseratio = self.ratio = ratio
        self.num_epochs, self.delta = num_epochs, delta
        # None means no epoch has been prepared yet.
        self.current_epoch = None
        all_indices = list(range(len(self.ori_dataset)))
        super().__init__(all_indices)

    def set_epoch(self, epoch):
        """Recompute the kept indices, at most once per distinct epoch."""
        if self.current_epoch == epoch:
            return
        self.indices = pruning_indices(
            self.ori_dataset, self.ratio, epoch, self.num_epochs, self.delta
        )
        self.current_epoch = epoch










def _print_stratified_sampling_stats(sampled_indices, n, N, P_sizes, Q):
    """Print a five-line summary of one stratified sampling round.

    Args:
        sampled_indices: The kept indices (only their count is reported).
        n: Total number of samples.
        N: Highest layer index; ``N + 1`` layers are reported.
        P_sizes: Per-layer population sizes.
        Q: Per-layer number of samples drawn.
    """
    kept = len(sampled_indices)
    print(f"Stratified kept samples: {kept}/{n} ({kept/n*100:.2f}%)")
    print(f"Layers: {N+1}")

    orig_cells = [f"{size:6d}" for size in P_sizes]
    print("Orig:   " + " ".join(orig_cells))

    sample_cells = [f"{drawn:6d}" for drawn in Q]
    print("Sample: " + " ".join(sample_cells))

    ratio_cells = []
    for size, drawn in zip(P_sizes, Q):
        # Empty layers report a ratio of 0 instead of dividing by zero.
        fraction = drawn / size if size > 0 else 0
        ratio_cells.append(f"{fraction:6.2f}")
    print("Ratio:  " + " ".join(ratio_cells))

def stratified_pruning_indices(dataset, ratio=0.5, epoch=0, num_epochs=None, delta=0.875, c=1.0):
    """Select indices for ``epoch`` by sampling score layers; set ``dataset.weights``.

    With ``H`` the mean score and ``n`` the number of samples, the samples are
    bucketed into ``N + 1`` layers (``N = ceil(log2(n))``): layer 0 holds
    scores ``<= H`` and layer ``j >= 1`` holds scores in
    ``(2**(j-1) * H, 2**j * H]``. Layer ``j`` is given the allocation
    proportion ``(H + 2**(1-j) * c) ** 2``, and ``_proportional_bucket_sampling``
    turns those into per-layer quotas summing towards ``round(ratio * n)``.
    Kept samples of a layer are weighted by ``layer_size / quota``.

    Args:
        dataset: Object supporting ``len()`` and exposing ``scores``; its
            ``weights`` attribute is overwritten.
        ratio: Target fraction of the dataset to keep.
        epoch: Current epoch; seeds the random draws. Epoch 0 keeps everything.
        num_epochs: Total number of epochs, or ``None``.
        delta: Once ``epoch > num_epochs * delta`` the full dataset is kept.
        c: Constant added into each layer's allocation proportion.

    Returns:
        list: Selected indices, grouped layer by layer.
    """
    past_cutoff = num_epochs is not None and epoch > num_epochs * delta
    if past_cutoff or epoch == 0:
        # No pruning: every sample is used with unit weight.
        dataset.weights = np.ones(len(dataset))
        return list(range(len(dataset)))

    scores = np.array(dataset.scores)
    n = len(scores)
    H = scores.mean()
    N = int(np.ceil(np.log2(n)))

    # Layer 0 is everything at or below the mean; later layers double in width.
    layers = [np.where(scores <= H)[0]]
    for j in range(1, N + 1):
        band_low = 2 ** (j - 1) * H
        band_high = 2 ** j * H
        in_band = (scores > band_low) & (scores <= band_high)
        layers.append(np.where(in_band)[0])

    P_sizes = [len(members) for members in layers]
    props = np.array([(H + 2 ** (1 - j) * c) ** 2 for j in range(N + 1)])
    Q = _proportional_bucket_sampling(P_sizes, props, int(round(ratio * n)))

    rng = np.random.default_rng(epoch)
    weights = np.ones(n)
    sampled_indices = []
    for members, quota, size in zip(layers, Q, P_sizes):
        if quota <= 0 or size <= 0:
            continue
        # Only draw when the quota is a strict subset of the layer.
        picked = rng.choice(members, quota, replace=False) if quota < size else members
        sampled_indices.extend(picked)
        weights[picked] = size / quota
    sampled_indices = np.array(sampled_indices)

    dataset.weights = weights
    _print_stratified_sampling_stats(sampled_indices, n, N, P_sizes, Q)
    return sampled_indices.tolist()


class StratifiedSampler(DistributedSampler):
    """Distributed sampler that re-selects a stratified subset at each new epoch.

    On ``set_epoch`` the original dataset is filtered with
    ``stratified_pruning_indices`` and the sampler's ``dataset`` is replaced by
    a ``Subset`` over the kept indices. ``current_epoch`` starts at 0, so a
    first call of ``set_epoch(0)`` performs no selection and keeps the
    original dataset.

    NOTE(review): after ``set_epoch`` the sampler is bound to the ``Subset``,
    so it presumably yields positions within that subset rather than indices
    into the original dataset -- confirm the DataLoader wraps ``self.dataset``.
    """

    def __init__(self, dataset, ratio=0.5, num_epochs=None, delta=0.875, c=1.0, **kwargs):
        """Store sampling settings; ``kwargs`` go to ``DistributedSampler``."""
        super().__init__(dataset, **kwargs)
        self.ori_dataset = dataset
        self.baseratio = self.ratio = ratio
        self.num_epochs, self.delta = num_epochs, delta
        self.current_epoch = 0
        self.c = c

    def _set_dataset_and_stats(self, dataset):
        """Bind ``dataset`` and recompute ``num_samples`` / ``total_size``."""
        self.dataset = dataset
        size, replicas = len(dataset), self.num_replicas
        uneven = size % replicas != 0
        if self.drop_last and uneven:
            # Drop the tail so every replica gets the same number of samples.
            per_replica = math.ceil((size - replicas) / replicas)
        else:
            per_replica = math.ceil(size / replicas)
        self.num_samples = per_replica
        self.total_size = per_replica * replicas

    def set_epoch(self, epoch):
        """Re-select for ``epoch`` (once per distinct epoch) and forward to the base."""
        if self.current_epoch != epoch:
            kept = stratified_pruning_indices(
                self.ori_dataset, self.ratio, epoch, self.num_epochs, self.delta, self.c
            )
            self._set_dataset_and_stats(Subset(self.ori_dataset, kept))
            self.current_epoch = epoch
        super().set_epoch(epoch)


def _proportional_bucket_sampling(P_sizes, props, target_total):
    """Split ``target_total`` draws across buckets without exceeding any bucket.

    Quotas are first assigned proportionally to ``props`` (empty buckets get
    proportion 0; if every proportion is 0 the split is uniform). Any quota
    above its bucket's size is clipped and the clipped amount is redistributed
    to buckets that still have room, proportionally to their spare capacity.

    Args:
        P_sizes: Per-bucket population sizes.
        props: Per-bucket allocation proportions, same length as ``P_sizes``.
        target_total: Desired total number of draws. Values above
            ``sum(P_sizes)`` are capped to it and negative values are treated
            as 0, so no quota can ever exceed its bucket.

    Returns:
        numpy.ndarray: Integer quota per bucket, each ``<=`` its bucket size.
    """
    P_sizes = np.array(P_sizes)
    props = np.array(props)

    if P_sizes.size == 0:
        # Nothing to allocate into (the uniform fallback would divide by 0).
        return np.zeros(0, dtype=int)

    # Without this cap the redistribution below can hand a bucket more than it
    # holds, because the excess would be larger than the total spare capacity.
    target_total = int(min(max(target_total, 0), P_sizes.sum()))

    # Empty buckets must not attract any share of the target.
    props = np.where(P_sizes == 0, 0, props)

    props_sum = props.sum()
    if props_sum == 0:
        Q = np.full_like(P_sizes, target_total // len(P_sizes), dtype=int)
        remainder = target_total % len(P_sizes)
    else:
        Q = np.floor(props / props_sum * target_total).astype(int)
        remainder = target_total - Q.sum()

    # Units lost to flooring go to the leading buckets, one each.
    if remainder > 0:
        Q[:remainder] += 1

    print(f">>> Try to allocate Q: {Q}")

    # Clip to bucket capacity and remember how much was cut off.
    excess = np.maximum(0, Q - P_sizes)
    Q = np.minimum(Q, P_sizes)

    total_excess = excess.sum()
    if total_excess > 0:
        remaining_capacity = P_sizes - Q
        total_remaining = remaining_capacity.sum()
        if total_remaining > 0:
            # total_excess <= total_remaining thanks to the cap above, so each
            # share stays within the bucket's spare capacity.
            excess_alloc = np.floor(
                remaining_capacity / total_remaining * total_excess
            ).astype(int)
            Q += excess_alloc

            # Hand out what flooring left over, one unit per non-full bucket.
            remaining_excess = total_excess - excess_alloc.sum()
            if remaining_excess > 0:
                available_indices = np.where(Q < P_sizes)[0]
                num_to_assign = min(remaining_excess, len(available_indices))
                Q[available_indices[:num_to_assign]] += 1

    return Q




class PrevStratifiedSampler(SubsetRandomSampler):
    """Single-process sampler whose index pool is re-stratified at each new epoch.

    Starts with every index of the dataset; ``set_epoch`` replaces
    ``self.indices`` with the output of ``stratified_pruning_indices``.
    ``current_epoch`` starts at 0, so a first call of ``set_epoch(0)`` leaves
    the full index pool untouched. When a trainer exposing
    ``dynamic_sampling_factor`` is attached, the keep ratio is refreshed to
    ``dynamic_sampling_factor * baseratio`` before each re-selection.
    """

    def __init__(self, dataset, ratio=0.5, num_epochs=None, delta=0.875, c=1.0, trainer=None):
        """Store sampling settings and start from the full index range."""
        self.ori_dataset = dataset
        self.baseratio = self.ratio = ratio
        self.num_epochs, self.delta = num_epochs, delta
        self.current_epoch = 0
        self.c = c
        self._trainer = trainer
        all_indices = list(range(len(self.ori_dataset)))
        super().__init__(all_indices)

    def set_epoch(self, epoch):
        """Recompute the kept indices, at most once per distinct epoch."""
        if self.current_epoch == epoch:
            return

        trainer = self._trainer
        if trainer and hasattr(trainer, 'dynamic_sampling_factor'):
            scaled = trainer.dynamic_sampling_factor * self.baseratio
            # The factor may be a tensor; store a plain Python number.
            self.ratio = scaled.item() if isinstance(scaled, torch.Tensor) else scaled
            print(f">>> Updated ratio from trainer: {self.ratio:.2%}")

        self.indices = stratified_pruning_indices(
            self.ori_dataset, self.ratio, epoch, self.num_epochs, self.delta, self.c
        )
        self.current_epoch = epoch


class RandomSubsetSampler(DistributedSampler):
    """Distributed sampler that trains on a fresh uniform random subset each epoch.

    On ``set_epoch`` a subset of ``int(ratio * n)`` indices is drawn without
    replacement and the sampler's ``dataset`` is replaced by a ``Subset`` over
    it. Once ``epoch > num_epochs * delta`` the full dataset is used and
    ``ori_dataset.weights`` is reset to ones.

    The draw is seeded with ``seed + epoch``, where ``seed`` is the
    ``DistributedSampler`` seed (default 0), so every replica selects the same
    subset and different seeds give different subsets.

    NOTE(review): after ``set_epoch`` the sampler is bound to the ``Subset``,
    so it presumably yields positions within that subset rather than indices
    into the original dataset -- confirm the DataLoader wraps ``self.dataset``.
    """

    def __init__(self, dataset, ratio=0.5, num_epochs=None, delta=0.875, **kwargs):
        """Store subset settings; ``kwargs`` go to ``DistributedSampler``."""
        super().__init__(dataset, **kwargs)
        self.ori_dataset = dataset
        self.baseratio = ratio
        self.ratio = ratio
        self.num_epochs = num_epochs
        self.delta = delta
        # None means no epoch has been prepared yet.
        self.current_epoch = None

    def _set_dataset_and_stats(self, dataset):
        """Bind ``dataset`` and recompute ``num_samples`` / ``total_size``."""
        self.dataset = dataset
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Drop the tail so every replica gets the same number of samples.
            self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        print(f"{self.num_replicas=},{self.total_size=},{self.num_samples=}")

    def set_epoch(self, epoch):
        """Draw the subset for ``epoch`` (once per distinct epoch) and forward to the base."""
        if self.current_epoch != epoch:
            n = len(self.ori_dataset)
            if self.num_epochs is not None and epoch > self.num_epochs * self.delta:
                self.ori_dataset.weights = np.ones(n)
                selected_indices = list(range(n))
            else:
                # Capped at n so that ratio > 1 cannot over-draw the population.
                num_select = min(int(self.ratio * n), n)
                if num_select > 0:
                    # Honour the sampler seed; with the default seed of 0 this
                    # matches seeding by epoch alone.
                    rng = np.random.default_rng(self.seed + epoch)
                    # Plain ints keep the Subset's index type the same in
                    # both branches.
                    selected_indices = rng.choice(np.arange(n), num_select, replace=False).tolist()
                else:
                    selected_indices = []
            subset = Subset(self.ori_dataset, selected_indices)
            print(f"RandomSubsetSampler kept samples: {len(selected_indices)}/{n} ({(len(selected_indices)/n*100 if n>0 else 0):.2f}%)")
            self._set_dataset_and_stats(subset)
            self.current_epoch = epoch
        super().set_epoch(epoch)

class PrevRandomSubsetSampler(SubsetRandomSampler):
    """Single-process sampler over a fresh uniform random subset each epoch.

    Starts with every index of the dataset; ``set_epoch`` replaces
    ``self.indices`` with ``int(ratio * n)`` indices drawn without replacement
    from a generator seeded with ``seed + epoch``. Once
    ``epoch > num_epochs * delta`` all indices are used and
    ``ori_dataset.weights`` is reset to ones.
    """

    def __init__(self, dataset, ratio=0.5, num_epochs=None, delta=0.875, seed=0):
        """Store subset settings and start from the full index range."""
        self.ori_dataset = dataset
        self.baseratio = self.ratio = ratio
        self.num_epochs, self.delta = num_epochs, delta
        # None means no epoch has been prepared yet.
        self.current_epoch = None
        self.seed = seed
        all_indices = list(range(len(self.ori_dataset)))
        super().__init__(all_indices)

    def set_epoch(self, epoch):
        """Draw the subset for ``epoch``, at most once per distinct epoch."""
        if self.current_epoch == epoch:
            return

        n = len(self.ori_dataset)
        full_pass = self.num_epochs is not None and epoch > self.num_epochs * self.delta
        if full_pass:
            self.ori_dataset.weights = np.ones(len(self.ori_dataset))
            chosen = list(range(n))
        else:
            quota = int(self.ratio * n)
            rng = np.random.default_rng(self.seed + epoch)
            if quota > 0:
                chosen = rng.choice(np.arange(n), quota, replace=False)
            else:
                chosen = np.array([], dtype=int)

        self.indices = chosen
        kept = len(chosen)
        percent = kept / n * 100 if n > 0 else 0
        print(f"PrevRandomSubsetSampler kept samples: {kept}/{n} ({percent:.2f}%)")
        self.current_epoch = epoch
