#!/usr/bin/env python3
"""
Collect responses exported from Google Forms and map to assignments.

Assumes one form per participant and that the exported CSV columns follow the
pattern:
  Q{n} Difference, Q{n} Difference explanation,
  Q{n} Relevance, Q{n} Relevance explanation,
  Q{n} Acknowledgement, Q{n} Acknowledgement explanation,
  Q{n} Refusal, Q{n} Refusal explanation
"""

import argparse
import csv
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

try:
    from googleapiclient import discovery  # type: ignore
    from httplib2 import Http  # type: ignore
    from oauth2client import client, file, tools  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    discovery = None
    Http = None
    client = None
    file = None
    tools = None

# OAuth scopes requested when a new token has to be obtained in ensure_services().
# NOTE(review): the Forms API reference lists drive, drive.file or
# forms.responses.readonly as the scopes accepted by forms.responses.list;
# confirm that the two scopes below are sufficient for downloading responses.
SCOPES = [
    "https://www.googleapis.com/auth/forms.body",
    "https://www.googleapis.com/auth/drive.readonly",
]
# Discovery document URL handed to discovery.build() for the Forms v1 API.
DISCOVERY_DOC = "https://forms.googleapis.com/$discovery/rest?version=v1"


# (label, key) pairs for the per-question fields.
#   label: text that follows "Q{n} " in a CSV column header, and that a form
#          item title is expected to start with when answers come from the API.
#   key:   field name written to each normalized record (plus "<key>_raw").
# Note that some labels are prefixes of others ("Difference" vs.
# "Difference explanation"), so any prefix matching must be order-aware.
# NOTE(review): the module docstring also lists "Refusal" and
# "Refusal explanation" columns, which have no entry here -- confirm whether
# they were dropped on purpose.
QUESTION_FIELDS = [
    ("Difference", "difference"),
    ("Difference explanation", "difference_explanation"),
    ("Relevance", "relevance"),
    ("Relevance explanation", "relevance_explanation"),
    ("Acknowledgement", "acknowledgement"),
    ("Acknowledgement explanation", "acknowledgement_explanation"),
]


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the response collector.

    Returns:
        The namespace produced by argparse for the current ``sys.argv``.
    """
    # Declared in the order they should appear in ``--help``.
    option_specs = [
        (
            "--responses-csv",
            {
                "help": "CSV exported from Google Forms responses. Optional if --use-forms-api is set.",
            },
        ),
        (
            "--assignments",
            {
                "default": "human_study/output/participants.jsonl",
                "help": "Path to participants.jsonl (assignment plan).",
            },
        ),
        (
            "--selection",
            {
                "default": "human_study/output/selection.jsonl",
                "help": "Path to selection.jsonl (question-model metadata).",
            },
        ),
        (
            "--forms-created",
            {
                "default": "human_study/output/forms/forms_created.json",
                "help": "Path to forms_created.json produced by create_forms.py (for API download).",
            },
        ),
        (
            "--use-forms-api",
            {
                "action": "store_true",
                "help": "If set, download latest responses from Google Forms API instead of reading a CSV.",
            },
        ),
        (
            "--output-dir",
            {
                "default": "human_study/output",
                "help": "Directory to write normalized responses.",
            },
        ),
        (
            "--attention-check",
            {
                "default": "human_study/attention_check.json",
                "help": "Path to the attention check JSON used when creating forms (to align positions).",
            },
        ),
        (
            "--save-raw-responses",
            {
                "action": "store_true",
                "help": "If set, save the latest raw response per participant to <output_dir>/responses_raw/.",
            },
        ),
    ]

    cli = argparse.ArgumentParser(description="Collect Google Forms responses.")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def read_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file into a list of parsed objects.

    Lines that are empty or contain only whitespace are ignored; every other
    line is decoded with ``json.loads`` in file order.
    """
    with Path(path).open("r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]


def load_selection_map(path: str) -> Dict[str, Dict]:
    """Index the rows of a JSON Lines file by their ``"id"`` field.

    Blank lines are ignored. When the same id occurs more than once, the row
    that appears last in the file wins. A row without an ``"id"`` key raises
    ``KeyError``.
    """
    by_id: Dict[str, Dict] = {}
    with Path(path).open("r", encoding="utf-8") as handle:
        entries = (json.loads(raw) for raw in handle if raw.strip())
        for entry in entries:
            by_id[entry["id"]] = entry
    return by_id


def load_attention_check(path: str) -> Optional[Dict]:
    """Load the optional attention-check definition.

    This is deliberately best-effort: the attention check is optional, so a
    missing, unreadable or malformed file yields ``None`` instead of an error.

    Args:
        path: Location of the attention-check JSON file.

    Returns:
        The decoded JSON object when the file holds a JSON object (dict);
        ``None`` when the file is absent, cannot be read or decoded, or holds
        any other JSON value.
    """
    source = Path(path)
    if not source.is_file():
        return None
    try:
        payload = json.loads(source.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: the file cannot be read. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError. Anything else is a
        # programming error and should surface instead of being swallowed.
        return None
    return payload if isinstance(payload, dict) else None


def ensure_services():
    """Build and return an authorized Google Forms v1 service object.

    Credentials are read from ``token.json`` located next to this script. When
    no usable credentials are stored, an OAuth flow is started from
    ``client_secrets.json`` (same directory) for ``SCOPES`` and its result is
    saved back through the same storage object.

    Raises:
        ImportError: if any of the optional Google client modules failed to
            import at module load time.
    """
    required_modules = (discovery, Http, client, file, tools)
    if not all(required_modules):
        raise ImportError(
            "google-api-python-client and oauth2client are required to call the Forms API."
        )

    script_dir = Path(__file__).resolve().parent
    credential_store = file.Storage(str(script_dir / "token.json"))
    credentials = credential_store.get()
    if not credentials or credentials.invalid:
        secrets_file = str(script_dir / "client_secrets.json")
        oauth_flow = client.flow_from_clientsecrets(secrets_file, SCOPES)
        # Parse an empty argv so this script's own CLI flags are not handed
        # to the oauth2client argument parser.
        credentials = tools.run_flow(
            oauth_flow, credential_store, tools.argparser.parse_args([])
        )

    authorized_http = credentials.authorize(Http())
    return discovery.build(
        "forms",
        "v1",
        http=authorized_http,
        discoveryServiceUrl=DISCOVERY_DOC,
        static_discovery=False,
    )


def load_forms_created(path: str) -> Dict[str, Tuple[str, Optional[int]]]:
    """Read forms_created.json as participant_id -> (form_id, attention_index).

    Returns an empty mapping when the file does not exist or when its top-level
    JSON value is not a list. Entries lacking a truthy ``participant_id`` or
    ``form_id`` are skipped; ``attention_index`` is ``None`` when absent.
    """
    source = Path(path)
    if not source.exists():
        return {}
    entries = json.loads(source.read_text(encoding="utf-8"))
    if not isinstance(entries, list):
        return {}

    forms_by_participant: Dict[str, Tuple[str, Optional[int]]] = {}
    for entry in entries:
        participant = entry.get("participant_id")
        form = entry.get("form_id")
        if not (participant and form):
            continue
        forms_by_participant[participant] = (form, entry.get("attention_index"))
    return forms_by_participant


def _extract_answer_value(ans: Dict) -> Optional[str]:
    """Return the first value stored in a single answer object, if any.

    ``textAnswers.answers[0].value`` is tried first, then
    ``choiceAnswers.values[0]``. A container whose list is missing or empty is
    skipped; ``None`` is returned when neither yields an entry.
    """
    if not ans:
        return None

    # (container key, list key inside it, how to read the first entry)
    extractors = (
        ("textAnswers", "answers", lambda first: first.get("value")),
        ("choiceAnswers", "values", lambda first: first),
    )
    for container, list_key, pick in extractors:
        if container not in ans:
            continue
        entries = ans[container].get(list_key)
        if entries:
            return pick(entries[0])
    return None


def _parse_numeric_prefix(value: Optional[str]) -> Optional[float]:
    """Convert the leading integer of a value to ``float``.

    Numbers are passed through ``float()`` directly. For strings, surrounding
    whitespace is stripped and the longest prefix made of an optional ``-``
    followed by digits is converted (so ``"3 - Somewhat"`` gives ``3.0`` and
    ``"3.7"`` gives ``3.0``). Anything else, including a string without such a
    prefix, gives ``None``.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if not isinstance(value, str):
        return None

    text = value.strip()
    # A minus sign only counts in the very first position.
    end = 1 if text.startswith("-") else 0
    while end < len(text) and text[end].isdigit():
        end += 1

    prefix = text[:end]
    if not prefix:
        return None
    try:
        return float(prefix)
    except ValueError:
        # e.g. a lone "-" or a digit-like character float() cannot parse
        return None


def _build_question_lookup(
    form_meta: Dict, fields: Optional[List[Tuple[str, str]]] = None
) -> Dict[str, Tuple[int, str]]:
    """
    Map questionId -> (q_idx, label_key) by walking the form items in order.

    q_idx is taken from the most recent section/page-break title containing
    "Q<number>". Question items seen before any such header are attributed to
    question 1. label_key is the key of the field whose label the item title
    starts with; items whose title matches no label, or that carry no
    questionId, are left out.

    Args:
        form_meta: Form resource as a dict; only its "items" list is read.
        fields: (label, key) pairs to match titles against. Defaults to the
            module-level QUESTION_FIELDS.

    Returns:
        Dict keyed by questionId with (q_idx, label_key) values.
    """
    import re

    if fields is None:
        fields = QUESTION_FIELDS
    # Some labels are prefixes of others ("Difference" vs. "Difference
    # explanation"). Try the longest label first so the more specific one wins;
    # sorted() is stable, so equal-length labels keep their declared order.
    labels_longest_first = sorted(
        fields, key=lambda pair: len(pair[0]), reverse=True
    )
    header_pattern = re.compile(r"Q(\d+)")

    current_q = 0
    lookup: Dict[str, Tuple[int, str]] = {}
    for item in form_meta.get("items") or []:
        title = item.get("title", "") or ""

        # Advance question counter only on section headers that look like "Qn"
        if "sectionHeaderItem" in item or "pageBreakItem" in item:
            match = header_pattern.search(title)
            if match:
                current_q = int(match.group(1))
            continue

        if "questionItem" not in item:
            continue

        # No "Qn" header seen yet: fall back to question 1 rather than 0 so the
        # index stays 1-based.
        if current_q == 0:
            current_q = 1

        question = (item.get("questionItem") or {}).get("question") or {}
        qid = question.get("questionId")
        if qid is None:
            # Nothing to key the answer on.
            continue

        for label, key in labels_longest_first:
            if title.startswith(label):
                lookup[qid] = (current_q, key)
                break
    return lookup


def download_responses(form_service, form_id: str) -> List[Dict]:
    """Fetch every response submitted to a form.

    The list endpoint is paginated: each page may carry a ``nextPageToken``
    that has to be sent back as ``pageToken`` to obtain the following page.
    All pages are followed so that no response is silently dropped.

    Args:
        form_service: Forms API service object (see ensure_services()).
        form_id: Identifier of the form to query.

    Returns:
        All response dicts in the order the API returned them; an empty list
        when the form has no responses.
    """
    collected: List[Dict] = []
    page_token: Optional[str] = None
    while True:
        params = {"formId": form_id}
        if page_token:
            params["pageToken"] = page_token
        page = form_service.forms().responses().list(**params).execute() or {}
        collected.extend(page.get("responses") or [])
        page_token = page.get("nextPageToken")
        if not page_token:
            return collected


def pick_latest_response(responses: List[Dict]) -> Optional[Dict]:
    """Return the most recently submitted response, or ``None`` if there is none.

    Responses are ranked by ``lastSubmittedTime``, falling back to
    ``createTime`` and then to an empty string; the timestamps are compared as
    plain strings. Among equally ranked responses the earliest in the input
    list is returned.
    """
    if not responses:
        return None

    def submitted_at(response: Dict) -> str:
        return response.get("lastSubmittedTime") or response.get("createTime") or ""

    return max(responses, key=submitted_at)


def inject_attention_check(
    item_ids: List[str], attention_id: Optional[str], insert_pos: Optional[int] = None
) -> Tuple[List[str], int]:
    """Place the attention-check id into a participant's item list.

    Without an ``attention_id`` the input list is returned as-is together with
    position ``-1``. Otherwise a new list is returned along with the index at
    which the id was inserted. When ``insert_pos`` is ``None`` that index is
    ``max(1, len(item_ids) // 2)``; an explicit position is clamped to
    ``[0, len(item_ids)]``. An empty item list yields ``([attention_id], 0)``.
    """
    if not attention_id:
        return item_ids, -1

    ordered = [*item_ids]
    if not ordered:
        return [attention_id], 0

    size = len(ordered)
    target = max(1, size // 2) if insert_pos is None else insert_pos
    target = min(max(target, 0), size)
    return ordered[:target] + [attention_id] + ordered[target:], target


def _fetch_api_answers(
    form_service,
    form_id: str,
    participant_id: str,
    raw_dir: Optional[Path] = None,
) -> Dict[Tuple[int, str], str]:
    """Download one participant's latest response, indexed by (q_idx, key).

    When ``raw_dir`` is given, the latest raw response is also written to
    ``<raw_dir>/<participant_id>.json``.
    """
    print(
        f"Downloading responses for participant {participant_id} from form {form_id}"
    )
    form_meta = form_service.forms().get(formId=form_id).execute()
    q_lookup = _build_question_lookup(form_meta)
    latest = pick_latest_response(download_responses(form_service, form_id))
    if latest and raw_dir is not None:
        (raw_dir / f"{participant_id}.json").write_text(
            json.dumps(latest, indent=2), encoding="utf-8"
        )

    answers: Dict[Tuple[int, str], str] = {}
    if not latest or "answers" not in latest:
        return answers
    for qid, ans in latest["answers"].items():
        if qid not in q_lookup:
            continue
        value = _extract_answer_value(ans)
        if value is None:
            continue
        slot = q_lookup[qid]
        existing = answers.get(slot)
        if existing is None:
            answers[slot] = value
        elif (
            _parse_numeric_prefix(existing) is None
            and _parse_numeric_prefix(value) is not None
        ):
            # Two form items landed on the same slot: keep an existing numeric
            # rating, but let a numeric value replace plain text.
            answers[slot] = value
    return answers


def _index_csv_rows_by_participant(
    rows: List[Dict[str, str]]
) -> Optional[Dict[str, Dict[str, str]]]:
    """Index CSV rows by their ``participant_id`` column.

    Returns ``None`` when the CSV has no such column, which tells the caller to
    fall back to positional alignment. If an id occurs several times, the last
    row wins.
    """
    if not rows or "participant_id" not in rows[0]:
        return None
    indexed: Dict[str, Dict[str, str]] = {}
    for row in rows:
        pid = (row.get("participant_id") or "").strip()
        if pid:
            indexed[pid] = row
    return indexed


def main() -> None:
    """Collect responses and write ``responses_normalized.jsonl``.

    For every participant in the assignment plan, one record is produced per
    item (including the attention check, when configured). Answers come from
    the Forms API when ``--use-forms-api`` is set; fields the API did not
    provide are filled from the ``Q{n} <label>`` columns of ``--responses-csv``
    when one is given.

    Rating fields are stored as floats when they start with a number;
    ``*_explanation`` fields are free text and are always kept verbatim. The
    unmodified answer is additionally stored under ``<key>_raw``.

    Raises:
        RuntimeError: if neither a CSV nor the Forms API is selected, or if the
            Forms API is selected but no form mapping could be loaded.
    """
    args = parse_args()
    if not args.use_forms_api and not args.responses_csv:
        raise RuntimeError("Provide --responses-csv or enable --use-forms-api.")
    participants = read_jsonl(args.assignments)
    selection = load_selection_map(args.selection)
    attention_check = load_attention_check(args.attention_check)
    attention_id = attention_check.get("id") if attention_check else None
    if attention_check and "id" in attention_check:
        # Lets the attention item resolve like any other selected item.
        selection[attention_check["id"]] = attention_check

    api_form_map: Dict[str, Tuple[str, Optional[int]]] = {}
    form_service = None
    if args.use_forms_api:
        api_form_map = load_forms_created(args.forms_created)
        if not api_form_map:
            raise RuntimeError(
                f"No form mapping found at {args.forms_created}. Run create_forms.py with --perform-api first."
            )
        form_service = ensure_services()

    # CSV rows are matched to participants through a "participant_id" column
    # when the export contains one; otherwise row i belongs to participant i.
    csv_rows: List[Dict[str, str]] = []
    if args.responses_csv:
        # newline="" lets the csv module handle line breaks inside quoted
        # fields itself, as its documentation recommends.
        with Path(args.responses_csv).open("r", encoding="utf-8", newline="") as f:
            csv_rows = list(csv.DictReader(f))
    csv_by_participant = _index_csv_rows_by_participant(csv_rows)

    raw_dir: Optional[Path] = None
    if args.save_raw_responses:
        raw_dir = Path(args.output_dir) / "responses_raw"
        raw_dir.mkdir(parents=True, exist_ok=True)

    normalized: List[Dict] = []
    for idx, participant in enumerate(participants):
        participant_id = participant["participant_id"]
        # api_form_map is only populated in API mode, so both values stay None
        # otherwise.
        form_id, att_idx_override = api_form_map.get(participant_id, (None, None))
        item_ids, _ = inject_attention_check(
            participant["item_ids"], attention_id, att_idx_override
        )

        # Pull API answers if enabled
        answers_by_key: Dict[Tuple[int, str], str] = {}
        if form_service is not None and form_id:
            answers_by_key = _fetch_api_answers(
                form_service, form_id, participant_id, raw_dir
            )

        # CSV values only fill the fields the API did not provide.
        if csv_by_participant is not None:
            csv_row = csv_by_participant.get(str(participant_id), {})
        else:
            csv_row = csv_rows[idx] if idx < len(csv_rows) else {}

        for q_idx, item_id in enumerate(item_ids, start=1):
            item_meta = selection.get(item_id, {})
            record = {
                "participant_id": participant_id,
                "item_id": item_id,
                "question_id": item_meta.get("question_id"),
                "model_id": item_meta.get("model_id"),
            }
            prefix = f"Q{q_idx} "
            for label, key in QUESTION_FIELDS:
                value = answers_by_key.get((q_idx, key))
                if value is None and csv_row:
                    value = csv_row.get(prefix + label)
                record[f"{key}_raw"] = value
                if key.endswith("_explanation"):
                    # Free text: an explanation such as "2 was clearer" must
                    # not be collapsed to the number 2.0.
                    record[key] = value
                else:
                    num_value = _parse_numeric_prefix(value)
                    record[key] = num_value if num_value is not None else value
            normalized.append(record)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "responses_normalized.jsonl"
    with out_path.open("w", encoding="utf-8") as f:
        for rec in normalized:
            f.write(json.dumps(rec) + "\n")

    print(f"Wrote {len(normalized)} response rows to {out_path}")


# Run the collector only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
