import os, json, time, traceback, argparse
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv
from tqdm.auto import tqdm

import anthropic

# ───────────────────────────────────────────────────────────────────────
# 0.  ENV + CLI
# ---------------------------------------------------------------------
# Load variables from a .env file (if present) before reading the key.
load_dotenv()
API_KEY = os.getenv("ANTHROPIC_API_KEY")
# Explicit check rather than `assert`: assertions are removed when Python runs
# with -O, which would let a missing key slip through to the first API call.
if not API_KEY:
    raise RuntimeError("Set ANTHROPIC_API_KEY in .env or shell!")

# Command-line options: input directory of .txt files, output directory, and
# the pause (in seconds) between API calls.
# NOTE(review): this parse runs at import time, and the `__main__` guard at the
# bottom of the file builds an equivalent parser and parses again — confirm
# whether both are intended.
ap = argparse.ArgumentParser()
ap.add_argument("--txt_dir", type=Path, required=True)
ap.add_argument("--out_dir", type=Path, required=True)
ap.add_argument(
    "--sleep",
    type=float,
    default=5.0,
    help="seconds to wait between calls",
)
args = ap.parse_args()

# Ensure the output directory exists, then derive the two report paths in it.
args.out_dir.mkdir(exist_ok=True, parents=True)
OUT_CSV, ERR_CSV = (
    args.out_dir / report_name
    for report_name in (
        "llm_2stage_fc_predictions_sonnet.csv",
        "llm_2stage_fc_errors_sonnet.csv",
    )
)

# ───────────────────────────────────────────────────────────────────────
# 1.  Claude client
# ---------------------------------------------------------------------
# Shared Anthropic SDK client, authenticated with API_KEY.
client = anthropic.Anthropic(api_key=API_KEY)
# Claude 3.7 Sonnet snapshot id. The previous value combined the 3.7 model name
# with the "20241022" date suffix of a Claude 3.5 Sonnet snapshot, which is not
# a published model id, so every request would be rejected.
# NOTE(review): if Claude 3.5 Sonnet was the intended model, use
# "claude-3-5-sonnet-20241022" instead.
MODEL  = "claude-3-7-sonnet-20250219"

# --- STAGE-1 tool ----------------------------------------------------
# JSON schema for one entry of "PediatricSummary": an object with two required
# string fields, `section` and `summary`.
_PEDIATRIC_SUMMARY_ITEM = {
    "type": "object",
    "properties": {
        "section": {"type": "string"},
        "summary": {"type": "string"},
    },
    "required": ["section", "summary"],
}

# Tool definition passed to the stage-1 call. All three top-level fields
# (PediatricSummary, AllAges, Comments) are required.
extract_tool = {
    "name": "extract_pediatric_summary",
    "description": "Scan an FDA product label and return every piece of pediatric evidence needed to judge extrapolation.",
    "input_schema": {
        "type": "object",
        "properties": {
            "PediatricSummary": {
                "type": "array",
                "items": _PEDIATRIC_SUMMARY_ITEM,
            },
            "AllAges": {
                "type": "array",
                "items": {"type": "string"},
            },
            "Comments": {"type": "string"},
        },
        "required": ["PediatricSummary", "AllAges", "Comments"],
    },
}

# --- STAGE-2 tool ----------------------------------------------------
def _string_field(*choices):
    """Return a JSON-schema string property, restricted to *choices* if given."""
    spec = {"type": "string"}
    if choices:
        spec["enum"] = list(choices)
    return spec


# Output fields of the stage-2 tool, in the order they are declared required.
_CLASSIFY_PROPERTIES = {
    "resolved_label":   _string_field("None", "Partial", "Full", "Unlabeled"),
    "peds_study_type":  _string_field("RCT", "PK+Safety", "PK Only", "None"),
    "efficacy_summary": _string_field(),
    "pk_summary":       _string_field(),
    "lowest_age_band":  _string_field(),
    "highest_age_band": _string_field(),
    "rationale":        _string_field(),
    "confidence":       _string_field("high", "medium", "low"),
}

# Tool definition passed to the stage-2 call; every property is required.
classify_tool = {
    "name": "classify_extrapolation",
    "description": "Given the pediatric summary, decide extrapolation category.",
    "input_schema": {
        "type": "object",
        "properties": _CLASSIFY_PROPERTIES,
        "required": list(_CLASSIFY_PROPERTIES),
    },
}

