import json
import sys
from tqdm import tqdm
from copy import deepcopy
import os
sys.path.append('../environments')
from redirect_output import redirect_output
from EnvManager import EnvManager
from model_api import *
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model_name", type=str, default="qwen2.5-14b")
parser.add_argument('--greedy', type=int, default=1)
parser.add_argument("--regen_exceed", type=int, default=0)
parser.add_argument('--extra_info', type=str, default='')
parser.add_argument('--allow_empty', type=int, default=0)


args = parser.parse_args()

env_manager = EnvManager()
# stop_token = '[FINISHED]'

model_name = args.model_name

if args.greedy:
    print("Use greedy decoding!")
    api_model_generation_config = {
        "temperature": 0.,
        "max_tokens": 8192
    }
    local_model_generation_config = {
        "do_sample": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "temperature": 1.0,
        "top_k": 0,
        "top_p": 1.0,
        "use_cache": True,
        "max_new_tokens": 2048,
    }
else:
    print("Use random sampling!")
    api_model_generation_config = {
        "temperature": 1.0,
        "max_tokens": 8192
    }
    local_model_generation_config = {
        "do_sample": True,
        # "repetition_penalty": 1.0,
        "num_beams": 1,
        # "temperature": 1.0,
        # "top_k": 0,
        # "top_p": 1.0,
        "use_cache": True,
        "max_new_tokens": 2048,
    }

if model_name == 'gpt4omini':
    agent_api = OpenaiAPI(model_name='gpt-4o-mini', generation_config=api_model_generation_config)
elif model_name == 'gpt4o':
    agent_api = OpenaiAPI(model_name='gpt-4o', generation_config=api_model_generation_config)
elif model_name == 'gpt4-turbo':
    agent_api = OpenaiAPI(model_name='gpt-4-turbo', generation_config=api_model_generation_config)
elif model_name == 'gpt-5':
    agent_api = OpenaiAPI(model_name='gpt-5', generation_config=api_model_generation_config)
elif model_name == 'claude3-opus':
    agent_api = ClaudeAPI(model_name='anthropic/claude-3-opus', generation_config=api_model_generation_config)
elif model_name == 'claude3.5-sonnet':
    agent_api = ClaudeAPI(model_name='anthropic/claude-3.5-sonnet', generation_config=api_model_generation_config)
elif model_name == 'claude3.5-haiku':
    agent_api = ClaudeAPI(model_name='anthropic/claude-3.5-haiku', generation_config=api_model_generation_config)
elif model_name == 'gemini1.5-pro':
    agent_api = GeminiAPI(model_name='google/gemini-pro-1.5', generation_config=api_model_generation_config)
elif model_name == 'gemini1.5-flash':
    agent_api = GeminiAPI(model_name='google/gemini-flash-1.5', generation_config=api_model_generation_config)
elif model_name == 'qwen2.5-72b':
    agent_api = QwenCloudAPI(model_name='qwen/qwen-2.5-72b-instruct', generation_config=api_model_generation_config)
elif model_name == 'deepseekv2.5':
    agent_api = DeepseekAPI(model_name='deepseek/deepseek-chat', generation_config=api_model_generation_config)
elif model_name == 'llama3.1-70b':
    agent_api = LlamaCloudAPI(model_name='meta-llama/llama-3.1-70b-instruct', generation_config=api_model_generation_config)
elif model_name == 'llama3.1-405b':
    agent_api = LlamaCloudAPI(model_name='meta-llama/llama-3.1-405b-instruct', generation_config=api_model_generation_config)

# local model
elif model_name == 'qwen2.5-7b':
    model_path = "Qwen/Qwen2.5-7B-Instruct"
    agent_api = QwenAPI(model_path=model_path, tokenizer_path=model_path, generation_config=local_model_generation_config)
elif model_name == 'qwen2.5-14b':
    model_path = "Qwen/Qwen2.5-14B-Instruct"
    agent_api = QwenAPI(model_path=model_path, tokenizer_path=model_path, generation_config=local_model_generation_config)
