import glob
import os
import urllib.request
import xml.etree.ElementTree as ET
import zipfile
from dataclasses import dataclass
from typing import Any

import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from research.wsl_ece.metric.sized_dataset import SizedDataset

# Marker tokens wrapped around the two entity mentions of a candidate pair.
E1_OPEN_TOKEN = "[E1]"
E1_CLOSE_TOKEN = "[/E1]"
E2_OPEN_TOKEN = "[E2]"
E2_CLOSE_TOKEN = "[/E2]"
# Optional tokens identifying the sub-corpus a sentence comes from.
DOM_DRUGBANK_TOKEN = "[DOM_DRUGBANK]"
DOM_MEDLINE_TOKEN = "[DOM_MEDLINE]"
# Every extra token that has to be registered with the tokenizer.
SPECIAL_TOKENS = [
    E1_OPEN_TOKEN,
    E1_CLOSE_TOKEN,
    E2_OPEN_TOKEN,
    E2_CLOSE_TOKEN,
    DOM_DRUGBANK_TOKEN,
    DOM_MEDLINE_TOKEN,
]


@dataclass
class Entity:
    """A single entity mention inside a sentence.

    Instances are created by ``load_xml_dir`` from the attributes of an
    ``<entity>`` XML element.
    """

    eid: str  # value of the element's "id" attribute
    etype: str  # value of the element's "type" attribute
    spans: list[tuple[int, int]]  # (start, end) character offsets, end inclusive; sorted by start
    text: str  # value of the "text" attribute, or "" when it is absent


@dataclass
class PairExample:
    """One candidate entity pair of a sentence together with its relation label.

    Instances are created by ``load_xml_dir`` from a ``<pair>`` XML element and
    the two entities of the same sentence that it references.
    """

    doc_id: str  # "id" attribute of the enclosing document element
    sent_id: str  # "id" attribute of the enclosing sentence element
    pair_id: str  # "id" attribute of the pair element
    domain: str  # 'drugbank' or 'medline' or 'unknown' (derived from doc_id)
    sent_text: str  # "text" attribute of the sentence, or "" when it is absent
    e1: Entity  # entity referenced by the pair's "e1" attribute
    e2: Entity  # entity referenced by the pair's "e2" attribute
    label_id: int  # 0..4, an index into DDI_TYPES; 0 is the "NONE" class


def _parse_spans(offset_str: str) -> list[tuple[int, int]]:
    """Turn an offset string such as ``"10-15;0-4"`` into a list of spans.

    The string is split on ``";"`` into parts, and each part on ``"-"`` into a
    start and an end that are converted to ``int``. The resulting
    ``(start, end)`` tuples are returned ordered by ``start``; parts sharing
    the same start keep their original relative order.

    Raises:
        ValueError: if a part does not split into exactly two pieces or a
            piece is not an integer.
    """
    raw_pairs = (chunk.split("-") for chunk in offset_str.split(";"))
    parsed = [(int(lo), int(hi)) for lo, hi in raw_pairs]
    # Order by start only, so ties are left in input order.
    return sorted(parsed, key=lambda span: span[0])


# Relation labels; the position in the list is the integer class id, with
# index 0 ("NONE") standing for "no interaction".
# These tables are defined exactly once so they cannot drift apart.
DDI_TYPES = ["NONE", "mechanism", "effect", "advise", "int"]
TYPE_TO_ID = {t: i for i, t in enumerate(DDI_TYPES)}
ID_TO_TYPE = {i: t for t, i in TYPE_TO_ID.items()}


def _infer_domain(doc_id: str) -> str:
    """Derive the sub-corpus name from a document id.

    The check is a case-insensitive substring match, e.g. ids of the form
    "DDI-DrugBank.*" or "DDI-MedLine.*".

    Args:
        doc_id: document identifier.

    Returns:
        ``"drugbank"`` if the id contains "drugbank", otherwise ``"medline"``
        if it contains "medline", otherwise ``"unknown"``.
    """
    lower = doc_id.lower()
    # "drugbank" is tested first, so it wins if both names appear in the id.
    for domain in ("drugbank", "medline"):
        if domain in lower:
            return domain
    return "unknown"


