import os
import json
import time
import torch
import numpy as np
from collections import deque
import random
import atexit
from openai import OpenAI

from config import GLOBAL_EMBEDDING_DIM, OPENAI_CONFIG, LLM_MODELS, TTS_COMBINATIONS, VERIFIER_MODEL
from cache_utils import load_cache, save_cache
from neural_ucb import NeuralUCB
from mixed_ucb import MixedUCB
from lin_ucb import LinUCB
from random_baseline import RandomBaseline
from knn_bandit import KNNBandit
from reward_function import reward_function
from verify_utils import grade_answer
import requests


# --------------------------------------------------------------

class TTSContextualBandit:
    def __init__(self, llm_models, reward_function,
                 algorithm='neural_ucb',
                 embedding_file="action_embeddings_origin.json",
                 cache_dir="cache",
                 loss_log_path="cache/loss_log.csv",
                 hidden_dims=None,
                 buffer_size=5000,
                 lr=0.0005,
                 beta=1.0,
                 lambda_=1.0,
                 alpha=1.0,
                 use_diag=True,
                 reset_theta_each_train=False,
                 embedding_fusion_mode='average',
                 use_mock_executor=False,
                 warmup_steps=0,
                 fixed_action=None,
                 offline_embeddings_file=None,
                 offline_results_file=None,
                 reward_weights=None,
                 allowed_action_space=None,
                 cost_metric="Normalized_EFLOPS"):
        """
        Build the bandit: offline data, embedding client, query cache,
        action embeddings, the selected bandit algorithm and the CSV logs.

        Args:
            llm_models: Sequence of model names. Action indices are laid out
                as ``model_index * len(TTS_COMBINATIONS) + tts_index``.
            reward_function: Callable stored as ``self.reward_function``.
            algorithm: One of 'neural_ucb', 'mixed_ucb', 'lin_ucb', 'random',
                'knn' or 'oracle'. 'oracle' instantiates a RandomBaseline as
                its ``bandit_algorithm``.
            embedding_file: JSON file holding the action embeddings.
            cache_dir: Directory (created if missing) for ``query_cache.json``.
            loss_log_path: CSV path of the per-step log. A header is written
                only when the file does not exist yet. The UCB log is placed
                in the same directory and is truncated on every construction.
            hidden_dims: Hidden layer sizes forwarded to NeuralUCB / MixedUCB.
                ``None`` (the default) means ``[32, 32]``.
            buffer_size: ``maxlen`` of ``self.buffer`` and ``max_history`` of
                KNNBandit.
            lr: Stored as ``self.lr``.
            beta: Forwarded to NeuralUCB / MixedUCB.
            lambda_: Forwarded to NeuralUCB / MixedUCB / LinUCB.
            alpha: Forwarded to LinUCB.
            use_diag: Forwarded as ``use_diag_z`` (NeuralUCB / MixedUCB) or
                ``use_diag`` (LinUCB).
            reset_theta_each_train: Forwarded to NeuralUCB / MixedUCB.
            embedding_fusion_mode: 'average' or 'concat'. 'concat' doubles
                the algorithm input dimension.
            use_mock_executor: Stored as ``self.use_mock_executor``.
            warmup_steps: Number of initial queries answered with a random
                action in ``select_action``.
            fixed_action: Comma-separated ``model+qpXcpYbsZ`` entries. Takes
                precedence over ``allowed_action_space``.
            offline_embeddings_file: Optional JSON file of query embeddings.
            offline_results_file: Optional JSON/CSV file of offline results.
            reward_weights: Optional ``[w_acc, w_ver, w_cost, bias]``.
            allowed_action_space: Optional preset name restricting actions.
            cost_metric: Name of the cost metric, stored as
                ``self.cost_metric``.

        Raises:
            ValueError: If ``algorithm`` is not one of the supported names
                (helpers called here may raise as well).
        """
        # Avoid sharing one list object between all instances (mutable default).
        if hidden_dims is None:
            hidden_dims = [32, 32]

        self.llm_models = llm_models
        self.reward_function = reward_function
        self.embedding_fusion_mode = embedding_fusion_mode
        self.algorithm_name = algorithm  # Kept for logging and dispatch in select_action
        self.use_mock_executor = use_mock_executor
        self.warmup_steps = warmup_steps
        self.fixed_action = fixed_action
        self.lr = lr
        self.reward_weights = reward_weights
        self.allowed_action_space = allowed_action_space
        self.cost_metric = cost_metric

        print(f"[Bandit Init] Algorithm: {algorithm}, Warmup Steps: {warmup_steps}, Fixed Action: {fixed_action}")
        print(f"[Bandit Init] Cost Metric: {self.cost_metric}")
        if self.reward_weights:
            print(f"[Bandit Init] Custom Reward Weights: {self.reward_weights}")
        if self.allowed_action_space:
            print(f"[Bandit Init] Restricted Action Space: {self.allowed_action_space}")

        # Load offline data if provided
        self.offline_embeddings = {}
        self.offline_results = {}
        if offline_embeddings_file:
            self._load_offline_embeddings(offline_embeddings_file)
        if offline_results_file:
            self._load_offline_results(offline_results_file)

        # Initialize client
        self.client = OpenAI(**OPENAI_CONFIG)

        # Set up the query-embedding cache
        os.makedirs(cache_dir, exist_ok=True)
        self.query_cache_path = os.path.join(cache_dir, "query_cache.json")
        self.loss_log_path = loss_log_path
        self.query_cache = load_cache(self.query_cache_path)

        self.cache_update_count = 0
        self.cache_save_interval = 20

        # Flush the cache when the interpreter exits.
        atexit.register(self.save_all_caches)

        # Load action embeddings
        self.action_embeddings, self.action_descriptions = self._load_embeddings(embedding_file)
        self.tts_combinations = TTS_COMBINATIONS

        # Restrict the action set; fixed_action wins over allowed_action_space.
        self.allowed_action_indices = None
        if self.fixed_action:
            self.allowed_action_indices = self._parse_fixed_action(self.fixed_action)
        elif self.allowed_action_space:
            self.allowed_action_indices = self._parse_allowed_action_space(self.allowed_action_space)

        # Calculate input dimension
        if self.embedding_fusion_mode == 'concat':
            model_input_dim = GLOBAL_EMBEDDING_DIM * 2
        else:
            model_input_dim = GLOBAL_EMBEDDING_DIM

        # Initialize corresponding algorithm based on selection
        if algorithm == 'neural_ucb':
            self.bandit_algorithm = NeuralUCB(
                input_dim=model_input_dim,
                hidden_dims=hidden_dims,
                lambda_=lambda_,
                beta=beta,
                use_diag_z=use_diag,
                reset_theta_each_train=reset_theta_each_train
            )
        elif algorithm == 'mixed_ucb':
            self.bandit_algorithm = MixedUCB(
                input_dim=model_input_dim,
                hidden_dims=hidden_dims,
                lambda_=lambda_,
                beta=beta,
                use_diag_z=use_diag,
                reset_theta_each_train=reset_theta_each_train
            )
        elif algorithm == 'lin_ucb':
            self.bandit_algorithm = LinUCB(
                input_dim=model_input_dim,
                alpha=alpha,
                lambda_=lambda_,
                use_diag=use_diag
            )
        elif algorithm == 'random':
            self.bandit_algorithm = RandomBaseline(
                input_dim=model_input_dim
            )
        elif algorithm == 'knn':
            self.bandit_algorithm = KNNBandit(
                k=5,
                max_history=buffer_size
            )
        elif algorithm == 'oracle':
            self.bandit_algorithm = RandomBaseline(input_dim=model_input_dim)
        else:
            raise ValueError(f"Unknown algorithm: {algorithm}")

        self.buffer = deque(maxlen=buffer_size)
        self.total_queries = 0
        self.cumulative_reward = 0

        # os.makedirs("") raises, so only create the directory when the log
        # path actually has a directory component.
        log_dir = os.path.dirname(loss_log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self._loss_log_step = 0
        # If file does not exist, write header
        if not os.path.exists(self.loss_log_path):
            with open(self.loss_log_path, "w", encoding="utf-8") as f:
                f.write("step,model,qp,cp,bs,verifier_score,is_correct,reward,regret,optimal_reward,loss,accuracy,eflops,normalized_eflops,weight_a,weight_b,weight_c,duration,csv_line_index\n")

        # Derive the UCB log name from the loss log name so both carry the
        # same experiment parameters.
        # Example: loss_log_neural_ucb_beta1.0.csv -> ucb_log_neural_ucb_beta1.0.csv
        loss_log_name = os.path.basename(loss_log_path)
        if loss_log_name.startswith("loss_log_"):
            stem = loss_log_name
            # Strip the extension only when it really is ".csv"; parameter
            # values such as "beta1.0" contain dots of their own.
            if stem.lower().endswith(".csv"):
                stem = stem[:-len(".csv")]
            params_part = stem[len("loss_log_"):]
            ucb_log_name = f"ucb_log_{params_part}.csv"
        else:
            # Fallback if naming convention doesn't match
            ucb_log_name = "ucb_log.csv"

        self.ucb_log_path = os.path.join(log_dir, ucb_log_name)

        # Overwrite existing UCB log file
        with open(self.ucb_log_path, "w", encoding="utf-8") as f:
            f.write("step,action_idx,model,qp,cp,bs,pred,bonus,ucb\n")

    def _load_offline_embeddings(self, file_path):
        """
        Load offline query embeddings.
        Format: JSON dict { "query_id": [float, ...] }
        """
        print(f"[Offline Data] Loading embeddings from {file_path}...")
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                self.offline_embeddings = json.load(f)
            print(f"[Offline Data] Loaded {len(self.offline_embeddings)} query embeddings.")
        except Exception as e:
            print(f"[Offline Data] Error loading embeddings: {e}")

    def _load_offline_results(self, file_path):
        """
        Load offline execution results.
        Supports JSON or CSV format.
        
        CSV Format:
        Dataset,Question_ID,Model,QP,CP,BS,Correctness,Verifier_Score,Memory_Access,Computation,EFLOPs,Total_Tokens,Noramlized_EFLOPS
        
        Target Internal Structure (self.offline_results):
        { 
            "query_id": { 
                "action_str": [ 
                    { "accuracy": float, "verifier_score": float, "eflops": float, "token_len": int, ... },
                    ... (trials)
                ],
                ...
            }
        }
        """
        print(f"[Offline Data] Loading results from {file_path}...")
        try:
            if file_path.endswith('.csv'):
                self._load_offline_results_csv(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    self.offline_results = json.load(f)
            print(f"[Offline Data] Loaded results for {len(self.offline_results)} queries.")
        except Exception as e:
            print(f"[Offline Data] Error loading results: {e}")

    def _load_offline_results_csv(self, file_path):
        import csv
        self.offline_results = {}
        
        # Mapping model names from CSV to internal names if necessary
        # CSV: Qwen3-4B -> Internal: qwen3-4b
        
        with open(file_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            count = 0
            for row in reader:
                # Construct Query ID
                # CSV has 'Dataset' and 'Question_ID' (integer)
                # We need to match main.py logic: f"{dataset_name}_question{id}"
                # Assumption: Dataset column (e.g., "AIME") maps to dataset filename prefix (e.g., "test_aime" or "test150")
                # But main.py uses filename. Let's try to map "AIME" -> "test_aime" and "MATH" -> "test150" or similar?
                # Or just use the constructed ID passed from main.py and hope it matches.
                
                # To be robust, let's look at row['Dataset'].
                # Example: Dataset="AIME", Question_ID="0" -> "test_aime_question1" (if 0-indexed in CSV but 1-indexed in main.py)
                # Let's assume CSV is 0-indexed based on "Question_ID,0" in preview.
                # And main.py uses 1-based index for ID generation.
                
                dataset_name = row['Dataset']
                q_id_num = int(row['Question_ID'])
                
                if dataset_name == "AIME":
                    query_prefix = "test_aime"
                elif dataset_name == "MATH-150" or dataset_name == "MATH-500": 
                    query_prefix = "test150" 
                else:
                    query_prefix = dataset_name.lower() # Fallback
                
                # Construct query_id to match main.py (1-based index)
                # CSV Question_ID seems to be 0-based from the snippet "AIME,0"
                query_id = f"{query_prefix}_question{q_id_num + 1}"
                
                if query_id not in self.offline_results:
                    self.offline_results[query_id] = {}
                
                # Construct Action Key
                # CSV Model: Qwen3-4B -> qwen3-4b
                model = row['Model'].lower()
                qp = int(row['QP'])
                cp = int(row['CP'])
                bs = int(row['BS'])
                
                action_key = f"{model}+qp{qp}cp{cp}bs{bs}"
                
                if action_key not in self.offline_results[query_id]:
                    self.offline_results[query_id][action_key] = []
                
                # Extract Metrics with safety check
                def safe_float(val, default=0.0):
                    try:
                        if not val or val == '': return default
                        return float(val)
                    except ValueError:
                        return default

                accuracy = safe_float(row.get('Correctness'))
                verifier_score = safe_float(row.get('Verifier_Score'))
                eflops = safe_float(row.get('EFLOPs'))
                token_len = int(safe_float(row.get('Total_Tokens')))
                
                # Extract Normalized Metrics if available
                normalized_eflops = safe_float(row.get('Normalized_EFLOPS'))
                difficulty_aware_eflops = safe_float(row.get('Difficulty_Aware_Normalized_EFLOPS'))

                # Store trial
                self.offline_results[query_id][action_key].append({
                    "accuracy": accuracy,
                    "verifier_score": verifier_score,
                    "eflops": eflops,
                    "token_len": token_len,
                    "normalized_eflops": normalized_eflops,
                    "difficulty_aware_eflops": difficulty_aware_eflops,
                    "csv_line_index": count + 2 # Header is line 1, first data row is line 2 (0-indexed count + 2)
                })
                count += 1
                
        print(f"[Offline Data] Processed {count} rows from CSV.")

    def save_all_caches(self):
        """
        Persist the in-memory query-embedding cache to ``self.query_cache_path``.

        Also registered with ``atexit`` by ``__init__`` so the cache is
        flushed when the interpreter exits. Success is only reported when
        ``save_cache`` returns a truthy value.
        NOTE(review): this assumes ``save_cache`` returns a success flag, as
        ``get_embedding`` already relies on - confirm against cache_utils.
        """
        if save_cache(self.query_cache_path, self.query_cache):
            # Everything pending has been written; restart the save counter.
            self.cache_update_count = 0
            print("All caches saved")
        else:
            print(f"Warning: failed to save query cache to {self.query_cache_path}")
        
    def _append_loss_log(self, log_data):
        """
        Log detailed training logs
        log_data: dict containing fields to record
        """
        with open(self.loss_log_path, "a", encoding="utf-8") as f:
            # Write in order of header
            line = f"{self._loss_log_step}," \
                   f"{log_data.get('model', '')}," \
                   f"{log_data.get('qp', '')}," \
                   f"{log_data.get('cp', '')}," \
                   f"{log_data.get('bs', '')}," \
                   f"{log_data.get('verifier_score', '')}," \
                   f"{log_data.get('is_correct', '')}," \
                   f"{log_data.get('reward', '')}," \
                   f"{log_data.get('regret', '')}," \
                   f"{log_data.get('optimal_reward', '')}," \
                   f"{log_data.get('loss', '')}," \
                   f"{log_data.get('accuracy', '')}," \
                   f"{log_data.get('eflops', '')}," \
                   f"{log_data.get('normalized_eflops', '')}," \
                   f"{log_data.get('weight_a', '')}," \
                   f"{log_data.get('weight_b', '')}," \
                   f"{log_data.get('weight_c', '')}," \
                   f"{log_data.get('duration', '')}," \
                   f"{log_data.get('csv_line_index', '')}\n"
            f.write(line)
        self._loss_log_step += 1
        
    def _load_embeddings(self, file_path):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Embedding file not found: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        embeddings = np.array(data.get("embeddings", []), dtype=np.float32)
        descriptions = data.get("descriptions", [])
        if embeddings.size == 0:
            raise ValueError("No embeddings found in action_embeddings_origin.json or empty.")
        return embeddings, descriptions
        
    def _fusion_embedding(self, query_emb, action_emb):
        if self.embedding_fusion_mode == 'average':
            fused = (query_emb + action_emb) / 2
        elif self.embedding_fusion_mode == 'concat':
            fused = np.concatenate([query_emb, action_emb])
        else:
            raise ValueError(f"Unsupported fusion mode: {self.embedding_fusion_mode}")
        
        # Normalize the fused embedding
        norm = np.linalg.norm(fused)
        if norm > 1e-12:
            fused = fused / norm
            
        return fused.astype(np.float32)
            
    def get_embedding(self, text, query_id=None):
        if query_id and query_id in self.offline_embeddings:
            return np.array(self.offline_embeddings[query_id], dtype=np.float32)

        if text in self.query_cache:
            return self.query_cache[text]
            
        try:
            response = self.client.embeddings.create(
                input=text,
                model="text-embedding-v4",
                dimensions=GLOBAL_EMBEDDING_DIM,
                encoding_format="float"
            )
            embedding = response.data[0].embedding if hasattr(response, "data") else response["data"][0]["embedding"]
            self.query_cache[text] = embedding
            self.cache_update_count += 1
            if self.cache_update_count >= self.cache_save_interval:
                if save_cache(self.query_cache_path, self.query_cache):
                    self.cache_update_count = 0
            return np.array(embedding, dtype=np.float32)
        except Exception as e:
            print(f"Failed to get embedding: {e}")
            fallback = np.random.randn(GLOBAL_EMBEDDING_DIM).astype(np.float32).tolist()
            self.query_cache[text] = fallback
            return np.array(fallback, dtype=np.float32)
            
    def get_context_embedding(self, user_query, action_idx=None, query_id=None):
        query_embedding = np.array(self.get_embedding(user_query, query_id=query_id), dtype=np.float32)
        if action_idx is not None:
            action_embedding = self.action_embeddings[action_idx]
            fused = self._fusion_embedding(query_embedding, action_embedding)
            return fused.tolist()
        else:
            # If allowed_action_indices is set, only return embeddings for those actions?
            # However, returning full list maintains index consistency.
            # But the user asked: "check if input bandit model actions are only candidate actions"
            # If we return all embeddings here, select_action iterates over all.
            # This is fine as long as select_action handles filtering.
            return [self._fusion_embedding(query_embedding, a_emb).tolist() for a_emb in self.action_embeddings]
            
    def _append_ucb_log(self, ucb_details):
        """
        Log UCB details for actions in a step.
        Only log actions that were actually considered (not -inf).
        """
        with open(self.ucb_log_path, "a", encoding="utf-8") as f:
            for idx, detail in enumerate(ucb_details):
                # Filter out skipped actions (pred == -inf)
                if detail['pred'] == -float('inf'):
                    continue
                    
                model, (qp, cp, bs) = self.parse_action_idx(idx)
                line = f"{self._loss_log_step}," \
                       f"{idx}," \
                       f"{model}," \
                       f"{qp}," \
                       f"{cp}," \
                       f"{bs}," \
                       f"{detail['pred']}," \
                       f"{detail['bonus']}," \
                       f"{detail['ucb']}\n"
                f.write(line)

    def _parse_allowed_action_space(self, space_config):
        """
        Parse allowed action space configuration.
        space_config: dict or str
        Examples:
        - "model_routing": 0.6B (QP1CP1BS1) to 32B (QP1CP1BS1) (Routing_Full equivalent but restricted to param 1/1/1?)
          Wait, Model Routing usually means fixed parameters (e.g. qp1cp1bs1) across different models?
          Or is it just selecting between models with default params?
          Based on user request: "Model Routing, Routing_Full and 0.6B All TTS (Adaptive TTS) and Mixed"
        
        Let's implement specific presets:
        - "model_routing": All models, fixed params QP1 CP1 BS1.
        - "routing_full": All models, All params (Default, no restriction if None, but here we can return all indices).
        - "adaptive_tts_0.6b": Only 0.6B model, All params.
        - "mixed_0.6b_32b": 0.6B (All params) + 32B (QP1 CP1 BS1).
        """
        indices = []
        tts_count = len(self.tts_combinations)
        
        print(f"[Action Space] Parsing config: {space_config}")
        
        if space_config == "model_routing":
            # Select QP=1, CP=1, BS=1 for all models
            # Find index of (1, 1, 1) in tts_combinations
            target_tts_idx = -1
            for j, (qp, cp, bs) in enumerate(self.tts_combinations):
                if qp == 1 and cp == 1 and bs == 1:
                    target_tts_idx = j
                    break
            
            if target_tts_idx != -1:
                for i in range(len(self.llm_models)):
                    indices.append(i * tts_count + target_tts_idx)
            else:
                print("Warning: (1, 1, 1) not found in TTS combinations.")

        elif space_config == "routing_full":
            # All actions allowed
            indices = list(range(len(self.llm_models) * tts_count))

        elif space_config == "adaptive_tts_0.6b":
            # Only qwen3-0.6b, all params
            model_idx = -1
            for i, model in enumerate(self.llm_models):
                if "0.6b" in model:
                    model_idx = i
                    break
            
            if model_idx != -1:
                 for j in range(tts_count):
                     indices.append(model_idx * tts_count + j)
            else:
                 print("Warning: 0.6b model not found.")

        elif space_config == "adaptive_tts_4b":
            # Only qwen3-4b, all params
            model_idx = -1
            for i, model in enumerate(self.llm_models):
                if "4b" in model:
                    model_idx = i
                    break
            
            if model_idx != -1:
                 for j in range(tts_count):
                     indices.append(model_idx * tts_count + j)
            else:
                 print("Warning: 4b model not found.")

        elif space_config == "mixed_0.6b_32b":
            # 0.6B (All TTS) + 32B (QP1CP1BS1)
            
            # 1. Add 0.6B All
            model_idx_06 = -1
            for i, model in enumerate(self.llm_models):
                if "0.6b" in model:
                    model_idx_06 = i
                    break
            
            if model_idx_06 != -1:
                 for j in range(tts_count):
                     indices.append(model_idx_06 * tts_count + j)
            
            # 2. Add 32B QP1CP1BS1
            model_idx_32 = -1
            for i, model in enumerate(self.llm_models):
                if "32b" in model:
                    model_idx_32 = i
                    break
            
            target_tts_idx = -1
            for j, (qp, cp, bs) in enumerate(self.tts_combinations):
                if qp == 1 and cp == 1 and bs == 1:
                    target_tts_idx = j
                    break
            
            if model_idx_32 != -1 and target_tts_idx != -1:
                indices.append(model_idx_32 * tts_count + target_tts_idx)
        
        else:
             print(f"Warning: Unknown action space config: {space_config}. Allowing all actions.")
             indices = list(range(len(self.llm_models) * tts_count))
             
        indices = sorted(list(set(indices))) # Ensure unique and sorted
        print(f"[Action Space] Allowed {len(indices)} actions.")
        return indices

    def _parse_fixed_action(self, action_str):
        indices = []
        parts = action_str.split(',')
        import re
        
        tts_count = len(self.tts_combinations)
        
        for part in parts:
            part = part.strip()
            try:
                # Format: model+qpXcpYbsZ
                subparts = part.split('+')
                target_model = subparts[0]
                params_str = subparts[1]
                
                match = re.match(r'qp(\d+)cp(\d+)bs(\d+)', params_str)
                if not match:
                     raise ValueError(f"Invalid fixed action parameter format: {params_str}")
                
                target_qp = int(match.group(1))
                target_cp = int(match.group(2))
                target_bs = int(match.group(3))
                
                found_idx = -1
                for i, model in enumerate(self.llm_models):
                    if model == target_model:
                        for j, (qp, cp, bs) in enumerate(self.tts_combinations):
                            if qp == target_qp and cp == target_cp and bs == target_bs:
                                found_idx = i * tts_count + j
                                break
                    if found_idx != -1:
                        break
                
                if found_idx != -1:
                    indices.append(found_idx)
                else:
                    print(f"Warning: Fixed action component {part} not found in available actions.")
                    
            except Exception as e:
                print(f"Error parsing fixed action component '{part}': {e}.")
        
        if not indices:
            raise ValueError(f"No valid actions found in fixed_action: {action_str}")
            
        return indices

    def select_action(self, context_embeddings, query_id=None):
        """
        Select action based on UCB algorithm.
        If fixed_action is set:
            - If single action: always return that index.
            - If multiple actions: restrict selection to these indices.
        If warmup_steps > 0 and current step < warmup_steps, select random action (from allowed set).
        """
        # 1. Handle Fixed Action Case (Single Action Optimization)
        if self.allowed_action_indices and len(self.allowed_action_indices) == 1:
            idx = self.allowed_action_indices[0]
            print(f"[Action Selection] Fixed Action: {self.fixed_action} -> Index {idx}")
            return idx

        # Handle Oracle
        if self.algorithm_name == 'oracle':
            if not query_id or query_id not in self.offline_results:
                print(f"[Oracle] Warning: query_id {query_id} not found in offline results. Using random.")
                if self.allowed_action_indices:
                    return random.choice(self.allowed_action_indices)
                return random.randint(0, len(context_embeddings) - 1)
            
            # Find best action
            best_idx = -1
            best_reward = -float('inf')
            
            # Helper to calculate reward
            if self.reward_weights:
                w_acc, w_ver, w_cost, bias = self.reward_weights
            else:
                w_acc, w_ver, w_cost, bias = 0.3669, 0.3669, 0.2662, 0.2662
            
            query_results = self.offline_results[query_id]
            
            # Iterate over all possible actions
            for idx in range(len(context_embeddings)):
                # Skip if not allowed
                if self.allowed_action_indices is not None and idx not in self.allowed_action_indices:
                    continue
                    
                model, (qp, cp, bs) = self.parse_action_idx(idx)
                action_key = f"{model}+qp{qp}cp{cp}bs{bs}"
                
                # Check results
                trials = query_results.get(action_key)
                if not trials:
                     trials = query_results.get(str(idx)) # Fallback
                
                if not trials:
                    continue # Skip action with no data
                
                # Calculate expected reward (average of trials)
                total_r = 0
                count = 0
                for sample in trials:
                    acc = float(sample.get('accuracy', 0.0))
                    ver = float(sample.get('verifier_score', 0.0))
                    
                    if self.cost_metric == "Difficulty_Aware_Normalized_EFLOPS":
                        c = sample.get('difficulty_aware_eflops', 0.0)
                        # Fallback logic if needed, but assuming data is clean or consistent with execute_action
                    else:
                        c = sample.get('normalized_eflops', 0.0)
                    
                    r = w_acc * acc + w_ver * ver - w_cost * c + bias
                    total_r += r
                    count += 1
                
                if count > 0:
                    avg_r = total_r / count
                    if avg_r > best_reward:
                        best_reward = avg_r
                        best_idx = idx
            
            if best_idx != -1:
                # print(f"[Oracle] Selected Best Action: {best_idx} with reward {best_reward:.4f}")
                return best_idx
            else:
                print("[Oracle] No valid action found with data. Using random.")
                if self.allowed_action_indices:
                    return random.choice(self.allowed_action_indices)
                return random.randint(0, len(context_embeddings) - 1)

        ucb_values = []
        ucb_details = []
        
        # Check if in warmup period
        in_warmup = self.total_queries < self.warmup_steps
        
        print(f"[Select Action Debug] Step: {self.total_queries}, Warmup Steps: {self.warmup_steps}, In Warmup: {in_warmup}")

        for i, emb in enumerate(context_embeddings):
            # If we have restricted actions, skip others
            if self.allowed_action_indices is not None and i not in self.allowed_action_indices:
                ucb_values.append(-float('inf'))
                ucb_details.append({'pred': -float('inf'), 'bonus': 0.0, 'ucb': -float('inf')})
                continue

            # Handle input based on algorithm type
            if self.algorithm_name in ['neural_ucb', 'mixed_ucb']:
                x = torch.tensor(emb, dtype=torch.float32).unsqueeze(0)
            else:
                x = np.array(emb, dtype=np.float32)  # LinUCB and Random use numpy array
            
            try:
                # Calculate UCB for logging purposes even during warmup
                result = self.bandit_algorithm.calc_ucb(x)
                if isinstance(result, tuple):
                    ucb, pred, bonus = result
                else:
                    ucb = result
                    pred = ucb
                    bonus = 0.0
            except Exception as e:
                print(f"Failed to calculate UCB: {e}. Using random selection as fallback.")
                ucb = random.random()
                pred = ucb
                bonus = 0.0
            ucb_values.append(ucb)
            ucb_details.append({'pred': pred, 'bonus': bonus, 'ucb': ucb})
        
        # Log UCB details
        self._append_ucb_log(ucb_details)

        if in_warmup:
            # Randomly select action during warmup
            # If allowed_indices is set, choose from them
            if self.allowed_action_indices:
                action_idx = random.choice(self.allowed_action_indices)
            else:
                action_idx = random.randint(0, len(context_embeddings) - 1)
            
            # Important: Log random selection choice for debugging
            # And override ucb_values for the chosen action if needed? 
            # No, standard LinUCB warmup is purely random selection, UCB values are just for observation.
            
            print(f"[Action Selection] Warmup Phase ({self.total_queries + 1}/{self.warmup_steps}). Random Action: {action_idx}")
        else:
            action_idx = int(np.argmax(ucb_values))
            # If ucb_values are all -inf (should not happen if allowed_indices is not empty), fallback
            if ucb_values[action_idx] == -float('inf'):
                 print("[Action Selection] Warning: All UCB values are -inf. Fallback to random choice from allowed.")
                 if self.allowed_action_indices:
                    action_idx = random.choice(self.allowed_action_indices)
                 else:
                    action_idx = random.randint(0, len(context_embeddings) - 1)
        
        # Add logging
        selected_model, (qp, cp, bs) = self.parse_action_idx(action_idx)
        print(f"[Action Selection] Index={action_idx}, Model={selected_model}, QP={qp}, CP={cp}, BS={bs}")
        return action_idx


        
    def parse_action_idx(self, action_idx):
        tts_count = len(self.tts_combinations)
        if tts_count == 0:
            raise ValueError("TTS_COMBINATIONS is empty.")
            
        model_idx = action_idx // tts_count
        tts_idx = action_idx % tts_count
        
        if model_idx >= len(self.llm_models):
            # Possibly action_embeddings_origin.json does not match LLM_MODELS in current config.py
            # Or action_idx is out of expected range
            print(f"Warning: model_idx {model_idx} is out of range for llm_models (len={len(self.llm_models)}). Using last model.")
            model_idx = len(self.llm_models) - 1
            
        return self.llm_models[model_idx], self.tts_combinations[tts_idx]
        
    def execute_action(self, action_idx, user_query, ground_truth=None, query_id=None):
        selected_model, (qp, cp, bs) = self.parse_action_idx(action_idx)
        try:
            # Check offline results first
            if query_id and query_id in self.offline_results:
                return self._execute_offline_action(query_id, action_idx, selected_model, qp, cp, bs, user_query)

            # Select Executor based on configuration
            executor_func = self.llm_executor_mock if self.use_mock_executor else self.llm_executor
            
            # Use injectable execution interface to get call results
            # verifier_score, is_correct, L_in, L_out, extracted_answer, duration = executor_func(...)
            verifier_score, is_correct, L_in, L_out, extracted_answer, duration = executor_func(
                selected_model, (qp, cp, bs), user_query, ground_truth=ground_truth
            )

            # reward is now a dictionary
            reward_result = self.reward_function(
                user_query=user_query,
                model=selected_model,
                verifier_score=verifier_score,
                L_in=L_in,
                L_out=L_out,
                is_correct=is_correct,
                qp=qp,
                cp=cp,
                verifier_model=VERIFIER_MODEL # Pass verifier model from config
            )
            
            # Extract total reward
            total_reward = reward_result.get('total_reward', 0.0)
            
            result = {
                'model': selected_model,
                'qp': qp, 'cp': cp, 'bs': bs,
                'reward': total_reward,
                'reward_details': reward_result, # Store full reward details
                'verifier_score': verifier_score,
                'L_in': L_in,
                'L_out': L_out,
                'is_correct': is_correct,
                'extracted_answer': extracted_answer,
                'duration': duration
            }

            return result, False

        except Exception as e:
            error_result = {'model': selected_model, 'qp': qp, 'cp': cp, 'bs': bs, 'error': str(e), 'reward': -10}
            return error_result, False

            
    def _execute_offline_action(self, query_id, action_idx, selected_model, qp, cp, bs, user_query):
        """
        Execute action using offline results.
        """
        # Construct action key
        action_key = f"{selected_model}+qp{qp}cp{cp}bs{bs}"
        
        query_results = self.offline_results.get(query_id)
        if not query_results:
             print(f"[Offline Exec] No results found for query {query_id}")
             # Check keys in offline_results to debug
             keys_preview = list(self.offline_results.keys())[:5]
             print(f"[Offline Exec] Available keys preview: {keys_preview}")
             return {'reward': 0, 'error': f'No offline results for query {query_id}'}, False
             
        # Try to find action results
             print(f"[Offline Exec] No results found for query {query_id}")
             return {'reward': 0, 'error': 'No offline results for query'}, False
             
        # Try to find action results
        action_results = query_results.get(action_key)
        
        # Fallback: try string index
        if not action_results:
             action_results = query_results.get(str(action_idx))
             
        if not action_results:
             print(f"[Offline Exec] No results found for action {action_key} (Query: {query_id})")
             # Fallback to online execution if configured? 
             # For now, return error or zero reward to avoid crash
             # Check available actions for this query
             avail_actions = list(query_results.keys())[:5]
             print(f"[Offline Exec] Available actions preview: {avail_actions}")
             
             return {'model': selected_model, 'qp': qp, 'cp': cp, 'bs': bs, 'reward': 0, 'error': 'No offline results'}, False
             
        # Sample one result
        sample = random.choice(action_results)
        
        # Extract fields
        accuracy = float(sample.get('Accuracy', sample.get('accuracy', 0.0)))
        verifier_score = float(sample.get('VerifierScore', sample.get('verifier_score', 0.0)))
        eflops = float(sample.get('eFLOPs', sample.get('eflops', 0.0)))
        token_len = int(sample.get('token_len', sample.get('response_len', 100)))
        
        # Extract CSV row index if available
        # When loading CSV, we can store 'csv_line_index' in the sample dict
        csv_line_index = sample.get('csv_line_index', -1)
        
        # Determine Cost Metric
        if self.cost_metric == "Difficulty_Aware_Normalized_EFLOPS":
            cost_val = sample.get('difficulty_aware_eflops', 0.0)
            # Fallback if 0.0 (maybe missing?) - strictly speaking should trust CSV
            if cost_val == 0.0 and sample.get('normalized_eflops', 0.0) > 0:
                 # If diff aware is missing but normal exists, maybe use normal? 
                 # Or maybe it really is 0. 
                 pass
        else:
            # Default to Normalized_EFLOPS
            cost_val = sample.get('normalized_eflops', 0.0)
            
            # Fallback calculation if not in CSV (for backward compatibility)
            if cost_val == 0.0 and 'normalized_eflops' not in sample:
                parallelism = max(1, qp * cp)
                metric = eflops / (max(1, token_len) / parallelism)
                log_val = np.log10(max(1.0, metric))
                normalized_eflops = (log_val - 11.4027) / 5.3646
                cost_val = max(0.0, min(1.0, normalized_eflops))
        
        normalized_eflops = cost_val # This variable name is used in log/result, keep it but it holds the CHOSEN cost metric value

        # Use Custom Weights if provided
        if self.reward_weights:
            # reward_weights: [w_acc, w_ver, w_cost, bias]
            w_acc, w_ver, w_cost, bias = self.reward_weights
        else:
            # Default weights
            w_acc = 0.3669
            w_ver = 0.3669
            w_cost = 0.2662
            bias = 0.2662
            
        reward = w_acc * accuracy + w_ver * verifier_score - w_cost * normalized_eflops + bias
        
        result = {
            'model': selected_model,
            'qp': qp, 'cp': cp, 'bs': bs,
            'reward': reward,
            'verifier_score': verifier_score,
            'is_correct': accuracy,
            'L_in': [100], # Dummy
            'L_out': [token_len],
            'extracted_answer': sample.get('extracted_answer', "OfflineAnswer"),
            'duration': 0.0,
            'csv_line_index': csv_line_index, # Include line index in result
            'reward_details': {
                'accuracy': accuracy,
                'verifier_score': verifier_score,
                'eflops': eflops,
                'normalized_eflops': normalized_eflops, # This is the cost used
                'weight_a': w_acc,
                'weight_b': w_ver,
                'weight_c': w_cost,
                'bias': bias
            }
        }
        
        return result, True

    def llm_executor_mock(self, selected_model, qcpbs, user_query, ground_truth=None):
        """
        Stand-in executor for tests: fabricates an execution outcome.

        Sleeps briefly, then draws a random verifier score in [0.5, 1.0], a
        correctness flag that is 1.0 about 80% of the time, and ``qp * cp``
        output segment lengths between 10 and 50 tokens.

        Returns:
            Tuple ``(verifier_score, is_correct, L_in, L_out, "MockAnswer", 0.1)``
            where ``L_in[i]`` is 100 plus the lengths of all earlier segments.
        """
        qp, cp, bs = qcpbs
        print(f"[Mock Executor] Model: {selected_model}, QP={qp}, CP={cp}, BS={bs}")

        # Pretend the call takes some time.
        time.sleep(0.1)

        # Random draws, in a fixed order: score, correctness, then segment lengths.
        verifier_score = random.uniform(0.5, 1.0)
        is_correct = 0.0 if random.random() >= 0.8 else 1.0

        L_in, L_out = [], []
        context_len = 100  # input length seen by the first segment
        for _ in range(qp * cp):
            segment_len = random.randint(10, 50)
            L_in.append(context_len)
            L_out.append(segment_len)
            # Each later segment also sees everything generated before it.
            context_len += segment_len

        return verifier_score, is_correct, L_in, L_out, "MockAnswer", 0.1

    def llm_executor(self, selected_model, qcpbs, user_query, ground_truth=None):
        """
        Call ttsrouter-v1.0 service to execute inference
        URL: http://localhost:7777/tts-router-json (or 17777 for 32B)

        Args:
            selected_model: Model name; names containing "32b" go to port 17777.
            qcpbs: Triple ``(qp, cp, bs)`` sent as the beam configuration.
            user_query: Problem text.
            ground_truth: Optional reference answer. ``None`` or an empty
                string means "no reference"; any other value (including 0)
                is sent to the service and used for grading.

        Returns:
            Tuple ``(verifier_score, is_correct, L_in, L_out, extracted_answer, duration)``.
            If the service returns no results, the defaults
            ``(0.0, 0.0, [100], [100], "", duration)`` are returned.
            If the request or response handling raises, the tuple
            ``(0.0, 0.0, [100], [100], "Error", 0.0)`` is returned.
        """
        import re

        def extract_answer_flexible(text):
            r"""Extract the content of the last \boxed{...} in the text.
               If none is found, fall back to the last LaTeX expression in $...$."""
            if not text:
                return None

            # 1. Try to find \boxed{...} and scan forward to its matching brace
            matches = list(re.finditer(r"\\boxed\s*\{", text))
            if matches:
                start = matches[-1].end()
                balance = 1
                for i in range(start, len(text)):
                    if text[i] == '{':
                        balance += 1
                    elif text[i] == '}':
                        balance -= 1

                    if balance == 0:
                        return text[start:i]

            # 2. Fallback: last LaTeX expression enclosed in $...$
            latex_matches = re.findall(r"\$([^$]+)\$", text)
            if latex_matches:
                return latex_matches[-1].strip()

            return None

        if "32b" in selected_model.lower():
            url = "http://localhost:17777/tts-router-json"
        else:
            url = "http://localhost:7777/tts-router-json"

        qp, cp, bs = qcpbs

        # A reference answer of 0 is valid, so test for absence explicitly
        # instead of relying on truthiness.
        has_ground_truth = ground_truth is not None and str(ground_truth) != ""

        # Map lowercase model names to the casing used in requests to the service;
        # unknown names are passed through unchanged.
        model_map = {
            "qwen3-0.6b": "Qwen3-0.6B",
            "qwen3-1.7b": "Qwen3-1.7B",
            "qwen3-4b": "Qwen3-4B",
            "qwen3-8b": "Qwen3-8B",
            "qwen3-14b": "Qwen3-14B",
            "qwen3-32b": "Qwen3-32B"
        }
        mapped_model = model_map.get(selected_model, selected_model)

        payload = {
            "problems": {
                "problem": user_query,
                "solution": str(ground_truth) if has_ground_truth else "",
                "lm": mapped_model,
                "beam": {
                    "QP": float(qp),
                    "CP": float(cp),
                    "BS": int(bs)
                }
            },
            "eval_config": {
                "method": "beam_search"
            }
        }

        try:
            print(f"[API Call] Sending request to {url} | Model: {mapped_model} | Params: QP={qp}, CP={cp}, BS={bs}")
            start_time = time.time()
            response = requests.post(url, json=payload, timeout=600) # Set a longer timeout
            duration = time.time() - start_time
            response.raise_for_status()
            print(f"[API Response] Successfully received response, duration: {duration:.2f}s")

            data = response.json()

            # Default values, used as-is when the service returns no results
            verifier_score = 0.0
            is_correct = 0.0
            L_in = [100]
            L_out = [100]
            extracted_answer = ""

            best_results = data.get("best_results", [])
            if best_results:
                # Only the first result is used
                first_result = best_results[0]

                # 1. Verifier score: mean of the returned score list
                scores = first_result.get("score", [])
                if scores:
                    verifier_score = sum(scores) / len(scores)

                # 2. Answer: run our own extraction on the "answer" field and
                #    fall back to the raw field when nothing can be extracted.
                raw_response = first_result.get("answer", "")
                extracted_answer = extract_answer_flexible(raw_response) or raw_response

                if has_ground_truth:
                    # Use verify_utils.grade_answer to compare
                    is_correct_bool = grade_answer(str(extracted_answer), str(ground_truth))
                    is_correct = 1.0 if is_correct_bool else 0.0

                # 3. Output length; the input length is not reported, so L_in keeps its default
                response_len = first_result.get("response_len", 0) or 0
                L_out = [int(response_len)]

            print(f"Received response: score={verifier_score:.4f}, is_correct={is_correct}, response_len={L_out[0]}")
            print(f"[Answer Comparison] Extracted (Flexible): {extracted_answer} | Ground Truth: {ground_truth}")

            return float(verifier_score), is_correct, L_in, L_out, extracted_answer, duration

        except Exception as e:
            # Best effort: any transport, HTTP or parsing failure yields neutral values.
            print(f"Error calling ttsrouter-v1.0: {e}")
            return 0.0, 0.0, [100], [100], "Error", 0.0
            
    def store_experience(self, context_embedding, reward):
        self.buffer.append((context_embedding, reward))
        
        # Handle input based on algorithm type
        if self.algorithm_name in ['neural_ucb', 'mixed_ucb']:
            x = torch.tensor(context_embedding, dtype=torch.float32).unsqueeze(0)
        else:
            x = context_embedding  # LinUCB and Random use numpy array
            
        self.bandit_algorithm.update(x, reward)
        
    def train(self, local_training_iter=30, lr=None):
        if len(self.buffer) == 0:
            return None
        
        # Use self.lr if lr is not provided
        if lr is None:
            lr = self.lr
            
        contexts, rewards = zip(*self.buffer)
        
        # Note: contexts contains embeddings for the SELECTED actions.
        # This matches standard Bandit training (training on observed (x, r) pairs).
        # We do not need to filter contexts here because store_experience already stores the selected context.
        
        if self.algorithm_name in ['neural_ucb', 'mixed_ucb']:
             loss = self.bandit_algorithm.train(contexts, rewards, local_training_iter=local_training_iter, lr=lr)
        else:
             loss = self.bandit_algorithm.train(contexts, rewards, local_training_iter=local_training_iter)
        
        return loss
        
    def calculate_optimal_reward(self, query_id):
        """
        Calculate the optimal reward for a given query based on offline results
        and current reward configuration.
        """
        if not query_id or query_id not in self.offline_results:
            return 0.0
            
        max_reward = -float('inf')
        
        query_results = self.offline_results[query_id]
        
        # Use Custom Weights if provided
        if self.reward_weights:
            w_acc, w_ver, w_cost, bias = self.reward_weights
        else:
            w_acc = 0.3669
            w_ver = 0.3669
            w_cost = 0.2662
            bias = 0.2662
            
        for action_key, trials in query_results.items():
            # Calculate average reward for this action across trials?
            # Or max possible reward for this action?
            # Usually in offline bandit evaluation (e.g. Replay), we care about the expected reward of the action.
            # But here we have multiple trials for same action-query pair in CSV? 
            # The CSV seems to have 1 row per trial.
            # Let's take the AVERAGE reward of the action as its "true" value for optimization.
            
            total_r = 0
            count = 0
            for sample in trials:
                accuracy = float(sample.get('accuracy', 0.0))
                verifier_score = float(sample.get('verifier_score', 0.0))
                
                # Cost
                if self.cost_metric == "Difficulty_Aware_Normalized_EFLOPS":
                    cost_val = sample.get('difficulty_aware_eflops', 0.0)
                    if cost_val == 0.0 and sample.get('normalized_eflops', 0.0) > 0:
                        pass # Should handle better but for now consistent with execute
                else:
                    cost_val = sample.get('normalized_eflops', 0.0)
                    # Fallback logic omitted for brevity, assuming CSV has it or 0 is fine
                
                r = w_acc * accuracy + w_ver * verifier_score - w_cost * cost_val + bias
                total_r += r
                count += 1
            
            if count > 0:
                avg_r = total_r / count
                if avg_r > max_reward:
                    max_reward = avg_r
                    
        return max_reward if max_reward != -float('inf') else 0.0

    def process_query(self, user_query, ground_truth=None, local_training_iter=30, query_id=None):
        context_embeddings = self.get_context_embedding(user_query, query_id=query_id)
        action_idx = self.select_action(context_embeddings, query_id=query_id)
        selected_context_embedding = self.get_context_embedding(user_query, action_idx, query_id=query_id)

        result, from_cache = self.execute_action(action_idx, user_query, ground_truth=ground_truth, query_id=query_id)
        reward = result.get('reward', 0)
        
        # Calculate Regret
        optimal_reward = self.calculate_optimal_reward(query_id)
        regret = optimal_reward - reward
        result['regret'] = regret
        result['optimal_reward'] = optimal_reward

        self.store_experience(selected_context_embedding, reward)
        loss = self.train(local_training_iter=local_training_iter, lr=self.lr)
        
        # Update cumulative stats
        self.cumulative_reward += reward
        self.total_queries += 1
        
        # Add training info to result
        result['training_loss'] = loss if loss is not None else 0
        result['avg_reward'] = self.cumulative_reward / max(1, self.total_queries)
        result['from_cache'] = from_cache
        result['tokens_used'] = result.get('L_out', [0])[0]
        result['algorithm'] = self.algorithm_name
        result['action_idx'] = action_idx
        
        # Log to CSV
        log_data = result.copy()
        log_data.update(result.get('reward_details', {}))
        log_data['loss'] = loss
        self._append_loss_log(log_data)
        
        return result