
import os, json, time, traceback, argparse
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv
from tqdm.auto import tqdm

import google.generativeai as genai

# ───────────────────────────────────────────────────────────────────────
# 0.  ENV + CLI
# ---------------------------------------------------------------------
# Load environment variables; per the error message below, GOOGLE_API_KEY may
# come from a .env file or from the shell environment.
load_dotenv()
API_KEY = os.getenv("GOOGLE_API_KEY")
# Explicit check instead of `assert`: assert statements are stripped when
# Python runs with -O, which would let a missing or placeholder key slip
# through to genai.configure(). AssertionError and the message are kept so
# the failure looks the same as before.
if not API_KEY or API_KEY == "your-google-api-key-here":
    raise AssertionError("Set GOOGLE_API_KEY in .env or shell!")

# NOTE(review): the command line is parsed here at module level AND again in
# the `if __name__ == "__main__":` guard at the bottom of the file, and main()
# recomputes the same output paths and re-creates out_dir. This module-level
# copy is kept so that `args`, `OUT_CSV` and `ERR_CSV` stay available as
# module attributes; consider removing it once nothing depends on them.
ap = argparse.ArgumentParser()
ap.add_argument("--txt_dir",  required=True,  type=Path)
ap.add_argument("--out_dir",  required=True,  type=Path)
ap.add_argument("--sleep",    default=2.0,    type=float,
                help="seconds to wait between calls")
args = ap.parse_args()

# Make sure the output directory exists before anything tries to write to it.
args.out_dir.mkdir(parents=True, exist_ok=True)
OUT_CSV = args.out_dir / "llm_2stage_fc_predictions_gemini.csv"
ERR_CSV = args.out_dir / "llm_2stage_fc_errors_gemini.csv"

# ───────────────────────────────────────────────────────────────────────
# 1.  Google Gemini setup
# ---------------------------------------------------------------------
# Hand the API key to the google.generativeai client library.
genai.configure(api_key=API_KEY)

# Model identifier used for every request (Gemini 2.5 Pro).
MODEL_NAME = "gemini-2.5-pro"

# Sampling / length settings applied to every generate_content() call made
# through `model` below.
generation_config = {
    "temperature": 0.1,  # low temperature — presumably to keep the JSON output stable
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 8192,  # upper bound on the length of each reply
}

# Every listed harm category is set to BLOCK_NONE.
# NOTE(review): presumably so that clinical wording in drug labels is not
# filtered — confirm this is intended.
safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
]

# Single shared model handle; call_gemini_with_retry() and main() both call
# model.generate_content() on it.
model = genai.GenerativeModel(
    model_name=MODEL_NAME,
    generation_config=generation_config,
    safety_settings=safety_settings
)

# --- JSON schemas for structured output -------------------------------------------------------------
# These schemas are not sent to the API as function declarations; they are
# pretty-printed into the prompt text by call_gemini_with_retry().

# Shape of one entry in the stage-1 "PediatricSummary" list.
_SUMMARY_ENTRY_SCHEMA = {
    "type": "object",
    "properties": {
        "section": {"type": "string"},
        "summary": {"type": "string"},
    },
    "required": ["section", "summary"],
}

# Stage 1 (extraction) reply: a list of per-section summaries, a list of
# strings under "AllAges", and a free-text "Comments" field — all required.
extract_schema = {
    "type": "object",
    "properties": {
        "PediatricSummary": {"type": "array", "items": _SUMMARY_ENTRY_SCHEMA},
        "AllAges": {"type": "array", "items": {"type": "string"}},
        "Comments": {"type": "string"},
    },
    "required": ["PediatricSummary", "AllAges", "Comments"],
}

# Stage 2 (classification) fields that are plain, unconstrained strings.
_FREE_TEXT_FIELDS = (
    "efficacy_summary",
    "pk_summary",
    "lowest_age_band",
    "highest_age_band",
    "rationale",
)

# Property order matters: the schema is pretty-printed into the prompt, and
# the "required" list below is derived from this same order.
_CLASSIFY_PROPERTIES = {
    "resolved_label": {
        "type": "string",
        "enum": ["None", "Partial", "Full", "Unlabeled"],
    },
    "peds_study_type": {
        "type": "string",
        "enum": ["RCT", "PK+Safety", "PK Only", "None"],
    },
    **{field: {"type": "string"} for field in _FREE_TEXT_FIELDS},
    "confidence": {
        "type": "string",
        "enum": ["high", "medium", "low"],
    },
}

# Stage 2 (classification) reply: every property is required.
classify_schema = {
    "type": "object",
    "properties": _CLASSIFY_PROPERTIES,
    "required": list(_CLASSIFY_PROPERTIES),
}

# Instruction text for stage 1 (extraction). main() passes it to
# call_gemini_with_retry() together with extract_schema, where it becomes the
# opening part of the prompt. The adjacent string literals are concatenated
# into a single string.
EXTRACT_SYSTEM = (
    "You are scanning an FDA product label. "
    "Return JSON ONLY in the exact format specified. Follow the schema exactly. "
    "• Summaries ≤150 words. • DO NOT invent data. "
    "• Return ONLY valid JSON, no other text."
)