def insert_markers(
    text: str,
    spans1: list[tuple[int, int]],
    spans2: list[tuple[int, int]],
    e1_open: str = E1_OPEN_TOKEN,
    e1_close: str = E1_CLOSE_TOKEN,
    e2_open: str = E2_OPEN_TOKEN,
    e2_close: str = E2_CLOSE_TOKEN,
) -> str:
    """Insert markers around the two (possibly multi-span) mentions.

    Strategy: place the opening marker at the first-span start and the closing
    marker right after the last-span end (inclusive end + 1). Insertions are
    applied from the highest index to the lowest so that earlier positions are
    not shifted by later insertions.

    When a closing and an opening marker fall on the same index (one mention
    ends exactly where the other begins), the closing marker is placed before
    the opening one, giving ``...[/E1][E2]...`` rather than ``...[E2][/E1]...``.

    Args:
        text: the sentence text.
        spans1: ``(start, end)`` offsets of the first mention, end inclusive.
        spans2: ``(start, end)`` offsets of the second mention, end inclusive.
        e1_open, e1_close: markers written around the first mention.
        e2_open, e2_close: markers written around the second mention.

    Returns:
        A copy of ``text`` with the four markers inserted.

    Raises:
        ValueError: if ``spans1`` or ``spans2`` is empty.
    """

    def bounds(spans: list[tuple[int, int]]) -> tuple[int, int]:
        # Smallest start and largest end; +1 turns the inclusive end into a slice index.
        return min(s for s, _ in spans), max(e for _, e in spans) + 1

    a_start, a_end = bounds(spans1)
    b_start, b_end = bounds(spans2)

    # Insertion plan: (index, is_open, marker).
    inserts = [
        (a_start, 1, e1_open),
        (a_end, 0, e1_close),
        (b_start, 1, e2_open),
        (b_end, 0, e2_close),
    ]

    # Highest index first. On equal indices the opening marker is inserted
    # first; the closing marker inserted afterwards at the same index then ends
    # up to its left. The sort is stable, so two markers of the same kind on
    # the same index keep their plan order.
    inserts.sort(key=lambda item: (item[0], item[1]), reverse=True)

    out = text
    for idx, _, marker in inserts:
        out = out[:idx] + marker + out[idx:]
    return out


def load_xml_dir(xml_dir: str) -> list[PairExample]:
    """Collect pair examples from every ``*.xml`` file found below ``xml_dir``.

    Files are discovered recursively and processed in sorted path order. Each
    file may hold several ``document`` elements under its root, or the root may
    itself be the ``document``. Documents and sentences without an ``id``,
    entities lacking ``id``/``type``/``charOffset``, and pairs lacking
    ``id``/``e1``/``e2`` or pointing at an entity that was not parsed are
    skipped.

    A pair is labelled with ``TYPE_TO_ID[type]`` when ``ddi == "true"`` and a
    ``type`` attribute is present (unrecognised types map to 0); every other
    pair gets label 0.

    Raises:
        FileNotFoundError: if no XML file exists under ``xml_dir``.
        AssertionError: if the files yield no example at all.
    """
    pattern = os.path.join(xml_dir, "**", "*.xml")
    xml_paths = glob.glob(pattern, recursive=True)
    if len(xml_paths) == 0:
        raise FileNotFoundError(f"No XML files found under {xml_dir}")

    collected: list[PairExample] = []
    for xml_path in tqdm(sorted(xml_paths), desc="Loading XML files from " + xml_dir):
        root = ET.parse(xml_path).getroot()

        # Either the root wraps several documents, or the root is the document.
        doc_nodes = root.findall("document")
        if root.tag == "document" and not doc_nodes:
            doc_nodes = [root]

        for doc_node in doc_nodes:
            doc_id = doc_node.get("id")
            if not doc_id:
                continue  # document without an id
            domain = _infer_domain(doc_id)

            for sent_node in doc_node.findall("sentence"):
                sent_id = sent_node.get("id")
                if not sent_id:
                    continue  # sentence without an id
                sent_text = sent_node.get("text") or ""

                # Index the sentence's entities by id so pairs can look them up.
                entities_by_id: dict[str, Entity] = {}
                for ent_node in sent_node.findall("entity"):
                    ent_id = ent_node.get("id")
                    ent_type = ent_node.get("type")
                    offsets = ent_node.get("charOffset")
                    if not (ent_id and ent_type and offsets):
                        continue  # a required attribute is missing
                    entities_by_id[ent_id] = Entity(
                        eid=ent_id,
                        etype=ent_type,
                        spans=_parse_spans(offsets),
                        text=ent_node.get("text") or "",
                    )

                for pair_node in sent_node.findall("pair"):
                    pair_id = pair_node.get("id")
                    first_id = pair_node.get("e1")
                    second_id = pair_node.get("e2")
                    if not (pair_id and first_id and second_id):
                        continue  # a required attribute is missing
                    if not (first_id in entities_by_id and second_id in entities_by_id):
                        continue  # pair references an entity that was not parsed

                    rel_type = pair_node.get("type")  # None if ddi="false"
                    is_interaction = pair_node.get("ddi") == "true"
                    label = TYPE_TO_ID.get(rel_type, 0) if is_interaction and rel_type else 0

                    collected.append(
                        PairExample(
                            doc_id=doc_id,
                            sent_id=sent_id,
                            pair_id=pair_id,
                            domain=domain,
                            sent_text=sent_text,
                            e1=entities_by_id[first_id],
                            e2=entities_by_id[second_id],
                            label_id=label,
                        )
                    )

    assert collected, f"No examples found in XML files under {xml_dir}"
    return collected


