"""
FedPOB: Federated Prompt Optimization with Bandits

This module implements a federated learning approach for automatic prompt engineering
using multi-armed bandit algorithms. It combines LinUCB (Linear Upper Confidence Bound)
with federated learning to optimize prompts across multiple agents collaboratively.

Key Features:
- Multi-agent prompt optimization using LinUCB
- Federated learning for collaborative prompt discovery
- Support for various language models (GPT, Qwen, DeepSeek, etc.)
- Automatic prompt generation and evaluation
"""

import json
import random
import torch
import numpy as np
import sys
import os
import time
import argparse
import re
import datetime
import torch.nn.functional as F
import nest_asyncio
from transformers import AutoTokenizer, AutoModel

# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from automatic_prompt_engineer import config, llm, ape, data, evaluate, template
from data.instruction_induction.load_data import load_data
from evaluation.instruction_induction.exec_accuracy import exec_accuracy_evaluator, exec_evaluator
from evaluation.instruction_induction.utility import set_all_seed

# Patch asyncio so an already-running event loop can be re-entered; presumably
# needed because the llm/evaluate helpers drive async API calls — TODO confirm.
nest_asyncio.apply()

# Task definitions
# Instruction Induction tasks (23 listed here; 6 more in extra_instructionInduction_tasks)
instructionInduction_tasks = [
    'active_to_passive', 'antonyms', 'common_concept', 'diff', 'first_word_letter',
    'informal_to_formal', 'larger_animal', 'letters_list', 'negation', 'num_to_verbal',
    'orthography_starts_with', 'rhymes', 'second_word_letter', 'sentence_similarity', 'sentiment',
    'singular_to_plural', 'sum', 'synonyms', 'taxonomy_animal', 'translation_en-de', 'translation_en-es',
    'translation_en-fr', 'word_in_context'
]

# Additional instruction-induction style tasks (6), kept in a separate list.
extra_instructionInduction_tasks = [
    'auto_categorization', 'object_counting', 'odd_one_out',
    'periodic_elements', 'word_sorting', 'word_unscrambling'
]

# Big-Bench Hard (BBH) tasks (24 listed here)
# NOTE(review): none of the three task lists is referenced in the visible part of
# this module; presumably consumed by the entry point — confirm before editing.
bbh_tasks = [
    'boolean_expressions', 'date_understanding', 'disambiguation_qa', 'dyck_languages', 'formal_fallacies',
    'geometric_shapes', 'hyperbaton',
    'logical_deduction_five_objects', 'logical_deduction_seven_objects', 'logical_deduction_three_objects',
    'movie_recommendation', 'multistep_arithmetic_two', 'navigate',
    'penguins_in_a_table', 'reasoning_about_colored_objects', 'ruin_names',
    'salient_translation_error_detection', 'snarks', 'sports_understanding', 'temporal_sequences',
    'tracking_shuffled_objects_five_objects', 'tracking_shuffled_objects_seven_objects',
    'tracking_shuffled_objects_three_objects', 'web_of_lies'
]

# Environment setup
# Synchronous CUDA kernel launches (a debugging aid; slows GPU work).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Turn off HuggingFace tokenizers' internal parallelism (presumably to avoid fork warnings).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# API keys are read from the environment rather than hard-coded. A missing key
# only triggers a warning here; run() then passes None as the api_key.
if "OPENROUTER_API_KEY" not in os.environ:
    print("Warning: OPENROUTER_API_KEY not found in environment variables")
if "OPENAI_API_KEY" not in os.environ:
    print("Warning: OPENAI_API_KEY not found in environment variables")

# Optional flag from the environment (None when unset); not referenced in the
# visible part of this module.
SMOKE_TEST = os.environ.get("SMOKE_TEST")

# Device configuration
# Default device/dtype for the bandit tensors. Note: `.to(**tkwargs)` also casts
# to float32, so tensors created as float64 and then moved with tkwargs end up float32.
tkwargs = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.float32,
}

# Default model configurations (used as LLMForwardAPI defaults and inside run())
DEFAULT_MODEL_NAME = "vicuna"
DEFAULT_API_MODEL = 'chatgpt'


def get_subset_index(N, agents, beta, shuffle=True, overlap='least', size=None):
    """
    Generate subset indices for federated learning data distribution.

    The first ``int(N * beta)`` indices are shared by every agent. The rest of
    ``range(N)`` is distributed according to ``overlap``:

    * ``'only'``: the remaining range is cut into ``agents`` contiguous,
      disjoint slices, one slice per agent.
    * ``'least'``: each agent independently draws ``size`` indices uniformly
      without replacement from the remaining range, so agents may overlap.
      ``size`` defaults to ``int(2 * N * (1 - beta) / agents)``. It is capped at
      the number of remaining indices so that ``random.sample`` cannot fail,
      for example when ``agents == 1``.

    Args:
        N (int): Total number of data points
        agents (int): Number of agents
        beta (float): Fraction of ``N`` shared by all agents
        shuffle (bool): If True, remap all indices through one random
            permutation of ``range(N)`` (the same permutation for every agent)
        overlap (str): Overlap strategy ('only' or 'least')
        size (int): Number of non-common indices per agent in 'least' mode

    Returns:
        list: One index list per agent; the common indices come first
        (before any shuffling remap).

    Raises:
        ValueError: If ``overlap`` is not 'only' or 'least'.
    """
    if overlap not in ('only', 'least'):
        # Previously an unknown strategy silently produced empty subsets.
        raise ValueError(f"overlap must be 'only' or 'least', got {overlap!r}")

    flag = int(N * beta)
    common_index = list(range(flag))
    indexs = [[] for _ in range(agents)]

    if overlap == 'only':
        share = N * (1 - beta) / agents
        for a in range(agents):
            indexs[a] = common_index + list(
                range(flag + int(share * a), flag + int(share * (a + 1))))
    else:  # 'least'
        size = int(2 * N * (1 - beta) / agents) if size is None else int(size)
        # Cap at the population size: random.sample raises ValueError otherwise.
        size = min(size, N - flag)
        for a in range(agents):
            indexs[a] = common_index + random.sample(range(flag, N), size)

    if shuffle:
        # One shared permutation keeps the common block common across agents.
        perm = list(range(N))
        random.shuffle(perm)
        indexs = [[perm[i] for i in idx] for idx in indexs]

    return indexs


