import os
import json
import argparse

from hotpotqa import HotPotQATask
from models import gpt_usage
from symphony import mcts_search
from tot import dfs_search
from rap import mcts_search
import logging
from typing import List, Dict
import math
class AgentPool:
    """Pool of LLM agents selected with an upper-confidence-bound (UCB) rule.

    Every agent is tried once first. After that, the agent maximising
    ``mean_reward + sqrt(alpha * ln(total_calls) / (call_count + 1))``
    is chosen, where ``alpha`` is ``exploration_weight``.
    """

    def __init__(self, llm_names: List[str], exploration_weight: float = 20.0):  # α defaults to 20
        """
        Args:
            llm_names: names of the agents in the pool.
            exploration_weight: α, the multiplier of the exploration bonus.
        """
        self.llm_names = llm_names
        self.exploration_weight = exploration_weight

        # Per-agent cumulative reward and number of times the agent was selected.
        self.llm_stats = {name: {"total_reward": 0.0, "call_count": 0} for name in llm_names}
        self.total_calls = 0

    def _require_known(self, llm_name: str) -> None:
        """Raise ValueError if ``llm_name`` is not part of this pool."""
        if llm_name not in self.llm_stats:
            raise ValueError(f"LLM {llm_name} does not exist")

    def select_llm(self) -> str:
        """Return the name of the next agent according to the UCB rule.

        Does not update any statistics; ``call_llm`` does that.

        Raises:
            ValueError: if the pool contains no agents.
        """
        if not self.llm_stats:
            raise ValueError("AgentPool is empty: no LLM to select")

        if self.total_calls < len(self.llm_names):
            # Warm-up: make sure every LLM has been called at least once.
            for name in self.llm_names:
                if self.llm_stats[name]["call_count"] == 0:
                    return name

        ucb_values = {}
        for name, stats in self.llm_stats.items():
            if stats["call_count"] == 0:
                continue
            exploitation = stats["total_reward"] / stats["call_count"]
            # The denominator is call_count + 1 (not the textbook call_count).
            exploration = math.sqrt(
                self.exploration_weight * math.log(self.total_calls) / (stats["call_count"] + 1)
            )
            ucb_values[name] = exploitation + exploration

        # Highest UCB wins; ties go to the first name in pool order.
        return max(ucb_values, key=ucb_values.get)

    def call_llm(self) -> str:
        """Select the next agent, record the call, and return its name."""
        llm_name = self.select_llm()

        self.llm_stats[llm_name]["call_count"] += 1
        self.total_calls += 1

        return llm_name

    def update_reward(self, llm_name: str, reward: float) -> None:
        """Add ``reward`` to the cumulative reward of ``llm_name``.

        Raises:
            ValueError: if ``llm_name`` is not in the pool.
        """
        self._require_known(llm_name)
        self.llm_stats[llm_name]["total_reward"] += reward

    def get_llm_stats(self) -> Dict[str, Dict[str, float]]:
        """Return the (live, mutable) per-agent statistics dictionary."""
        return self.llm_stats

    def get_average_reward(self, llm_name: str) -> float:
        """Return the mean reward per call of ``llm_name`` (0.0 if never called).

        Raises:
            ValueError: if ``llm_name`` is not in the pool.
        """
        self._require_known(llm_name)

        count = self.llm_stats[llm_name]["call_count"]
        if count == 0:
            return 0.0

        return self.llm_stats[llm_name]["total_reward"] / count