# Instruction text for stage 2 (classification). main() passes it to
# call_gemini_with_retry() together with classify_schema. The four labels in
# the decision tree (None / Partial / Full / Unlabeled) are the same values
# allowed by the "resolved_label" enum in classify_schema — keep them in sync.
CLASSIFY_SYSTEM = (
    "You are an expert in FDA pediatric extrapolation. "
    "Use the decision tree:\n"
    "• None – ≥1 pediatric efficacy RCT\n"
    "• Partial – pediatric PK and/or safety evidence but NO efficacy RCT\n"
    "• Full – only PK / exposure modelling; no pediatric safety cohort\n"
    "• Unlabeled – no pediatric evidence\n\n"
    "Return ONLY valid JSON in the exact format specified, no other text."
)

# ────────────────────────────────────────────────────────────────────────
def truncate_text(text, max_chars=150000):
    """Cap ``text`` at roughly ``max_chars`` characters.

    Text that already fits is returned unchanged. Otherwise the first
    ``max_chars`` characters are kept, cut back to the last paragraph, line
    or sentence break when that break lies beyond 80% of ``max_chars``, and
    a "[Document truncated due to length]" marker is appended. The marker is
    added on top of the kept text, so the result can be slightly longer than
    ``max_chars``.
    """
    marker = "\n\n[Document truncated due to length]"

    if len(text) <= max_chars:
        return text

    head = text[:max_chars]
    min_break_pos = max_chars * 0.8

    # Separators are tried from coarsest (paragraph) to finest (sentence);
    # the first one found late enough in the text wins.
    for separator in ('\n\n', '\n', '. ', '! ', '? '):
        position = head.rfind(separator)
        if position > min_break_pos:
            head = head[:position + len(separator)]
            break

    return head + marker


def extract_json_from_response(text):
    """Parse the first JSON object found in a Gemini text reply.

    Surrounding whitespace and a wrapping markdown code fence are removed,
    then the JSON document that starts at the first ``{`` is decoded. Any
    text that follows that document (for example a trailing remark from the
    model, even one that itself contains braces) is ignored.

    Args:
        text: Raw reply text from the model.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If the reply has no ``{`` followed by a ``}``.
        json.JSONDecodeError: If the text starting at the first ``{`` is not
            valid JSON (a subclass of ValueError).
    """
    text = text.strip()

    # Remove markdown code blocks if present
    if text.startswith('```json'):
        text = text[7:]
    if text.startswith('```'):
        text = text[3:]
    if text.endswith('```'):
        text = text[:-3]

    text = text.strip()

    # Locate the candidate JSON object
    start = text.find('{')
    end = text.rfind('}') + 1

    if start < 0 or end <= start:
        raise ValueError(f"No valid JSON found in response: {text}")

    # raw_decode() reads exactly one JSON document from the beginning of the
    # string and reports where it ended. Slicing from the first '{' to the
    # LAST '}' instead would pull trailing commentary such as
    # '{"a": 1} Note: {x}' into the parse and fail on otherwise good output.
    parsed, _ = json.JSONDecoder().raw_decode(text[start:])
    return parsed