def extract_sub_sentence(long_sentence):
    """
    Extract prompts from text enclosed in <prompt></prompt> tags.

    Matching is non-greedy, so several tagged prompts are returned separately.
    ``re.DOTALL`` lets a prompt span multiple lines; without it, LLM outputs
    such as ``"<prompt>\\nDo X\\n</prompt>"`` yield no match at all.

    Args:
        long_sentence (str): Input text containing prompt tags

    Returns:
        list: Extracted prompt strings, in order of appearance (may be empty)
    """
    return re.findall(r'<prompt>(.*?)</prompt>', long_sentence, flags=re.DOTALL)


def mean_pooling(model_output, attention_mask):
    """
    Average token embeddings over the positions that the attention mask keeps.

    Args:
        model_output: Transformer output whose first element holds the token
            embeddings, indexed as (batch, tokens, hidden).
        attention_mask: Per-token mask (1 = real token, 0 = padding), indexed
            as (batch, tokens).

    Returns:
        torch.Tensor: One pooled vector per batch row. Rows whose mask is all
        zeros come out as zeros, because the divisor is clamped to 1e-9.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = (hidden_states * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_counts


def get_sen_embedding(model, tokenizer, sentences):
    """
    Encode sentences into L2-normalized, mean-pooled transformer embeddings.

    Args:
        model: Transformer encoder; called as ``model(**batch)``
        tokenizer: Matching tokenizer; its output must include 'attention_mask'
        sentences: List of sentences to encode

    Returns:
        torch.Tensor: One unit-norm embedding per sentence
    """
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**batch)

    pooled = mean_pooling(outputs, batch['attention_mask'])
    return F.normalize(pooled, p=2, dim=1)


class LLMForwardAPI:
    """
    API wrapper for Language Model forward pass and prompt evaluation.

    This class handles prompt generation, evaluation, and management for
    the federated prompt optimization process.

    Scores of evaluated prompts are cached in ``prompts_set`` (prompt text ->
    score) and logged, in evaluation order, in ``prompts_list``. ``eval``
    returns the cached score for a prompt it has already stored.
    """

    # In 'induction' mode, these tasks first sample an induced instruction at
    # temperature 1 and then ask the model to rephrase it. All other tasks use a
    # single induction call at temperature 0.5. ('snarks' is the BBH task name.)
    _INDUCE_THEN_REPHRASE_TASKS = frozenset({
        'sum', 'first_word_letter', 'periodic_elements', 'active_to_passive',
        'boolean_expressions', 'multistep_arithmetic_two', 'dyck_languages',
        'snarks', 'translation_en-de', 'second_word_letter', 'num_to_verbal',
        'diff', 'odd_one_out', 'word_unscrambling', 'translation_en-es',
        'translation_en-fr', 'salient_translation_error_detection',
    })

    # In 'rephrase' mode, these tasks rephrase the seed instruction at
    # temperature 1.5 instead of 1.
    _HIGH_TEMPERATURE_REPHRASE_TASKS = frozenset({'odd_one_out', 'orthography_starts_with'})

    def __init__(self, model_name=DEFAULT_MODEL_NAME, eval_datas=None, init_prompt=None,
                 init_qa_gen=None, conf=None, base_conf=None, prompt_gen_datas=None,
                 few_shot_datas=None, agents=None, api_model=DEFAULT_API_MODEL, task_name=None):
        """
        Initialize the LLM Forward API.

        Args:
            model_name (str): Name of the model to use
            eval_datas (list): Evaluation datasets for each agent
            init_prompt (list): Initial prompt templates (only the first is used)
            init_qa_gen (callable): Builds a demonstration string from a prompt-gen dataset
            conf (dict): Configuration dictionary
            base_conf (str): Base configuration file path
            prompt_gen_datas (list): Prompt generation datasets, one per agent
            few_shot_datas (list): Few-shot datasets; defaults to ``prompt_gen_datas``
            agents (int): Number of agents
            api_model (str): API model type
            task_name (str): Name of the task
        """
        self.model_name = model_name
        self.api_model = api_model
        self.task_name = task_name
        self.init_qa_gen = init_qa_gen
        self.init_prompt = init_prompt[0] if init_prompt else ""

        # Generate initial tokens for each agent. NOTE: when any of these inputs
        # is missing, init_tokens / init_token / one_prompt_gen_data are never set,
        # so update_init_token() and initialize_prompts() cannot be used.
        if init_qa_gen and prompt_gen_datas and agents:
            init_qas = [init_qa_gen(prompt_gen_datas[a]) for a in range(agents)]
            self.init_tokens = [self.init_prompt + init_qas[a] for a in range(agents)]
            self.init_token = self.init_prompt + init_qas[0]
            self.one_prompt_gen_data = prompt_gen_datas[0]

        self.count = 0

        # Evaluation preparation
        self.conf = config.update_config(conf, base_conf)
        self.eval_datas = eval_datas
        self.eval_template = template.EvalTemplate("Instruction: [PROMPT]\n\nInput: [INPUT]\n Output: [OUTPUT]")
        self.demos_template = template.DemosTemplate("Input: [INPUT]\nOutput: [OUTPUT]")

        # Initialize API model if needed
        if self.api_model in ['llama', 'flan-t5']:
            self.api_model_instance = exec_evaluator(self.api_model, self.conf)

        # Set few-shot data
        self.few_shot_datas = few_shot_datas if few_shot_datas is not None else prompt_gen_datas

        self.best_train_perf = 0.0
        self.best_dev_perf = 0.0
        self.best_last_perf = 10
        self.best_prompt = None
        self.num_call = 0
        self.best_instruction = None
        self.prompts_set = dict()
        self.prompts_list = []
        self.parents = []
        self.best_score = 0
        self.score_mean = None
        self.score_std = None
        self.score_min = None
        self.score_max = None
        self.init_user_prompt = None
        # Best offspring score from the latest update() call.
        self.this_iter_best = None

    def update_init_token(self):
        """Rebuild ``self.init_token`` from a fresh ``init_qa_gen`` call on the first agent's data."""
        # randomly choose a qa
        init_qa = self.init_qa_gen(self.one_prompt_gen_data)
        self.init_token = self.init_prompt + init_qa

    def initialize_prompts(self, num_init, task, method):
        """
        Generate ``num_init`` distinct candidate instructions with the evaluation model.

        Args:
            num_init (int): Number of distinct instructions to return.
            task (str): Task name; selects the sampling strategy (see the class-level task sets).
            method (str): 'induction' (induce from the demonstrations in ``self.init_token``)
                or 'rephrase' (induce one seed instruction, then collect rephrasings of it).

        Returns:
            list: Distinct generated instructions, in generation order.

        Raises:
            ValueError: If ``method`` is neither 'induction' nor 'rephrase'; such a
                value would otherwise never add prompts and loop forever.
        """
        if method not in ('induction', 'rephrase'):
            raise ValueError(f"Unknown prompt initialization method: {method!r}")

        ini_prompts_his = {}  # dict used as an insertion-ordered set of prompts
        print("=== Model Configuration ===")
        print("Model config:", self.conf['evaluation']['model'])
        print("Using model:", self.conf['evaluation']['model']['gpt_config']['model'])
        print("API base_url:", self.conf['evaluation']['model'].get('base_url', 'OpenAI default'))
        print("==========================")
        model = llm.model_from_config(self.conf['evaluation']['model'])
        if method == 'rephrase':
            model_outputs = model.generate_text(self.init_token, 1, 0.5)
            ini_prompts_his[model_outputs[0]] = 0
            self.init_user_prompt = model_outputs[0]
        while len(ini_prompts_his) < num_init:
            if method == 'induction':
                if task in self._INDUCE_THEN_REPHRASE_TASKS:
                    random_prompt = model.generate_text(self.init_token, 1, 1, use_seed=False)[0]
                    model_outputs = model.generate_text(
                        "Rephrase the following instruction: " + random_prompt + "\n the rephrased instruction is: ", 1,
                        1, use_seed=False)
                else:
                    model_outputs = model.generate_text(self.init_token, 1, 0.5)
                # Duplicates simply overwrite their dict entry and do not count.
                ini_prompts_his[model_outputs[0]] = 0
                # Refresh init_token so the next induction call can see different demonstrations.
                self.update_init_token()
            else:  # method == 'rephrase'
                temperature = 1.5 if task in self._HIGH_TEMPERATURE_REPHRASE_TASKS else 1
                model_outputs = model.generate_text(
                    "Rephrase the following instruction: " + self.init_user_prompt + "\n the rephrased instruction is: ",
                    1, temperature, use_seed=False)
                ini_prompts_his[model_outputs[0]] = 0
            # Progress is reported after every generation attempt.
            print(f'{task}: {len(ini_prompts_his)}')
            # Only show generated prompts in the initial stage
            if 1 <= len(ini_prompts_his) <= 4:
                print("Generated prompts:", list(ini_prompts_his.keys()))
        return list(ini_prompts_his.keys())

    def selection(self, num_next_gen):
        """
        Sample ``num_next_gen`` parent pairs, with probability proportional to cached score.

        Uniform probabilities are used when all parent scores are zero. Each pair
        is drawn without replacement when possible and with replacement otherwise.

        Returns:
            list: ``num_next_gen`` numpy arrays, each holding two parent prompts.
        """
        scores = np.array([self.prompts_set[p] for p in self.parents])
        num_parents = len(self.parents)
        total = np.sum(scores)
        if total == 0:
            probability = np.ones(num_parents) / num_parents
        else:
            probability = scores / total

        all_parents = []
        for _ in range(num_next_gen):
            try:
                parent_pair = np.random.choice(self.parents, size=2, replace=False, p=probability)
            except ValueError:
                # Fewer than two parents, or fewer than two with non-zero probability.
                parent_pair = np.random.choice(self.parents, size=2, replace=True, p=probability)
            all_parents.append(parent_pair)
        return all_parents

    def evolution(self, all_parents):
        """
        Ask the evaluation model to cross over and mutate each parent pair (temperature 0).

        Returns:
            list: One new prompt per pair. This is the first <prompt>...</prompt>
            span of the model output, or the raw output when no tagged span is found.
        """
        model = llm.model_from_config(self.conf['evaluation']['model'])

        # Named to avoid shadowing the imported ``template`` module.
        evolution_template = "Please follow the instruction step-by-step to generate a better prompt.\n1. Cross over the following prompts and generate a new prompt:\nPrompt 1: [prompt_id1].\nPrompt 2: [prompt_id2].\n2. Mutate the prompt generated in Step 1 and generate a final prompt bracketed with <prompt> and </prompt>."
        next_gens = []
        for parents_ in all_parents:
            query = evolution_template.replace('[prompt_id1]', parents_[0])
            query = query.replace('[prompt_id2]', parents_[1])
            raw_output = model.generate_text(query, 1, 0)[0]
            extracted = extract_sub_sentence(raw_output)
            if extracted:
                new_prompt = extracted[0]
                print(f"EVOL: {new_prompt}")
            else:
                new_prompt = raw_output
            next_gens.append(new_prompt)
        return next_gens

    def update(self, next_gens, set_index=0):
        """
        Score the offspring and keep the top-scoring prompts as the new parents.

        Args:
            next_gens (list): Candidate prompts, e.g. produced by ``evolution``.
            set_index: Evaluation-set index forwarded to ``eval``. The default of 0
                matches how ``run`` calls ``eval``.

        Side effects:
            Sets ``self.this_iter_best`` to the best offspring score. Replaces
            ``self.parents`` with the ``len(self.parents)`` highest-scoring prompts
            among the old parents and the offspring.
        """
        task_name = getattr(self, 'task_name', None)
        # The instruction is eval's second argument, after set_index.
        next_gens_scores = [self.eval(set_index, [gen_], task_name=task_name) for gen_ in next_gens]
        self.this_iter_best = np.max(next_gens_scores)
        num_parents = len(self.parents)
        candidates = self.parents + next_gens
        all_scores = [self.prompts_set[p] for p in candidates]
        ranked = np.argsort(all_scores)
        self.parents = [candidates[i] for i in ranked[-num_parents:]]

    def eval(self, set_index, instruction=None, test=False, task_name=None):
        """
        Evaluate a prompt instruction, reusing a cached score when available.

        Args:
            set_index: Index of the evaluation set, forwarded to ``evaluate.evaluate_prompts``
            instruction (list): One-element list holding the instruction to evaluate
            test (bool): If True, neither cache the score nor update best-so-far tracking
            task_name (str): Task name; 'sentence_similarity' adds answer options to the query

        Returns:
            The instruction's score, or 0.0 if the evaluation call raised.

        Raises:
            NotImplementedError: If the prompt is not cached and ``self.api_model`` is not 'chatgpt'.
        """
        if instruction[0] in self.prompts_set:
            dev_perf = self.prompts_set[instruction[0]]
        else:
            if self.api_model in ['chatgpt']:
                try:
                    # Special handling for sentence similarity task
                    extra_msgs = ("Options:\n0 - definitely not\n1 - probably not\n2 - possibly\n"
                                "3 - probably\n4 - almost perfectly\n5 - perfectly\n") if task_name == 'sentence_similarity' else ''

                    dev_perf, _ = evaluate.evaluate_prompts(
                        instruction, self.eval_template, set_index, self.eval_datas,
                        self.demos_template, self.few_shot_datas,
                        self.conf['evaluation']['method'], self.conf['evaluation'],
                        extra_msg=extra_msgs
                    )
                    dev_perf = dev_perf.sorted()[1][0]
                except Exception as e:
                    # Best-effort: any API/evaluation failure scores the prompt as 0.0.
                    dev_perf = 0.0
                    print(f'Evaluation failed with content filter, setting score to 0.0: {e}')
            else:
                raise NotImplementedError(f"API model {self.api_model} not implemented")

            if not test:
                if dev_perf >= self.best_last_perf:
                    self.count += 1
                if dev_perf >= self.best_dev_perf:
                    self.best_dev_perf = dev_perf
                    self.best_instruction = instruction
                self.prompts_set[instruction[0]] = dev_perf
                self.prompts_list.append((len(self.prompts_list), instruction[0], dev_perf))
                print('Dev loss: {}. Dev perf: {}. Best dev perf: {}'.format(
                    round(float(dev_perf), 4),
                    round(float(dev_perf), 4),
                    round(float(self.best_dev_perf), 4)))
                print('********* Done *********')
        return dev_perf

    def return_best_prompt(self):
        """Return the best instruction (a one-element list) seen by non-test ``eval`` calls."""
        return self.best_instruction

    def return_prompts_set(self):
        """Return the prompt -> score cache."""
        return self.prompts_set

    def return_prompts_list(self):
        """Return the (order, prompt, score) evaluation log."""
        return self.prompts_list


