from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Literal, Optional, Tuple
import json
import os
import re
from collections import Counter
from pathlib import Path

import pandas as pd
from tqdm import tqdm


# ---------------------------------------------------------------------------
# Fragment helpers
# ---------------------------------------------------------------------------

# Matches one bracketed attachment-point label made of digits and a star,
# e.g. "[8*]" or "[16*]". A bare "[*]" (no digits) does not match.
_NUMERIC_ANCHOR = re.compile(r"\[\d+\*\]")


def strip_numeric_anchors(fragment: str) -> str:
    """
    Return *fragment* with every numeric anchor such as ``[8*]`` deleted.

    All occurrences are removed, wherever they appear in the string; text
    without an anchor is returned unchanged.

    Example: "[8*]C(C)(C)C" -> "C(C)(C)C"
    """
    cleaned = re.sub(_NUMERIC_ANCHOR, "", fragment)
    return cleaned


# Second public name for the same function, kept for callers that still use
# the older spelling.
normalize_fragment = strip_numeric_anchors  # alias


# ---------------------------------------------------------------------------
# Decomposition CSV (SMILES -> fragments)
# ---------------------------------------------------------------------------

def generate_decomposition_csv_from_json(
    json_path: str | Path,
    output_csv_path: str | Path = "decomposition_task.csv",
    instruction: str = "Decompose the SMILES into fragments.",
    smiles_key: str = "smiles",
    fragments_key: str = "fragments",
) -> pd.DataFrame:
    """
    Turn a JSON list of molecules into an instruction-style CSV.

    Args:
        json_path: UTF-8 JSON file holding a list of dicts.
        output_csv_path: Where the CSV is written (no index column).
        instruction: Text copied verbatim into every row's ``instruction``.
        smiles_key: Key under which each entry stores its SMILES string.
        fragments_key: Key under which each entry stores its fragment list.

    Returns:
        The DataFrame that was written, with columns
        ``instruction``, ``input`` and ``output``.

    Raises:
        KeyError: If an entry lacks ``smiles_key`` or ``fragments_key``.
    """
    with Path(json_path).open(encoding="utf-8") as handle:
        molecules = json.load(handle)

    records = []
    for molecule in tqdm(molecules, desc="building decomposition CSV"):
        smiles = molecule[smiles_key]
        # Anchors are stripped so the output lists plain fragment strings.
        cleaned = [strip_numeric_anchors(frag) for frag in molecule[fragments_key]]
        records.append(
            {
                "instruction": instruction,
                "input": f"SMILES: <SMILES>{smiles}</SMILES>",
                "output": "Fragments: " + ", ".join(cleaned),
            }
        )

    # Explicit column list keeps the header stable even when there are no rows.
    table = pd.DataFrame(records, columns=["instruction", "input", "output"])
    table.to_csv(output_csv_path, index=False)
    return table


# ---------------------------------------------------------------------------
# Neighbor-pair construction (1-hop by fragment multiset distance)
# ---------------------------------------------------------------------------

def is_one_hop_neighbor(frags_a: List[str], frags_b: List[str]) -> bool:
    """
    Tell whether two fragment lists are exactly one edit apart.

    Both lists are compared as multisets after numeric anchors are stripped.
    They are 1-hop neighbors when either:
      - they have the same number of fragments and exactly one fragment on
        each side is unmatched (a single replace), or
      - their sizes differ by one and exactly one fragment is unmatched
        (a single add/remove).
    Identical multisets are not neighbors.
    """
    bag_a = Counter(strip_numeric_anchors(frag) for frag in frags_a)
    bag_b = Counter(strip_numeric_anchors(frag) for frag in frags_b)

    # Counter subtraction drops non-positive counts, so the two one-sided
    # differences together give the size of the symmetric multiset difference.
    only_in_a = sum((bag_a - bag_b).values())
    only_in_b = sum((bag_b - bag_a).values())
    mismatches = only_in_a + only_in_b

    # Total counts equal the original list lengths.
    size_gap = abs(sum(bag_a.values()) - sum(bag_b.values()))

    if size_gap == 0:
        return mismatches == 2
    if size_gap == 1:
        return mismatches == 1
    return False


