import torch

import pandas as pd

import numpy as np

from pyfaidx import Fasta

from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl

from typing import List, Dict, Union, Optional

import random

import os

import requests

from tqdm import tqdm







# Fallback download location for the BED interval file; GenomicDataset fetches
# it when the requested bed_file path does not exist on disk.
DEFAULT_BED_URL = "https://example.com/human-sequences.bed"

# Fallback download location for the reference FASTA; GenomicDataset fetches
# it when the requested fasta_file path does not exist on disk.
DEFAULT_FASTA_URL = "https://example.com/hg38.ml.fa"



def download_file(url: str, dest_path: str, timeout: float = 60.0):
    """Download ``url`` to ``dest_path`` unless the destination already exists.

    The payload is streamed into ``dest_path + ".part"`` and only renamed to
    ``dest_path`` once the transfer has finished. An interrupted download
    therefore never leaves a truncated file behind that a later call would
    mistake for a complete one.

    Args:
        url: HTTP(S) location of the file to fetch.
        dest_path: Local path to write. Missing parent directories are created.
        timeout: Seconds to wait for the connection and for each read before
            the request is aborted.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(dest_path):
        print(f"File already exists: {dest_path}")
        return

    print(f"Downloading {url} to {dest_path}...")

    # A bare filename has an empty dirname, which os.makedirs rejects.
    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = f"{dest_path}.part"
    block_size = 1024 * 1024

    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # Servers may omit content-length; tqdm then shows a plain counter.
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as f:
                with tqdm(total=total_size or None, unit='iB', unit_scale=True) as progress_bar:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        progress_bar.update(len(data))
    except BaseException:
        # Also covers Ctrl-C: drop the partial payload, then re-raise unchanged.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise

    # Publish the finished file in a single rename step.
    os.replace(partial_path, dest_path)
    print(f"Downloaded {dest_path}")





# Sequence symbols A, C, G, T and N receive ids 0-4, in that order.
NUCLEOTIDE_VOCAB = {base: token_id for token_id, base in enumerate("ACGTN")}

# Control tokens continue the numbering right after the sequence symbols:
# [UNK]=5, [CLS]=6, [SEP]=7, [PAD]=8, [MASK]=9.
SPECIAL_TOKENS = {
    token: token_id
    for token_id, token in enumerate(
        ("[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"),
        start=len(NUCLEOTIDE_VOCAB),
    )
}

# Full token -> id table: sequence symbols first, control tokens after.
VOCAB = dict(NUCLEOTIDE_VOCAB)
VOCAB.update(SPECIAL_TOKENS)

VOCAB_SIZE = len(VOCAB)

# Shortcuts for the ids that the rest of the module refers to directly.
UNK_ID, PAD_ID, MASK_ID = (VOCAB[name] for name in ("[UNK]", "[PAD]", "[MASK]"))





def seed_worker(worker_id: int) -> None:
    """Seed NumPy's and Python's global RNGs from the current torch seed.

    Used as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted to
    match that callback signature and is not otherwise used.
    """
    # Reduce the torch seed to 32 bits before handing it to the other RNGs.
    seed_32bit = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed_32bit)



def a(ids: List[int]) -> str:
    """Render a sequence of token IDs as one string of vocabulary tokens.

    Each ID is replaced by its token text from ``VOCAB``; IDs that are not in
    the vocabulary are rendered as ``"?"``.
    """
    # Invert the token -> id table once per call.
    inverse_vocab = dict(zip(VOCAB.values(), VOCAB.keys()))
    return "".join(inverse_vocab.get(token_id, "?") for token_id in ids)



def tokenize(seq: str, add_special_tokens: bool = False) -> List[int]:
    """Map every character of a DNA string to its integer token ID.

    The input is upper-cased first. Characters that are not keys of
    ``NUCLEOTIDE_VOCAB`` become ``UNK_ID``. When ``add_special_tokens`` is
    true, the result is wrapped as ``[CLS] ... [SEP]``.
    """
    lookup = NUCLEOTIDE_VOCAB.get
    encoded = [lookup(symbol, UNK_ID) for symbol in seq.upper()]
    if not add_special_tokens:
        return encoded
    return [VOCAB["[CLS]"], *encoded, VOCAB["[SEP]"]]



class GenomicDataset(Dataset):
    """
    Dataset for loading genomic sequences from a FASTA file based on intervals
    defined in a BED file. Automatically downloads files if not found.

    Each item is a randomly positioned window inside one BED interval,
    converted to token IDs. Windows are drawn with Python's ``random`` module,
    so repeated reads of the same index generally return different windows.

    Args:
        bed_file (str): Path to the BED file containing genomic intervals.
            Only the first three tab-separated columns are read.
        fasta_file (str): Path to the FASTA file containing the reference genome.
        min_seq_len (int): Minimum sequence length for sampling. Intervals
            shorter than this are discarded.
        max_seq_len (int): Maximum sequence length for sampling.

    Raises:
        ValueError: If ``min_seq_len`` is smaller than 1, if ``max_seq_len`` is
            smaller than ``min_seq_len``, or if no BED interval is at least
            ``min_seq_len`` long.
    """

    def __init__(self, bed_file: str, fasta_file: str, min_seq_len: int, max_seq_len: int):
        # Reject impossible length ranges before any download or file access;
        # otherwise the failure only shows up later inside random.randint.
        if min_seq_len < 1:
            raise ValueError(f"min_seq_len must be at least 1, got {min_seq_len}.")
        if max_seq_len < min_seq_len:
            raise ValueError(
                f"max_seq_len ({max_seq_len}) must not be smaller than "
                f"min_seq_len ({min_seq_len})."
            )

        if not os.path.exists(bed_file):
            print(f"BED file not found at {bed_file}. Attempting to download...")
            download_file(DEFAULT_BED_URL, bed_file)
        if not os.path.exists(fasta_file):
            print(f"FASTA file not found at {fasta_file}. Attempting to download...")
            download_file(DEFAULT_FASTA_URL, fasta_file)

        self.bed_file = bed_file
        self.fasta_file = fasta_file
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len

        intervals = pd.read_csv(
            bed_file, sep="\t", header=None,
            usecols=[0, 1, 2], names=["chrom", "start", "end"],
            dtype={"chrom": str, "start": int, "end": int},
        )

        # Opening here surfaces FASTA problems at construction time. The handle
        # is re-opened transparently in other processes (see ``genome``).
        self._genome = None
        self._genome_pid = None
        self.genome = Fasta(fasta_file, sequence_always_upper=True)

        # Keep only intervals that can hold at least the shortest window.
        long_enough = (intervals["end"] - intervals["start"]) >= min_seq_len
        self.intervals = intervals[long_enough]
        if len(self.intervals) == 0:
            raise ValueError("No intervals in the BED file are large enough for the specified min_seq_len.")

    @property
    def genome(self):
        """FASTA handle belonging to the current process.

        A handle inherited through ``fork`` shares its OS-level file offset
        with the parent and the sibling workers, so concurrent seeks and reads
        could interfere. Each process therefore opens its own handle the first
        time it needs one.
        """
        current_pid = os.getpid()
        if self._genome is None or self._genome_pid != current_pid:
            self._genome = Fasta(self.fasta_file, sequence_always_upper=True)
            self._genome_pid = current_pid
        return self._genome

    @genome.setter
    def genome(self, handle) -> None:
        # Remember which process owns the handle so other processes re-open it.
        self._genome = handle
        self._genome_pid = os.getpid()

    def __getstate__(self):
        # Never ship an open file handle to another process when pickled;
        # the receiving process opens its own through ``genome``.
        state = self.__dict__.copy()
        state["_genome"] = None
        state["_genome_pid"] = None
        return state

    def __len__(self) -> int:
        """Return the number of usable BED intervals."""
        return len(self.intervals)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample one window from interval ``idx`` and return its token IDs.

        Returns:
            A dict with key ``"input_ids"`` holding a 1-D ``torch.long`` tensor
            without special tokens.
        """
        interval = self.intervals.iloc[idx]
        chrom = interval["chrom"]
        # Plain ints, regardless of how pandas boxes the row values.
        start, end = int(interval["start"]), int(interval["end"])
        interval_len = end - start

        if self.min_seq_len == self.max_seq_len:
            seq_len = self.min_seq_len
        else:
            seq_len = random.randint(self.min_seq_len, self.max_seq_len)

        if seq_len >= interval_len:
            # The interval cannot hold a longer window: use all of it. No
            # random draw is made here, so the RNG stream is left untouched.
            sample_start = start
            seq_len = interval_len
        else:
            sample_start = start + random.randint(0, interval_len - seq_len)

        sequence_str = self.genome[chrom][sample_start : sample_start + seq_len].seq
        token_ids = tokenize(sequence_str, add_special_tokens=False)

        return {"input_ids": torch.tensor(token_ids, dtype=torch.long)}





