import math
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from  preprocessing_utils import get_fast5_file, get_read
from torch.utils.data import Sampler


class ReadDataset(Dataset):
    """Per-read dataset of raw signals, each split into fixed-length chunks.

    ``info_df`` must provide at least the columns ``read_id``, ``fpath``,
    ``transcript_id`` and ``is_reverse`` (all read in ``__getitem__``).
    At construction time rows can be restricted to a set of transcripts.
    Transcripts with too few reads can be dropped, and transcripts with too
    many reads are randomly subsampled.
    """

    def __init__(self, info_df,
                 signal_length=5200, min_read_counts=20, max_read_counts=100,
                 transcripts=None):
        """Filter and subsample ``info_df`` into the final list of reads.

        Args:
            info_df: pandas DataFrame with one row per read (see class docstring).
            signal_length: length of every signal chunk produced by
                ``__getitem__``.
            min_read_counts: transcripts with fewer reads than this are dropped
                (the bound is inclusive). A value ``<= 0`` disables the filter.
            max_read_counts: transcripts with more reads than this are randomly
                subsampled, without replacement and using the global NumPy RNG,
                down to exactly this many reads. A value ``<= 0`` disables
                subsampling.
            transcripts: optional list-like of transcript ids to keep. After
                construction ``self.transcripts`` is replaced by the unique ids
                that survive all filters.
        """
        self.info_df = info_df
        self.min_read_counts = min_read_counts
        self.max_read_counts = max_read_counts
        self.transcripts = transcripts
        self.signal_length = signal_length

        if self.transcripts is not None:
            self.info_df = self.info_df[self.info_df["transcript_id"].isin(transcripts)]

        if self.min_read_counts > 0:
            # Inclusive: a transcript with exactly ``min_read_counts`` reads is kept.
            self.info_df = self.info_df.groupby("transcript_id").filter(
                lambda group: len(group) >= self.min_read_counts)

        if self.max_read_counts > 0:
            self.info_df = self._cap_reads_per_transcript(self.info_df, self.max_read_counts)

        self.transcripts = self.info_df.transcript_id.unique()

    @staticmethod
    def _cap_reads_per_transcript(info_df, max_reads):
        """Return ``info_df`` with at most ``max_reads`` randomly chosen rows per transcript.

        Groups come out in ``groupby`` order, and the result has a fresh
        ``RangeIndex``. The groups are concatenated explicitly instead of
        using ``groupby().apply()``. With ``apply``, the grouping key ends up
        in the index, and a later ``reset_index()`` then collides with the
        existing ``transcript_id`` column.
        """
        import pandas as pd

        parts = []
        for _, group in info_df.groupby("transcript_id"):
            if len(group) > max_reads:
                group = group.iloc[np.random.choice(len(group), max_reads, replace=False)]
            parts.append(group)
        if not parts:
            # pd.concat rejects an empty list; keep the columns, drop all rows.
            return info_df.iloc[:0].reset_index(drop=True)
        return pd.concat(parts, ignore_index=True)

    def __getitem__(self, idx):
        """Load read ``idx`` and return ``(signals, read_id, transcript_id, is_reverse)``.

        The raw signal from ``get_read`` is cut into consecutive windows of
        ``signal_length``. The windows are right-padded with zeros to
        ``signal_length`` and stacked, so ``signals`` has shape
        ``(n_chunks, signal_length)``. This assumes ``get_read`` returns a
        non-empty 1-D sequence (TODO confirm); an empty signal would make
        ``torch.stack`` fail.
        """
        row = self.info_df.iloc[idx]
        read_id, fpath, tx_id, is_reverse = row[["read_id", "fpath", "transcript_id", "is_reverse"]]

        # NOTE(review): whether get_fast5_file's handle needs explicit closing
        # is not visible here — confirm in preprocessing_utils.
        raw_signal = torch.tensor(get_read(read_id, get_fast5_file(fpath)))
        # Only the final window can be short; full windows get zero padding.
        chunks = [torch.nn.functional.pad(chunk, (0, self.signal_length - len(chunk)))
                  for chunk in raw_signal.split(self.signal_length)]
        signals = torch.stack(chunks)

        return signals, read_id, tx_id, is_reverse

    def __len__(self):
        """Return the number of reads (rows) remaining after filtering."""
        return len(self.info_df)

    def set_distributed_mode(self, world_size, rank):
        """Keep only this rank's share of transcripts (in place).

        The current transcript ids are split into ``world_size`` contiguous,
        near-equal parts with ``np.array_split``. Part ``rank`` is kept, so
        all reads of a transcript stay on one rank. The DataFrame index is not
        reset; this is harmless because ``__getitem__`` indexes positionally
        with ``iloc``.
        """
        transcripts_split = np.array_split(self.transcripts, world_size)
        self.transcripts = transcripts_split[rank]
        self.info_df = self.info_df[self.info_df["transcript_id"].isin(self.transcripts)]


def collate_fn(batches):
    """Collate ``ReadDataset`` samples into per-field containers without stacking.

    Each sample is ``(signals, read_id, transcript_id, is_reverse)``.
    ``signals`` is not stacked across samples because each read can yield a
    different number of chunks.

    Returns:
        ``(signals, read_ids, transcripts, lengths, is_reverse)``:
        ``signals``, ``read_ids`` and ``is_reverse`` are lists in batch order.
        ``transcripts`` is a NumPy array of transcript ids. ``lengths`` is a
        NumPy array holding ``len()`` of each sample's ``signals``.
    """
    def field(position):
        # Gather element ``position`` from every sample, preserving batch order.
        return [sample[position] for sample in batches]

    signals = field(0)
    chunk_counts = np.array([len(sample[0]) for sample in batches])
    read_ids = field(1)
    transcript_ids = np.array(field(2))
    reverse_flags = field(3)
    return signals, read_ids, transcript_ids, chunk_counts, reverse_flags
