#!/usr/bin/env python3
"""
Convert TrialBench trial-duration-forecasting data to SPR-RAFT format.

SPR-RAFT format (JSONL):
{
    "id": str,
    "input": str,
    "y": float,
    "y_text": str (optional)
}

The default target is trial duration in years (the `year` column of
`{split}_y.csv`); the `month` and `time_day` columns can be selected instead.
"""

import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional
import argparse


def clean_text(text: str) -> str:
    """Normalize a raw field value into a single-line string.

    Missing values (anything ``pd.isna`` reports as null) and the empty
    string yield ``''``. Any other value is converted with ``str`` and every
    run of whitespace (spaces, tabs, newlines) is collapsed to one space,
    leaving no leading or trailing whitespace.
    """
    # The null check must come first: comparing pd.NA with '' is not a bool.
    if not (pd.isna(text) or text == ''):
        words = str(text).split()
        return ' '.join(words).strip()
    return ''


def format_list_field(value) -> str:
    """Render a list-like field stored as a string as a comma-separated string.

    Args:
        value: Raw cell value. Typically the ``repr`` of a Python list
            (e.g. ``"['Drug A', 'None']"``), a plain string, or a missing value.

    Returns:
        ``''`` for missing/empty values (null, ``''``, ``'[]'``, ``"['None']"``).
        If ``value`` is a string that parses as a Python list literal, the
        truthy items other than ``'None'`` joined with ``', '``. Otherwise the
        whitespace-normalized string form of ``value``.
    """
    # Function-scope import: ast is not imported at module level.
    import ast

    if pd.isna(value) or value == '' or value == '[]' or value == "['None']":
        return ''

    if isinstance(value, str):
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a Python literal: fall through and treat it as plain text.
            # Only the errors literal_eval documents are caught, so that
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            parsed = None

        if isinstance(parsed, list):
            # Drop empty entries and the literal 'None' placeholder.
            kept = [item for item in parsed if item and item != 'None']
            return ', '.join(str(item) for item in kept)

    return clean_text(str(value))


def create_trial_prompt(row: pd.Series, nct_id: str) -> str:
    """
    Build the duration-prediction prompt for a single trial.

    Each available field of ``row`` becomes one ``"Label: value"`` line, in a
    fixed order; fields that are missing or empty are omitted. The summary is
    cut to 500 characters (plus ``"..."``) when longer. The field lines are
    followed by a blank line and the task instruction.

    Args:
        row: Trial record; fields are read with ``row.get(column, default)``.
        nct_id: Trial identifier. Accepted for the caller's convenience but
            not used when building the prompt.

    Returns:
        The prompt text.
    """
    lines = []

    def cleaned(column: str) -> str:
        # Whitespace-normalized text of one column ('' when absent/missing).
        return clean_text(row.get(column, ''))

    def emit(label: str, value: str) -> None:
        # Add a "Label: value" line unless the value is empty.
        if value:
            lines.append(f"{label}: {value}")

    # Title and (possibly truncated) summary
    emit("Trial", cleaned('brief_title'))

    summary_text = cleaned('brief_summary/textblock')
    if len(summary_text) > 500:
        summary_text = summary_text[:500] + "..."
    emit("Summary", summary_text)

    # Condition, phase, study type and design
    for label, column in (
        ("Condition", 'condition'),
        ("Phase", 'phase'),
        ("Study Type", 'study_type'),
        ("Allocation", 'study_design_info/allocation'),
        ("Intervention Model", 'study_design_info/intervention_model'),
        ("Masking", 'study_design_info/masking'),
        ("Primary Purpose", 'study_design_info/primary_purpose'),
    ):
        emit(label, cleaned(column))

    # Number of arms is numeric, so it bypasses the text cleaning
    arm_count = row.get('number_of_arms', np.nan)
    if pd.notna(arm_count):
        lines.append(f"Number of Arms: {int(arm_count)}")

    # Interventions are stored as a stringified list
    emit("Interventions", format_list_field(row.get('intervention/intervention_name', '')))

    # Eligibility
    emit("Eligible Gender", cleaned('eligibility/gender'))

    youngest = cleaned('eligibility/minimum_age')
    oldest = cleaned('eligibility/maximum_age')
    if youngest and oldest:
        emit("Age Range", f"{youngest} - {oldest}")
    else:
        # At most one bound is present; show it alone (or nothing at all).
        emit("Age Range", youngest or oldest)

    emit("Healthy Volunteers", cleaned('eligibility/healthy_volunteers'))

    # Sponsor class
    emit("Sponsor Class", cleaned('sponsors/lead_sponsor/agency_class'))

    instruction = (
        "Based on the clinical trial information above, predict the duration of this trial in years.\n"
        "Think step by step about factors that might affect trial duration, "
        "then provide your prediction."
    )

    return "\n".join(lines) + "\n\n" + instruction


