"""
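FedPOB-Pref: federated prompt optimization from pairwise preference feedback.

This module implements the FedPOB-Pref algorithm: each client selects pairs of candidate instructions with a linear dueling bandit, fits its local preference model with FedDyn-regularized training, and a central server aggregates the clients' parameters and covariance updates.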
"""

import torch
import numpy as np
from tqdm import tqdm
import os

from common_components import FederatedLinearDuelingBanditEnvironment

# Maps a ``--gpt`` model identifier to the folder name used for result files
# (``../all_results/FedPOB-Pref/<folder>``) and for cached candidate scores
# (``./query_score/<folder>/``). Identifiers missing from this table resolve to
# 'unknown_model' in get_model_folder_name.
# Several identifiers share a folder, e.g. both DeepSeek chat ids -> 'deepseek'
# and 'qwen3-235b' / 'qwen/qwen3-235b-a22b-2507' -> 'qwen3-235b'.
# NOTE(review): run_FedPOB_Pref also passes the mapped value as the API model
# name ('gpt_config' -> 'model'); confirm that is intended for aliases such as
# 'gpt4' or 'deepseek'.
MODEL_FOLDER_MAPPING = {
    'gpt-3.5-turbo': 'gpt-3.5-turbo',
    'gpt-4': 'gpt4',
    'gpt-4-turbo': 'gpt4_turbo',
    'gpt-4o-mini': 'gpt-4o-mini',
    'qwen3-235b': 'qwen3-235b',
    'deepseek/deepseek-chat': 'deepseek',
    'deepseek/deepseek-chat-v3-0324': 'deepseek',
    'qwen/qwen-2.5-72b-instruct': 'qwen',
    'qwen/qwen-2.5-32b-instruct': 'qwen_32b',
    'qwen/qwen3-235b-a22b-2507':'qwen3-235b',
    'anthropic/claude-3-sonnet': 'claude3_sonnet',
    'anthropic/claude-3-haiku': 'claude3_haiku',
    'meta-llama/llama-3.1-70b-instruct': 'llama3_70b',
    'meta-llama/llama-3.1-8b-instruct': 'llama3_8b',
    'mistralai/mistral-7b-instruct': 'mistral_7b',
    'google/gemini-pro': 'gemini_pro',
}



class FederatedClient_optimal:
    """
    Federated client for preference-based optimization.

    Each client runs a linear dueling bandit over a matrix of candidate
    feature vectors ("arms"): it picks a greedy first arm and an optimistic
    (UCB) second arm, the caller records the pairwise preference outcome in
    the history lists, and the client refits its local parameter vector with
    a FedDyn-regularized logistic preference loss.
    """
    def __init__(self, client_id, feature_dim, learning_rate, lambda_reg):
        """
        Initialize federated client.

        Args:
            client_id: Unique identifier for the client
            feature_dim: Dimension of feature vectors
            learning_rate: Adam learning rate for local optimization
            lambda_reg: Regularization parameter (scale of the initial
                ``V_t`` and weight of the FedDyn proximal term)
        """
        self.client_id = client_id
        self.feature_dim = feature_dim
        self.lambda_reg = lambda_reg
        self.learning_rate = learning_rate
        # First GPU when present, otherwise CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.V_t = lambda_reg * torch.eye(feature_dim, device=self.device)
        self.W_new = torch.zeros((feature_dim, feature_dim), device=self.device)
        self.local_gradient = torch.zeros(feature_dim, device=self.device)

        # History tracking (one entry per duel, appended by the caller).
        self.history_arm_1 = []
        self.history_arm_2 = []
        self.history_selected_arm = []
        self.history_winner = []
        self.history_arm_1_index = []
        self.history_arm_2_index = []

        # Best arm tracking
        self.best_r = -np.inf
        self.best_arm = None
        self.best_values = []
        self.best_instruction_over_iter = []
        self.best_r1 = -np.inf
        self.best_arm1_idx = None

    def select_arms(self, theta_local, context, beta_t, W_sync, score_list):
        """
        Select two arms for dueling bandit comparison.

        Args:
            theta_local: Local parameter vector, shape ``(feature_dim,)``
            context: Arm feature matrix, shape ``(num_arms, feature_dim)``
            beta_t: Confidence (exploration) parameter
            W_sync: Synchronized covariance matrix; stored by reference as
                ``self.V_t``
            score_list: Ground truth scores for debugging (not used here)

        Returns:
            tuple: (arm1_idx, arm2_idx) integer indices of the selected arms
        """
        # First arm: greedy w.r.t. the current estimate.
        arm1_idx = torch.argmax(context @ theta_local).item()
        arm1 = context[arm1_idx]

        # Second arm: optimistic w.r.t. the difference to the first arm.
        self.V_t = W_sync
        V_t_inv = torch.linalg.inv(self.V_t)
        diff = context - arm1  # broadcasts arm1 over every row
        mu = diff @ theta_local
        # Row-wise quadratic form diff_k^T V^{-1} diff_k, i.e. the diagonal of
        # diff @ V^{-1} @ diff.T without materializing the (num_arms, num_arms)
        # matrix. Clamp guards against tiny negatives from round-off (sqrt -> NaN).
        quad = ((diff @ V_t_inv) * diff).sum(dim=1).clamp_min(0)
        ucb = mu + 2 * beta_t * torch.sqrt(quad)
        arm2_idx = torch.argmax(ucb).item()

        return arm1_idx, arm2_idx

    def compute_w_update(self, arm1, arm2, scale=10):
        """
        Compute this client's covariance update for one duel.

        Args:
            arm1, arm2: Feature vectors of the two compared arms
            scale: Multiplier applied to ``arm1 - arm2`` before the outer
                product (the update therefore grows with ``scale ** 2``)

        Returns:
            torch.Tensor: ``outer(d, d)`` with ``d = scale * (arm1 - arm2)``;
            also stored as ``self.W_new``.
        """
        diff = scale * (arm1 - arm2)
        self.W_new = torch.outer(diff, diff)
        return self.W_new

    def local_train_feddyn(self, theta_local, theta_global, grad_history, num_steps=30):
        """
        Local training using FedDyn algorithm.

        Minimizes, with Adam, the logistic preference loss over the recorded
        duels plus ``lambda_reg / 2 * ||theta - theta_global||^2`` minus the
        FedDyn linear term ``<grad_history, theta>``.

        Args:
            theta_local: Local parameter vector (not modified)
            theta_global: Global parameter vector
            grad_history: Gradient history for FedDyn (not modified)
            num_steps: Number of Adam steps

        Returns:
            tuple: (updated_theta_local, updated_grad_history), both new tensors
        """
        theta_local = theta_local.clone().detach().requires_grad_(True)

        # Stack the duel history once; the loss is re-evaluated num_steps times.
        duels = list(zip(self.history_arm_1, self.history_arm_2, self.history_winner))
        if duels:
            x1s, x2s, winners = zip(*duels)
            diffs = torch.stack(x1s) - torch.stack(x2s)
            # +1 when arm 1 won (y == 1), -1 otherwise.
            signs = torch.tensor([1.0 if y == 1 else -1.0 for y in winners],
                                 dtype=diffs.dtype, device=diffs.device)
        else:
            diffs = signs = None

        def loss_fn():
            loss = self.lambda_reg / 2 * torch.sum((theta_local - theta_global) ** 2)
            loss = loss - torch.dot(grad_history, theta_local)
            if diffs is not None:
                # logsigmoid is the numerically stable form of log(sigmoid(x));
                # the naive form underflows to -inf and produces NaN gradients.
                loss = loss - torch.nn.functional.logsigmoid(signs * (diffs @ theta_local)).sum()
            return loss

        optimizer = torch.optim.Adam([theta_local], lr=self.learning_rate)

        for _ in range(num_steps):
            optimizer.zero_grad()
            loss = loss_fn()
            loss.backward()
            optimizer.step()

        theta_new = theta_local.detach()
        # FedDyn dual update, out-of-place so the caller's tensor is not mutated.
        grad_history = grad_history - self.lambda_reg * (theta_new - theta_global)
        return theta_new, grad_history


class CentralServer_optimal:
    """
    Central server for federated preference-based optimization.

    Manages global model parameters and coordinates federated learning
    across multiple clients using FedDyn algorithm.
    """

    def __init__(self, feature_dim, num_clients, learning_rate, lambda_reg=1.0, device=None):
        """
        Initialize central server.

        Args:
            feature_dim: Dimension of feature vectors
            num_clients: Number of federated clients
            learning_rate: Learning rate for global optimization
            lambda_reg: Regularization parameter (initial ``W_sync`` scale and
                FedDyn weight)
            device: Torch device for the server tensors; defaults to CUDA when
                available and the CPU otherwise (previously CUDA was required)
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.theta_sync = torch.nn.Parameter(torch.zeros(feature_dim, device=device))
        self.optimizer = torch.optim.Adam([self.theta_sync], lr=learning_rate, weight_decay=lambda_reg)
        self.theta_global = torch.zeros(feature_dim, device=device)
        self.W_sync = lambda_reg * torch.eye(feature_dim, device=device)
        self.feature_dim = feature_dim
        self.lambda_reg = lambda_reg
        self.num_clients = num_clients
        self.learning_rate = learning_rate

    def aggregate_w_updates(self, w_updates):
        """
        Add every client's W update into ``W_sync``.

        The addition is in place on purpose: clients hold a reference to
        ``W_sync`` (as their ``V_t``), so they observe the new value.
        """
        for W_new in w_updates:
            self.W_sync += W_new

    def update_global_feddyn(self, local_thetas, grad_histories):
        """
        Update global model using FedDyn algorithm.

        Computes ``mean_i(theta_i - g_i / lambda_reg)`` over the clients.

        Args:
            local_thetas: List of local parameter vectors
            grad_histories: List of gradient histories, aligned with
                ``local_thetas``

        Returns:
            torch.Tensor: Updated global parameter vector

        Raises:
            ValueError: if ``local_thetas`` is empty.
        """
        if not local_thetas:
            raise ValueError("update_global_feddyn needs at least one local theta")

        N = len(local_thetas)
        theta_global = torch.zeros_like(local_thetas[0])

        for theta_i, g_i in zip(local_thetas, grad_histories):
            theta_global += theta_i - (1.0 / self.lambda_reg) * g_i

        theta_global /= N
        return theta_global


def get_model_folder_name(gpt_model):
    """Return the folder name for *gpt_model*, or 'unknown_model' if it is not mapped."""
    try:
        return MODEL_FOLDER_MAPPING[gpt_model]
    except KeyError:
        return 'unknown_model'


def create_result_directory(gpt_model):
    """Create (if needed) and return the result folder for *gpt_model*.

    The folder is ``../all_results/FedPOB-Pref/<model folder>``, where the
    model folder comes from ``get_model_folder_name``.
    """
    model_folder = get_model_folder_name(gpt_model)
    result_folder = f"../all_results/FedPOB-Pref/{model_folder}"

    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(result_folder, exist_ok=True)

    return result_folder


def generate_filename(args, task):
    """Build the result filename ``_dataset_<task>_delta_<..>_..._num_iterations_<..>.json``.

    Args:
        args: Namespace with ``delta``, ``learning_rate``, ``num_clients``,
            ``noise``, ``lambda_reg`` and ``num_iterations``.
        task: Task / dataset name.
    """
    def format_float(value):
        # Integral floats print without a fractional part (1.0 -> "1"); other
        # floats use a compact 3-decimal form so existing filenames stay stable.
        if not isinstance(value, float):
            return str(value)
        try:
            if value == int(value):
                return str(int(value))
        except (OverflowError, ValueError):  # inf / nan have no int form
            return str(value)
        short = f"{value:.3f}".rstrip('0').rstrip('.')
        # The 3-decimal form silently collapses e.g. 1e-4 to "0" (colliding
        # filenames); fall back to the exact repr when it loses information.
        return short if float(short) == value else repr(value)

    params = {
        'dataset': task,
        'delta': format_float(args.delta),
        'lr': format_float(args.learning_rate),
        'agents': args.num_clients,
        'noise': args.noise,
        'lambda_reg': format_float(args.lambda_reg),
        'num_iterations': args.num_iterations,
    }

    return "_" + "_".join(f"{key}_{value}" for key, value in params.items()) + ".json"


def get_result_file_path(args, task):
    """Return the forward-slash path of the JSON results file for *task*.

    Ensures the model's result folder exists (through ``create_result_directory``).
    """
    folder = create_result_directory(args.gpt)
    joined = os.path.join(folder, generate_filename(args, task))
    return joined.replace('\\', '/')


def run_FedPOB_Pref(task, n_prompt_tokens, delta, lambda_reg, noise, n_domain, num_iterations,
                   learning_rate, random_proj, intrinsic_dim, n_eval, gpt, init_scale, pooling,
                   num_clients, trial, args):
    """
    Run the complete FedPOB-Pref method process.

    Loads the task data, builds (or loads cached) candidate instructions and
    their validation scores, splits the candidates across ``num_clients``
    clients, and runs ``num_iterations`` rounds of federated dueling-bandit
    selection with FedDyn aggregation.

    Args:
        task: Task name
        n_prompt_tokens: Number of prompt tokens
        delta: Confidence parameter
        lambda_reg: Regularization parameter
        noise: Noise level (passed to the dueling-bandit environment)
        n_domain: Number of domains (candidate instructions)
        num_iterations: Number of iterations
        learning_rate: Learning rate
        random_proj: Random projection method
        intrinsic_dim: Intrinsic dimension
        n_eval: Number of evaluations (currently unused here)
        gpt: Model name
        init_scale: Initialization scale (currently unused here)
        pooling: Pooling method (currently unused here)
        num_clients: Number of clients
        trial: Trial number (seeds the per-trial torch RNG)
        args: Additional arguments (``trial``, ``candidate_method``, ``n_domain``)

    Returns:
        tuple: ``(test_score, prompts, prompts_set, best_values, now_values,
        best_instruction_over_iter, selected_instruction_over_iter,
        init_instructions, rewards, acc_rewards, best_rewards,
        selected_rewards, max_scores, average_score)``. ``test_score`` is a
        fixed placeholder ``0.0``; ``average_score`` is the best score in
        client 0's candidate pool.

    Raises:
        ValueError: if ``args.candidate_method`` is neither ``"induction"``
            nor ``"rephrase"``.
    """
    import json
    import random
    import os
    from transformers import AutoTokenizer, AutoModel
    from automatic_prompt_engineer import data, template
    from data.instruction_induction.load_data import load_data
    from evaluation.instruction_induction.exec_accuracy import exec_accuracy_evaluator
    from experiments.evaluation.instruction_induction.utility import set_all_seed
    from common_components import LMForwardAPI, get_sen_embedding, tkwargs

    # Bandit state lives on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    use_openrouter = 1
    base_url = "https://openrouter.ai/api/v1" if use_openrouter else None
    api_key = os.environ.get("OPENROUTER_API_KEY") if use_openrouter else os.environ.get("OPENAI_API_KEY")

    # NOTE(review): the folder alias (e.g. 'gpt4', 'deepseek') is also sent to
    # the API as the model name below; confirm LMForwardAPI expects that.
    actual_model = get_model_folder_name(gpt)
    print('Model used:')
    print(actual_model)

    induce_data, test_data = load_data('induce', task)[:2], load_data('eval', task)[:2]
    msgs = ''

    induce_data_size = len(induce_data[0])
    prompt_gen_size = min(int(induce_data_size * 0.5), 100)

    prompt_gen_data, eval_data = data.create_split(induce_data, split_size=prompt_gen_size, seed=None)

    # Data is in the form input: single item, output: list of items
    # For prompt_gen_data, sample a single item from the output list
    prompt_gen_data = prompt_gen_data[0], [random.sample(output, 1)[0]
                                           for output in prompt_gen_data[1]]

    demos_template = "Input: [INPUT]\nOutput: [OUTPUT]"
    eval_template = "Instruction: [PROMPT]\n\nInput: [INPUT]\n\nOUTPUT: [OUTPUT]"  # change the evaluation template
    init_prompt = ['\n']
    prompt_gen_template = "[full_DEMO]\n\nThe instruction was to"

    base_conf = '../experiments/configs/instruction_induction.yaml'
    model_config = {
        'name': 'GPT_forward',
        'batch_size': 1,
        'base_url': base_url,
        'api_key': api_key,
        'gpt_config': {
            'model': actual_model
        }
    }

    # Never echo the API key (not even a prefix) into logs.
    print(f"Model configuration debug info:")
    print(f"  - use_openrouter: {use_openrouter}")
    print(f"  - base_url: {base_url}")
    print(f"  - api_key: {'<set>' if api_key else 'None'}")
    print(f"  - actual_model: {actual_model}")
    print(f"  - model_config: {dict(model_config, api_key='<redacted>' if api_key else None)}")

    conf = {
        'generation': {
            'num_subsamples': 1,
            'num_demos': 5,
            'num_prompts_per_subsample': 20,
            'model': model_config
        },
        'evaluation': {
            'method': exec_accuracy_evaluator,
            'task': task,
            'num_samples': min(50, len(eval_data[0])),
            'num_few_shot': 5,
            'model': model_config
        }
    }

    def init_qa_gen():
        subsampled_data = data.subsample_data(prompt_gen_data, conf['generation']['num_demos'])
        prompt_gen_template_ = template.InitQATemplate(prompt_gen_template)
        d_template = template.DemosTemplate(demos_template)
        demos = d_template.fill(subsampled_data)
        return prompt_gen_template_.fill(demos)

    model_forward_api = LMForwardAPI(model_name='vicuna', eval_data=eval_data, init_prompt=init_prompt,
                                     init_qa_gen=init_qa_gen, conf=conf, base_conf=base_conf,
                                     prompt_gen_data=prompt_gen_data,
                                     n_prompt_tokens=n_prompt_tokens, random_proj=random_proj,
                                     intrinsic_dim=intrinsic_dim, eval_extra_msg=msgs)
    # NOTE(review): seeds with args.trial, not the `trial` argument; the
    # __main__ sweep varies only `trial`. Confirm which one is intended.
    print(set_all_seed(args.trial))

    # Generate prompt domain and evaluate on validation set
    if args.candidate_method == "induction":
        path_ = f"./query/{task}_{args.n_domain}"
    elif args.candidate_method == "rephrase":
        path_ = f"./query/{task}_{args.n_domain}_rephrase"
    else:
        raise ValueError(f"Unknown candidate_method: {args.candidate_method!r}")

    if os.path.exists(path_):
        with open(path_, 'r') as fp:
            domains = json.load(fp)
            init_instructions = domains['instructions']
    else:
        # Create the folder path_ actually lives in (previously "../query" was
        # created while the file is written under "./query").
        os.makedirs(os.path.dirname(path_), exist_ok=True)
        init_instructions = model_forward_api.initialize_prompts(args.n_domain, task, args.candidate_method)
        with open(path_, 'x') as fp:
            domains = {"instructions": init_instructions}
            json.dump(domains, fp, indent=4)

    sen_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
    sen_model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

    sen_embeddings = get_sen_embedding(sen_model, sen_tokenizer, init_instructions)
    sen_embeddings = sen_embeddings.to(**tkwargs)

    # Cached per-candidate validation scores: ./query_score/<model>/<task>_<n>.json
    name = task + '_' + str(n_domain) + '.json'
    file_path = "./query_score/" + actual_model + "/" + name
    target_list_key = task
    print('file_path',file_path)
    # The per-model subfolder must exist too, or the open(..., "w") below fails.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    if not os.path.exists(file_path):
        with open(file_path, "w") as f:
            json.dump({}, f)

    with open(file_path, "r") as f:
        try:
            datas = json.load(f)
        except json.JSONDecodeError:
            datas = {}

    if target_list_key not in datas:
        score_list = []
        for i in range(n_domain):
            score_test = model_forward_api.eval([init_instructions[i]])
            score_list.append(float(round(score_test, 2)))

        datas[target_list_key] = score_list
        with open(file_path, "w") as f:
            json.dump(datas, f, indent=4)

    List_score = datas.get(target_list_key)
    List_score = np.array(List_score)

    # Score statistics from a random sample of candidates, stored on the API object.
    test_num = 40
    all_tmp_scores = []
    for tmp in range(test_num):
        prompt_tmp = np.random.choice(range(n_domain))
        score_tmp = List_score[prompt_tmp]
        all_tmp_scores += [score_tmp]
    model_forward_api.score_mean = np.mean(all_tmp_scores)
    model_forward_api.score_std = np.std(all_tmp_scores)
    model_forward_api.score_min = np.min(all_tmp_scores)
    model_forward_api.score_max = np.max(all_tmp_scores)

    # Split candidates across clients: `common_count` shared indices plus
    # per-client unique indices, `n_domain / 5` candidates per client in total.
    contexts = [[] for _ in range(num_clients)]
    score_list = [[] for _ in range(num_clients)]
    now_values = [[] for _ in range(num_clients)]
    arm2_scores = [[] for _ in range(num_clients)]
    common_ratio = 0
    total_samples = n_domain / 5
    common_count = int(total_samples * common_ratio)
    unique_count = total_samples - common_count

    torch.manual_seed(8864+10*trial)

    shared_indices = torch.randperm(sen_embeddings.size(0))[:common_count]
    shared_set = set(shared_indices.tolist())  # O(1) membership tests

    for i in range(num_clients):
        seed = 66 + i + 10*trial
        torch.manual_seed(seed)

        all_indices = torch.randperm(sen_embeddings.size(0))
        unique_indices = []

        for idx in all_indices.tolist():
            if idx not in shared_set:
                unique_indices.append(idx)
            if len(unique_indices) >= unique_count:
                break

        client_sample = torch.cat([shared_indices, torch.tensor(unique_indices, dtype=torch.long)], dim=0)
        sample = client_sample.tolist()
        contexts[i] = sen_embeddings[sample]
        score_list[i] = List_score[sample]

    # Best achievable score in client 0's candidate pool.
    average_score = max(score_list[0])

    feature_dim = 768
    num_arms = len(contexts[0])

    env = FederatedLinearDuelingBanditEnvironment(feature_dim, num_arms, num_clients, noise)
    clients = [FederatedClient_optimal(client_id=i, feature_dim=feature_dim, learning_rate=learning_rate,
                                       lambda_reg=lambda_reg) for i in range(num_clients)]
    server = CentralServer_optimal(feature_dim, num_clients, learning_rate, lambda_reg)

    theta_global = torch.zeros(feature_dim, device=device)
    theta_locals = []
    for i in range(num_clients):
        torch.manual_seed(i+10*trial)
        theta_locals.append(0.1 * torch.randn(feature_dim, device=device))
    client_grad_histories = [torch.zeros(feature_dim, device=device) for _ in range(num_clients)]

    best_instruction_over_iter = [[] for _ in range(num_clients)]
    selected_instruction_over_iter = [[] for _ in range(num_clients)]
    best_values = []

    rewards = [[] for _ in range(num_clients)]
    acc_rewards = [[0] for _ in range(num_clients)]
    best_rewards = [[0] for _ in range(num_clients)]
    selected_rewards = [[] for _ in range(num_clients)]
    max_scores = [max(score_list[i]) for i in range(num_clients)]

    for t in tqdm(range(1, num_iterations + 1)):
        print("iter:", t)

        w_updates = []

        # beta_t depends only on t, so it is computed once per round.
        beta_t = torch.sqrt(2 * torch.log(torch.tensor(1.0 / delta, device=device)) + feature_dim * torch.log(
            torch.tensor(1 + t * num_clients / (lambda_reg * feature_dim), device=device)))

        for i, client, context in zip(range(num_clients), clients, contexts):
            arm1_idx, arm2_idx = client.select_arms(theta_locals[i], context, beta_t, server.W_sync, score_list[i])

            print('arm1',arm1_idx,score_list[i][arm1_idx])
            print('arm2',arm2_idx,score_list[i][arm2_idx])
            arm1, arm2 = context[arm1_idx], context[arm2_idx]
            Max_score = max(score_list[i])

            winner = env.get_preference(arm1_idx, arm2_idx, score_list[i])

            if winner == 1:
                selected_arm_idx = arm1_idx
                selected_arm_reward = score_list[i][arm1_idx]
            else:
                selected_arm_idx = arm2_idx
                selected_arm_reward = score_list[i][arm2_idx]

            client.history_arm_1.append(arm1)
            client.history_arm_2.append(arm2)
            client.history_winner.append(winner)
            client.history_selected_arm.append(selected_arm_idx)
            client.history_arm_1_index.append(arm1_idx)
            client.history_arm_2_index.append(arm2_idx)

            w_updates.append(client.compute_w_update(arm1, arm2))

            r1 = score_list[i][arm1_idx]
            r2 = score_list[i][arm2_idx]
            now_values[i] += [r1]
            arm2_scores[i] += [r2]

            rewards[i].append(r1)
            selected_rewards[i].append(selected_arm_reward)
            acc_rewards[i].append(acc_rewards[i][-1] + r1)

            if r1 >= client.best_r1:
                client.best_r1 = r1
                client.best_arm1_idx = arm1_idx

            if r1 > best_rewards[i][-1]:
                best_rewards[i].append(r1)
            else:
                best_rewards[i].append(best_rewards[i][-1])

            if selected_arm_reward >= client.best_r:
                client.best_r = selected_arm_reward
                client.best_arm = selected_arm_idx

            best_instruction_over_iter[i] += [(t, init_instructions[client.best_arm1_idx])]
            selected_instruction_over_iter[i] += [(t, init_instructions[selected_arm_idx])]
            best_values.append(client.best_r1)

            theta_locals[i], client_grad_histories[i] = client.local_train_feddyn(
                theta_local=theta_locals[i],
                theta_global=theta_global,
                grad_history=client_grad_histories[i])

            print("agent:", i)
            print("arm:", arm1_idx, arm2_idx)
            print("reward:", r1, r2)
            print("best_arm:", client.best_arm)
            print(f"Best value found till now: {client.best_r}/{Max_score}")

        server.aggregate_w_updates(w_updates)
        theta_global = server.update_global_feddyn(theta_locals, client_grad_histories)

    print("First row is arm1, second row is arm2, third row is arm1's reward")
    for i, client in zip(range(num_clients), clients):
        scores1 = [float(x) for x in now_values[i]]
        scores2 = [float(x) for x in arm2_scores[i]]

        print("agent:", i, client.history_arm_1_index)
        print("agent:", i, client.history_arm_2_index)
        print("score1:", i, scores1 )
        print("score2:", i, scores2 )
        print("preference:", i, client.history_winner)

    prompts = []
    prompts_set = {}
    for i, client in enumerate(clients):
        if client.best_arm is not None:
            best_prompt = init_instructions[client.best_arm]
            prompts.append(best_prompt)
            prompts_set[best_prompt] = client.best_r

    # Placeholder: no held-out test evaluation is performed here.
    test_score = 0.0

    return test_score, prompts, prompts_set, best_values, now_values, best_instruction_over_iter, selected_instruction_over_iter, init_instructions, rewards, acc_rewards, best_rewards, selected_rewards, max_scores, average_score


def parse_args(argv=None):
    """Parse command-line options for FedPOB-Pref.

    Args:
        argv: Optional list of argument strings; ``None`` (the default)
            parses ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    import argparse
    parser = argparse.ArgumentParser(description="FedPOB-Pref method")
    parser.add_argument("--task", type=str, default='larger_animal', help="Task name")
    # type=int so values given on the command line are not left as strings.
    parser.add_argument("--n_prompt_tokens", type=int, default=5, help="The number of prompt tokens.")
    parser.add_argument("--delta", type=float, default=0.1, help="Set the parameter delta.")
    parser.add_argument("--lambda_reg", type=float, default=1, help="Set the lambda parameter.")
    parser.add_argument("--noise", type=int, default=10, help="Set the noise parameter.")
    parser.add_argument("--n_domain", type=int, default=500, help="Set the number of domain.")
    parser.add_argument("--num_iterations", type=int, default=50, help="Set the number of total queries.")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Set the learning rate.")
    parser.add_argument("--random_proj", type=str, default='uniform', help="Set the projection method.")
    parser.add_argument("--intrinsic_dim", type=int, default=100, help="Set the number of intrinsic dim.")
    parser.add_argument("--n_eval", type=int, default=1000, help="Set the number of domains to be evaluated.")
    parser.add_argument("--gpt", type=str, default="gpt-3.5-turbo", help="Which model to use.")
    parser.add_argument("--init_scale", type=float, default=1, help="Which scale to use.")
    parser.add_argument("--pooling", type=str, default="last", help="Which pooling method to use.")
    parser.add_argument("--trial", type=int, default=0, help="Trial ID.")
    parser.add_argument("--magnitude", type=int, default=10, help="The magnitude of the scores.")
    parser.add_argument("--norm_method", type=str, default='standard', help="The way to transform the value.")
    parser.add_argument("--candidate_method", type=str, default='induction', help="The way to generate candidates.")
    parser.add_argument("--num_clients", type=int, default=1, help="The number of agents.")
    parser.add_argument("--name", type=str, default="", help="Set the name of the experiments.")

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Main execution script for FedPOB-Pref experiments: runs every task in
    # `run_tasks` for every agent count in `run_agents`, `num_experiments`
    # trials each, and saves the successful trials to one JSON file per
    # (task, agents) pair.
    import time
    import json
    import os
    args = parse_args()

    run_tasks = [
    'boolean_expressions', 'date_understanding', 'disambiguation_qa','dyck_languages', 'formal_fallacies',
    'geometric_shapes', 'hyperbaton',
    'logical_deduction_five_objects', 'logical_deduction_seven_objects', 'logical_deduction_three_objects',
    'movie_recommendation', 'multistep_arithmetic_two', 'navigate',
    'penguins_in_a_table', 'reasoning_about_colored_objects', 'ruin_names',
    'salient_translation_error_detection', 'snarks', 'sports_understanding', 'temporal_sequences',
    'tracking_shuffled_objects_five_objects', 'tracking_shuffled_objects_seven_objects',
    'tracking_shuffled_objects_three_objects','web_of_lies']
    run_agents = [1,3,10]
    num_experiments = 5

    for run_task in run_tasks:
        for run_agent in run_agents:
            args.task = run_task
            args.num_clients = run_agent

            print(f"\n{'='*50}")
            print(f"Running FedPOB-Pref method for task: {args.task}, agents: {args.num_clients}")
            print(f"Model: {args.gpt}")
            print(f"{'='*50}")

            log_file_name = get_result_file_path(args, run_task)

            print(f"Results will be saved to: {log_file_name}")

            all_experiments = []
            ave_scores = []

            for experiment_id in range(num_experiments):
                print(f"\nRunning experiment {experiment_id+1}/{num_experiments}...")
                start_time = time.time()

                try:
                    test_scores, prompts, prompts_set, best_values, now_values, best_instruction_over_iter, selected_instruction_over_iter, init_instructions, rewards, acc_rewards, best_rewards, selected_rewards, max_scores, average_score= run_FedPOB_Pref(
                        task=args.task,
                        n_prompt_tokens=args.n_prompt_tokens,
                        delta=args.delta,
                        lambda_reg=args.lambda_reg,
                        noise=args.noise,
                        n_domain=args.n_domain,
                        num_iterations=args.num_iterations,
                        learning_rate=args.learning_rate,
                        random_proj=args.random_proj,
                        intrinsic_dim=args.intrinsic_dim,
                        n_eval=args.n_eval,
                        gpt=args.gpt,
                        init_scale=args.init_scale,
                        pooling=args.pooling,
                        num_clients=args.num_clients,
                        trial=experiment_id,
                        args=args
                    )

                    end_time = time.time()

                    ave_scores.append(average_score)

                    experiment_result = {
                        'experiment_id': experiment_id,
                        'task': args.task,
                        'agents': args.num_clients,
                        'num_iterations': args.num_iterations,
                        'delta': args.delta,
                        'lambda_reg': args.lambda_reg,
                        'noise': args.noise,
                        'learning_rate': args.learning_rate,
                        'gpt': args.gpt,
                        'test_scores': test_scores,
                        'prompts': prompts,
                        'prompts_set': prompts_set,
                        'best_values': best_values,
                        'now_values': now_values,
                        'best_instruction_over_iter': best_instruction_over_iter,
                        'selected_instruction_over_iter': selected_instruction_over_iter,
                        'init_instructions': init_instructions,
                        'rewards': rewards,
                        'acc_rewards': acc_rewards,
                        'best_rewards': best_rewards,
                        'selected_rewards': selected_rewards,
                        'max scores':max_scores,
                        'runtime': end_time - start_time,
                        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
                    }

                    all_experiments.append(experiment_result)

                    print(f"Experiment {experiment_id+1} completed successfully!")
                    print(f"Runtime: {end_time - start_time:.2f} seconds")
                    if rewards and len(rewards[0]) > 0:
                        print(f"Final arm1 reward: {rewards[0][-1]:.4f}")
                        print(f"Best arm1 reward: {best_rewards[0][-1]:.4f}")
                        print(f"Final selected arm reward: {selected_rewards[0][-1]:.4f}")
                        print(f"Total accumulated reward (arm1): {acc_rewards[0][-1]:.4f}")

                except Exception as e:
                    # Top-level boundary: report the failure and continue the sweep.
                    print(f"Experiment {experiment_id+1} failed with error: {e}")
                    import traceback
                    traceback.print_exc()
                    continue

            # Guard against every trial failing (previously a ZeroDivisionError
            # here aborted the remaining tasks).
            if ave_scores:
                average = sum(ave_scores) / len(ave_scores)
                print(f"score: {average}")
            else:
                print("score: n/a (no successful experiments)")

            if all_experiments:
                with open(log_file_name, 'w', encoding='utf-8') as f:
                    json.dump(all_experiments, f, indent=2, ensure_ascii=False)
                print(f"\nResults saved to: {log_file_name}")
                print(f"Completed {len(all_experiments)}/{num_experiments} experiments for {args.task} with {args.num_clients} agents")
            else:
                print(f"No successful experiments for {args.task} with {args.num_clients} agents")

    print(f"\n{'='*50}")
    print("All FedPOB-Pref experiments completed!")

