import os
import json
import argparse
import logging

from symphony import mcts_search
from webshop import WebShopTask

# Standard-library helpers used by AgentPool (UCB math and type hints)
import math
from typing import List, Dict

class AgentPool:
    """Pool of LLM names with UCB-based selection of the next one to call.

    Each LLM tracks a cumulative reward and a call count. ``call_llm`` picks an
    LLM via ``select_llm`` and increments its count; rewards are fed back
    separately through ``update_reward``.
    """

    def __init__(self, llm_names: List[str], exploration_weight: float = 20.0):
        """Initialise the agent pool.

        Args:
            llm_names: names of the LLMs in the pool; never-called LLMs are
                tried in this order during warm-up.
            exploration_weight: UCB exploration coefficient (alpha); larger
                values give more weight to rarely-called LLMs.
        """
        self.llm_names = llm_names
        self.exploration_weight = exploration_weight

        # Per-LLM running statistics, keyed by LLM name.
        self.llm_stats = {name: {"total_reward": 0.0, "call_count": 0} for name in llm_names}
        self.total_calls = 0

    def _check_known(self, llm_name: str) -> None:
        """Raise ``ValueError`` if *llm_name* is not part of this pool."""
        if llm_name not in self.llm_stats:
            raise ValueError(f"LLM {llm_name} not found in the agent pool")

    def select_llm(self) -> str:
        """Return the name of the LLM the UCB rule picks next.

        Any LLM that has never been called is returned first (in
        ``llm_names`` order), so every LLM is tried at least once. After that
        the LLM maximising

            mean_reward + sqrt(alpha * ln(total_calls) / (call_count + 1))

        is chosen; ties go to the earliest LLM in ``llm_names``. Note the
        ``+ 1`` in the denominator differs from textbook UCB1. This method does
        not modify any statistics.

        Raises:
            ValueError: if the pool contains no LLMs.
        """
        if not self.llm_stats:
            raise ValueError("AgentPool has no LLMs to select from")

        # Warm-up: decide from the actual counts rather than total_calls, so
        # duplicate names in llm_names cannot skip an untried LLM.
        for name in self.llm_names:
            if self.llm_stats[name]["call_count"] == 0:
                return name

        log_total = math.log(self.total_calls)

        def ucb(name: str) -> float:
            stats = self.llm_stats[name]
            exploitation = stats["total_reward"] / stats["call_count"]
            exploration = math.sqrt(
                self.exploration_weight * log_total / (stats["call_count"] + 1)
            )
            return exploitation + exploration

        return max(self.llm_stats, key=ucb)

    def call_llm(self) -> str:
        """Select an LLM, record one call against it, and return its name."""
        llm_name = self.select_llm()

        self.llm_stats[llm_name]["call_count"] += 1
        self.total_calls += 1

        return llm_name

    def update_reward(self, llm_name: str, reward: float) -> None:
        """Add *reward* to the cumulative reward of *llm_name*.

        Does not change call counts; ``call_llm`` already recorded the call.

        Raises:
            ValueError: if *llm_name* is not in the pool.
        """
        self._check_known(llm_name)
        self.llm_stats[llm_name]["total_reward"] += reward

    def get_llm_stats(self) -> Dict[str, Dict[str, float]]:
        """Return the live per-LLM statistics dict (not a copy)."""
        return self.llm_stats

    def get_average_reward(self, llm_name: str) -> float:
        """Return the mean reward per call for *llm_name*, or 0.0 if never called.

        Raises:
            ValueError: if *llm_name* is not in the pool.
        """
        self._check_known(llm_name)

        stats = self.llm_stats[llm_name]
        if stats["call_count"] == 0:
            return 0.0
        return stats["total_reward"] / stats["call_count"]


