from __future__ import annotations
from typing import Dict, Any, List, Optional
import os, json, re, ast
from dataclasses import dataclass

import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig

# --- RDKit property utils for rewards ---
from rdkit import Chem
from rdkit.Chem import QED, Crippen, Descriptors
from collections import Counter, defaultdict

# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------

def _bf16_supported() -> bool:
    """Report whether bfloat16 can be used: a CUDA device exists and it supports bf16."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.is_bf16_supported()

def _ensure_pad_eos(tokenizer):
    """Reuse the EOS token as the pad token when the tokenizer has no pad token id."""
    if tokenizer.pad_token_id is not None:
        return
    # Padding needs some token; fall back to EOS.
    tokenizer.pad_token = tokenizer.eos_token

def _add_special_tokens(tokenizer, specials: List[str]):
    """Register ``specials`` as additional special tokens; does nothing for an empty/None list."""
    if not specials:
        return
    tokenizer.add_special_tokens({"additional_special_tokens": specials})

def _load_dataset_from_cfg(cfg: Dict[str, Any], purpose: str) -> Dataset:
    """
    Load a dataset either from the HF Hub (data.dataset_name / data.split) or from a local
    file (data.train_file: .jsonl, .json or .csv; the extension is matched case-insensitively).

    `purpose` is one of {"sft", "grpo"} and selects the column check:
      - SFT:   requires columns ``input`` and ``output``.
      - GRPO:  requires ``prompt``; if absent it is copied from ``input``.

    Raises:
        ValueError: no data source configured, unsupported file extension, or the
            required columns are missing.
    """
    # `or {}` also tolerates an explicit `data: null` in the YAML.
    data_cfg = cfg.get("data") or {}
    ds_name = data_cfg.get("dataset_name")
    split = data_cfg.get("split", "train")
    path = data_cfg.get("train_file")

    if ds_name:
        ds = load_dataset(ds_name, split=split)
    elif path:
        ext = os.path.splitext(path)[1].lower()
        if ext == ".jsonl":
            # One JSON object per non-blank line; `with` guarantees the handle is closed.
            with open(path, "r", encoding="utf-8") as fh:
                rows = [json.loads(line) for line in fh if line.strip()]
            ds = Dataset.from_list(rows)
        elif ext == ".json":
            # Dataset.from_list needs a list of row dicts, i.e. a top-level JSON array.
            with open(path, "r", encoding="utf-8") as fh:
                rows = json.load(fh)
            ds = Dataset.from_list(rows)
        elif ext == ".csv":
            import pandas as pd

            ds = Dataset.from_pandas(pd.read_csv(path))
        else:
            raise ValueError(f"Unsupported data file: {path}")
    else:
        raise ValueError("Provide either data.dataset_name or data.train_file in config.yaml")

    # Basic checks / formatting
    if purpose == "sft":
        missing = [c for c in ("input", "output") if c not in ds.column_names]
        if missing:
            raise ValueError(f"SFT dataset must have columns 'input' and 'output', missing: {missing}")
    elif purpose == "grpo":
        # Accept 'prompt' or build it from 'input'
        if "prompt" not in ds.column_names:
            if "input" in ds.column_names:
                # map input -> prompt; output is not used in GRPO
                ds = ds.map(lambda ex: {"prompt": ex["input"]})
            else:
                raise ValueError("GRPO dataset must have 'prompt' or 'input' column.")
    return ds

def _build_lora_config(cfg: Dict[str, Any]) -> LoraConfig:
    """
    Translate the optional ``lora`` config section into a causal-LM ``LoraConfig``.

    Defaults: r=32, alpha=32, dropout=0.05, bias="none", and the attention + MLP
    projection modules as targets.
    """
    section = cfg.get("lora", {})
    default_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    options = dict(
        r=section.get("r", 32),
        lora_alpha=section.get("alpha", 32),
        lora_dropout=section.get("dropout", 0.05),
        bias=section.get("bias", "none"),
        task_type=TaskType.CAUSAL_LM,
        target_modules=section.get("target_modules", default_targets),
    )
    return LoraConfig(**options)

def _build_tokenizer(cfg: Dict[str, Any], *, pretrained_path: Optional[str] = None):
    """
    Load and prepare a tokenizer.

    Loads from ``pretrained_path`` when it is truthy, otherwise from ``cfg["model"]["base"]``.
    Then falls back to EOS as the pad token if none is set, and registers any
    ``model.special_tokens`` listed in the config.
    """
    source = pretrained_path if pretrained_path else cfg["model"]["base"]
    tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=True)
    _ensure_pad_eos(tokenizer)
    extra_tokens = cfg["model"].get("special_tokens", [])
    _add_special_tokens(tokenizer, extra_tokens)
    return tokenizer

def _build_model_for_sft(cfg: Dict[str, Any], tokenizer) -> torch.nn.Module:
    """
    Load the base causal LM for SFT and wrap it with a fresh LoRA adapter.

    Reads model.base and model.load_in_8bit (default True). 8-bit loading is requested
    through ``BitsAndBytesConfig``; passing ``load_in_8bit=`` straight to
    ``from_pretrained`` is deprecated in transformers. Weights use bfloat16 when
    supported, float16 otherwise.

    Args:
        cfg: full config dict.
        tokenizer: tokenizer whose length the embedding matrix is resized to.

    Returns:
        A PEFT-wrapped model (``get_peft_model`` with the ``lora`` config section).
    """
    from transformers import BitsAndBytesConfig

    base = cfg["model"]["base"]
    load_8bit = cfg.get("model", {}).get("load_in_8bit", True)
    model = AutoModelForCausalLM.from_pretrained(
        base,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True) if load_8bit else None,
        torch_dtype=torch.bfloat16 if _bf16_supported() else torch.float16,
        device_map="auto",
    )
    # Added special tokens grow the vocab; embeddings must match before LoRA is attached.
    model.resize_token_embeddings(len(tokenizer))

    return get_peft_model(model, _build_lora_config(cfg))

def _build_model_for_grpo(cfg: Dict[str, Any], tokenizer) -> torch.nn.Module:
    """
    Build the policy model for GRPO.

    - grpo.init_ckpt given: attach the PEFT adapter saved at that path (trainable) to the
      base model. NOTE(review): this goes through PeftModel.from_pretrained, so the
      checkpoint must be an adapter directory, not a full model.
    - otherwise: attach a fresh LoRA built from the ``lora`` config section.
    - grpo.lora_adapter given: load it as an additional named adapter ("stacked") and
      activate it together with the first adapter, both trainable.

    The base model is loaded once (8-bit via BitsAndBytesConfig when model.load_in_8bit,
    default True) and its embeddings are resized to ``len(tokenizer)``.

    Returns:
        A ``PeftModel`` around the base model.
    """
    from transformers import BitsAndBytesConfig

    grpo_cfg = cfg.get("grpo", {})
    init_ckpt = grpo_cfg.get("init_ckpt")
    base_name = cfg["model"]["base"]
    load_8bit = cfg.get("model", {}).get("load_in_8bit", True)

    # Both branches need the same base model; load it once.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_name,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True) if load_8bit else None,
        torch_dtype=torch.bfloat16 if _bf16_supported() else torch.float16,
        device_map="auto",
    )
    # Match the (possibly extended) vocab before any adapter is attached.
    base_model.resize_token_embeddings(len(tokenizer))

    if init_ckpt:
        model = PeftModel.from_pretrained(base_model, init_ckpt, is_trainable=True)
    else:
        # Fresh LoRA on top of base
        model = get_peft_model(base_model, _build_lora_config(cfg))

    lora_adapter = grpo_cfg.get("lora_adapter")
    if lora_adapter:
        # `model` is already a PeftModel: wrapping it again with PeftModel.from_pretrained
        # would nest PEFT wrappers and inject LoRA into LoRA layers. Load the extra
        # adapter by name instead, then activate both so their deltas are combined.
        first_adapter = model.active_adapter
        model.load_adapter(lora_adapter, adapter_name="stacked", is_trainable=True)
        # The LoRA tuner's set_adapter accepts a list and marks those adapters trainable.
        model.base_model.set_adapter([first_adapter, "stacked"])
    return model

# -----------------------------------------------------------------------------
# SFT
# -----------------------------------------------------------------------------

def _sft_formatting_func(system_prompt: str):
    """
    Build a formatter that renders one SFT example (keys ``input``/``output``) as a
    single-turn "[Round 0]" Human/Assistant transcript, with ``system_prompt`` placed
    directly before the input.
    """
    def render(example):
        human_turn = f"Human: {system_prompt}{example['input']}"
        assistant_turn = f"Assistant: {example['output']}"
        return "\n".join(("[Round 0]", human_turn, assistant_turn))

    return render

def train_sft(cfg: Dict[str, Any]) -> None:
    """
    Supervised fine-tuning with LoRA. Expects dataset with columns: input, output.

    Config keys read:
      model.system_prompt   prepended to every ``input`` (has a default)
      sft.output_dir        defaults to <project.exp_dir>/sft (project.exp_dir is only
                            required when output_dir is unset)
      sft.batch_size, sft.grad_accum, sft.epochs, sft.lr, sft.save_steps,
      sft.save_total_limit, sft.logging_steps, sft.max_seq_length
      logging.report_to, logging.dir

    The trained model (adapter) and the tokenizer are saved to the output directory.
    """
    from dataclasses import fields

    system_prompt = cfg.get("model", {}).get(
        "system_prompt",
        "You love and excel at editing SMILES strings to make original SMILES meet the required numeric properties.\n",
    )
    # Every sft.* / logging.* key has a default, so the sections themselves are optional.
    sft_cfg = cfg.get("sft") or {}
    log_cfg = cfg.get("logging") or {}

    # Tokenizer
    tokenizer = _build_tokenizer(cfg)

    # Dataset
    ds = _load_dataset_from_cfg(cfg, purpose="sft")

    # Model (+LoRA)
    model = _build_model_for_sft(cfg, tokenizer)

    # Resolve the default lazily so project.exp_dir is not required when output_dir is set.
    out_dir = sft_cfg.get("output_dir") or os.path.join(cfg["project"]["exp_dir"], "sft")

    # TRL renamed SFTConfig.max_seq_length to max_length (and later dropped the old
    # name); pass the value under whichever field this TRL version defines.
    sft_field_names = {f.name for f in fields(SFTConfig)}
    length_kw = "max_length" if "max_length" in sft_field_names else "max_seq_length"

    sft_args = SFTConfig(
        output_dir=out_dir,
        per_device_train_batch_size=sft_cfg.get("batch_size", 4),
        gradient_accumulation_steps=sft_cfg.get("grad_accum", 4),
        num_train_epochs=sft_cfg.get("epochs", 1),
        learning_rate=sft_cfg.get("lr", 5e-5),
        save_strategy="steps",
        save_steps=sft_cfg.get("save_steps", 500),
        save_total_limit=sft_cfg.get("save_total_limit", 2),
        logging_steps=sft_cfg.get("logging_steps", 50),
        report_to=log_cfg.get("report_to", "none"),
        logging_dir=log_cfg.get("dir", "./logs-sft"),
        bf16=_bf16_supported(),
        fp16=not _bf16_supported(),
        **{length_kw: sft_cfg.get("max_seq_length", 1024)},
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        train_dataset=ds,
        peft_config=None,  # already wrapped with LoRA by _build_model_for_sft
        formatting_func=_sft_formatting_func(system_prompt),
        processing_class=tokenizer,
    )

    trainer.train()
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)

# -----------------------------------------------------------------------------
# GRPO reward helpers
# -----------------------------------------------------------------------------

def _canonical_smiles(s: str) -> Optional[str]:
    """Return RDKit's canonical SMILES for ``s``, or None when RDKit cannot parse it."""
    mol = Chem.MolFromSmiles(s)
    if not mol:
        return None
    return Chem.MolToSmiles(mol, canonical=True)