# System prompt for the stage-1 (extraction) call; sentences are joined with
# single spaces.
EXTRACT_SYSTEM = " ".join([
    "You are scanning an FDA product label.",
    "Return JSON ONLY via the function call.",
    "Follow the schema exactly.",
    "• Summaries ≤150 words.",
    "• DO NOT invent data.",
])

# System prompt for the stage-2 (classification) call: one line per decision
# rule, a blank line, then the output instruction.
CLASSIFY_SYSTEM = "\n".join([
    "You are an expert in FDA pediatric extrapolation. Use the decision tree:",
    "• None – ≥1 pediatric efficacy RCT",
    "• Partial – pediatric PK and/or safety evidence but NO efficacy RCT",
    "• Full – only PK / exposure modelling; no pediatric safety cohort",
    "• Unlabeled – no pediatric evidence",
    "",
    "Return JSON ONLY via the function call.",
])

# ────────────────────────────────────────────────────────────────────────
def truncate_text(text, max_chars=50000):
    """Cap *text* at roughly *max_chars* characters, cutting at a natural break.

    Text that already fits is returned unchanged. Otherwise the first
    *max_chars* characters are kept and, where possible, cut back to the end
    of the last paragraph, line or sentence break. A break is only used if it
    lies beyond 80% of *max_chars*, and separators are tried in that priority
    order. A truncation notice is appended, so the result may be slightly
    longer than *max_chars*.
    """
    if len(text) <= max_chars:
        return text

    notice = "\n\n[Document truncated due to length]"
    head = text[:max_chars]
    earliest_cut = max_chars * 0.8

    # Separators in priority order: paragraph, line, then sentence endings.
    candidates = (
        (head.rfind(sep), len(sep))
        for sep in ("\n\n", "\n", ". ", "! ", "? ")
    )
    for position, width in candidates:
        if position > earliest_cut:
            # Keep the separator itself so the text ends cleanly.
            return head[:position + width] + notice

    # No suitable break close enough to the limit: hard cut.
    return head + notice


def call_claude_with_retry(user_content, system_content, tools, forced_name, max_retries=3):
    """Call the model with a forced tool and return that tool call's input.

    The user content is first shortened with ``truncate_text``. Each attempt
    issues a regular (non-streaming) request; if the error text says streaming
    is required, the same request is retried once via
    ``call_claude_streaming`` within that attempt.

    Args:
        user_content: Text sent as the single user message.
        system_content: System prompt.
        tools: List of tool definitions offered to the model.
        forced_name: Name of the tool the model is forced to call.
        max_retries: Total number of attempts (must be >= 1).

    Returns:
        The ``input`` of the first ``tool_use`` block in the response.

    Raises:
        ValueError: If *max_retries* is less than 1.
        anthropic.RateLimitError: If the final attempt is rate-limited.
        Exception: Whatever the final attempt raised, including the
            ``ValueError`` for a response without a ``tool_use`` block.
    """
    if max_retries < 1:
        # Without this guard the loop would not run and None would be returned.
        raise ValueError("max_retries must be at least 1")

    # Truncate input to avoid rate limits
    user_content = truncate_text(user_content)

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            # First try regular API call
            resp = client.messages.create(
                model       = MODEL,
                system      = system_content,
                messages    = [{"role": "user", "content": user_content}],
                tools       = tools,
                tool_choice = {"type": "tool", "name": forced_name},
                max_tokens  = 4096,
            )

            # Extract the tool use from the response
            for content_block in resp.content:
                if content_block.type == "tool_use":
                    return content_block.input

            raise ValueError("No tool use found in response")

        except anthropic.RateLimitError:
            # Give up immediately on the last attempt instead of sleeping
            # through a back-off that no retry will follow.
            if is_last_attempt:
                raise
            wait_time = 60 * (attempt + 1)  # Linear backoff: 60s, 120s, ...
            print(f"Rate limit hit, waiting {wait_time}s before retry {attempt + 1}/{max_retries}")
            time.sleep(wait_time)

        except Exception as e:
            if "Streaming is required" in str(e):
                # Try with streaming
                try:
                    return call_claude_streaming(user_content, system_content, tools, forced_name)
                except Exception:
                    if is_last_attempt:
                        raise
                    time.sleep(30)  # Wait before retry
            else:
                if is_last_attempt:
                    raise
                time.sleep(10)  # Short wait for other errors


