import os
import pandas as pd
import numpy as np
import torch
import joblib
import json
from torch.utils.data import Dataset
from itertools import product
from collections import Iterable
from glob import glob
from random import randint
from torch.utils.data._utils.collate import default_collate
from torch.nn.functional import pad
from collections import Iterable
from itertools import product


# Integer code for each DNA base, assigned in the order A, C, G, T.
nucleotide_dict = dict(zip("ACGT", range(4)))

class NanoporeDS(Dataset):
    """Dataset that yields the reads and site annotations of one transcript per item.

    Files read from ``root_dir``:

    * ``<tx>.npy``        -- one array per transcript; its columns are sliced
      using the ranges stored in ``metadata.joblib``.
    * ``metadata.joblib`` -- mapping ``key -> (start_column, end_column)``; the
      keys used here are ``'signal'``, ``'positions'`` and ``'positions_length'``.
    * ``data.index``      -- CSV with columns ``tx_id``, ``start_pos``, ``end_pos``
      giving the slice of ``tx_index.json`` that belongs to each transcript.
    * ``data.info``       -- CSV with columns ``tx_id``, ``n_reads``,
      ``tx_position``, ``modification_status`` and ``sequence`` (one row per site).
    * ``tx_index.json``   -- per-transcript JSON objects mapping a site position
      (as a string) to the row indices of the reads covering it.
    """

    def __init__(self, root_dir, mode, min_reads=0, sequence_context=10, max_reads=None):
        """Load the index tables and enumerate the usable transcripts.

        Args:
            root_dir: directory containing the files listed in the class docstring.
            mode: stored on the instance; not used by this class itself.
            min_reads: if > 0, sites whose ``n_reads`` is below it are dropped.
            sequence_context: accepted for interface compatibility; unused here.
            max_reads: if not None, at most this many reads are sampled per site.
        """
        self.root_dir = root_dir
        self.mode = mode
        self.max_reads = max_reads
        self.min_reads = min_reads

        self.metadata = joblib.load(os.path.join(root_dir, "metadata.joblib"))
        self.data_index = pd.read_csv(os.path.join(root_dir, "data.index")).set_index("tx_id")
        # Always index by tx_id: __getitem__ looks sites up with .loc[tx], which
        # previously failed with KeyError whenever min_reads was left at 0.
        self.data_info = self._index_sites(
            pd.read_csv(os.path.join(root_dir, "data.info")), self.min_reads)

        # Keep only transcripts that still have at least one annotated site, so
        # every index below __len__ can actually be served. Sorted because
        # os.listdir returns entries in arbitrary order.
        suffix = ".npy"
        annotated = set(self.data_info.index)
        self.transcripts = sorted(
            name[:-len(suffix)] for name in os.listdir(root_dir)
            if name.endswith(suffix) and name[:-len(suffix)] in annotated)

    @staticmethod
    def _index_sites(data_info, min_reads):
        """Return ``data_info`` filtered by ``min_reads`` and indexed by ``tx_id``."""
        if min_reads > 0:
            data_info = data_info[data_info["n_reads"] >= min_reads]
        return data_info.reset_index(drop=True).set_index("tx_id").sort_index()

    def __len__(self):
        """Return the number of transcripts that can be served."""
        return len(self.transcripts)

    def load(self, arr, key):
        """Return the columns of ``arr`` that ``self.metadata`` assigns to ``key``."""
        start_idx, end_idx = self.metadata[key]
        return arr[:, start_idx:end_idx]

    def load_tx_index(self, tx):
        """Read the ``tx`` entry of ``tx_index.json`` without parsing the whole file.

        Returns:
            The value stored under ``tx`` in the JSON object found between the
            ``start_pos`` and ``end_pos`` offsets listed for ``tx`` in ``data.index``.
        """
        row = self.data_index.loc[tx]
        start_pos, end_pos = int(row["start_pos"]), int(row["end_pos"])
        # NOTE(review): the file is opened in text mode, so the offsets are only
        # exact if this slice is single-byte encoded (e.g. ASCII) -- confirm
        # against the code that writes data.index.
        with open(os.path.join(self.root_dir, "tx_index.json")) as f:
            f.seek(start_pos, 0)
            json_str = f.read(end_pos - start_pos)
        return json.loads(json_str)[tx]

    def __getitem__(self, idx):
        """Return the tensors for the transcript at position ``idx``.

        Returns:
            A tuple ``(signals, labels, position_mask, sequences)``:

            * ``signals``       -- ``torch.Tensor`` with one row per selected read.
            * ``labels``        -- ``torch.LongTensor`` with one entry per site.
            * ``position_mask`` -- ``torch.BoolTensor`` of shape (reads, sites);
              True where the read's position list contains the site.
            * ``sequences``     -- ``torch.LongTensor`` with one encoded row per site.
        """
        tx = self.transcripts[idx]
        arr = np.load(os.path.join(self.root_dir, tx + ".npy"))
        signals = self.load(arr, 'signal')
        positions = self.load(arr, 'positions').astype('int64')
        pos_lengths = self.load(arr, 'positions_length').flatten().astype('int64')

        indices_dict = self.load_tx_index(tx)

        # A list selector always yields a DataFrame, even when the transcript
        # has a single site, so no Series special case is needed.
        tx_info = self.data_info.loc[[tx]]
        tx_positions = tx_info["tx_position"].values
        labels = tx_info["modification_status"].values
        tx_sequences = tx_info["sequence"].values

        sequences = np.stack([[nucleotide_dict[x] for x in seq] for seq in tx_sequences])

        # Collect the read rows covering any site, sub-sampling per site when
        # max_reads is set.
        indices = []
        for pos in tx_positions:
            reads = indices_dict[str(pos)]
            if (self.max_reads is not None) and (len(reads) >= self.max_reads):
                indices.extend(np.random.choice(reads, self.max_reads, replace=False))
            else:
                indices.extend(reads)
        # Explicit integer dtype keeps an empty selection usable as an index.
        indices = np.unique(np.asarray(indices, dtype='int64'))

        signals = signals[indices]

        # Constructing signal - position mask
        position_mask = np.zeros((len(indices), len(tx_positions)), dtype=bool)
        selected = zip(positions[indices], pos_lengths[indices])
        for row, (read_pos, read_pos_length) in enumerate(selected):
            # Only the first read_pos_length entries of each row are compared.
            position_mask[row] = np.isin(tx_positions, read_pos[:read_pos_length])

        return torch.Tensor(signals), torch.LongTensor(labels),\
            torch.BoolTensor(position_mask), torch.LongTensor(sequences)


def collate_fn(batches):
    """Merge dataset items into one batch.

    Each element of ``batches`` is indexed as ``(signals, labels, mask, sequences)``.
    Signals, labels and sequences are concatenated along their first dimension;
    the masks are returned as a plain list, one per item. The two length tensors
    record how many signal rows and sequence rows each item contributed.

    Returns:
        ``(signals, signals_length, labels, positions_masks, sequences,
        sequences_length)``.
    """
    signal_parts, label_parts, positions_masks, sequence_parts = [], [], [], []
    for item in batches:
        signal_parts.append(item[0])
        label_parts.append(item[1])
        positions_masks.append(item[2])
        sequence_parts.append(item[3])

    signals = torch.cat(signal_parts)
    signals_length = torch.LongTensor([len(part) for part in signal_parts])
    labels = torch.cat(label_parts)
    sequences = torch.cat(sequence_parts)
    sequences_length = torch.LongTensor([len(part) for part in sequence_parts])
    return signals, signals_length, labels, positions_masks, sequences, sequences_length