def run(args):
    """Run the configured search algorithm over a range of HotPotQA tasks.

    For each index in ``range(args.task_start_index, args.task_end_index)``
    (negative indices are skipped), the search selected by
    ``args.algorithm`` ('mcts', 'tot' or 'rap') is executed. The running
    average of ``em`` is printed, and a summary is appended to
    ``./log1/runname.txt``. ``gpt_usage()`` is logged to ``args.log``.

    Args:
        args: configuration object exposing at least ``log``,
            ``task_start_index``, ``task_end_index``, ``algorithm`` and
            ``iterations`` (e.g. an ``Args`` instance).

    Raises:
        ValueError: if ``args.algorithm`` is not a supported option.
    """
    # Both ``symphony`` and ``rap`` export ``mcts_search``. At module level the
    # later ``rap`` import shadows the ``symphony`` one, so bind each explicitly.
    from symphony import mcts_search as symphony_mcts_search
    from rap import mcts_search as rap_mcts_search

    task = HotPotQATask()
    print(task)

    log_dir = os.path.dirname(args.log)
    print(f"Log directory: {log_dir}")
    # dirname() is "" for a bare file name like "runname.log", and
    # os.makedirs("") raises FileNotFoundError, so only create real directories.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    log_path = os.path.abspath(args.log)
    print(f"Log file path: {log_path}")

    logging.basicConfig(
        filename=args.log,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filemode='w'
    )
    logging.info("Logging configuration is successful.")

    result_dir = "./log1/"
    # The per-task result file is opened in append mode below; ensure its directory exists.
    os.makedirs(result_dir, exist_ok=True)
    result_file = os.path.join(result_dir, "runname.txt")

    task_accs = []

    for i in range(args.task_start_index, args.task_end_index):
        if i < 0:
            continue

        # Only the symphony MCTS search returns the question text.
        question = None
        if args.algorithm == 'mcts':
            _state, value, _all_nodes, reward, em, question = symphony_mcts_search(
                args, task, i, args.iterations, True
            )
        elif args.algorithm == 'tot':
            _state, value, _all_nodes, reward, em = dfs_search(args, task, i, args.iterations)
        elif args.algorithm == 'rap':
            _state, value, _all_nodes, reward, em = rap_mcts_search(args, task, i, args.iterations)
        else:
            raise ValueError("Search algorithm option not valid")

        # Main metric: a missing em counts as 0.
        if em is None:
            em = 0
        task_accs.append(em)
        cnt_avg = sum(task_accs) / len(task_accs)
        print(i, 'len(task_accs)', len(task_accs), 'cnt_avg', cnt_avg, '\n')

        result = (
            f"The current {i} th task: {question} \n, value is {value}, "
            f"reward score is {reward}, em is {em} \n "
            f"The detailed information of the task is: {i}, "
            f"len (task_accs): {len(task_accs)}, cnt.avg: {cnt_avg}"
        )

        # Snapshot usage once so the file, stdout and the log all report the same value.
        usage = gpt_usage()
        with open(result_file, "a", encoding="utf-8") as file:
            file.write(result + "\n")
            file.write(f'cuurent task{i}  consume :{usage}' + "\n ------------------------- \n")
        print("The task result has been written to the file")

        print(f'task{i}', usage)
        logging.info(f'task{i}:{usage}')

    print('usage_so_far', gpt_usage())
    logging.info(f'usage_so_far :{gpt_usage()}')
# def parse_args():
#     args = argparse.ArgumentParser()
#     args.add_argument('--backend', type=str, choices=['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613'], default='gpt-3.5-turbo-0613')
#     args.add_argument('--temperature', type=float, default=1.0)
#     args.add_argument('--task_start_index', type=int, default=900)
#     args.add_argument('--task_end_index', type=int, default=1000)
#     args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot'])
#     args.add_argument('--n_generate_sample', type=int, default=1)
#     args.add_argument('--n_evaluate_sample', type=int, default=1)
#     args.add_argument('--iterations', type=int, default=50)
#     args.add_argument('--log', type=str)
#     args.add_argument('--algorithm', type=str, choices=  ['lats', 'rap', 'tot'])
#
#     args = args.parse_args()
#     return args




class Args:
    """Plain configuration object standing in for parsed command-line arguments.

    Each constructor keyword is stored as an attribute of the same name.
    """

    def __init__(self, backend = "Qwen2___5-7B-Instruct", temperature=1.0, task_start_index=900, task_end_index=1000, prompt_sample = 'cot', n_generate_sample=1, n_evaluate_sample=1, iterations=50, log = "", algorithm = "mcts"):
        self.backend = backend
        self.temperature = temperature
        self.task_start_index = task_start_index
        self.task_end_index = task_end_index
        self.prompt_sample = prompt_sample
        self.n_generate_sample = n_generate_sample
        self.n_evaluate_sample = n_evaluate_sample
        self.iterations = iterations
        self.log = log
        self.algorithm = algorithm

    def __repr__(self):
        # Show every stored setting so print(args) is informative.
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails; provides a clearer message.
        raise AttributeError(f"{name} attribute does not exist")


def get_args():
    """Return the hard-coded experiment configuration used by the script entry point."""
    # Other backends that have been used: "Qwen2___5-7B-Instruct", "Qwen2.5-14B-Instruct".
    settings = dict(
        backend="Qwen2.5-7B-Instruct-1M",
        temperature=0.2,
        task_start_index=1,
        task_end_index=100,
        prompt_sample='cot',
        n_generate_sample=4,
        n_evaluate_sample=1,
        iterations=10,
        log="runname.log",
        algorithm="mcts",
    )
    return Args(**settings)


# Module-level agent pool shared by whoever imports this module.
llm_manager = AgentPool(
    [
        "Qwen2.5-7B-Instruct-1M",
        "Mistral-7B-Instruct-v0.3",
        "Meta-Llama-3.1-8B-Instruct",
    ]
)

if __name__ == '__main__':
    # Configuration comes from get_args() rather than the command line.
    args = get_args()
    print(args)
    run(args)