@dataclass
class EncodedExample:
    """
    A dataclass representing an encoded DDI example with tokenized text and metadata.

    Instances are built by ``BinarizedDDI2013._build_examples`` from the
    tokenizer output for one marked sentence.
    """

    input_ids: torch.Tensor  # long tensor of token ids (tokenized with padding=False)
    attention_mask: torch.Tensor  # long tensor built from the tokenizer's "attention_mask"
    token_type_ids: torch.Tensor  # long tensor; all zeros when the tokenizer returns none
    label: int  # binary label: 0 = no interaction, 1 = any interaction type
    domain: str  # domain of the source pair ('drugbank', 'medline' or 'unknown')
    meta: dict  # keys: "doc_id", "sent_id", "pair_id", "e1", "e2" (e1/e2 are entity texts)


class BinarizedDDI2013(SizedDataset):
    """
    A binarized DDI (Drug-Drug Interaction) dataset that loads examples from XML directories.

    This dataset converts multi-class DDI classification to binary classification,
    where label 0 represents "no interaction" and label 1 represents "interaction".
    The dataset structure is specialized for huggingface's BERT-based models.
    The class assumes the dataset is from:
    https://github.com/isegura/DDICorpus/raw/refs/heads/master/DDICorpus-2013.zip

    All examples are tokenized eagerly in the constructor and kept in
    ``self.examples``.
    """

    # Locations of the two splits inside the extracted archive, relative to ``root``.
    _TRAIN_SUBDIR = os.path.join("DDICorpus", "Train")
    _TEST_SUBDIR = os.path.join("DDICorpus", "Test", "Test for DDI Extraction task")

    def __init__(
        self,
        root: str,
        tokenizer=None,
        max_len: int = 512,
        use_domain_token: bool = False,
        neg_ratio: float | None = None,
        xml_dirs: list[str] | None = None,
        train: bool = True,
        download: bool = False,
    ):
        """Load, optionally subsample, and tokenize the corpus.

        Args:
            root: directory that contains (or will receive) the extracted corpus.
            tokenizer: callable tokenizer; when ``None``, ``bert-base-uncased`` is
                loaded and ``SPECIAL_TOKENS`` are registered on it. A tokenizer
                passed in by the caller is used as-is.
            max_len: maximum sequence length passed to the tokenizer (with truncation).
            use_domain_token: prepend a domain token for drugbank/medline sentences.
            neg_ratio: if given, keep at most ``int(#positives * neg_ratio)``
                randomly chosen negatives; must be non-negative.
            xml_dirs: directories to read; defaults to the train or test split
                under ``root`` depending on ``train``.
            train: selects the default split when ``xml_dirs`` is ``None``.
            download: download and extract the corpus into ``root`` first.

        Raises:
            ValueError: if ``neg_ratio`` is negative.
        """
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
            # Add special tokens for entity marking
            tokenizer.add_special_tokens({"additional_special_tokens": SPECIAL_TOKENS})

        # Download dataset if needed
        if download:
            self._download_and_extract(root)

        # Default XML directories if not specified; built with os.path.join so
        # they match the paths checked by _download_and_extract.
        if xml_dirs is None:
            split_subdir = self._TRAIN_SUBDIR if train else self._TEST_SUBDIR
            xml_dirs = [os.path.join(root, split_subdir)]
        self.examples = self._build_examples(xml_dirs, tokenizer, max_len, use_domain_token, neg_ratio)

    def _build_examples(
        self,
        xml_dirs: list[str],
        tokenizer,
        max_len: int,
        use_domain_token: bool,
        neg_ratio: float | None = None,
    ) -> list[EncodedExample]:
        """Build encoded examples from XML directories.

        Pairs from all directories are merged, optionally undersampled, marked
        with entity (and domain) tokens, tokenized, and given a binary label.

        Raises:
            ValueError: if ``neg_ratio`` is negative.
        """
        import random

        # A negative ratio would turn the slice below into "drop the last k
        # negatives", which is never what the caller means; fail fast instead.
        if neg_ratio is not None and neg_ratio < 0:
            raise ValueError(f"neg_ratio must be non-negative, got {neg_ratio}")

        # Load and merge examples from multiple directories
        all_pairs: list[PairExample] = []
        for d in xml_dirs:
            all_pairs.extend(load_xml_dir(d))

        # Optional negative undersampling (uses the global `random` state, so
        # callers control reproducibility through random.seed()).
        if neg_ratio is not None:
            pos = [ex for ex in all_pairs if ex.label_id != 0]
            neg = [ex for ex in all_pairs if ex.label_id == 0]
            keep_neg = int(len(pos) * neg_ratio)
            random.shuffle(neg)
            all_pairs = pos + neg[:keep_neg]
            random.shuffle(all_pairs)

        # Build encoded examples
        enc_list: list[EncodedExample] = []
        for ex in all_pairs:
            marked = insert_markers(ex.sent_text, ex.e1.spans, ex.e2.spans)
            prefix = ""
            if use_domain_token:
                if ex.domain == "drugbank":
                    prefix = DOM_DRUGBANK_TOKEN + " "
                elif ex.domain == "medline":
                    prefix = DOM_MEDLINE_TOKEN + " "
            text_in = (prefix + marked).strip()

            # Tokenize a single sentence without padding; batching/padding is
            # left to whoever consumes the dataset.
            toks = tokenizer(text_in, truncation=True, max_length=max_len, padding=False, return_tensors=None)
            input_ids = torch.tensor(toks["input_ids"], dtype=torch.long)
            attention_mask = torch.tensor(toks["attention_mask"], dtype=torch.long)
            # Fall back to all zeros for tokenizers that return no segment ids.
            token_type_ids = torch.tensor(toks.get("token_type_ids", [0] * len(toks["input_ids"])), dtype=torch.long)

            # Binarize: every interaction type collapses to 1.
            label = 0 if ex.label_id == 0 else 1

            enc_list.append(
                EncodedExample(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                    label=label,
                    domain=ex.domain,
                    meta={
                        "doc_id": ex.doc_id,
                        "sent_id": ex.sent_id,
                        "pair_id": ex.pair_id,
                        "e1": ex.e1.text,
                        "e2": ex.e2.text,
                    },
                )
            )
        return enc_list

    def _download_and_extract(self, root: str) -> None:
        """Download and extract the DDI Corpus 2013 dataset.

        Nothing is downloaded when both the train and the test directory
        already exist under ``root``. The zip file is deleted after a
        successful extraction.

        Raises:
            RuntimeError: if the download or the extraction fails; the
                underlying exception is attached as ``__cause__``.
        """
        ddi_url = "https://github.com/isegura/DDICorpus/raw/refs/heads/master/DDICorpus-2013.zip"
        zip_path = os.path.join(root, "DDICorpus-2013.zip")

        # Create root directory if it doesn't exist
        os.makedirs(root, exist_ok=True)

        # Check if dataset is already extracted
        train_dir = os.path.join(root, self._TRAIN_SUBDIR)
        test_dir = os.path.join(root, self._TEST_SUBDIR)
        if os.path.exists(train_dir) and os.path.exists(test_dir):
            print(f"DDI Corpus dataset already exists at {root}")
            return

        # Download the dataset
        print(f"Downloading DDI Corpus 2013 from {ddi_url}...")
        try:
            urllib.request.urlretrieve(ddi_url, zip_path)
            print(f"Downloaded to {zip_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to download DDI Corpus dataset: {e}") from e

        # Extract the zip file
        print("Extracting dataset...")
        try:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(root)
            print(f"Extracted to {root}")

            # Remove the zip file to save space
            os.remove(zip_path)
            print("Cleaned up zip file")

        except Exception as e:
            raise RuntimeError(f"Failed to extract DDI Corpus dataset: {e}") from e

    def __getitem__(self, index) -> tuple[Any, Any]:
        """Get an item from the dataset. Returns (features, label) for compatibility with other datasets.

        ``features`` is a dict with the keys "input_ids", "attention_mask",
        "token_type_ids", "domain" and "meta"; ``label`` is the binary label.
        """
        ex = self.examples[index]
        # For compatibility with other dataset classes, return features dict and label
        features = {
            "input_ids": ex.input_ids,
            "attention_mask": ex.attention_mask,
            "token_type_ids": ex.token_type_ids,
            "domain": ex.domain,
            "meta": ex.meta,
        }
        return features, ex.label

    def __len__(self) -> int:
        """Return the number of encoded examples."""
        return len(self.examples)