def run(task, agents, D, indexs, nu, lamdba, n_domain, total_iter,
        n_eval, gpt, init_scale, args, round=None):
    """
    Main function to run the federated prompt optimization experiment.

    Pipeline:
      1. Load the induce/eval data for ``task`` and split the induce data per agent.
      2. Load, or generate and cache, ``args.n_domain`` candidate instructions.
      3. Embed the candidates with sentence-transformers/all-mpnet-base-v2.
      4. Load, or compute and cache, a score for each candidate.
      5. Run a fixed 50 rounds of federated LinUCB over the cached scores. Local
         statistics are merged into the shared ones whenever an agent's log-det
         criterion times the rounds since the last sync exceeds ``D``.

    Args:
        task (str): Task name to run
        agents (int): Number of agents
        D (float): Synchronization threshold for the communication criterion
        indexs (list): Per-agent lists of candidate indices each agent may select
        nu (float): Weight on the LinUCB uncertainty term
        lamdba (float): Ridge regularization (V = lamdba * I + sum x x^T)
        n_domain (int): Number of candidates scored when building the score cache
        total_iter (int): Total iterations. NOTE: currently unused; the loop runs 50 rounds
        n_eval (int): Number of evaluation samples
        gpt (str): Model identifier for generation/evaluation
        init_scale (float): Initial scale (unused here)
        args: Command line arguments (reads gpt, candidate_method, n_domain, task)
        round: Optional round number; if given, embeddings are restricted to indexs[round]

    Returns:
        tuple: (test_scores, prompts, prompts_set, best_values, now_values,
        best_instruction_over_iter, selected_instruction_over_iter, init_instructions,
        instruction_select_historys, rewards, acc_rewards, best_rewards, max_scores, cnt).
        ``test_scores`` is a per-agent 0.01 placeholder; test evaluation is disabled.

    Raises:
        NotImplementedError: If ``args.gpt`` has no supported model prefix.
        ValueError: If ``args.candidate_method`` is not 'induction' or 'rephrase'.
    """
    # API configuration
    base_url = os.environ.get("OPENROUTER_BASE_URL")
    api_key = os.environ.get("OPENROUTER_API_KEY")

    # Model configuration
    actual_model = gpt
    model_name = DEFAULT_MODEL_NAME
    api_model = DEFAULT_API_MODEL

    print(f'Using model: {actual_model}')

    # Resolve the model family first. It selects the prompt-generation template,
    # which must be fixed before LLMForwardAPI builds its initial tokens, and the
    # score-cache directory. Order matters: more specific prefixes come first.
    if args.gpt.startswith('gpt-3.5'):
        model_type = "gpt-3.5-turbo"
    elif args.gpt.startswith('openai/gpt-4o-mini'):
        model_type = "gpt-4o-mini"
    elif args.gpt.startswith('deepseek/'):
        model_type = "deepseek"
    elif args.gpt.startswith('google/gemini'):
        model_type = "gemini"
    elif args.gpt.startswith('qwen/qwen3-235b-a22b-2507'):
        model_type = "qwen3-235b"
    elif args.gpt.startswith('qwen/qwen3-coder'):
        model_type = "qwen3-coder"
    elif args.gpt.startswith('qwen/'):
        model_type = "qwen"
    else:
        raise NotImplementedError('no such model')

    induce_data_inputs, induce_data_outputs, _ = load_data('induce', task)
    # Eval data is loaded for the (currently disabled) test-set evaluation.
    test_data_inputs, test_data_outputs, _ = load_data('eval', task)
    induce_data = (induce_data_inputs, induce_data_outputs)
    test_data = (test_data_inputs, test_data_outputs)

    # Induce data is split into prompt_gen_data and eval_data
    induce_data_size = len(induce_data[0])
    prompt_gen_size = min(int(induce_data_size * 0.5), 100)
    prompt_gen_datas = [[] for _ in range(agents)]
    eval_datas = [[] for _ in range(agents)]
    for a in range(agents):
        prompt_gen_datas[a], eval_datas[a] = data.create_split(
            induce_data, prompt_gen_size, a)
        # Data is in the form input: single item, output: list of items.
        # For prompt_gen_data, sample a single item from each output list.
        prompt_gen_datas[a] = prompt_gen_datas[a][0], [random.sample(output, 1)[0]
                                                       for output in prompt_gen_datas[a][1]]

    demos_template = "Input: [INPUT]\nOutput: [OUTPUT]"
    init_prompt = ['\n']
    if model_type.startswith(('qwen', 'deepseek')):
        prompt_gen_template = "[full_DEMO]\n\nWrite a concise instruction in 20 words or less. No explanation needed. The instruction was to"
    else:
        prompt_gen_template = "[full_DEMO]\n\nThe instruction was to"

    base_conf = '../experiments/configs/instruction_induction.yaml'
    conf = {
        'generation': {
            'num_subsamples': 1,
            'num_demos': 5,
            'num_prompts_per_subsample': 20,
            'model': {
                'name': 'GPT_forward',
                'batch_size': 500,
                'base_url': base_url,
                'api_key': api_key,
                'gpt_config': {
                    'model': actual_model,
                    'temperature': 0.7,  # Keep consistent with GPT-3.5
                    'max_tokens': 50,    # Restore to original length limit
                    'top_p': 1.0,        # Keep consistent with GPT-3.5
                    'frequency_penalty': 0.0,
                    'presence_penalty': 0.0
                }
            }
        },
        'evaluation': {
            'method': exec_accuracy_evaluator,
            'task': task,
            'num_samples': min(n_eval, len(eval_datas[0][0])),
            'model': {
                'name': 'GPT_forward',
                'batch_size': 20,
                'base_url': base_url,
                'api_key': api_key,
                'gpt_config': {
                    'model': actual_model,
                    'temperature': 0.7,
                    'max_tokens': 500,
                    'top_p': 1.0,
                    'frequency_penalty': 0.0,
                    'presence_penalty': 0.0
                }
            }
        }
    }
    # Communication bookkeeping: per-agent sync request flags and last sync round.
    communication_flag = np.zeros(agents)
    t_last = 0

    def init_qa_gen(prompt_gen_data):
        """Build an induction query from num_demos subsampled demonstrations."""
        subsampled_data = data.subsample_data(prompt_gen_data, conf['generation']['num_demos'])
        prompt_gen_template_ = template.InitQATemplate(prompt_gen_template)
        d_template = template.DemosTemplate(demos_template)
        demos = d_template.fill(subsampled_data)
        return prompt_gen_template_.fill(demos)

    model_forward_api = LLMForwardAPI(
        model_name=model_name,
        eval_datas=eval_datas,
        init_prompt=init_prompt,
        init_qa_gen=init_qa_gen,
        conf=conf,
        base_conf=base_conf,
        prompt_gen_datas=prompt_gen_datas,
        agents=agents,
        api_model=api_model,
        task_name=task
    )

    # Uniformly use the gpt-3.5 query file to ensure all models evaluate the same set of prompts.
    folder = "./query/gpt-3.5-turbo"
    if args.candidate_method == "induction":
        path_ = f"{folder}/{task}_{args.n_domain}"
    elif args.candidate_method == "rephrase":
        path_ = f"{folder}/{task}_{args.n_domain}_rephrase"
    else:
        raise ValueError(f"Unknown candidate_method: {args.candidate_method!r}")

    print('path_ = ', path_)

    if os.path.exists(path_):
        with open(path_, 'r', encoding='utf-8') as fp:
            init_instructions = json.load(fp)['instructions']
        print(f"Loading existing gpt-3.5 query file: {path_}")
    else:
        # NOTE(review): generation uses the model configured under conf['evaluation']
        # (actual_model here), which may not be gpt-3.5 despite the messages — confirm.
        print("gpt-3.5 query file doesn't exist, using gpt-3.5 model to generate unified prompt set")
        os.makedirs(folder, exist_ok=True)
        init_instructions = model_forward_api.initialize_prompts(args.n_domain, task, args.candidate_method)
        print(f"Generated {len(init_instructions)} prompts using GPT-3.5")
        with open(path_, 'w', encoding='utf-8') as fp:
            json.dump({"instructions": init_instructions}, fp, indent=4)
        print(f"New gpt-3.5 query file saved: {path_}")

    sen_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
    sen_model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')

    sen_embeddings = get_sen_embedding(sen_model, sen_tokenizer, init_instructions)
    if round is not None:
        # NOTE(review): arms below are still chosen via indexs[a] into this
        # restricted tensor — confirm the two index spaces agree.
        sen_embeddings = sen_embeddings[torch.tensor(indexs[round])]
    sen_embeddings = sen_embeddings.to(**tkwargs)

    # Score one random candidate (test=True: not cached) to fill the normalization stats.
    test_num = 1
    all_tmp_scores = []
    for _ in range(test_num):
        prompt_tmp = np.random.choice(init_instructions)
        # Currently only doing normalization for one dev_set
        score_tmp = model_forward_api.eval(0, [prompt_tmp], test=True, task_name=task)
        all_tmp_scores.append(score_tmp)
        print("********")
    model_forward_api.score_mean = np.mean(all_tmp_scores)
    model_forward_api.score_std = np.std(all_tmp_scores)
    model_forward_api.score_min = np.min(all_tmp_scores)
    model_forward_api.score_max = np.max(all_tmp_scores)

    # Bandit state. W/b are Gram matrix / reward vectors: *_sync is shared, *_new is
    # each agent's local delta since the last sync.
    select_idx_historys = [[] for _ in range(agents)]
    instruction_select_historys = [[] for _ in range(agents)]
    feature = sen_embeddings.shape[-1]
    W_sync = torch.zeros((feature, feature), dtype=torch.float64).to(**tkwargs)
    W_new = torch.zeros((agents, feature, feature), dtype=torch.float64).to(**tkwargs)
    b_sync = torch.zeros((feature, 1), dtype=torch.float64).to(**tkwargs)
    b_new = torch.zeros((agents, feature, 1), dtype=torch.float64).to(**tkwargs)
    V_t = torch.zeros((agents, feature, feature), dtype=torch.float64).to(**tkwargs)
    V_t_inverse = torch.zeros((agents, feature, feature), dtype=torch.float64).to(**tkwargs)
    V_last = lamdba * torch.eye(feature, dtype=torch.float64).to(**tkwargs)

    best_r = [0 for _ in range(agents)]
    best_values = []
    now_values = [[] for _ in range(agents)]
    best_instruction_over_iter = [[] for _ in range(agents)]
    selected_instruction_over_iter = [[] for _ in range(agents)]

    score_file_name = "./query_score/" + model_type + '/' + args.task + "_" + str(args.n_domain) + ".json"
    print('score_file_name', score_file_name)

    # Ensure query_score directory exists
    score_folder = f"./query_score/{model_type}"
    os.makedirs(score_folder, exist_ok=True)
    prompt_scores = []
    if os.path.exists(score_file_name):
        with open(score_file_name, "r", encoding='utf-8') as f:
            json_data = json.load(f)
            prompt_scores = json_data[args.task]  # Extract scores for corresponding task based on JSON structure
        print(f'Loading existing score file: {score_file_name}')
    else:
        print('Score file does not exist, starting calculation...')
        for i in range(n_domain):
            score = model_forward_api.eval(0, [init_instructions[i]], task_name=task)
            prompt_scores.append(score)
            print('using '+model_type+' evaluating prompt '+str(i)+' '+str(score))

        # Save as JSON format
        data_to_save = {args.task: prompt_scores}
        with open(score_file_name, "w", encoding='utf-8') as f:
            json.dump(data_to_save, f, indent=2, ensure_ascii=False)
        print(f'New score file saved: {score_file_name}')

    prompt_scores = torch.tensor(prompt_scores)

    best_score = torch.max(prompt_scores)
    regrets = [[] for _ in range(agents)]
    acc_regrets = [[0] for _ in range(agents)]
    rewards = [[] for _ in range(agents)]
    acc_rewards = [[0] for _ in range(agents)]
    best_rewards = [[0] for _ in range(agents)]
    max_scores = [max(prompt_scores[indexs[a]]).item() for a in range(agents)]
    # Per-agent arm behind best_rewards[a] (previously one variable was shared by all agents).
    agent_best_arms = [None] * agents

    n_rounds = 50  # NOTE: fixed round count; ``total_iter`` is not used.
    cnt = 0
    for t in range(n_rounds):
        # Selection: each agent maximizes its LinUCB score over its allowed arms.
        arm_selects = []
        for a in range(agents):
            V_t[a, :, :] = lamdba * torch.eye(feature).to(**tkwargs) + W_sync + W_new[a, :, :]
            V_t_inverse[a, :, :] = torch.inverse(V_t[a, :, :])
            theta_hat_t = V_t_inverse[a, :, :] @ (b_sync + b_new[a, :, :])
            # UCB
            mu = sen_embeddings @ theta_hat_t
            uncertainty = nu * torch.sqrt(
                torch.diagonal(sen_embeddings @ V_t_inverse[a, :, :] @ sen_embeddings.T)).reshape(-1, 1).to(
                **tkwargs)
            Lin_UCB = mu + uncertainty
            if indexs is None:
                arm_select = torch.argmax(Lin_UCB)
            else:
                sub_Lin_UCB = Lin_UCB[indexs[a]]
                flag = torch.argmax(sub_Lin_UCB)
                arm_select = torch.tensor(indexs[a][flag])
            arm_selects.append(arm_select)

        # Evaluation: look up the cached score of each selected arm and update local stats.
        for a in range(agents):
            arm_select = arm_selects[a].item()
            x_a = sen_embeddings[arm_select].reshape(-1, 1)

            select_idx_historys[a] += [[arm_select]]
            score = prompt_scores[arm_select]
            instruction_select_historys[a] += [
                (init_instructions[arm_select], score)]

            # Record the instruction actually selected in current iteration
            selected_instruction_over_iter[a] += [(t, init_instructions[arm_select])]
            regret = best_score - score
            regrets[a].append(regret.item())
            rewards[a].append(score.item())
            # Cumulative regret builds on the regret series (it used acc_rewards before).
            acc_regrets[a].append(acc_regrets[a][-1] + regret.item())
            acc_rewards[a].append(acc_rewards[a][-1] + score.item())

            if score > best_rewards[a][-1]:
                best_rewards[a].append(score.item())
                agent_best_arms[a] = arm_select
            else:
                best_rewards[a].append(best_rewards[a][-1])
                if agent_best_arms[a] is None:
                    # No positive score yet: report the agent's own first selection.
                    agent_best_arms[a] = arm_select
            now_values[a] += [score]
            # Only save iteration number and instruction text, not scores
            best_instruction_over_iter[a] += [(t, init_instructions[agent_best_arms[a]])]

            best_r[a] = max(score, best_r[a])

            # update
            W_new[a, :, :] += x_a @ x_a.T
            b_new[a, :] += score * x_a
            V_t[a, :, :] = lamdba * torch.eye(feature).to(**tkwargs) + W_sync + W_new[a, :, :]
            # Log-det growth of V since the last sync (V is diagonal-dominant by construction
            # only up to the regularizer; this uses the diagonal as in the original).
            criterion = (torch.sum(torch.log(torch.diagonal(V_t[a, :, :], 0))) -
                         torch.sum(torch.log(torch.diagonal(V_last, 0))))
            if criterion * (t - t_last) > D:
                communication_flag[a] = 1

            # Snapshot (copy) so later updates don't mutate earlier entries.
            best_values.append(list(best_r))

        # Communication: merge every agent's local delta into the shared statistics.
        if np.any(communication_flag):
            cnt += 1
            t_last = t
            communication_flag = np.zeros(agents)
            # reset W_new and b_new after uploading to server
            for a in range(agents):
                W_sync += W_new[a, :, :]
                b_sync += b_new[a, :, :]
                W_new[a].zero_()
                b_new[a].zero_()
            V_last = lamdba * torch.eye(feature).to(**tkwargs) + W_sync
    print('communication round: ', cnt)
    print('best_instruction_over_iter[0]', best_instruction_over_iter[0])
    prompts = [best_instruction_over_iter[a][-1][1] for a in range(agents)]
    prompts_set = model_forward_api.return_prompts_set()

    # Test-set evaluation is currently disabled; 0.01 is a per-agent placeholder.
    test_scores = [0.01 for _ in range(agents)]
    return (test_scores, prompts, prompts_set, best_values, now_values, best_instruction_over_iter,
            selected_instruction_over_iter, init_instructions, instruction_select_historys, rewards,
            acc_rewards, best_rewards, max_scores, cnt)


