import math
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch.distributed as dist
from torch.utils.data import Sampler, BatchSampler
from tqdm import tqdm
from .scorer import Llama_Scorer

class InfoBatch(Dataset):
    """Dataset wrapper that tracks per-sample statistics and sample weights.

    Wraps an indexable ``dataset`` whose items are mappings with
    ``"input_ids"``, ``"labels"`` and ``"text"`` keys, and keeps one slot per
    sample for scores, complexities, qualities, embeddings, ``grand`` and
    ``el2n`` values plus a per-sample weight.

    The ``__set*__`` methods merge per-rank updates with
    ``torch.distributed.all_reduce`` on CUDA tensors. They are collective
    calls: every rank in the process group must call them.
    """

    def __init__(self, dataset, data_method="random", data_ratio=0.5, embed_dim=4096):
        """
        Args:
            dataset: indexable dataset of mappings (see class docstring).
            data_method: stored as ``self.method``; not read by this class's methods.
            data_ratio: stored as ``self.ratio``; not read by this class's methods.
            embed_dim: width of each stored embedding; ``self.embeds`` has
                shape ``(len(dataset), 1, embed_dim)``. Defaults to 4096.
        """
        num = len(dataset)
        self.dataset = dataset
        self.method = data_method
        self.ratio = data_ratio
        self.scores = np.ones([num])
        self.complexities = np.zeros([num])
        self.qualities = np.zeros([num])
        self.embeds = np.ones((num, 1, embed_dim))
        self.grand = np.zeros([num])
        self.el2n = np.zeros([num])
        self.weights = np.ones(num)
        self.save_num = 0
        self.seq = list(range(num))
        # Debug flags (intended to print once per method on rank 0); not read
        # by this class's methods.
        self.debug_longest_printed = False
        self.debug_entropy_printed = False

    def _merge_across_ranks(self, target, indices, values):
        """All-reduce per-rank updates and store the per-sample mean in ``target``.

        Each rank contributes ``values`` for its local ``indices``. After
        summing over all ranks, every sample reported by at least one rank is
        overwritten in ``target`` (a numpy array indexed by sample along its
        first axis) with the mean of the reported values. Samples that no
        rank reported keep their previous value.
        """
        num = len(self.dataset)
        count = torch.zeros(num, device="cuda")
        totals = torch.zeros((num, *target.shape[1:]), device="cuda")
        count[indices] = 1
        totals[indices] = torch.as_tensor(values, dtype=totals.dtype, device="cuda")
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        # Select by "was reported" rather than "is non-zero" so a genuine
        # 0.0 value is stored instead of being silently dropped.
        reported = count > 0
        # Reshape the counts so they broadcast over any trailing (embedding) dims.
        divisor = count[reported].view(-1, *([1] * (totals.dim() - 1)))
        target[reported.cpu().numpy()] = (totals[reported] / divisor).cpu().numpy()

    def __setscore__(self, indices, values):
        """Merge per-rank ``values`` for ``indices`` into ``self.scores``."""
        self._merge_across_ranks(self.scores, indices, values)

    def __setcomplexity__(self, indices, values):
        """Merge per-rank ``values`` for ``indices`` into ``self.complexities``."""
        self._merge_across_ranks(self.complexities, indices, values)

    def __setquality__(self, indices, values):
        """Merge per-rank ``values`` for ``indices`` into ``self.qualities``."""
        self._merge_across_ranks(self.qualities, indices, values)

    def __setembed__(self, indices, values):
        """Merge per-rank embedding ``values`` for ``indices`` into ``self.embeds``."""
        self._merge_across_ranks(self.embeds, indices, values)

    def __setgrand__(self, indices, values):
        """Merge per-rank ``values`` for ``indices`` into ``self.grand``."""
        self._merge_across_ranks(self.grand, indices, values)

    def __setel2n__(self, indices, values):
        """Merge per-rank ``values`` for ``indices`` into ``self.el2n``."""
        self._merge_across_ranks(self.el2n, indices, values)

    def __setmodel__(self, model, tokenizer):
        """Attach a model and tokenizer to this dataset for later use."""
        self.model = model
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.seq)

    def __getitem__(self, i):
        """Return sample ``i`` with its index and current weight attached."""
        key = i.item() if isinstance(i, np.integer) else i
        # Fetch the row once instead of re-indexing the dataset per field.
        row = self.dataset[key]
        return dict(
            input_ids=row["input_ids"],
            labels=row["labels"],
            text=row["text"],
            indices=i,
            weights=self.weights[i],
        )

    def prune(self):
        """Return every sample index in order (no samples are actually pruned)."""
        return list(range(len(self.dataset)))

    def pruning_sampler(self):
        """Return an ``InfoBatchSampler`` that draws indices from :meth:`prune`."""
        return InfoBatchSampler(self)

    def no_prune(self):
        """Return every sample index in a random order (NumPy global RNG)."""
        samples = list(range(len(self.dataset)))
        np.random.shuffle(samples)
        return samples

    def mean_score(self):
        """Return the mean of ``self.scores``."""
        return self.scores.mean()

    def normal_sampler_no_prune(self):
        """Return an ``InfoBatchSampler`` that draws indices from :meth:`no_prune`."""
        import types

        # InfoBatchSampler calls ``.prune()`` on its source. A bare bound method
        # has no such attribute, so expose ``no_prune`` under that name.
        return InfoBatchSampler(types.SimpleNamespace(prune=self.no_prune))

    def get_weights(self, indexes):
        """Return the weights at ``indexes``."""
        return self.weights[indexes]

    def total_save(self):
        """Return ``self.save_num``."""
        return self.save_num

    def reset_weights(self):
        """Reset every sample weight to 1."""
        self.weights = np.ones(len(self.dataset))