elif model_name == 'llama3.1-8b':
    model_path = "meta-llama/Llama-3.1-8B-Instruct"
    agent_api = Llama3API(model_path=model_path, tokenizer_path=model_path, generation_config=local_model_generation_config)
elif model_name == 'glm4-9b':
    model_path = "THUDM/glm-4-9b-chat"
    agent_api = GLM4API(model_path=model_path, tokenizer_path=model_path, generation_config=local_model_generation_config)
else:
    raise ValueError(f"Model {model_name} not found.")

error_log_path = './error_samples.json'
exceed_log_path = './exceed_round_samples.json'

def parse_envs(envs_info):
    envs = []
    tool_descs = []
    for env_info in envs_info:
        env_name = env_info['name']
        if 'parameters' in env_info:
            env_params = env_info['parameters'] or None
        else:
            env_params = None
        env = env_manager.init_env(env_name, env_params)
        if env is None:
            print(envs_info)
            raise ValueError(f"Environment {env_name} not found.")
        envs.append(env)
        
        tool_names = env_info['tools']
        tool_desc = env.get_tool_descs(tool_names)
        tool_descs.extend(tool_desc)
        
    return envs, tool_descs

def remove_exceed_samples(outpath):
    try:
        with open(outpath, 'r', encoding='utf-8') as f:
            data = json.load(f)
            
        with open(exceed_log_path, 'r', encoding='utf-8') as f:
            exceed_samples = json.load(f)
    except Exception as e:
        # print(e)
        return
    
    exceed_ids = set([d['info']['id'] for d in exceed_samples])
    new_data = []
    for d in data:
        if d['id'] in exceed_ids:
            continue
        new_data.append(d)
        
    print(f'remove {len(data) - len(new_data)} exceed samples')
    
    with open(outpath, 'w', encoding='utf-8') as f:
        json.dump(new_data, f, indent=1, ensure_ascii=False)