def call_claude_streaming(user_content, system_content, tools, forced_name):
    """Issue the forced-tool request through the streaming API.

    Sends the same request shape as the non-streaming call, waits for the
    final assembled message, and returns the ``input`` of its first
    ``tool_use`` block.

    Raises:
        ValueError: If the final message contains no ``tool_use`` block.
    """
    request = {
        "model": MODEL,
        "system": system_content,
        "messages": [{"role": "user", "content": user_content}],
        "tools": tools,
        "tool_choice": {"type": "tool", "name": forced_name},
        "max_tokens": 4096,
    }
    with client.messages.stream(**request) as stream:
        blocks = stream.get_final_message().content

        # Return the first tool call's payload, if there is one.
        tool_inputs = (blk.input for blk in blocks if blk.type == "tool_use")
        for payload in tool_inputs:
            return payload

        raise ValueError("No tool use found in streaming response")


# ────────────────────────────────────────────────────────────────────────
def main(txt_dir: Path, out_dir: Path, delay: float):
    """Run extraction then classification on every ``*.txt`` file in *txt_dir*.

    For each file (in sorted order) the text goes through the stage-1
    extraction call, and the extraction result, serialized as compact JSON,
    goes through the stage-2 classification call. Successful results are
    collected as prediction rows; any exception for a file is recorded as an
    error row and processing continues with the next file.

    Both CSVs are written in a ``finally`` clause, so results gathered so far
    are kept even if the run is interrupted (e.g. Ctrl-C).

    Args:
        txt_dir: Directory containing the ``*.txt`` inputs.
        out_dir: Directory for the output CSVs; created if missing.
        delay: Seconds to sleep between the two stages and after each file.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    preds_file = out_dir / "llm_2stage_fc_predictions_sonnet.csv"
    errs_file  = out_dir / "llm_2stage_fc_errors_sonnet.csv"

    records, errors = [], []

    try:
        for txt_path in tqdm(sorted(txt_dir.glob("*.txt")), desc="Files"):
            app_id = txt_path.stem

            try:
                # Read inside the try so one unreadable file is logged as an
                # error instead of aborting the whole batch.
                text = txt_path.read_text("utf-8", errors="ignore")

                # Stage 1: extraction
                s1 = call_claude_with_retry(
                    user_content=text,
                    system_content=EXTRACT_SYSTEM,
                    tools=[extract_tool],
                    forced_name="extract_pediatric_summary",
                )

                # Wait between stages to respect rate limits
                time.sleep(delay)

                # Stage 2: classification, fed the compact JSON of stage 1
                summary_json = json.dumps(s1, separators=(",", ":"))
                s2 = call_claude_with_retry(
                    user_content=summary_json,
                    system_content=CLASSIFY_SYSTEM,
                    tools=[classify_tool],
                    forced_name="classify_extrapolation",
                )

                # Build the status line before storing the row: if the
                # payload lacks a field, the file is recorded once as an
                # error rather than as both a prediction and an error.
                status = f"✓ {app_id:>14} → {s2['resolved_label']:9} ({s2['confidence']})"

                records.append({
                    "app_id": app_id,
                    **s2,
                    "summary_json": summary_json,
                    "txt_file": str(txt_path)
                })
                print(status)

            except Exception as exc:
                traceback.print_exc(limit=1)
                errors.append({"app_id": app_id, "error": str(exc)})
                print(f"✗ {app_id} FAILED – {exc}")

            # Wait between files to respect rate limits
            time.sleep(delay)
    finally:
        # save results (also reached when the loop is interrupted)
        if records:
            pd.DataFrame(records).to_csv(preds_file, index=False)
            print(f"\n✔ Saved {len(records)} rows → {preds_file}")
        pd.DataFrame(errors).to_csv(errs_file, index=False)
        print(f"Finished: {len(records)} OK, {len(errors)} errors → {errs_file}")


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the batch.
    p = argparse.ArgumentParser()
    p.add_argument("--txt_dir", type=Path, required=True)
    p.add_argument("--out_dir", type=Path, required=True)
    p.add_argument(
        "--sleep",
        type=float,
        default=5.0,
        help="delay between calls (s)",
    )
    args = p.parse_args()
    main(txt_dir=args.txt_dir, out_dir=args.out_dir, delay=args.sleep)