"""
Script to obtain sample trajectories that summarize a discovered objective.

Implements Section 4.3 of Methods.md:
- Computes trend fidelity for each trajectory: fid(ξ) = exp(- Σ(u_t - f*(t))^2)
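- Computes diversity by K-Means clustering OpenAI embeddings of each trajectory (prompt + all responses)
- Uses a greedy algorithm to select k trajectories maximizing the normalized objective
  F(S) = (1-λ)*Fidelity(S)/Fidelity_max + λ*Diversity(S)/Diversity_max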

Example usage:
    python scripts/obtain_objective_summary.py --config_file configs/obj-summary/PQA2A3.yaml
"""

import os

import sys
import argparse
import json
import pickle
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Optional
from tqdm import tqdm
from datetime import datetime

# Add project root to path so the `src.*` imports below resolve when this file
# is executed directly as a script. The project root is taken to be the parent
# of the directory that contains this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
# Appending (rather than inserting at index 0) means entries already on
# sys.path are searched before the project root.
if project_root not in sys.path:
    sys.path.append(project_root)

from src.objective_summary import (
    TrajectoryData,
    get_embeddings_batch,
    compute_kmeans_partitions,
    compute_normalization_constants,
    compute_cluster_diversity,
    greedy_select_trajectories,
    save_objective_summary,
    compute_objective_function,
    load_global_trend_from_results_dir,
    compute_total_fidelity
)
from src.objective_scorer import ObjectiveScorer
from src.model_generation import generate_responses_batched
from src.setup_datasets_clean import load_evaluation_dataset, convert_sample_to_prompt
import src.constants as constants


def _str2bool(value) -> bool:
    """
    Parse a boolean given on the command line or in a config file.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit parser is required.
    Real booleans are passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_args():
    """
    Parse command-line arguments, optionally overlaid by a YAML/JSON config file.

    Every key in ``--config_file`` that matches an argument name overrides both
    the default and any value passed on the command line; unknown keys are
    reported on stderr and ignored. ``model_paths`` may be given as a JSON list
    string and is decoded to a Python list.

    Returns:
        argparse.Namespace holding all experiment settings.
    """
    parser = argparse.ArgumentParser(description='Obtain objective summary trajectories')

    # Config file
    parser.add_argument('--config_file', type=str, default=None,
                        help='Path to yaml/json config file')

    # Experiment identification
    parser.add_argument('--experiment_name', type=str, default='obj_summary_exp',
                        help='Name for this experiment')
    parser.add_argument('--save_dir', type=str, default='',
                        help='Directory to save results')

    # Model checkpoints
    parser.add_argument('--model_paths', type=str, default='',
                        help='JSON list of model checkpoint paths in order [t1, t2, ..., tT]')

    # Objectives discovery results directory
    parser.add_argument('--objectives_results_dir', type=str, default=None,
                        help='Path to objectives discovery results directory containing trend_plots_* subdirs with final_trend_params.json. If provided, loads global trend for fidelity calculation.')

    # Cache directory for rubrics/examples/descriptions
    parser.add_argument('--cache_dir', type=str, default=None,
                        help='Path to cache directory containing rubrics_cache.json, examples_cache.json, descriptions_cache.json. Usually the same as objectives_results_dir.')

    # Dataset parameters
    parser.add_argument('--dataset_name', type=str, default='local_alpaca_gpt4_10k',
                        help='Name of evaluation dataset')
    parser.add_argument('--dataset_split', type=str, default='test',
                        help='Dataset split to use')
    parser.add_argument('--num_candidate_samples', type=int, default=100,
                        help='Number of candidate prompts to randomly sample from dataset')

    # Objective to summarize
    parser.add_argument('--objective_name', type=str, default='helpfulness',
                        help='Name of the objective to summarize')

    # Scoring parameters. --scorer_use_api is on by default; the paired
    # --no_scorer_use_api flag (same dest) is the only way to turn it off.
    parser.add_argument('--scorer_use_api', action='store_true', default=True,
                        help='Use API for objective scoring (default: enabled)')
    parser.add_argument('--no_scorer_use_api', dest='scorer_use_api', action='store_false',
                        help='Disable API-based objective scoring')
    parser.add_argument('--scorer_model_name', type=str, default='gpt-4o-mini',
                        help='Model name for scoring')
    # nargs='?' + const=True lets a bare "--normalize_scores" mean True while
    # still accepting an explicit value such as "--normalize_scores false".
    parser.add_argument('--normalize_scores', type=_str2bool, nargs='?', const=True, default=True,
                        help='Normalize scores to [0, 1] range (true/false)')

    # Greedy algorithm parameters
    parser.add_argument('--num_summary_samples', type=int, default=5,
                        help='Number of trajectories to select (k)')
    parser.add_argument('--lambda_weight', type=float, default=0.5,
                        help='Weight for diversity term in convex combination F(S) = (1-lambda)*fid + lambda*div')
    parser.add_argument('--num_clusters', type=int, default=5,
                        help='Number of K-Means clusters (m) for diversity computation')

    # Generation parameters
    parser.add_argument('--batch_size', type=int, default=8,
                        help='Batch size for generation')
    parser.add_argument('--max_new_tokens', type=int, default=512,
                        help='Maximum new tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Sampling temperature')
    parser.add_argument('--top_p', type=float, default=0.9,
                        help='Top-p sampling parameter')

    # Device and memory
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')

    # Embedding parameters
    parser.add_argument('--embedding_model', type=str, default='text-embedding-3-small',
                        help='OpenAI embedding model to use for diversity')

    args = parser.parse_args()

    # Load config file if provided
    if args.config_file:
        with open(args.config_file, encoding='utf-8') as f:
            if args.config_file.endswith(('.yaml', '.yml')):
                import yaml
                config = yaml.safe_load(f)
            else:
                config = json.load(f)

        # An empty YAML file (or a JSON "null") loads as None: no overrides.
        config = config or {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Config file {args.config_file} must contain a mapping, got {type(config).__name__}"
            )

        # Update args with config values
        bool_keys = ('scorer_use_api', 'normalize_scores')
        for key, value in config.items():
            if not hasattr(args, key):
                print(f"Warning: ignoring unknown config key '{key}'", file=sys.stderr)
                continue
            if key in bool_keys:
                # Config files may spell booleans as strings ("false"), which
                # would otherwise be stored as a truthy str.
                value = _str2bool(value)
            setattr(args, key, value)

    # Parse JSON strings
    if isinstance(args.model_paths, str) and args.model_paths:
        args.model_paths = json.loads(args.model_paths)
        # A JSON string literal ("\"ckpt\"") decodes to a str; wrap it so
        # downstream code always iterates over a list of paths, not characters.
        if isinstance(args.model_paths, str):
            args.model_paths = [args.model_paths]

    return args


def sample_candidate_prompts(
    dataset_name: str,
    dataset_split: str,
    num_samples: int,
    seed: int
) -> List[Dict]:
    """
    Draw a reproducible random subset of prompts from an evaluation dataset.

    NumPy's global RNG is seeded with ``seed`` and row indices are drawn
    without replacement; the request is capped at the dataset size. Each
    sampled row is converted to single-turn chat messages.

    Args:
        dataset_name: Name of the dataset
        dataset_split: Split to use
        num_samples: Number of samples to draw (capped at the dataset length)
        seed: Seed applied to NumPy's global random state

    Returns:
        List of dicts with keys 'idx' (dataset row index), 'prompt_messages'
        (output of convert_sample_to_prompt) and 'prompt_text' (content of the
        first message, or "" when there are no messages).
    """
    print(f"\nLoading dataset: {dataset_name} ({dataset_split} split)")
    dataset = load_evaluation_dataset(dataset_name, dataset_split)

    # Seeds the *global* NumPy RNG (kept deliberately; later code may rely on it).
    np.random.seed(seed)
    dataset_size = len(dataset)
    chosen_rows = np.random.choice(dataset_size, size=min(num_samples, dataset_size), replace=False)

    def _to_entry(row: int) -> Dict:
        messages = convert_sample_to_prompt(dataset[row], dataset_name, multi_turn=False)
        return {
            'idx': row,
            'prompt_messages': messages,
            'prompt_text': messages[0]['content'] if messages else "",
        }

    samples = [_to_entry(int(row)) for row in chosen_rows]

    print(f"Sampled {len(samples)} candidate prompts")
    return samples


def generate_trajectory_responses(
    model_paths: List[str],
    prompts: List[Dict],
    batch_size: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    device: str
) -> Dict[int, List[str]]:
    """
    Generate one response per checkpoint for every prompt, forming trajectories.

    Checkpoints are visited in the order given, so ``result[idx][t]`` is the
    response of ``model_paths[t]`` to the prompt whose 'idx' is ``idx``.

    Args:
        model_paths: List of model checkpoint paths [t1, t2, ..., tT]
        prompts: List of prompt dictionaries (each needs 'idx' and 'prompt_messages')
        batch_size: Batch size for generation
        max_new_tokens: Max tokens to generate
        temperature: Sampling temperature
        top_p: Top-p value
        device: Device to use

    Returns:
        Dictionary mapping prompt index to list of responses (one per checkpoint)

    Raises:
        ValueError: If a checkpoint returns a different number of responses than
            there are prompts. Responses are matched to prompts by position, so a
            mismatch would otherwise surface as an opaque IndexError (too few) or
            silently drop responses (too many).
    """
    # Initialize: prompt_idx -> [response_t1, response_t2, ..., response_tT]
    trajectories = {p['idx']: [] for p in prompts}
    prompt_messages_list = [p['prompt_messages'] for p in prompts]
    num_checkpoints = len(model_paths)

    for t, model_path in enumerate(model_paths):
        print(f"\n{'='*60}")
        print(f"Generating responses from checkpoint {t+1}/{num_checkpoints}")
        # normpath strips a trailing separator so 'ckpt-100/' prints as 'ckpt-100'.
        print(f"Model: {os.path.basename(os.path.normpath(model_path))}")
        print(f"{'='*60}")

        responses = generate_responses_batched(
            model_path=model_path,
            prompts=prompt_messages_list,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            temperature=temperature,
            top_p=top_p,
            device=device,
            use_cache=True
        )

        if len(responses) != len(prompts):
            raise ValueError(
                f"Checkpoint {model_path} returned {len(responses)} responses "
                f"for {len(prompts)} prompts"
            )

        # Store responses for each prompt (positional alignment)
        for prompt, response in zip(prompts, responses):
            trajectories[prompt['idx']].append(response)

    return trajectories


def score_trajectories(
    prompts: List[Dict],
    trajectories: Dict[int, List[str]],
    objective_name: str,
    scorer: ObjectiveScorer
) -> Dict[int, List[float]]:
    """
    Score every response of every trajectory on a single objective.

    Each response is scored independently against its prompt text via
    ``scorer.score_single_objective``; the scores keep the checkpoint order
    of ``trajectories[idx]``.

    Args:
        prompts: List of prompt dictionaries (each needs 'idx' and 'prompt_text')
        trajectories: Dictionary mapping prompt index to list of responses
        objective_name: Name of objective to score
        scorer: ObjectiveScorer instance

    Returns:
        Dictionary mapping prompt index to list of scores, one per response
    """
    print(f"\n{'='*60}")
    print(f"Scoring trajectories for objective: {objective_name}")
    print(f"{'='*60}")

    scores: Dict[int, List[float]] = {}
    for entry in tqdm(prompts, desc="Scoring prompts"):
        query_text = entry['prompt_text']
        scores[entry['idx']] = [
            scorer.score_single_objective(
                query=query_text,
                response=reply,
                objective=objective_name,
            )
            for reply in trajectories[entry['idx']]
        ]

    return scores


def compute_embeddings(
    prompts: List[Dict],
    trajectories: Dict[int, List[str]],
    embedding_model: str
) -> Dict[int, np.ndarray]:
    """
    Embed each trajectory as one text: the prompt followed by all its responses.

    The text for a trajectory is
    ``"PROMPT: <prompt>\\n\\nRESPONSE 1: <r1>\\n\\nRESPONSE 2: <r2>\\n\\n..."``
    and all texts are sent to ``get_embeddings_batch`` in a single call.

    Args:
        prompts: List of prompt dictionaries (each needs 'idx' and 'prompt_text')
        trajectories: Dictionary mapping prompt index to list of responses
        embedding_model: OpenAI embedding model name

    Returns:
        Dictionary mapping prompt index to embedding (paired positionally with
        the batch result)
    """
    print(f"\n{'='*60}")
    print(f"Computing embeddings using {embedding_model}")
    print(f"{'='*60}")

    texts_to_embed: List[str] = []
    indices: List[int] = []
    for entry in prompts:
        idx = entry['idx']
        pieces = [f"PROMPT: {entry['prompt_text']}\n\n"]
        pieces.extend(
            f"RESPONSE {step}: {reply}\n\n"
            for step, reply in enumerate(trajectories[idx], start=1)
        )
        texts_to_embed.append("".join(pieces))
        indices.append(idx)

    embeddings_list = get_embeddings_batch(texts_to_embed, model=embedding_model)
    embeddings = dict(zip(indices, embeddings_list))

    print(f"Computed {len(embeddings)} embeddings")
    return embeddings


def build_trajectory_data(
    prompts: List[Dict],
    trajectories: Dict[int, List[str]],
    scores: Dict[int, List[float]],
    embeddings: Dict[int, np.ndarray],
    global_trend_func=None,
    global_trend_params=None
) -> List[TrajectoryData]:
    """
    Assemble one TrajectoryData per prompt from the per-stage results.

    Responses and scores are looked up by prompt index (a missing index raises
    KeyError); a missing embedding is passed as None. The same global trend
    function/parameters are attached to every trajectory.

    Args:
        prompts: List of prompt dictionaries (each needs 'idx' and 'prompt_text')
        trajectories: Dictionary mapping prompt index to responses
        scores: Dictionary mapping prompt index to scores
        embeddings: Dictionary mapping prompt index to embeddings
        global_trend_func: Optional global trend function f*(t) for fidelity calculation
        global_trend_params: Optional parameters for the global trend function

    Returns:
        List of TrajectoryData objects, in the order of ``prompts``
    """
    return [
        TrajectoryData(
            prompt_idx=entry['idx'],
            prompt=entry['prompt_text'],
            responses=trajectories[entry['idx']],
            scores=scores[entry['idx']],
            embedding=embeddings.get(entry['idx']),
            global_trend_func=global_trend_func,
            global_trend_params=global_trend_params
        )
        for entry in prompts
    ]


def plot_selection_history(
    selection_history: List[Dict],
    save_path: str,
    lambda_weight: float
) -> None:
    """
    Plot the normalized fidelity, diversity, and total objective per greedy iteration.

    The figure is saved to ``save_path`` at 300 dpi (matplotlib infers the
    format from the extension), and a PDF copy is written next to it with the
    extension replaced by ``.pdf``. When ``selection_history`` is empty, only a
    message is printed.

    Args:
        selection_history: Per-iteration dicts with keys 'iteration',
            'total_fidelity', 'total_diversity' and 'total_f'
        save_path: Path to save the plot (e.g. '.../selection_history_plot.png')
        lambda_weight: Lambda weight used, shown in the legend label
    """
    if not selection_history:
        print("No selection history to plot")
        return

    iterations = [h['iteration'] for h in selection_history]
    total_fidelity = [h['total_fidelity'] for h in selection_history]  # Now normalized
    total_diversity = [h['total_diversity'] for h in selection_history]  # Now normalized
    total_f = [h['total_f'] for h in selection_history]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Plot lines
        ax.plot(iterations, total_fidelity, 'o-', label='Fid_norm', color='#2ca02c', linewidth=2, markersize=8)
        ax.plot(iterations, total_diversity, 's-', label='Div_norm', color='#1f77b4', linewidth=2, markersize=8)
        ax.plot(iterations, total_f, '^-', label=f'F(S) = (1-{lambda_weight})*Fid + {lambda_weight}*Div', color='#d62728', linewidth=2, markersize=8)

        ax.set_xlabel('Iteration', fontsize=12, fontweight='bold')
        ax.set_ylabel('Normalized Score', fontsize=12, fontweight='bold')
        ax.set_title('Greedy Selection Progress (Normalized)', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(iterations)
        ax.set_ylim(0, 1.05)  # Normalized values are expected in [0, 1]

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Selection history plot saved to: {save_path}")

        # Swap only the extension: str.replace('.png', '.pdf') would also rewrite
        # '.png' inside directory names, and would return save_path unchanged
        # (overwriting the raster with the PDF) when it has no '.png' suffix.
        pdf_path = os.path.splitext(save_path)[0] + '.pdf'
        fig.savefig(pdf_path, format='pdf', bbox_inches='tight')
        print(f"PDF saved to: {pdf_path}")
    finally:
        # Close this specific figure even if saving fails, so figures don't accumulate.
        plt.close(fig)


def _print_section(title: str) -> None:
    """Print ``title`` framed by 70-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _to_plain_list(value):
    """
    Convert an array-like to a plain Python list for pickling; None passes through.

    ``np.asarray`` accepts ndarrays as well as lists/tuples, so this works whether
    an upstream helper returned a NumPy array or a Python sequence.
    """
    if value is None:
        return None
    return np.asarray(value).tolist()


def main():
    """
    Run the objective-summary pipeline end to end and write all artifacts.

    Steps: sample candidate prompts, generate one response per checkpoint for
    each prompt, score every response for ``args.objective_name``, embed each
    trajectory, partition trajectories with K-Means, greedily select
    ``args.num_summary_samples`` of them, then save a text summary, a
    selection-history plot, the config, and a raw-data pickle into a
    timestamped directory under ``args.save_dir``.

    Raises:
        ValueError: If no model checkpoints are configured, or if no candidate
            prompts could be sampled.
    """
    args = parse_args()

    # Validate inputs
    if not args.model_paths:
        raise ValueError("model_paths must be provided")

    # Create a timestamped experiment directory. makedirs also creates any
    # missing parents (including save_dir), so save_dir needs no separate
    # makedirs, which would raise FileNotFoundError for the default save_dir ''.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_dir = os.path.join(
        args.save_dir or '', f"ObjSummary-{args.experiment_name}-{timestamp}"
    )
    os.makedirs(experiment_dir, exist_ok=True)

    # Save config for reproducibility
    config_save_path = os.path.join(experiment_dir, 'config.json')
    with open(config_save_path, 'w') as f:
        json.dump(vars(args), f, indent=2, default=str)
    print(f"Config saved to: {config_save_path}")

    # Checkpoint labels; normpath strips a trailing separator so that
    # 'ckpt-100/' is labelled 'ckpt-100' rather than ''.
    checkpoint_names = [os.path.basename(os.path.normpath(p)) for p in args.model_paths]

    # Load global trend from objectives results directory (if provided)
    global_trend_func = None
    global_trend_params = None
    best_trend_type = None
    if args.objectives_results_dir:
        _print_section("LOADING GLOBAL TREND FROM OBJECTIVES RESULTS")
        global_trend_func, global_trend_params, best_trend_type, _ = load_global_trend_from_results_dir(
            args.objectives_results_dir,
            args.objective_name
        )
        if global_trend_func is None:
            print("Warning: Could not load global trend. Fidelity will fall back to legacy salience calculation.")

    _print_section("OBJECTIVE SUMMARY GENERATION")
    print(f"Objective: {args.objective_name}")
    print(f"Number of checkpoints: {len(args.model_paths)}")
    print(f"Candidate samples: {args.num_candidate_samples}")
    print(f"Summary samples (k): {args.num_summary_samples}")
    print(f"Lambda weight: {args.lambda_weight}")
    print(f"Number of K-Means clusters: {args.num_clusters}")
    print(f"Global trend type: {best_trend_type if best_trend_type else 'Not loaded (using legacy salience)'}")
    print(f"Objective: F(S) = (1 - lambda) * f_fid(S) + lambda * f_div(S)")
    print(f"Fidelity: fid(ξ) = exp(- Σ(u_t - f*(t))^2)")

    # Step 1: Sample candidate prompts
    _print_section("STEP 1: Sampling candidate prompts")
    prompts = sample_candidate_prompts(
        dataset_name=args.dataset_name,
        dataset_split=args.dataset_split,
        num_samples=args.num_candidate_samples,
        seed=args.seed
    )
    # Fail fast: with no candidates every later step is wasted work and the
    # fidelity statistics below would fail on min() of an empty sequence.
    if not prompts:
        raise ValueError(
            "No candidate prompts were sampled; check dataset_name, dataset_split "
            "and num_candidate_samples"
        )

    # Step 2: Generate responses from each checkpoint
    _print_section("STEP 2: Generating trajectory responses")
    trajectories = generate_trajectory_responses(
        model_paths=args.model_paths,
        prompts=prompts,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        device=args.device
    )

    # Step 3: Score trajectories
    _print_section("STEP 3: Scoring trajectories")
    scorer = ObjectiveScorer(
        use_api=args.scorer_use_api,
        model_name=args.scorer_model_name,
        dataset_type=constants.DATASET_NAMES_DICT.get(args.dataset_name, 'hh'),
        normalize_scores=args.normalize_scores,
        use_detailed_rubric=True,
        cache_dir=args.cache_dir,
        save_dir=experiment_dir
    )
    scores = score_trajectories(prompts, trajectories, args.objective_name, scorer)

    # Step 4: Compute embeddings for diversity
    _print_section("STEP 4: Computing embeddings")
    embeddings = compute_embeddings(prompts, trajectories, args.embedding_model)

    # Step 5: Build TrajectoryData objects
    _print_section("STEP 5: Building trajectory data")
    trajectory_data_list = build_trajectory_data(
        prompts, trajectories, scores, embeddings,
        global_trend_func=global_trend_func,
        global_trend_params=global_trend_params
    )
    print(f"Built {len(trajectory_data_list)} trajectory objects")

    # Print fidelity statistics
    fidelities = [t.fidelity for t in trajectory_data_list]
    print(f"Fidelity statistics: min={min(fidelities):.4f}, max={max(fidelities):.4f}, "
          f"mean={np.mean(fidelities):.4f}, std={np.std(fidelities):.4f}")

    # Step 6: Compute K-Means partitions for diversity
    _print_section("STEP 6: Computing K-Means partitions")
    trajectory_data_list, actual_num_clusters = compute_kmeans_partitions(
        trajectory_data_list,
        num_clusters=args.num_clusters,
        random_state=args.seed
    )
    print(f"Partitioned {len(trajectory_data_list)} trajectories into {actual_num_clusters} clusters")

    # Print cluster distribution (trajectories without a label are not counted)
    cluster_counts = {}
    for t in trajectory_data_list:
        if t.cluster_label is not None:
            cluster_counts[t.cluster_label] = cluster_counts.get(t.cluster_label, 0) + 1
    print(f"Cluster distribution: {dict(sorted(cluster_counts.items()))}")

    # Compute normalization constants
    fid_normalizer, div_normalizer = compute_normalization_constants(
        trajectory_data_list, actual_num_clusters
    )
    print(f"\nNormalization constants:")
    print(f"  fid_max (full set fidelity): {fid_normalizer:.4f}")
    print(f"  div_max (full set diversity): {div_normalizer:.4f}")

    # Step 7: Run greedy selection algorithm
    _print_section("STEP 7: Running greedy selection algorithm")
    print(f"Objective: F(S) = (1-λ) * (fid/fid_max) + λ * (div/div_max), λ={args.lambda_weight}")
    selected_trajectories, selection_history = greedy_select_trajectories(
        all_trajectories=trajectory_data_list,
        k=args.num_summary_samples,
        lambda_weight=args.lambda_weight,
        num_clusters=actual_num_clusters,
        fid_normalizer=fid_normalizer,
        div_normalizer=div_normalizer,
        verbose=True
    )

    # Step 8: Save results
    _print_section("STEP 8: Saving results")

    # Save objective summary text file
    summary_path = os.path.join(experiment_dir, f'objective_summary_{args.objective_name}.txt')
    save_objective_summary(
        selected_trajectories=selected_trajectories,
        selection_history=selection_history,
        objective_name=args.objective_name,
        output_path=summary_path,
        checkpoint_names=checkpoint_names,
        lambda_weight=args.lambda_weight,
        num_clusters=actual_num_clusters,
        fid_normalizer=fid_normalizer,
        div_normalizer=div_normalizer,
        best_trend_type=best_trend_type
    )

    # Save selection history plot
    plot_path = os.path.join(experiment_dir, 'selection_history_plot.png')
    plot_selection_history(selection_history, plot_path, args.lambda_weight)

    # Save raw data as pickle. Arrays are converted to plain lists so the
    # pickle does not depend on the exact array type returned upstream.
    raw_data = {
        'prompts': prompts,
        'trajectories': trajectories,
        'scores': scores,
        'embeddings': {k: _to_plain_list(v) for k, v in embeddings.items()},
        'num_clusters': actual_num_clusters,
        'fid_normalizer': fid_normalizer,
        'div_normalizer': div_normalizer,
        'best_trend_type': best_trend_type,
        'global_trend_params': _to_plain_list(global_trend_params),
        'objectives_results_dir': args.objectives_results_dir,
        'trajectory_data_list': [
            {
                'prompt_idx': t.prompt_idx,
                'prompt': t.prompt,
                'responses': t.responses,
                'scores': t.scores,
                'fidelity': t.fidelity,
                'cluster_label': t.cluster_label
            }
            for t in trajectory_data_list
        ],
        'selected_indices': [t.prompt_idx for t in selected_trajectories],
        'selection_history': selection_history,
        'args': vars(args)
    }
    pkl_path = os.path.join(experiment_dir, 'raw_data.pkl')
    with open(pkl_path, 'wb') as f:
        pickle.dump(raw_data, f)
    print(f"Raw data saved to: {pkl_path}")

    # Print final summary
    _print_section("SUMMARY")
    print(f"Objective: {args.objective_name}")
    print(f"Global trend type: {best_trend_type if best_trend_type else 'Not loaded (using legacy salience)'}")
    print(f"K-Means clusters: {actual_num_clusters}")
    print(f"Normalization: fid_max={fid_normalizer:.4f}, div_max={div_normalizer:.4f}")
    print(f"Selected {len(selected_trajectories)} trajectories from {len(trajectory_data_list)} candidates")
    print(f"\nSelected trajectories (prompt_idx, cluster_label, fidelity):")
    for t in selected_trajectories:
        print(f"  - Prompt {t.prompt_idx}, Cluster {t.cluster_label}, Fidelity {t.fidelity:.4f}")
    if selection_history:
        final = selection_history[-1]
        print(f"\nFinal normalized values:")
        print(f"  F(S): {final['total_f']:.4f}")
        print(f"  Fid_norm: {final['total_fidelity']:.4f}")
        print(f"  Div_norm: {final['total_diversity']:.4f}")
    print(f"\nAll results saved to: {experiment_dir}")


if __name__ == '__main__':
    main()