class DataCollatorForMLM:
    """
    Data collator for Masked Language Modeling. Handles batching, padding,
    and masking of tokens.

    Args:
        mlm_probability: Per-token probability of being selected as a
            prediction target. ``[CLS]``, ``[SEP]`` and ``[PAD]`` positions are
            never selected.
        add_special_tokens: When true, every sequence is wrapped as
            ``[CLS] ... [SEP]`` before padding.

    Raises:
        ValueError: If ``mlm_probability`` is outside ``[0, 1]``.
    """

    def __init__(self, mlm_probability: float = 0.15, add_special_tokens: bool = True):
        # Fail at construction rather than inside torch.bernoulli mid-training.
        if not 0.0 <= mlm_probability <= 1.0:
            raise ValueError(f"mlm_probability must be in [0, 1], got {mlm_probability}.")
        self.mlm_probability = mlm_probability
        self.add_special_tokens = add_special_tokens

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Pad a list of samples and apply MLM corruption.

        The items in ``batch`` are left untouched; new tensors are built for
        the collated output.

        Args:
            batch: Samples with a 1-D ``"input_ids"`` tensor each.

        Returns:
            A dict with ``"input_ids"`` (the padded, corrupted batch) and
            ``"labels"`` (original IDs at selected positions, ``-100``
            everywhere else).
        """
        sequences = [item["input_ids"] for item in batch]

        if self.add_special_tokens:
            # Build wrapped copies instead of writing back into the caller's
            # dicts, so collating the same items twice does not stack tokens.
            cls_token = torch.tensor([VOCAB["[CLS]"]], dtype=torch.long)
            sep_token = torch.tensor([VOCAB["[SEP]"]], dtype=torch.long)
            sequences = [torch.cat([cls_token, seq, sep_token]) for seq in sequences]

        padded_inputs = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=PAD_ID
        )

        labels = padded_inputs.clone()
        probability_matrix = torch.full(labels.shape, self.mlm_probability)

        # Structural tokens must never become prediction targets.
        special_tokens_mask = (
            (padded_inputs == VOCAB["[CLS]"])
            | (padded_inputs == VOCAB["[SEP]"])
            | (padded_inputs == VOCAB["[PAD]"])
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # -100 marks positions that are not prediction targets.
        labels[~masked_indices] = -100

        # 80% of the selected positions are replaced by [MASK].
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        padded_inputs[indices_replaced] = MASK_ID

        # Half of the remaining selected positions (10% overall) receive a
        # random id drawn from [0, len(NUCLEOTIDE_VOCAB)).
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(NUCLEOTIDE_VOCAB), labels.shape, dtype=torch.long)
        padded_inputs[indices_random] = random_words[indices_random]

        # The last 10% of selected positions keep their original token.
        return {
            "input_ids": padded_inputs,
            "labels": labels,
        }





class PretrainingDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for pretraining on genomic sequences.
    Splits the BED file intervals into train/val sets.

    Args:
        bed_file: Path to the BED file with genomic intervals.
        fasta_file: Path to the reference FASTA file.
        min_seq_len: Minimum sampled sequence length.
        max_seq_len: Maximum sampled sequence length.
        mlm_probability: Masking probability passed to ``DataCollatorForMLM``.
        batch_size: Batch size of both dataloaders.
        num_workers: Worker processes per dataloader.
        val_split: Fraction of intervals assigned to the validation set.
        seed: Seed for the train/val split and for the dataloader generator.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        bed_file: str,
        fasta_file: str,
        min_seq_len: int = 512,
        max_seq_len: int = 2048,
        mlm_probability: float = 0.15,
        batch_size: int = 32,
        num_workers: int = 4,
        val_split: float = 0.05,
        seed: int = 42,
    ):
        super().__init__()
        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be in [0, 1], got {val_split}.")

        self.bed_file = bed_file
        self.fasta_file = fasta_file
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.mlm_probability = mlm_probability
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed
        # Handed to both dataloaders as their ``generator`` argument.
        self._generator = torch.Generator().manual_seed(seed)

        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Fetch missing input files before any dataset is built.

        Lightning runs this hook in a single process, so the download does not
        race with the other ranks that would otherwise all try it in ``setup``.
        """
        sources = (
            (DEFAULT_BED_URL, self.bed_file),
            (DEFAULT_FASTA_URL, self.fasta_file),
        )
        for url, path in sources:
            if not os.path.exists(path):
                download_file(url, path)

    def setup(self, stage: Optional[str] = None):
        """Build the dataset and split it into train and validation subsets.

        Runs for ``stage`` ``None``, ``"fit"`` and ``"validate"``; the latter
        is needed so a standalone validation run also gets ``val_dataset``.
        """
        if stage not in (None, "fit", "validate"):
            return

        full_dataset = GenomicDataset(
            bed_file=self.bed_file,
            fasta_file=self.fasta_file,
            min_seq_len=self.min_seq_len,
            max_seq_len=self.max_seq_len,
        )

        total_size = len(full_dataset)
        val_size = int(total_size * self.val_split)
        train_size = total_size - val_size

        # Seed the split from the configured seed so ``seed`` actually controls
        # it; a fresh generator keeps the split identical on every call.
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, val_size],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def _build_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Create a dataloader with the settings shared by train and val."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=DataCollatorForMLM(mlm_probability=self.mlm_probability),
            pin_memory=True,
            worker_init_fn=seed_worker,
            generator=self._generator,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._build_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return self._build_dataloader(self.val_dataset, shuffle=False)