def _get_element_count(smiles: str) -> Optional[Counter]:
    """
    Tally the atoms of the parsed molecule (those yielded by ``GetAtoms()``) by element
    symbol. Returns None when the SMILES cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        return Counter(atom.GetSymbol() for atom in mol.GetAtoms())
    return None

def _get_clean_fragments(smiles: str) -> List[str]:
    """
    BRICS fragments without attachment point markers like [4*].

    Returns the fragments of ``smiles`` as isomeric SMILES with the dummy (``*``) atoms
    removed. Returns [] when the input cannot be parsed or the BRICS decomposition
    itself fails. A fragment that cannot be re-sanitized after its dummy atoms are
    removed is skipped individually, so one bad fragment no longer discards the others.
    """
    try:
        from rdkit.Chem import BRICS

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return []
        mol_broken = BRICS.BreakBRICSBonds(mol)
        frags = Chem.GetMolFrags(mol_broken, asMols=True, sanitizeFrags=True)
    except Exception:
        # Best effort: unparsable input (e.g. None) or an RDKit failure -> no fragments.
        return []

    out = []
    for frag in frags:
        try:
            ed = Chem.EditableMol(frag)
            to_remove = [a.GetIdx() for a in frag.GetAtoms() if a.GetSymbol() == "*" and a.GetAtomicNum() == 0]
            # Delete from the highest index down so the remaining indices stay valid.
            for idx in sorted(to_remove, reverse=True):
                ed.RemoveAtom(idx)
            cleaned = ed.GetMol()
            # Sanitization can fail after dummy removal (for instance, kekulization of an
            # aromatic ring atom that lost its substituent); skip just that fragment.
            Chem.SanitizeMol(cleaned)
            out.append(Chem.MolToSmiles(cleaned, isomericSmiles=True))
        except Exception:
            continue
    return out

def _extract_fragments_from_prompt(prompt: str) -> List[str]:
    """
    Pull fragments list from the prompt pattern:
      "... fragments ['frag1', 'frag2', ...]. Propose ..."

    The list text is parsed with ``ast.literal_eval`` (Python literals only, no code
    execution). SMILES can themselves contain "]." (e.g. the salt '[Na+].[Cl-]'), so
    every "]." after the opening bracket is tried as the terminator, and the first
    span that parses to a list is returned.

    Returns:
        The parsed list, or [] when ``prompt`` is not a string, the pattern is absent,
        or no candidate span parses to a list.
    """
    if not isinstance(prompt, str):
        return []
    head = re.search(r"fragments\s*\[", prompt)
    if head is None:
        return []
    start = head.end() - 1  # index of the opening "["
    for close in re.finditer(r"\]\.", prompt[start:]):
        candidate = prompt[start : start + close.start() + 1]  # up to and including "]"
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a complete literal yet (e.g. cut inside a quoted SMILES); try the next "].".
            continue
        if isinstance(parsed, list):
            return parsed
    return []

# -----------------------------------------------------------------------------
# GRPO rewards (cleaned)
# -----------------------------------------------------------------------------

def strict_format_reward(completions: List[str], **kwargs) -> List[float]:
    """
    Score each completion on exact answer formatting.

    After stripping surrounding whitespace, a completion earns +1.0 when the whole text is
        '<Replace|Add|Remove> ... [with ...] to form <SMILES>...</SMILES>.'
    (the ``with ...`` clause is optional for every verb), and -10.0 otherwise.
    Extra keyword arguments are ignored.
    """
    matcher = re.compile(
        r"^(Replace|Add|Remove)\s+(.+?)(?:\s+with\s+(.+?))?\s+to\s+form\s+<SMILES>(.+)</SMILES>\.$",
        re.DOTALL,
    )
    good, bad = 1.0, -10.0
    return [good if matcher.fullmatch(candidate.strip()) else bad for candidate in completions]

# Process-wide tallies used by repetition_penalty_reward:
#   _completion_counter: canonical final SMILES -> times seen so far
#   _fragment_counter:   canonical replacement fragment -> times seen so far
# Nothing in this file resets them, so counts accumulate for the life of the process.
_completion_counter, _fragment_counter = defaultdict(int), defaultdict(int)

def repetition_penalty_reward(completions: List[str], **kwargs) -> List[float]:
    """
    Penalize repeated final SMILES and overused replacement fragments.

    Per completion (starting from 0.0):
      * first <SMILES>...</SMILES>: if it canonicalizes, subtract 0.1 x the number of
        earlier sightings of that canonical SMILES; otherwise subtract 1.0.
      * "Replace|Add ... with <frag> to form <SMILES>" (checked on the stripped text):
        if <frag> canonicalizes, subtract 0.05 x its earlier sightings; otherwise 0.5.
    Sightings are recorded in the module-level ``_completion_counter`` /
    ``_fragment_counter`` as a side effect. Extra keyword arguments are ignored.
    """
    def _seen_before(tally, key):
        # How often `key` was seen before this sighting; records this sighting.
        previous = tally[key]
        tally[key] = previous + 1
        return previous

    scores = []
    for text in completions:
        score = 0.0

        final = re.search(r"<SMILES>(.*?)</SMILES>", text)
        if final is not None:
            canon = _canonical_smiles(final.group(1))
            if not canon:
                score += -1.0  # invalid SMILES
            else:
                score += -0.1 * _seen_before(_completion_counter, canon)

        edit = re.match(
            r"^(Replace|Add)\s+(.+?)(?:\s+with\s+(.+?))?\s+to\s+form\s+<SMILES>",
            text.strip(),
        )
        fragment = (edit.group(3) or "").strip() if edit is not None else ""
        if fragment:
            canon_frag = _canonical_smiles(fragment)
            if not canon_frag:
                score += -0.5  # invalid fragment
            else:
                score += -0.05 * _seen_before(_fragment_counter, canon_frag)

        scores.append(score)
    return scores

def multi_property_reward(completions: List[str], **kwargs) -> List[float]:
    """
    Reward in (-inf, 1]: 1 - (weighted, direction-aware normalized error).

    The source molecule is the first <SMILES>...</SMILES> in the prompt and the
    generated one is the first <SMILES>...</SMILES> in the completion. The prompt must
    also carry three requested deltas with a direction:
        <QED>x</QED> (higher|lower), <LogP>y</LogP> (higher|lower), <MW>z</MW> (higher|lower)
    Deltas may be integers, decimals or use an exponent (e.g. 50, 0.1, 1e-2).
    Target = source value + delta ("higher") or - delta ("lower"); the QED target is
    clamped to [0.01, 0.99]. Each absolute error is divided by a fixed scale
    (QED 0.5, LogP 3.0, MW 50.0 for exact mass) and doubled when the property moved
    opposite to the requested direction.

    Returns -10.0 for a sample when tags are missing, either SMILES is invalid, or RDKit
    raises while computing descriptors. Requires ``kwargs["prompts"]`` aligned with
    ``completions``.
    """
    prompts: List[str] = kwargs["prompts"]
    fail = -10.0

    smiles_re = re.compile(r"<SMILES>(.*?)</SMILES>")
    number = r"(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
    qed_re = re.compile(r"<QED>" + number + r"</QED>\s*(higher|lower)")
    logp_re = re.compile(r"<LogP>" + number + r"</LogP>\s*(higher|lower)")
    mw_re = re.compile(r"<MW>" + number + r"</MW>\s*(higher|lower)")
    # Fixed normalization scales (not read from config).
    qed_scale, logp_scale, mw_scale = 0.5, 3.0, 50.0

    def _target(match, base):
        # (target value, +1/-1 direction) from a "<TAG>delta</TAG> higher|lower" match.
        delta = float(match.group(1))
        direction = 1 if match.group(2) == "higher" else -1
        return base + direction * delta, direction

    def _properties(mol):
        return QED.qed(mol), Crippen.MolLogP(mol), Descriptors.ExactMolWt(mol)

    out: List[float] = []
    for comp, prompt in zip(completions, prompts):
        in_m = smiles_re.search(prompt)
        out_m = smiles_re.search(comp)
        qed_m = qed_re.search(prompt)
        logp_m = logp_re.search(prompt)
        mw_m = mw_re.search(prompt)
        if not (in_m and out_m and qed_m and logp_m and mw_m):
            out.append(fail)
            continue

        m_in, m_out = Chem.MolFromSmiles(in_m.group(1)), Chem.MolFromSmiles(out_m.group(1))
        if not (m_in and m_out):
            out.append(fail)
            continue

        try:
            qed_in, logp_in, mw_in = _properties(m_in)
            qed_out, logp_out, mw_out = _properties(m_out)
        except Exception:
            # Best effort: an exception here would propagate out of the reward call;
            # score the sample as a failure instead.
            out.append(fail)
            continue

        qed_t, d_qed = _target(qed_m, qed_in)
        logp_t, d_logp = _target(logp_m, logp_in)
        mw_t, d_mw = _target(mw_m, mw_in)

        # QED lives in [0, 1]; keep the target strictly inside it.
        qed_t = max(0.01, min(0.99, qed_t))

        qed_err = abs(qed_out - qed_t) / qed_scale
        logp_err = abs(logp_out - logp_t) / logp_scale
        mw_err = abs(mw_out - mw_t) / mw_scale

        # Moving a property the wrong way relative to the source doubles its error.
        dir_pen_qed = 2 if (qed_out - qed_in) * d_qed < 0 else 1
        dir_pen_logp = 2 if (logp_out - logp_in) * d_logp < 0 else 1
        dir_pen_mw = 2 if (mw_out - mw_in) * d_mw < 0 else 1

        total_err = qed_err * dir_pen_qed + logp_err * dir_pen_logp + mw_err * dir_pen_mw
        out.append(1.0 - total_err)
    return out

# -----------------------------------------------------------------------------
# GRPO
# -----------------------------------------------------------------------------

def _prepare_grpo_prompts(ds: Dataset, system_prompt: str) -> Dataset:
    """
    Wrap each prompt in the "[Round 0]" Human/Assistant template used for SFT and drop
    exact duplicates (first occurrence kept, order preserved). Only ``prompt`` survives.
    """
    def _reformat(ex):
        p = ex["prompt"] if "prompt" in ex else ex["input"]
        return {"prompt": f"[Round 0]\nHuman: {system_prompt}{p}\nAssistant:"}

    ds = ds.map(_reformat, remove_columns=[c for c in ds.column_names if c != "prompt"])
    unique_prompts = dict.fromkeys(ds["prompt"])  # ordered de-duplication
    return Dataset.from_list([{"prompt": p} for p in unique_prompts])


def train_grpo(cfg: Dict[str, Any]) -> None:
    """
    GRPO with LoRA on top of either:
      - a prior SFT checkpoint (grpo.init_ckpt), or
      - the base model (fresh LoRA).
    Dataset must provide 'prompt' (or 'input' -> we wrap to 'prompt').

    Rewards: strict_format_reward, multi_property_reward, repetition_penalty_reward.
    Sampling settings come from grpo.rollout (temperature, top_p, top_k, max_new_tokens).
    The completions-per-prompt count is read from grpo.num_generations, else
    grpo.group_size, else grpo.groups_per_step (default 8). The KL coefficient comes
    from grpo.kl_coef.
    The trained adapter and the tokenizer are saved to grpo.output_dir
    (default <project.exp_dir>/grpo).
    """
    system_prompt = cfg.get("model", {}).get(
        "system_prompt",
        "You love and excel at editing SMILES strings to make original SMILES meet the required numeric properties.\n",
    )
    grpo_cfg = cfg.get("grpo") or {}
    log_cfg = cfg.get("logging") or {}

    # Tokenizer: prefer tokenizer from init_ckpt if given, else base
    tokenizer_from = grpo_cfg.get("tokenizer_path") or grpo_cfg.get("init_ckpt")
    tokenizer = _build_tokenizer(cfg, pretrained_path=tokenizer_from)
    # (If tokenizer already had the specials from SFT, add_special_tokens() is a no-op)
    # GRPOTrainer requires a left-padded processing class for generation.
    tokenizer.padding_side = "left"

    # Model (+LoRA or resume)
    model = _build_model_for_grpo(cfg, tokenizer)
    model.print_trainable_parameters()

    # Dataset (prompts): templated and de-duplicated
    ds = _prepare_grpo_prompts(_load_dataset_from_cfg(cfg, purpose="grpo"), system_prompt)

    # Resolve lazily so project.exp_dir is only required when output_dir is unset.
    out_dir = grpo_cfg.get("output_dir") or os.path.join(cfg["project"]["exp_dir"], "grpo")
    rollout = grpo_cfg.get("rollout") or {}

    # GRPOConfig has no `group_size` / `kl_coef` fields (passing them raised TypeError):
    # completions sampled per prompt is `num_generations`, the KL coefficient is `beta`.
    num_generations = grpo_cfg.get(
        "num_generations", grpo_cfg.get("group_size", grpo_cfg.get("groups_per_step", 8))
    )

    grpo_args = GRPOConfig(
        output_dir=out_dir,
        per_device_train_batch_size=grpo_cfg.get("per_device_train_batch_size", 4),
        gradient_accumulation_steps=grpo_cfg.get("gradient_accumulation_steps", 8),
        dataloader_drop_last=True,
        dataloader_num_workers=grpo_cfg.get("num_workers", 4),
        # Sampling knobs are first-class GRPOConfig fields (GRPO always samples);
        # max_completion_length is the field the trainer uses to bound completions.
        temperature=rollout.get("temperature", 1.0),
        top_p=rollout.get("top_p", 0.9),
        top_k=rollout.get("top_k", 50),
        max_completion_length=rollout.get("max_new_tokens", 128),
        disable_tqdm=False,
        max_steps=grpo_cfg.get("max_steps", 1000),
        learning_rate=grpo_cfg.get("lr", 3e-5),
        bf16=_bf16_supported(),
        logging_steps=grpo_cfg.get("logging_steps", 50),
        save_steps=grpo_cfg.get("save_steps", 500),
        save_total_limit=grpo_cfg.get("save_total_limit", 10),
        report_to=log_cfg.get("report_to", "none"),
        logging_dir=log_cfg.get("dir", "./logs-grpo"),
        remove_unused_columns=False,
        num_generations=num_generations,
        beta=grpo_cfg.get("kl_coef", 0.02),
    )

    # A function listed twice counts twice, which is one way to weight it.
    reward_funcs = [
        strict_format_reward,
        multi_property_reward,
        repetition_penalty_reward,
    ]

    trainer = GRPOTrainer(
        model=model,
        args=grpo_args,
        train_dataset=ds,
        reward_funcs=reward_funcs,
        # Pass our tokenizer explicitly: without it the trainer loads one from the
        # model's name/path, which lacks the special tokens added by _build_tokenizer.
        processing_class=tokenizer,
    )

    trainer.train()
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)