def find_neighbors_in_chunk(
    records: List[Dict[str, Any]],
    max_neighbors_per_i: int = 10,
    min_len_fragments: int = 3,
) -> Iterator[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Yield ``(mol_i, mol_j)`` for 1-hop neighbors within one chunk.

    Each record ``i`` that has at least ``min_len_fragments`` fragments is
    compared with every later record ``j > i`` in ``records``; the pair is
    yielded when ``is_one_hop_neighbor`` accepts their fragment lists. The
    minimum-length filter applies to ``mol_i`` only, so ``mol_j`` may be
    shorter. At most ``max_neighbors_per_i`` pairs are yielded per ``mol_i``.

    Records whose ``"fragments"`` entry is missing (or ``None``) are never
    paired, on either side, instead of raising ``KeyError``.

    Args:
        records: Molecule dicts; the fragment list is read from "fragments".
        max_neighbors_per_i: Cap on pairs emitted for a single ``mol_i``.
        min_len_fragments: Minimum fragment count for a record to act as
            ``mol_i``.

    Yields:
        Tuples of the two original record dicts, earlier record first.
    """
    # Read each record's fragment list once instead of once per comparison.
    frag_lists = [mol.get("fragments") for mol in records]

    valid_idx = [
        i
        for i, frags in enumerate(frag_lists)
        if frags is not None and len(frags) >= min_len_fragments
    ]

    total = len(records)
    for i in tqdm(valid_idx, desc="pairing", leave=False):
        mol_i = records[i]
        frags_i = frag_lists[i]
        len_i = len(frags_i)
        seen = 0
        for j in range(i + 1, total):
            if seen >= max_neighbors_per_i:
                break
            frags_j = frag_lists[j]
            if frags_j is None:
                continue
            # Lists whose sizes differ by more than one can never be 1-hop
            # neighbors, so skip the multiset comparison entirely.
            if abs(len_i - len(frags_j)) > 1:
                continue
            if is_one_hop_neighbor(frags_i, frags_j):
                yield mol_i, records[j]
                seen += 1


def _neighbor_pairs_for_chunk(
    chunk: List[Dict[str, Any]],
    max_neighbors_per_i: int,
    min_len_fragments: int,
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Collect every 1-hop pair of one chunk into a list.

    Lives at module level (rather than nested inside ``build_neighbor_pairs``)
    because ``multiprocessing`` must pickle the callable it sends to pool
    workers, and locally defined functions cannot be pickled.
    """
    return list(
        find_neighbors_in_chunk(
            chunk,
            max_neighbors_per_i=max_neighbors_per_i,
            min_len_fragments=min_len_fragments,
        )
    )


def build_neighbor_pairs(
    input_json: str | Path,
    output_path: str | Path,
    chunk_size: int = 100_000,
    processes: int = 4,
    max_neighbors_per_i: int = 10,
    min_len_fragments: int = 3,
    output_format: Literal["jsonl", "json"] = "jsonl",
) -> None:
    """
    Build 1-hop neighbor pairs from a molecule list, chunked and (optionally) parallelized.

    - Input file: JSON list of molecule dicts (fragments read from "fragments")
    - Output:
        * jsonl: one JSON pair per line: [mol_i, mol_j]
        * json (or any other value): a single JSON array of pairs

    Args:
        input_json: UTF-8 JSON file containing the list of molecules.
        output_path: Destination file; missing parent directories are created.
        chunk_size: Number of consecutive records per chunk. Neighbors are
            searched only inside a chunk, never across chunk boundaries.
        processes: Worker process count; values of 1 or less (or 0/None)
            run everything serially in the current process.
        max_neighbors_per_i: Forwarded to ``find_neighbors_in_chunk``.
        min_len_fragments: Forwarded to ``find_neighbors_in_chunk``.
        output_format: "jsonl" streams pairs line by line as each chunk
            finishes; anything else accumulates all pairs and writes one
            indented JSON array.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.

    Notes:
      * The whole input list is loaded into memory; only the jsonl *output*
        is streamed. If a chunk fails midway in jsonl mode, the output file
        holds the pairs written so far.
      * NOTE(review): with the "spawn" start method the calling script
        presumably needs an ``if __name__ == "__main__":`` guard — confirm
        against the multiprocessing documentation for the target platform.
    """
    import functools
    import multiprocessing as mp  # local import to avoid issues on some platforms

    # A non-positive step would either crash in range() or silently yield
    # no chunks at all, so reject it up front with a clear message.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    input_path = Path(input_json)
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    data = json.loads(input_path.read_text(encoding="utf-8"))

    # Split into chunks
    chunks: List[List[Dict[str, Any]]] = [
        data[i : i + chunk_size] for i in range(0, len(data), chunk_size)
    ]

    # A partial over a module-level function is picklable, unlike a closure.
    worker = functools.partial(
        _neighbor_pairs_for_chunk,
        max_neighbors_per_i=max_neighbors_per_i,
        min_len_fragments=min_len_fragments,
    )

    def _write(chunk_results: Iterable[List[Tuple[Dict[str, Any], Dict[str, Any]]]]) -> None:
        """Consume per-chunk pair lists (in chunk order) and write the output file."""
        if output_format == "jsonl":
            with output_path.open("w", encoding="utf-8") as fout:
                for pairs in chunk_results:
                    fout.writelines(json.dumps([a, b]) + "\n" for a, b in pairs)
        else:
            all_pairs = [[a, b] for pairs in chunk_results for a, b in pairs]
            output_path.write_text(json.dumps(all_pairs, indent=2), encoding="utf-8")

    # Map (parallel or serial). Results are consumed lazily, so writing must
    # happen while the pool is still alive.
    if processes and processes > 1:
        with mp.Pool(processes=processes) as pool:
            _write(tqdm(pool.imap(worker, chunks), total=len(chunks), desc="chunks"))
    else:
        _write(tqdm(map(worker, chunks), total=len(chunks), desc="chunks"))


# ---------------------------------------------------------------------------
# Single-step modification detector (replace/add/remove), and QA builders
# ---------------------------------------------------------------------------

def canonicalize_fragment(fragment: str) -> str:
    """Return the comparison form of *fragment*: the string with numeric anchors stripped.

    Thin wrapper that delegates to ``strip_numeric_anchors``.
    """
    canonical = strip_numeric_anchors(fragment)
    return canonical


def detect_single_step_modification(
    old_frags: List[str], new_frags: List[str]
) -> Tuple[Literal["replace","add","remove"], str, Optional[str]]:
    """
    Classify the one fragment-level edit that turns ``old_frags`` into ``new_frags``.

    Fragments are canonicalized (anchors stripped) and compared as multisets,
    so order is irrelevant and duplicates are counted.

    Returns (all fragments in canonical, anchor-free form):
      - ("replace", old_fragment, new_fragment) when sizes are equal
      - ("add",    new_fragment, None)          when new has one more fragment
      - ("remove", old_fragment, None)          when new has one fewer fragment

    Raises:
        ValueError: If the sizes differ by more than one, or the number of
            unmatched fragments is not exactly one on the relevant side(s)
            (this includes two identical fragment multisets).
    """
    before = [canonicalize_fragment(frag) for frag in old_frags]
    after = [canonicalize_fragment(frag) for frag in new_frags]

    size_delta = len(after) - len(before)
    if abs(size_delta) > 1:
        raise ValueError("Lists differ by more than 1; not a single-step edit.")

    # Multiset differences: fragments left on one side once every match on the
    # other side has been cancelled out.
    before_bag, after_bag = Counter(before), Counter(after)
    only_old = list((before_bag - after_bag).elements())
    only_new = list((after_bag - before_bag).elements())

    if size_delta == 0:
        # Same size: the edit can only be a replace.
        if not (len(only_old) == 1 and len(only_new) == 1):
            raise ValueError("Expected exactly one replace difference.")
        return ("replace", only_old[0], only_new[0])

    if size_delta == 1:
        # New list is longer by one: an add.
        if len(only_new) != 1:
            raise ValueError("Expected exactly one unmatched new fragment (add).")
        return ("add", only_new[0], None)

    # New list is shorter by one: a remove.
    if len(only_old) != 1:
        raise ValueError("Expected exactly one unmatched old fragment (remove).")
    return ("remove", only_old[0], None)


def _fmt_delta(a: float, b: float) -> Tuple[str, float]:
    """Return ('higher'|'lower', abs_delta)."""
    direction = "higher" if b > a else "lower"
    return direction, abs(b - a)


def generate_qa_pair(pair: List[Dict[str, Any]]) -> Tuple[str, str]:
    """
    Build a (question, answer) pair asking for a single fragment-level edit.

    ``pair`` is ``[old_mol, new_mol]``; each dict must provide 'smiles' and
    'fragments'. No 'properties' entry is read here.

    The question shows the old SMILES with its anchor-free fragments; the
    answer names the edit (replace/add/remove) and the resulting SMILES.

    Raises:
        ValueError: Propagated from ``detect_single_step_modification`` when
            the two fragment lists are not exactly one edit apart.
    """
    source_mol, target_mol = pair
    source_smiles = source_mol["smiles"]
    target_smiles = target_mol["smiles"]

    kind, fragment, replacement = detect_single_step_modification(
        source_mol["fragments"], target_mol["fragments"]
    )
    # Any kind other than replace/add is rendered as a removal.
    templates = {
        "replace": "Replace {0} with {1}",
        "add": "Add {0}",
    }
    edit_text = templates.get(kind, "Remove {0}").format(fragment, replacement)

    source_fragments = [canonicalize_fragment(frag) for frag in source_mol["fragments"]]
    fragments_text = str(source_fragments)

    question = "".join(
        [
            "Given the intermediate molecule SMILES ",
            f"<SMILES>{source_smiles}</SMILES>, which is composed of fragments {fragments_text}. ",
            "Propose a single replace, add or remove step on fragment level that makes the ",
            "new molecule's properties change accordingly.",
        ]
    )
    answer = f"{edit_text} to form <SMILES>{target_smiles}</SMILES>."
    return question, answer


def generate_qa_pair_for_modification(pair: List[Dict[str, Any]]) -> Tuple[str, str]:
    """
    Build a (question, answer) pair where the edit is given and the SMILES is asked for.

    ``pair`` is ``[old_mol, new_mol]``; each dict must provide 'smiles' and
    'fragments'. The question states the old SMILES, its anchor-free
    fragments and the detected edit; the answer is the new SMILES wrapped in
    ``<SMILES>`` tags.

    Raises:
        ValueError: Propagated from ``detect_single_step_modification`` when
            the two fragment lists are not exactly one edit apart.
    """
    source_mol, target_mol = pair
    source_smiles = source_mol["smiles"]
    target_smiles = target_mol["smiles"]

    kind, fragment, replacement = detect_single_step_modification(
        source_mol["fragments"], target_mol["fragments"]
    )
    # The edit is embedded mid-sentence, so the verb is written in lowercase.
    if kind == "replace":
        inline_edit = f"replace {fragment} with {replacement}"
    elif kind == "add":
        inline_edit = f"add {fragment}"
    else:
        inline_edit = f"remove {fragment}"

    fragments_text = str([canonicalize_fragment(frag) for frag in source_mol["fragments"]])

    question = (
        f"Given the SMILES <SMILES>{source_smiles}</SMILES>, "
        f"which is composed of fragments {fragments_text}. "
        f"Give me the modified SMILES if I {inline_edit}."
    )
    answer = f"<SMILES>{target_smiles}</SMILES>"
    return question, answer


def build_qa_dataframe(pairs: Iterable[List[Dict[str, Any]]]) -> pd.DataFrame:
    """
    Build the edit-proposal QA table, with columns: id, question, answer.

    ``id`` is the pair's position in ``pairs``. Pairs for which
    ``generate_qa_pair`` raises ``ValueError`` are skipped, so ids may have gaps.
    """
    records = []
    for pair_id, pair in enumerate(pairs):
        try:
            question, answer = generate_qa_pair(pair)
        except ValueError:
            # Not a valid single-step pair (or malformed); leave it out.
            continue
        records.append({"id": pair_id, "question": question, "answer": answer})

    return pd.DataFrame(records, columns=["id", "question", "answer"])


def build_qa_dataframe_for_modification(pairs: Iterable[List[Dict[str, Any]]]) -> pd.DataFrame:
    """
    Build the "apply this edit" QA table, with columns: id, question, answer.

    ``id`` is the pair's position in ``pairs``. Pairs for which
    ``generate_qa_pair_for_modification`` raises ``ValueError`` are skipped,
    so ids may have gaps.
    """
    records = []
    for pair_id, pair in enumerate(pairs):
        try:
            question, answer = generate_qa_pair_for_modification(pair)
        except ValueError:
            # Not a valid single-step pair; leave it out.
            continue
        records.append({"id": pair_id, "question": question, "answer": answer})

    return pd.DataFrame(records, columns=["id", "question", "answer"])


def write_qa_csv(
    neighbor_pairs_path: str | Path,
    out_csv_path: str | Path,
    task: Literal["directional", "modification"] = "directional",
) -> pd.DataFrame:
    """
    Read neighbor pairs from disk, turn them into QA rows and save them as CSV.

    Args:
        neighbor_pairs_path: Pairs file. A ``.jsonl`` suffix (case-insensitive)
            is read as one JSON pair per non-blank line; any other suffix is
            parsed as a single JSON document.
        out_csv_path: Destination CSV, written without an index column.
        task: "directional" builds questions that ask for an edit (answer:
            edit + SMILES); any other value builds questions that state the
            edit (answer: SMILES).

    Returns:
        The DataFrame that was written.
    """
    source = Path(neighbor_pairs_path)
    pairs: List[List[Dict[str, Any]]]
    if source.suffix.lower() == ".jsonl":
        with source.open("r", encoding="utf-8") as handle:
            # Blank / whitespace-only lines are ignored.
            pairs = [json.loads(raw) for raw in handle if raw.strip()]
    else:
        pairs = json.loads(source.read_text(encoding="utf-8"))

    builder = build_qa_dataframe if task == "directional" else build_qa_dataframe_for_modification
    table = builder(pairs)

    table.to_csv(out_csv_path, index=False)
    return table