def run(args):
    """Run MCTS search over a range of WebShop tasks and record the results.

    For every task index ``i`` in ``[args.task_start_index, args.task_end_index)``
    this calls ``mcts_search(args, task, f'fixed_{i}', args.iterations, True)``,
    records the returned reward, prints running metrics, appends a textual
    summary to ``./log1/runname.txt`` and logs the running metrics to the file
    named by ``args.log``.

    Running metrics (computed over the tasks processed so far):
        r:  mean reward.
        sr: fraction of tasks whose reward equals 1.
        fr: ``count`` / tasks processed (see NOTE on ``count`` below).

    Args:
        args: configuration object exposing at least ``task_start_index``,
            ``task_end_index``, ``iterations``, ``log`` and ``backend``
            (e.g. the ``Args`` instance built by ``get_args``).
    """
    # Project-local import, hoisted out of the loop so the name is bound even
    # when the task range is empty (the final usage print needs it).
    from models import gpt_usage

    task = WebShopTask()
    print(task)

    logging.basicConfig(filename=args.log, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='a')

    # NOTE(review): ``count`` is never incremented, so ``fr`` is always 0.
    # It presumably was meant to count failed tasks -- confirm and wire it up.
    count = 0
    task_accs = []

    result_dir = "./log1/"  # define your path   TODO
    result_path = os.path.join(result_dir, "runname.txt")  # TODO
    # Create the output directory up front; open(..., "a") does not.
    os.makedirs(result_dir, exist_ok=True)

    for i in range(args.task_start_index, args.task_end_index):
        state, value, reward, em, question = mcts_search(args, task, f'fixed_{i}', args.iterations, True)
        if state is None:
            # The search produced no state: score this task as a zero.
            value = 0
            reward = 0
            em = 0
            question = "error"
        print("best reward", reward)
        task_accs.append(reward)

        # Metrics are normalised by the number of tasks processed so far
        # (previously some used task_end_index, which is wrong whenever
        # task_start_index != 0).
        n_done = len(task_accs)
        r = sum(task_accs) / n_done
        sr = sum(1 for acc in task_accs if acc == 1) / n_done
        fr = count / n_done
        print(i + 1, r, sr, fr)

        result = (
            f"The current {i} task: {question}\n"
            f"value is {value}, reward score is {reward}, em is {em}\n"
            f"The detailed information of the result of the task is: {i}, "
            f"average score r: {r}, proportion of completely successful SR: {sr}\n"
        )
        with open(result_path, "a", encoding="utf-8") as file:
            file.write(result)
            # NOTE(review): gpt_usage() is called without a backend here but
            # with args.backend below -- confirm which form is intended.
            file.write(f'{i} usage :{gpt_usage()}' + "\n ------------------------- \n")
        print("OK")
        print('-------------')

        logging.info("RESULTS: %s, %s, %s", r, sr, fr)

    print('usage_so_far', gpt_usage(args.backend))

# def parse_args():
#     args = argparse.ArgumentParser()
#     args.add_argument('--backend', type=str, choices=['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', 'llama2', "text-davinci-002"], default='gpt-3.5-turbo-16k')
#     args.add_argument('--temperature', type=float, default=1.0)
#     args.add_argument('--task_start_index', type=int, default=900)
#     args.add_argument('--task_end_index', type=int, default=1000)
#     args.add_argument('--prompt_sample', type=str, choices=['standard', 'cot'])
#     args.add_argument('--n_generate_sample', type=int, default=1)  # only thing needed if naive_run
#     args.add_argument('--n_evaluate_sample', type=int, default=1)
#     args.add_argument('--iterations', type=int, default=30)
#     args.add_argument('--log', type=str)
#
#     args = args.parse_args()
#     return args



class Args:
    """Plain container for the run configuration consumed by ``run``.

    Mirrors the options of the (commented-out) argparse-based ``parse_args``
    so the script can be configured in code instead of on the command line.
    """

    def __init__(self, backend = "Qwen2.5-7B-Instruct-1M", temperature=1.0, task_start_index=900, task_end_index=1000, prompt_sample = 'cot', n_generate_sample=1, n_evaluate_sample=1, iterations=30, log = ""):
        self.backend = backend                        # model name passed through to the search
        self.temperature = temperature                # sampling temperature
        self.task_start_index = task_start_index      # first task index (inclusive)
        self.task_end_index = task_end_index          # last task index (exclusive)
        self.prompt_sample = prompt_sample            # prompting style, e.g. 'cot'
        self.n_generate_sample = n_generate_sample
        self.n_evaluate_sample = n_evaluate_sample
        self.iterations = iterations                  # search iterations per task
        self.log = log                                # path of the logging output file

    def __repr__(self):
        """Show every configured field so ``print(args)`` is informative."""
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, i.e. for options
        # that were never set; replaces the default AttributeError message.
        raise AttributeError(f"{name} attribute does not exist")


def get_args():
    """Build the hard-coded run configuration used when running as a script.

    Returns:
        Args: backend ``Qwen2.5-7B-Instruct-1M`` at temperature 0.2, ``cot``
        prompting, 4 generated / 1 evaluated sample(s), 10 search iterations,
        task indices [0, 100), logging to ``runname.log``.
    """
    # Other backends that have been used: Qwen2___5-7B-Instruct, Qwen2.5-14B-Instruct
    settings = dict(
        backend="Qwen2.5-7B-Instruct-1M",
        temperature=0.2,
        task_start_index=0,
        task_end_index=100,
        prompt_sample='cot',
        n_generate_sample=4,
        n_evaluate_sample=1,
        iterations=10,
        log="runname.log",  # define your path TODO
    )
    return Args(**settings)

# Module-level agent pool over the candidate models. It is not used elsewhere
# in this file. NOTE(review): presumably imported by other modules; confirm.
llm_manager = AgentPool(
    [
        "Qwen2.5-7B-Instruct-1M",
        "Mistral-7B-Instruct-v0.3",
        "Meta-Llama-3.1-8B-Instruct",
    ]
)

if __name__ == '__main__':
    # Configuration comes from get_args() rather than command-line parsing.
    run_config = get_args()
    print(run_config)
    run(run_config)