def log_error_samples(messages, info):
    try:
        with open(error_log_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        data = []
    data.append({'messages': messages, 'info': info})
    with open(error_log_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def log_exceed_round_samples(messages, info):
    try:
        with open(exceed_log_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        data = []
    data.append({"messages": messages, "info": info})
    with open(exceed_log_path, "w", encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

def clear_log(clear_error=True, clear_exceed=True):
    if clear_error:
        try:
            with open(error_log_path, 'w', encoding='utf-8') as f:
                json.dump([], f, indent=2)
        except:
            pass
    if clear_exceed:
        try:
            with open(exceed_log_path, 'w', encoding='utf-8') as f:
                json.dump([], f, indent=2)
        except:
            pass

def evaluate(d):
    envs = []
    tool_descs = []

    max_rounds = 50
    rounds = 0
    final_result = None

    info = {"id": d["id"]}
    try:
        messages = [{"role": "system", "content": agent_api.get_system_prompt(d)}]

        if "dialog" in d:
            messages.extend(d["dialog"])
        else:
            messages.append({"role": "user", "content": d["instruction"]})

        if d['environments'] and d['environments'][0]['name'] != '':
            envs, tool_descs = parse_envs(d['environments'])

            for item in tool_descs:
                item['type'] = "object"

            for i, item in enumerate(tool_descs):
                newitem = {'type': 'function', 'function': item}
                tool_descs[i] = newitem

        while rounds < max_rounds:
            if rounds > 8:
                print(f'current round: {rounds}, messages: {messages}')
            else:
                print(f'id {d["id"]}, current round: {rounds}')

            completion = agent_api.generate_response(messages, tool_descs)

            rounds += 1

            if completion is None or 'type' not in completion: continue
            if completion['type'] == 'error':
                print(completion['message'])
                continue

            if completion['type'] == 'tool':
                tool_call_id = completion['tool_call_id']
                tool_name = completion['tool_name']
                arguments = completion['arguments']

                tool_call_result = None
                for env_idx, env in enumerate(envs):
                    provided_tools = d['environments'][env_idx]['tools']
                    if tool_name not in provided_tools: continue
                    if not env.has_tool(tool_name): continue
                    tool_call_result = env.call_tool(tool_name, deepcopy(arguments))
                    break

                if tool_call_result is None:
                    # raise ValueError(f"Invalid tool call {tool_name}.")
                    # print(f"Invalid tool call {tool_name}.")
                    tool_call_result = {"success": False, "message": f"Tool {tool_name} doesn't exist."}
                    # continue

                # print('tool call result: ', tool_call_result)

                tool_call_response = {
                    "role": "assistant",
                    "tool_calls": [
                        {
                            "id": tool_call_id,
                            "type": "function",
                            "function": {
                                "arguments": json.dumps(arguments),
                                "name": tool_name
                            }
                        }
                    ]
                }

                tool_call_result_message = {
                    "role": "tool",
                    "content": json.dumps(tool_call_result),
                    "tool_call_id": tool_call_id,
                    "name": tool_name
                }

                messages.append(tool_call_response)
                messages.append(tool_call_result_message)

            ## normal content part
            else:
                content = completion['content']
                messages.append({
                    'role': 'assistant',
                    'content': content
                })
                # if stop_token not in content: continue
                final_result = messages
                break

    except Exception as e:
        print(e)
        import traceback; traceback.print_exc();
        log_error_samples(messages, info)
        return None, rounds

    if final_result is None:
        if rounds == 10:
            final_result = messages
            log_exceed_round_samples(messages, info)
        else:
            log_error_samples(messages, info)
    else:
        if not args.allow_empty and (final_result[-1]['role'] != 'assistant' or final_result[-1]['content'] == ''):
            log_error_samples(messages, info)
            return None, rounds
            
    return final_result, rounds

def eval_file(path, outpath):
    basedir = os.path.dirname(outpath)
    os.makedirs(basedir, exist_ok=True)
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    if os.path.exists(outpath):
        try:
            with open(outpath, 'r', encoding='utf-8') as f:
                outdata = json.load(f)
        except Exception as e:
            print(e)
            outdata = []
    else:
        outdata = []
    
    _outdata = []
    for d in outdata:
        if 'label' in d and d['label'] == -3:
            # changed samples should be re-cd
            continue
        _outdata.append(d)
    outdata = _outdata
    
    outids = set([d['id'] for d in outdata])
    _data = []
    print("data: ", len(outids))
    for d in data:
        if d['id'] in outids:
            continue
        if 'finish' in d:
            if d['finish'] != 1:
                continue
        _data.append(d)
        
    data = _data
    success_count = 0
    fail_count = 0
    exceed_count = 0
    for d in tqdm(data):
        opt, rounds = evaluate(d)
        if opt is not None:
            success_count += 1
            if rounds == 10:
                print(f'Exceed max rounds! id: {d["id"]}')
                exceed_count += 1
            
            d['output'] = opt
            
            outdata.append(d)

            with open(outpath, 'w', encoding='utf-8') as fw:
                json.dump(outdata, fw, indent=2, ensure_ascii=False)
        else:
            fail_count += 1
            print(f'Fail! id: {d["id"]}')
    
    print(f'success count: {success_count} (exceed count: {exceed_count}), fail count: {fail_count}')


if __name__ == '__main__':

    path = "../data/released_data.json"
    outdir = f"./evaluation_results/tot-{model_name}"

    outdir += f'{args.extra_info}'
    outpath = f"{outdir}/gen_res.json"
    error_log_path = f"{outdir}/error_samples.json"
    exceed_log_path = f"{outdir}/exceed_round_samples.json"
    log_path = f"{outdir}/gen_log.txt"
    
    os.makedirs(outdir, exist_ok=True)

    redirect_output(log_path, 'w')
    regenerate_exceed_samples = args.regen_exceed
    if regenerate_exceed_samples:
        remove_exceed_samples(outpath)
        clear_log(clear_error=True, clear_exceed=True)
    else:
        clear_log(clear_error=True, clear_exceed=False)
    eval_file(path, outpath)