def call_gemini_with_retry(user_content, system_content, schema, max_retries=3):
    """Send one prompt to Gemini and return the reply parsed as JSON.

    The prompt is built from ``system_content``, the pretty-printed
    ``schema`` and ``user_content`` (shortened with ``truncate_text``).

    Args:
        user_content: Text to analyze.
        system_content: Instruction text placed at the start of the prompt.
        schema: JSON-schema dict embedded in the prompt as guidance; it is
            not enforced on the reply here.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        The JSON object decoded from the model's reply.

    Raises:
        ValueError: If ``max_retries`` is below 1, if the prompt is blocked
            or generation is stopped (raised immediately, no retry), or if
            the reply is still not valid JSON after the last attempt.
        Exception: Any other error from the last attempt is re-raised
            unchanged.
    """
    if max_retries < 1:
        # With zero attempts the loop below never runs and the function
        # would fall through and return None without calling the API.
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    # Truncate input to avoid context limits
    user_content = truncate_text(user_content)

    prompt = f"{system_content}\n\nRequired JSON Schema:\n{json.dumps(schema, indent=2)}\n\nContent to analyze:\n{user_content}\n\nReturn ONLY valid JSON:"

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = model.generate_content(prompt)

            # A reply without candidates would otherwise surface as an
            # opaque "list index out of range" from candidates[0].
            candidates = response.candidates
            if not candidates:
                feedback = getattr(response, "prompt_feedback", None)
                raise ValueError(
                    f"Gemini returned no candidates (prompt_feedback: {feedback})"
                )

            # Handle safety blocks. This ValueError is raised inside the try
            # block, so the generic handler below treats it like any other
            # failure and retries it.
            if candidates[0].finish_reason.name == "SAFETY":
                raise ValueError("Response blocked by safety filters")

            return extract_json_from_response(response.text)

        except genai.types.BlockedPromptException as exc:
            # Raised from inside a handler, so this leaves the function
            # immediately instead of being retried.
            raise ValueError("Prompt blocked by safety filters") from exc

        except genai.types.StopCandidateException as exc:
            raise ValueError("Response generation stopped") from exc

        except json.JSONDecodeError as e:
            print(f"JSON decode error on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise ValueError(f"Failed to parse JSON after {max_retries} attempts: {e}") from e
            time.sleep(2)

        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if is_last_attempt:
                raise
            # Linear back-off: 5s, 10s, 15s, ...
            wait_time = 5 * (attempt + 1)
            print(f"Waiting {wait_time}s before retry...")
            time.sleep(wait_time)


# ────────────────────────────────────────────────────────────────────────
def _save_results(records, errors, preds_file, errs_file):
    """Write collected predictions and errors to CSV and print a summary."""
    if records:
        pd.DataFrame(records).to_csv(preds_file, index=False)
        print(f"\n✔ Saved {len(records)} rows → {preds_file}")
    if errors:
        pd.DataFrame(errors).to_csv(errs_file, index=False)
        print(f"Errors saved to → {errs_file}")
    print(f"Finished: {len(records)} OK, {len(errors)} errors")


def main(txt_dir: Path, out_dir: Path, delay: float):
    """Run the two-stage Gemini pipeline over every ``*.txt`` file in ``txt_dir``.

    For each file, stage 1 extracts a summary (``extract_schema``) and stage 2
    classifies that summary (``classify_schema``). Successful rows go to
    ``llm_2stage_fc_predictions_gemini.csv`` and failures to
    ``llm_2stage_fc_errors_gemini.csv`` inside ``out_dir``.

    Args:
        txt_dir: Directory holding the input ``*.txt`` files.
        out_dir: Directory for the output CSVs; created if missing.
        delay: Seconds to sleep between the two stages and between files.

    Returns:
        None. Returns early, writing nothing, if the initial API test call
        fails.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    preds_file = out_dir / "llm_2stage_fc_predictions_gemini.csv"
    errs_file  = out_dir / "llm_2stage_fc_errors_gemini.csv"

    records, errors = [], []

    # Test API connection first
    print("Testing Gemini API connection...")
    try:
        test_response = model.generate_content("Hello, respond with 'API Working'")
        print(f"✓ API connection successful! Response: {test_response.text}")
    except Exception as e:
        print(f"✗ API connection failed: {e}")
        print("Please check your API key and try again.")
        return

    try:
        for txt_path in tqdm(sorted(txt_dir.glob("*.txt")), desc="Files"):
            app_id = txt_path.stem

            try:
                # Read inside the try so one unreadable file is recorded as
                # an error instead of aborting the whole batch.
                text = txt_path.read_text("utf-8", errors="ignore")

                # 1️⃣  extraction
                s1 = call_gemini_with_retry(
                    user_content=text,
                    system_content=EXTRACT_SYSTEM,
                    schema=extract_schema,
                )
                summary_json = json.dumps(s1, separators=(",",":"))

                # Wait between stages to respect rate limits
                time.sleep(delay)

                # 2️⃣  classification
                s2 = call_gemini_with_retry(
                    user_content=summary_json,
                    system_content=CLASSIFY_SYSTEM,
                    schema=classify_schema,
                )

                # Build the progress line BEFORE appending, with tolerant
                # lookups: indexing s2[...] after the append meant a reply
                # missing a key was stored as a prediction and then also
                # reported as a failure.
                label = str(s2.get("resolved_label", "?"))
                confidence = s2.get("confidence", "?")
                progress_line = f"✓ {app_id:>14} → {label:9} ({confidence})"

                records.append({
                    "app_id": app_id,
                    **s2,
                    "summary_json": summary_json,
                    "txt_file": str(txt_path)
                })
                print(progress_line)

            except Exception as exc:
                traceback.print_exc(limit=1)
                errors.append({"app_id": app_id, "error": str(exc)})
                print(f"✗ {app_id} FAILED – {exc}")

            # Wait between files to respect rate limits
            time.sleep(delay)
    finally:
        # Save whatever has been collected even when the run is interrupted
        # (e.g. Ctrl-C), so completed API calls are not thrown away.
        _save_results(records, errors, preds_file, errs_file)


if __name__ == "__main__":
    # Script entry point: read the command-line options and hand them to main().
    parser = argparse.ArgumentParser()
    for required_dir_flag in ("--txt_dir", "--out_dir"):
        parser.add_argument(required_dir_flag, required=True, type=Path)
    parser.add_argument(
        "--sleep",
        type=float,
        default=2.0,
        help="delay between calls (s)",
    )
    args = parser.parse_args()
    main(txt_dir=args.txt_dir, out_dir=args.out_dir, delay=args.sleep)