class InfoBatchSampler(Sampler):
    """Distributed sampler that re-queries its source for indices every epoch.

    On each :meth:`reset`, the indices returned by ``infobatch_dataset.prune()``
    are padded by repetition to a multiple of ``num_replicas``. This rank then
    takes its strided share, and shuffles it with NumPy's global RNG when
    ``shuffle`` is true. Exhausting the iterator triggers a reset, so the next
    epoch sees freshly drawn indices.
    """

    def __init__(self, infobatch_dataset, num_replicas=None, rank=None, shuffle=True):
        """
        Args:
            infobatch_dataset: any object with a ``prune()`` method returning a
                sequence of sample indices.
            num_replicas: number of ranks; defaults to ``dist.get_world_size()``.
            rank: this process's rank; defaults to ``dist.get_rank()``.
            shuffle: whether to shuffle this rank's shard on every reset.
        """
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.infobatch_dataset = infobatch_dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seq = None
        self.seed = 0
        self.reset()

    def reset(self):
        """Rebuild this rank's index sequence for the next epoch."""
        # Copy so the padding below never mutates a list owned by the source.
        seq = list(self.infobatch_dataset.prune())
        self.num_samples = int(math.ceil(len(seq) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        padding = self.total_size - len(seq)
        if padding > 0:
            # Repeat the whole list as often as needed. A single ``seq[:padding]``
            # is too short when there are fewer indices than replicas.
            seq += (seq * math.ceil(padding / len(seq)))[:padding]
        seq = seq[self.rank:self.total_size:self.num_replicas]
        if self.shuffle:
            np.random.shuffle(seq)
        self.seq = seq
        self.ite = iter(self.seq)
        self.new_length = len(self.seq)

    def __next__(self):
        """Return the next index; on exhaustion, reset for the next epoch and stop."""
        try:
            return next(self.ite)
        except StopIteration:
            self.reset()
            raise

    def __len__(self):
        """Return the number of indices in this rank's current shard."""
        return len(self.seq)

    def __iter__(self):
        """Rewind over the current shard (without re-pruning) and return self."""
        self.ite = iter(self.seq)
        return self