def format_y(y: float, decimals: int = 4) -> str:
    """Render ``y`` in fixed-point notation with ``decimals`` fractional digits."""
    fixed_point_spec = f".{decimals}f"
    return format(y, fixed_point_spec)


def convert_trialbench_to_spr_raft(
    data_dir: str,
    output_path: str,
    split: str = "train",
    phase: str = "Phase3",
    max_samples: Optional[int] = None,
    decimals: int = 4,
    target_col: str = "year",
    seed: int = 42
):
    """
    Convert TrialBench trial-duration-forecasting data to SPR-RAFT format.

    Reads ``{split}_x.csv`` and ``{split}_y.csv`` from
    ``<data_dir>/trial-duration-forecasting/<phase>``, inner-joins them on
    their index (first CSV column), optionally subsamples, and writes one JSON
    object per trial to ``output_path``.

    Args:
        data_dir: Path to TrialBench directory
        output_path: Output JSONL file path (parent directories are created)
        split: 'train' or 'test'
        phase: Phase subdirectory (Phase1, Phase2, Phase3, Phase4)
        max_samples: Maximum number of samples (None for all)
        decimals: Number of decimal places for y formatting
        target_col: Target column in the y CSV ('year', 'month', or 'time_day')
        seed: Random seed for sampling

    Returns:
        The list of written samples, each a dict with keys ``id``, ``input``,
        ``y`` and ``y_text``. Rows whose target is missing are skipped, so the
        list may be shorter than ``max_samples``.
    """
    data_dir = Path(data_dir)
    phase_dir = data_dir / "trial-duration-forecasting" / phase

    # Load data
    x_path = phase_dir / f"{split}_x.csv"
    y_path = phase_dir / f"{split}_y.csv"

    print(f"Loading data from {phase_dir}")
    features = pd.read_csv(x_path, index_col=0)
    targets = pd.read_csv(y_path, index_col=0)

    print(f"Loaded {len(features)} samples")

    # Merge x and y on the shared index
    data = features.join(targets[[target_col]], how='inner')
    print(f"After join: {len(data)} samples")

    # Sample if needed
    if max_samples and max_samples < len(data):
        data = data.sample(n=max_samples, random_state=seed)
        print(f"Sampled {max_samples} samples")

    # Convert to SPR-RAFT format
    samples = []
    for nct_id, row in data.iterrows():
        y_value = row[target_col]

        # Skip if y is invalid
        if pd.isna(y_value):
            continue

        samples.append({
            # The documented format declares id as a string; coercing also
            # keeps non-string index labels JSON-serializable.
            "id": str(nct_id),
            "input": create_trial_prompt(row, nct_id),
            "y": float(y_value),
            "y_text": format_y(y_value, decimals)
        })

    # Write to JSONL
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # would otherwise be encoded with the platform's locale encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')

    print(f"Wrote {len(samples)} samples to {output_path}")

    # Print statistics (min/max are undefined for an empty list)
    y_values = [s['y'] for s in samples]
    if y_values:
        print(f"\nTarget statistics ({target_col}):")
        print(f"  Min: {min(y_values):.4f}")
        print(f"  Max: {max(y_values):.4f}")
        print(f"  Mean: {np.mean(y_values):.4f}")
        print(f"  Std: {np.std(y_values):.4f}")
    else:
        print(f"\nNo valid samples; skipping target statistics ({target_col}).")

    return samples