def parse_args(argv=None):
    """Parse the command-line options of a FedPOB experiment.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing an
            explicit list lets tests and notebooks call this without touching
            the real command line.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description="FedPOB: federated prompt optimization with bandits"
    )
    parser.add_argument(
        "--task",
        type=str,
        default='boolean_expressions',
        help="Name of the task to optimize prompts for (e.g. a BBH or Instruction Induction task).",
    )
    parser.add_argument(
        "--agents",
        type=int,
        default=6,
        help="Set the number of agents."
    )
    parser.add_argument(
        "--D",
        type=float,
        default=10.0,
        help="Set threshold for aggregation."
    )
    parser.add_argument(
        "--nu",
        type=float,
        default=0.3,
        help="Set the parameter nu."
    )
    # The misspelled option/attribute name ``lamdba`` is kept on purpose:
    # the entry point reads ``args.lamdba`` (run() call, result dict and
    # log-file name). The integer default is also deliberate: argparse only
    # applies ``type`` to *string* defaults, so the default renders as ``1``
    # (not ``1.0``) in result file names; changing it would rename existing
    # result files.
    parser.add_argument(
        "--lamdba",
        type=float,
        default=1,
        help="Set the lambda parameter."
    )
    parser.add_argument(
        "--n_domain",
        type=int,
        default=6,
        help="Number of candidate prompts."
    )
    parser.add_argument(
        "--total_iter",
        type=int,
        default=100,
        help="Set the number of total queries."
    )
    parser.add_argument(
        "--n_eval",
        type=int,
        default=50,
        help="Number of evaluation samples used to score a prompt (capped by the dataset size)."
    )
    parser.add_argument(
        "--name",
        type=str,
        default="",
        help="Set the name of the experiments."
    )
    parser.add_argument(
        "--gpt",
        type=str,
        default="gpt-3.5-turbo",
        help="Which model to use. Examples: gpt-3.5-turbo, openai/gpt-4o-mini, deepseek/deepseek-chat, google/gemini-2.5-flash, qwen/qwen3-235b-a22b-2507"
    )
    # Same note as --lamdba: the integer default is kept unchanged on purpose.
    parser.add_argument(
        "--init_scale",
        type=float,
        default=1,
        help="Which scale to use."
    )
    parser.add_argument(
        "--trial",
        type=int,
        default=0,
        help="Trial ID."
    )
    parser.add_argument(
        "--candidate_method",
        type=str,
        default='induction',
        help="The way to generate candidates."
    )
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()

    # Configure tasks to run - using BBH tasks starting from index 13
    run_tasks = bbh_tasks

    # Configure number of agents for experiments
    run_agents = [10]

    for run_task in run_tasks:
        for run_agent in run_agents:
            # Generate data subset indices for each agent
            indexs = get_subset_index(args.n_domain, run_agent, 0.0, True, 'least', min(100, args.n_domain))
            args.task = run_task
            args.agents = run_agent

            # Determine result folder based on model type
            if args.gpt.startswith('deepseek/'):
                result_folder = "./all_results/FedPOB/Deepseek"
            elif args.gpt.startswith('gpt-3.5'):
                result_folder = "./all_results/FedPOB/gpt-3.5-turbo"
            elif args.gpt.startswith('openai/gpt-4o-mini'):
                result_folder = "./all_results/FedPOB/gpt-4o-mini"
            elif args.gpt.startswith('google/gemini'):
                result_folder = "./all_results/FedPOB/Gemini"
            elif args.gpt.startswith('qwen/qwen3-235b-a22b-2507'):
                result_folder = "./all_results/FedPOB/qwen3-235b"
            else:
                result_folder = "./all_results/FedPOB/default"
                print(f"Warning: Unknown model type {args.gpt}, using default folder")

            # Ensure result directory exists
            if not os.path.exists(result_folder):
                os.makedirs(result_folder)
                print(f'Created result directory: {result_folder}')

            # Generate log file name
            log_file_name = (f"{result_folder}/_dataset_{args.task}"
                           f"_prompt_number_{args.n_domain}"
                           f"_agents_{args.agents}"
                           f"_D_{args.D}"
                           f"_lam_{args.lamdba}"
                           f"_nu_{args.nu}.json")
            print(f'Log file: {log_file_name}')

            # Store all experiment results
            all_experiments = []

            # Run multiple experiments for statistical significance
            for i in range(5):
                print(f"Setting random seed: {set_all_seed(i)}")
                indexs = get_subset_index(args.n_domain, run_agent, 0.0, True, 'least', min(100, args.n_domain))

                # Run the main experiment
                (test_scores, prompts, prompts_set, best_values, now_values,
                 best_instruction_over_iter, selected_instruction_over_iter,
                 init_instructions, instruction_select_history, rewards,
                 acc_rewards, best_rewards, max_scores, cnt) = run(
                    task=args.task,
                    agents=args.agents,
                    D=args.D,
                    indexs=indexs,
                    nu=args.nu,
                    lamdba=args.lamdba,
                    n_domain=args.n_domain,
                    total_iter=args.total_iter,
                    n_eval=args.n_eval,
                    gpt=args.gpt,
                    init_scale=args.init_scale,
                    args=args,
                    round=None
                )

                # Prepare experiment result data
                experiment_result = {
                    'experiment_id': i,
                    'task': args.task,
                    'agents': args.agents,
                    'total_iter': args.total_iter,
                    'D': args.D,
                    'communication_rounds': cnt,
                    'lambda': args.lamdba,
                    'nu': args.nu,
                    'n_domain': args.n_domain,
                    'model': args.gpt,
                    'best_prompts': prompts,
                    'best_instruction_over_iter': best_instruction_over_iter,
                    'selected_instruction_over_iter': selected_instruction_over_iter,
                    'rewards': rewards.tolist() if hasattr(rewards, 'tolist') else rewards,
                    'best_rewards': best_rewards.tolist() if hasattr(best_rewards, 'tolist') else best_rewards,
                    'acc_rewards': acc_rewards.tolist() if hasattr(acc_rewards, 'tolist') else acc_rewards,
                    'max_scores': max_scores.tolist() if hasattr(max_scores, 'tolist') else max_scores,
                    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
                }

                # Add to experiment results list
                all_experiments.append(experiment_result)
                print(f'Task {run_task} Agent={run_agent} Experiment={i} Finished!')

            # Save all experiment results to JSON file
            try:
                with open(log_file_name, 'w', encoding='utf-8') as f:
                    json.dump(all_experiments, f, indent=2, ensure_ascii=False)
                print(f'All experiment results saved to: {log_file_name}')
            except Exception as e:
                print(f'Error saving results file: {e}')