import os
import pandas as pd
import argparse
import jsonlines
from tts_contextual_bandit import TTSContextualBandit
from reward_function import reward_function
from config import LLM_MODELS, CACHE_DIR, EMBEDDING_FILE, LOSS_LOG_PATH

def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the contextual bandit experiment."""
    parser = argparse.ArgumentParser(description='Run Contextual Bandit Experiment')
    parser.add_argument('--algorithm', type=str, default='lin_ucb',
                        choices=['neural_ucb', 'lin_ucb', 'random', 'mixed_ucb', 'knn', 'oracle'],
                        help='Select algorithm: neural_ucb, lin_ucb, random, mixed_ucb, knn, oracle')
    parser.add_argument('--alpha', type=float, default=1.0, help='Alpha parameter for LinUCB')
    parser.add_argument('--beta', type=float, default=1.0, help='Beta parameter for NeuralUCB')
    parser.add_argument('--lambda_val', type=float, default=1.0, help='Regularization parameter lambda')
    parser.add_argument('--warm_up', type=int, default=0, help='Number of problems for warm-up phase (random selection)')
    parser.add_argument('--use_diag', action='store_true', default=False,
                        help='Enable diagonalization approximation (default: False, use full matrix)')
    parser.add_argument('--fusion_mode', type=str, default='average',
                        choices=['average', 'concat'],
                        help='Embedding fusion mode: average or concat')
    parser.add_argument('--fixed_action', type=str, default=None,
                        help='Fixed action to run (e.g., "qwen3-0.6b+qp1cp1bs1")')
    parser.add_argument('--use_mock_executor', action='store_true', help='Use mock executor for testing')
    parser.add_argument('--max_problems', type=int, default=None, help='Maximum number of problems to run')
    # Default dataset lives next to this script so the run does not depend on the CWD.
    default_data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "test100.jsonl")
    parser.add_argument('--data_path', type=str, default=default_data_path,
                        help='Path to data file')
    parser.add_argument('--exp_label', type=str, default=None, help='Experiment label for file naming')
    parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate for NeuralUCB')
    parser.add_argument('--hidden_dims', type=int, nargs='+', default=[32, 32], help='Hidden dimensions for NeuralUCB (e.g. 32 32)')
    parser.add_argument('--local_training_iter', type=int, default=30, help='Number of local training iterations for NeuralUCB')
    parser.add_argument('--offline_embeddings', type=str, default=None, help='Path to offline query embeddings JSON')
    parser.add_argument('--offline_results', type=str, default=None, help='Path to offline execution results JSON')
    parser.add_argument('--reward_weights', type=float, nargs='+', default=None, help='Custom reward weights [w_acc, w_ver, w_cost, bias]')
    parser.add_argument('--allowed_action_space', type=str, default=None, help='Restrict action space: model_routing, routing_full, etc.')
    parser.add_argument('--virtual_dataset', action='store_true', help='Use virtual dataset logic (no physical file needed)')
    parser.add_argument('--cost_metric', type=str, default="Normalized_EFLOPS", help='Cost metric to use: Normalized_EFLOPS or Difficulty_Aware_Normalized_EFLOPS')
    parser.add_argument('--action_embedding_file', type=str, default="action_embeddings_origin.json", help='Path to action embeddings file')
    return parser


def _virtual_query_id(idx: int) -> str:
    """Map a virtual dataset index to its query ID.

    Indices 0-59 map to ``test_aime_question1..60`` and indices 60-209 map to
    ``test150_question1..150``; anything else gets an ``unknown_question`` ID.
    """
    if 0 <= idx <= 59:
        return f"test_aime_question{idx + 1}"
    if 60 <= idx <= 209:
        return f"test150_question{idx - 60 + 1}"
    return f"unknown_question{idx}"


def _load_problems(args) -> tuple:
    """Load the problem list selected by the parsed arguments.

    Args:
        args: Parsed command-line namespace; reads ``virtual_dataset``,
            ``data_path`` and ``max_problems``.

    Returns:
        A ``(problems, data_filename)`` tuple: the list of problem dicts and
        the dataset name used for the log directory.

    Raises:
        FileNotFoundError: If a non-virtual ``data_path`` does not exist.
        ValueError: If the file is neither ``.parquet`` nor ``.jsonl``.
    """
    if args.virtual_dataset:
        print("Using Virtual Dataset Mode.")

        if args.data_path and "combined_dataset_2100" in args.data_path:
            # Load from the prepared jsonl file.
            with jsonlines.open(args.data_path) as reader:
                problems = list(reader)
            data_filename = "combined_dataset_2100"
        else:
            import random

            # Fixed to 210 total available unique problems, shuffled with a
            # fixed seed for reproducibility. The global RNG is seeded on
            # purpose-preserving grounds: later random draws in this process
            # keep seeing the same state as before this refactor.
            indices = list(range(210))
            random.seed(42)
            random.shuffle(indices)

            problems = [
                {
                    'id': q_id,
                    'problem': f"Virtual Query {q_id}",  # Dummy text, embedding lookup uses ID
                    'solution': None,
                    'answer': None,
                }
                for q_id in map(_virtual_query_id, indices)
            ]
            # Mock data filename for logging
            data_filename = "virtual_dataset_210"

    elif not os.path.exists(args.data_path):
        raise FileNotFoundError(f"Data file not found: {args.data_path}")

    elif args.data_path.endswith('.parquet'):
        df = pd.read_parquet(args.data_path)
        # Adapt to old format.
        # NOTE(review): no 'answer' is carried over from parquet, so every
        # problem is later reported as missing its ground truth -- confirm
        # whether the parquet files have an answer column that should be used.
        problems = [{'problem': text, 'solution': None} for text in df['problem']]
        data_filename = os.path.splitext(os.path.basename(args.data_path))[0]

    elif args.data_path.endswith('.jsonl'):
        with jsonlines.open(args.data_path) as reader:
            problems = list(reader)
        data_filename = os.path.splitext(os.path.basename(args.data_path))[0]

    else:
        raise ValueError("Unsupported file format, please use .parquet or .jsonl")

    # Apply --max_problems uniformly, including the virtual combined dataset,
    # which previously ignored the limit.
    if args.max_problems and args.max_problems > 0:
        problems = problems[:args.max_problems]
        print(f"Limiting to {args.max_problems} problems.")

    return problems, data_filename


def _build_param_str(args) -> str:
    """Build the short parameter string used in the log file names.

    The base format is e.g. ``lin_a1.0_b1.0_l1.0_lr0.0005_h32x32_iter30_f_ave``;
    it is prefixed with the experiment label, or with ``fix_<action>`` when a
    fixed action is given and no label is set.
    """
    algo_short = {
        "lin_ucb": "lin",
        "neural_ucb": "neu",
        "mixed_ucb": "mix",
        "knn": "knn",
        "oracle": "orc",
    }.get(args.algorithm, "rnd")

    diag_str = 'd' if args.use_diag else 'f'  # d=diag, f=full
    fusion_str = args.fusion_mode[:3]  # ave/con

    # Include LR, hidden dims, and train iterations to avoid overwriting logs.
    lr_str = f"_lr{args.lr}" if args.lr is not None else ""
    if args.hidden_dims is None:
        hd_str = ""
    elif len(args.hidden_dims) == 1:
        # A single value is rendered as a square layout, e.g. 32 -> h32x32.
        hd_str = f"_h{args.hidden_dims[0]}x{args.hidden_dims[0]}"
    else:
        hd_str = f"_h{'x'.join(map(str, args.hidden_dims))}"
    iter_str = f"_iter{args.local_training_iter}" if args.local_training_iter is not None else ""

    base_param_str = f"{algo_short}_a{args.alpha}_b{args.beta}_l{args.lambda_val}{lr_str}{hd_str}{iter_str}_{diag_str}_{fusion_str}"

    if args.exp_label:
        label_safe = args.exp_label.replace(" ", "_")
        return f"{label_safe}_{base_param_str}"

    if args.fixed_action:
        # Long action names are replaced by a short hash to keep file names short.
        if len(args.fixed_action) > 8:
            import hashlib
            hash_str = hashlib.md5(args.fixed_action.encode()).hexdigest()[:8]
            return f"fix_{hash_str}_{base_param_str}"
        return f"fix_{args.fixed_action}_{base_param_str}"

    return base_param_str


def main():
    """Run the contextual bandit experiment configured on the command line.

    Parses arguments, loads the dataset, derives the log file names, then
    feeds every problem to ``TTSContextualBandit.process_query`` and appends
    a per-problem summary to the process log. Caches are saved at the end.

    Raises:
        FileNotFoundError: If the data file does not exist.
        ValueError: If the data file format is unsupported.
    """
    args = _build_arg_parser().parse_args()

    print("="*40)
    print("Arguments:")
    for arg in vars(args):
        print(f"  {arg}: {getattr(args, arg)}")
    print("="*40)

    # Load dataset
    problems, data_filename = _load_problems(args)
    total = len(problems)
    print(f"Successfully loaded {total} problems.")

    # Warmup steps configuration
    warmup_steps = args.warm_up
    if warmup_steps > 0:
        print(f"Warmup steps set to {warmup_steps}.")

    # Logs go under the script's directory so the path is absolute.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    log_dir = os.path.join(script_dir, "logs", data_filename)
    # exist_ok avoids the check-then-create race between parallel runs.
    os.makedirs(log_dir, exist_ok=True)

    param_str = _build_param_str(args)
    log_file = os.path.join(log_dir, f"bandit_process_{param_str}.log")
    csv_log_file = os.path.join(log_dir, f"loss_log_{param_str}.csv")

    print(f"Log will be written to: {log_file}")
    print(f"CSV Log will be written to: {csv_log_file}")

    bandit = TTSContextualBandit(
        llm_models=LLM_MODELS,
        reward_function=reward_function,
        algorithm=args.algorithm,
        embedding_file=args.action_embedding_file,
        cache_dir=CACHE_DIR,
        loss_log_path=csv_log_file,
        beta=args.beta,
        lambda_=args.lambda_val,
        lr=args.lr,
        hidden_dims=args.hidden_dims,
        alpha=args.alpha,
        use_diag=args.use_diag,
        reset_theta_each_train=False,
        embedding_fusion_mode=args.fusion_mode,
        use_mock_executor=args.use_mock_executor,
        warmup_steps=warmup_steps,
        fixed_action=args.fixed_action,
        offline_embeddings_file=args.offline_embeddings,
        offline_results_file=args.offline_results,
        reward_weights=args.reward_weights,
        allowed_action_space=args.allowed_action_space,
        cost_metric=args.cost_metric,
    )

    # Start a fresh process log with a header line.
    with open(log_file, "w", encoding="utf-8") as f:
        if args.fixed_action:
            f.write(f"Bandit Processing Log - Algorithm: Fixed Action ({args.fixed_action})\n")
        else:
            diag_status = "Diagonal" if args.use_diag else "Full Matrix"
            f.write(f"Bandit Processing Log - Algorithm: {args.algorithm} - {diag_status}\n")
        f.write("=" * 80 + "\n")

    diag_label = 'diag' if args.use_diag else 'full'

    for i, item in enumerate(problems):
        query = item['problem']
        # 'answer' is used as the ground truth for comparison.
        ground_truth_for_comparison = item.get('answer')

        # Only None / empty string count as missing; a falsy but valid answer
        # such as 0 must still be passed through as the ground truth.
        if ground_truth_for_comparison is None or ground_truth_for_comparison == "":
            print(f"Error: 'answer' field missing for problem {i+1}. Fallback mechanisms are disabled.")
            ground_truth_for_comparison = None

        print(f"\n[Progress {i+1}/{total}] Processing problem...")
        print(f"[Problem Summary] {query[:60]}..." if len(query) > 60 else f"[Problem Summary] {query}")

        # Determine Query ID; fall back to "<dataset>_question<n>" when the
        # record carries none (an ID of 0 is kept, not treated as missing).
        query_id = item.get('id')
        if query_id is None or query_id == "":
            query_id = f"{data_filename}_question{i+1}"

        print(f"[Query ID] {query_id}")

        result = bandit.process_query(query, ground_truth=ground_truth_for_comparison, local_training_iter=args.local_training_iter, query_id=query_id)

        reward_details = result.get('reward_details', {})
        log_content = [
            f"\nProblem ID: {i+1}/{total} (QueryID: {query_id})",
            f"Algorithm: {args.algorithm}, Diagonal: {args.use_diag}",
            f"Selected Model: {result['model']}, QP: {result.get('qp', 'N/A')}, "
            f"CP: {result.get('cp', 'N/A')}, BS: {result.get('bs', 'N/A')}, Tokens: {result.get('tokens_used', 0)}",
            f"Reward: {result['reward']:.4f} (Acc={reward_details.get('accuracy', 0):.2f}*w{reward_details.get('weight_a', 0)}, "
            f"Ver={reward_details.get('verifier_score', 0):.2f}*w{reward_details.get('weight_b', 0)}, "
            f"eFLOPs={reward_details.get('eflops', 0):.2e}*w{reward_details.get('weight_c', 0)})",
            f"eFLOPs_raw: {reward_details.get('eflops', 0):.4e}, Normalized eFLOPs: {reward_details.get('normalized_eflops', 0):.4f}, CSV_Line: {result.get('csv_line_index', -1)}",
            f"Avg Reward: {result['avg_reward']:.4f}, Training Loss: {result['training_loss']}, Cache Hit: {result['from_cache']}",
            f"Answer: {result.get('extracted_answer', 'N/A')} | Ground Truth: {ground_truth_for_comparison}",
            "-" * 80
        ]
        # Re-open in append mode per problem so each entry hits disk even if
        # a later problem fails.
        with open(log_file, "a", encoding="utf-8") as f:
            f.write("\n".join(log_content))

        print(f"[{args.algorithm}|{diag_label}] Processed {i+1}/{total}, Current Avg Reward: {result['avg_reward']:.2f}")

    bandit.save_all_caches()
    print("Experiment finished.")

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()