def convert_to_sharegpt_format(
    input_path: str,
    output_path: str,
    include_cot: bool = True
):
    """
    Convert SPR-RAFT JSONL to ShareGPT format for LLaMA-Factory.

    ShareGPT format:
    {
        "id": "...",
        "conversations": [
            {"from": "human", "value": "..."},
            {"from": "gpt", "value": "... [REG] 1.2345"}
        ],
        "y": 1.2345
    }

    Args:
        input_path: SPR-RAFT JSONL file to read. Blank lines are ignored.
        output_path: ShareGPT JSONL file to write (parent directories are
            created).
        include_cot: If True, the assistant turn contains a fixed
            chain-of-thought template before the ``[REG]`` answer; otherwise
            it contains only the ``[REG]`` answer.

    Returns:
        The list of ShareGPT sample dicts that were written.

    Raises:
        KeyError: If a sample lacks ``input`` or ``y``.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    samples = []
    # Both files are opened as UTF-8 so that the result does not depend on
    # the platform's locale encoding (the JSON is written with raw non-ASCII).
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines, e.g. a trailing newline at end of file.
            if line.strip():
                samples.append(json.loads(line))

    sharegpt_data = []
    for sample in samples:
        # Human message
        human_msg = sample['input']

        # y_text is optional in the SPR-RAFT format; derive it from y when it
        # is absent, using four decimal places.
        y_text = sample.get('y_text')
        if y_text is None:
            y_text = f"{sample['y']:.4f}"

        # Assistant response with [REG] token
        if include_cot:
            # Simple CoT template
            response = (
                "Let me analyze the key factors affecting trial duration:\n"
                "- The trial phase and study design influence timeline\n"
                "- Enrollment criteria and target population affect recruitment\n"
                "- Intervention complexity impacts study duration\n\n"
                "Based on these factors, I predict the trial duration:\n"
                f"[REG] {y_text}"
            )
        else:
            response = f"[REG] {y_text}"

        sharegpt_data.append({
            "id": sample.get('id', ''),
            "conversations": [
                {"from": "human", "value": human_msg},
                {"from": "gpt", "value": response}
            ],
            "y": sample['y']  # Keep y for regression
        })

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        for sample in sharegpt_data:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')

    print(f"Wrote {len(sharegpt_data)} ShareGPT samples to {output_path}")

    return sharegpt_data


def _write_jsonl(path, samples):
    """Write ``samples`` to ``path`` as UTF-8 JSONL, one JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        for sample in samples:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')


def main():
    """Command-line entry point.

    Converts the train split (optionally carving a validation set out of it)
    and the test split to raw SPR-RAFT JSONL, then converts each resulting
    file to ShareGPT JSONL. With ``--toy`` the outputs go to a ``toy``
    subdirectory, sizes are capped at 50 train / 20 test samples and no
    validation split is made.
    """
    parser = argparse.ArgumentParser(description="Convert TrialBench to SPR-RAFT format")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="/data1/jinxiang/bio-reg-gpt/TrialBench",
        help="Path to TrialBench directory"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/data1/jinxiang/bio-reg-gpt/data/spr_raft_trialbench",
        help="Output directory"
    )
    parser.add_argument(
        "--phase",
        type=str,
        default="Phase3",
        choices=["Phase1", "Phase2", "Phase3", "Phase4"],
        help="Trial phase"
    )
    parser.add_argument(
        "--max_train",
        type=int,
        default=None,
        help="Max training samples"
    )
    parser.add_argument(
        "--max_test",
        type=int,
        default=None,
        help="Max test samples"
    )
    parser.add_argument(
        "--toy",
        action="store_true",
        help="Create toy dataset (50 train, 20 test)"
    )
    parser.add_argument(
        "--target",
        type=str,
        default="year",
        choices=["year", "month", "time_day"],
        help="Target column"
    )
    parser.add_argument(
        "--include_cot",
        action="store_true",
        default=True,
        help="Include CoT in response (this is the default)"
    )
    # --include_cot defaults to True and can only set True, so on its own it
    # can never disable CoT; this flag provides the off switch.
    parser.add_argument(
        "--no_cot",
        dest="include_cot",
        action="store_false",
        help="Emit only the [REG] answer, without the CoT text"
    )
    parser.add_argument(
        "--val_split",
        type=float,
        default=0.1,
        help="Validation split ratio (0-1). If >0, split train data into train/val"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for shuffling and splitting"
    )

    args = parser.parse_args()

    # A ratio of 1 or more would leave no training data at all.
    if not 0.0 <= args.val_split < 1.0:
        parser.error("--val_split must satisfy 0 <= val_split < 1")

    output_dir = Path(args.output_dir)

    if args.toy:
        args.max_train = 50
        args.max_test = 20
        output_dir = output_dir / "toy"
        args.val_split = 0.0  # No validation for toy dataset

    # Convert train split
    print("\n=== Converting training data ===")
    train_raw_path = output_dir / "train_raw.jsonl"
    train_samples = convert_trialbench_to_spr_raft(
        data_dir=args.data_dir,
        output_path=str(train_raw_path),
        split="train",
        phase=args.phase,
        max_samples=args.max_train,
        target_col=args.target,
        seed=args.seed
    )

    # Split into train/val if requested
    if args.val_split > 0:
        import random

        # A private generator gives the same shuffle order as seeding the
        # module-level one, without touching global random state.
        random.Random(args.seed).shuffle(train_samples)
        n_val = int(len(train_samples) * args.val_split)
        val_samples = train_samples[:n_val]
        train_samples = train_samples[n_val:]

        print("\n=== Splitting train data ===")
        print(f"Total samples: {len(train_samples) + len(val_samples)}")
        print(f"Train samples: {len(train_samples)}")
        print(f"Val samples: {len(val_samples)}")

        # Overwrite the raw train file with the reduced training subset
        _write_jsonl(train_raw_path, train_samples)

        # Write validation set
        val_raw_path = output_dir / "val_raw.jsonl"
        _write_jsonl(val_raw_path, val_samples)
        print(f"Wrote validation set to {val_raw_path}")

        # Convert validation to ShareGPT format
        val_sharegpt_path = output_dir / "val.jsonl"
        convert_to_sharegpt_format(
            input_path=str(val_raw_path),
            output_path=str(val_sharegpt_path),
            include_cot=args.include_cot
        )

    # Convert to ShareGPT format
    train_sharegpt_path = output_dir / "train.jsonl"
    convert_to_sharegpt_format(
        input_path=str(train_raw_path),
        output_path=str(train_sharegpt_path),
        include_cot=args.include_cot
    )

    # Convert test split
    print("\n=== Converting test data ===")
    test_raw_path = output_dir / "test_raw.jsonl"
    convert_trialbench_to_spr_raft(
        data_dir=args.data_dir,
        output_path=str(test_raw_path),
        split="test",
        phase=args.phase,
        max_samples=args.max_test,
        target_col=args.target,
        seed=args.seed
    )

    # Convert to ShareGPT format
    test_sharegpt_path = output_dir / "test.jsonl"
    convert_to_sharegpt_format(
        input_path=str(test_raw_path),
        output_path=str(test_sharegpt_path),
        include_cot=args.include_cot
    )

    print("\n=== Done ===")
    print(f"Output directory: {output_dir}")
    print(f"  train.jsonl: ShareGPT format for training ({len(train_samples) if args.val_split > 0 else 'all'} samples)")
    if args.val_split > 0:
        print(f"  val.jsonl: ShareGPT format for validation ({len(val_samples)} samples)")
    print("  test.jsonl: ShareGPT format for evaluation")
    print("  *_raw.jsonl: Raw SPR-RAFT format files")


# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
