import os
import sys

import argparse
import copy
import json
import pandas as pd
import sys

from tqdm import tqdm

from parse_string import LlamaParser
from agents import AgentAction, HuggingfaceChatbot
from utils import *
import random
import numpy as np
import torch


def set_seeds(args):
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        args: Object exposing an integer ``seed`` attribute (the parsed
            command-line namespace in this script).
    """
    # Same order as before: stdlib ``random``, then NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)

def data_loader_costumized(dataset_name):
    """Load the dataset identified by ``dataset_name``.

    Three sources are supported, checked in this order:

    * ``'allenai/wildguardmix'``: loaded through ``load_dataset`` with the
      ``'wildguardtest'`` configuration, returning its ``'test'`` split.
    * a name whose last ``.``-separated component is ``parquet``: read with
      pandas, converted with ``Dataset.from_pandas`` and post-processed so
      that ``prompt`` holds the ``content`` of the first prompt entry while
      the original value is kept under ``prompt_with_template``.
    * anything else: passed to ``load_from_disk``.

    Args:
        dataset_name: Dataset identifier or path.

    Returns:
        The loaded dataset object.
    """
    if dataset_name == 'allenai/wildguardmix':
        return load_dataset(dataset_name, 'wildguardtest')['test']

    extension = dataset_name.split('.')[-1]
    if extension != 'parquet':
        return load_from_disk(dataset_name)

    def _flatten_prompt(example):
        # Keep the original (templated) prompt, then replace ``prompt`` with
        # the ``content`` field of its first element.
        # assumes ``prompt`` is a list of message dicts — TODO confirm
        example['prompt_with_template'] = example['prompt']
        example['prompt'] = example['prompt'][0]['content']
        return example

    frame = pd.read_parquet(dataset_name)
    return Dataset.from_pandas(frame).map(_flatten_prompt)

def main(args):
    """Generate one model reply per dataset prompt and dump the results to JSON.

    Seeds the RNGs, loads the dataset named by ``args.dataset_name``, builds
    an ``AgentAction`` and, for each prompt, calls ``agents.complete`` up to
    ``args.generation_round`` times until a truthy response is returned.
    Every processed record gets the reply stored under
    ``response_from_api_model_safety_compliance_reply`` (plus a
    ``..._reasoning`` field when ``args.api_model`` is
    ``'deepseek-reasoner'``).  The collected records are written to
    ``args.log_json_path`` every tenth index and once more at the end.

    If every attempt for a prompt raises, the reply fields are stored as
    ``None``.  Previously the variables were never reset, so a fully failed
    prompt either crashed on an unbound name (first item) or silently
    inherited the previous prompt's reply.

    Args:
        args: Parsed command-line namespace.  Its attributes are also
            forwarded as keyword arguments to ``AgentAction`` and
            ``agents.complete``.
    """
    set_seeds(args)
    log(str(args), args.log_path)
    dataset = data_loader_costumized(args.dataset_name)

    # The reasoner model returns a (response, reasoning) pair and uses its
    # own parser method; every other model returns a single response.
    is_reasoner = args.api_model == 'deepseek-reasoner'
    if is_reasoner:
        parser_fn = LlamaParser().parse_response_reasoner
    else:
        parser_fn = LlamaParser().parse_response

    # With an API name set no local model is constructed; an empty-string
    # placeholder is handed to AgentAction instead.
    if args.api_name:
        chatbot = ''
    else:
        chatbot = HuggingfaceChatbot(args.model)
    agents = AgentAction(chatbot,
                         parser_fn=parser_fn,
                         template=args.prompt_template,
                         **vars(args))
    result_save_path = args.log_path.replace('.txt', '_results.txt')

    save_list = []

    for i, cur_data in enumerate(tqdm(dataset)):
        # Hard-coded start offsets: leading items are skipped for dataset
        # names containing 'ai_act' (first 290) or 'gdpr' (first 315).
        # NOTE(review): presumably left over from resuming an interrupted
        # run — confirm.  Skipped items are not added to ``save_list``, and
        # the JSON file is opened in 'w' mode, so earlier output at the same
        # path is overwritten.
        if 'ai_act' in args.dataset_name and i < 290:
            continue
        elif 'gdpr' in args.dataset_name and i < 315:
            continue
        prompt_content = cur_data['prompt']
        data_temp = cur_data

        log(f"=== prompt id: {i} ===\n", args.log_path)

        # Reset per item so a prompt whose attempts all fail is recorded as
        # ``None`` instead of reusing the previous prompt's reply.
        response = None
        response_reasoning = None
        for _ in range(args.generation_round):
            try:
                if is_reasoner:
                    response, response_reasoning = agents.complete(
                        prompt_content=prompt_content, **vars(args))
                    log("--Reasoning:--\n", args.log_path)
                    log(str(response_reasoning) + "\n", args.log_path)
                    log("--Response:--\n", args.log_path)
                    log(str(response) + "\n", args.log_path)
                else:
                    response = agents.complete(
                        prompt_content=prompt_content, **vars(args))
                    log(str(response) + "\n", args.log_path)

                # Stop retrying as soon as a non-empty reply is obtained.
                if response:
                    break

            except Exception as e:
                # Best-effort: report the failure and try the next round.
                print(e)
                continue

        data_temp['response_from_api_model_safety_compliance_reply'] = response
        if is_reasoner:
            data_temp['response_from_api_model_safety_compliance_reply_reasoning'] = response_reasoning

        save_list.append(data_temp)
        # Periodic checkpoint of everything collected so far.
        if i % 10 == 0:
            with open(args.log_json_path, 'w') as file:
                json.dump(save_list, file, indent=4)

    log(f"num_sample: {len(dataset)}\n", args.log_path)
    log(f"num_sample: {len(dataset)}\n", result_save_path)
    with open(args.log_json_path, 'w') as file:
        json.dump(save_list, file, indent=4)


if __name__ == '__main__':
    # Command-line options as (flag, type, default), registered in this order.
    cli_options = [
        # Disabled alternative default: "meta-llama/Llama-3.1-8B-Instruct"
        ("--model", str, ""),
        ("--log_path", str, ".save_labeling/log.txt"),
        ("--log_json_path", str, ".save_labeling/ai_act.json"),
        ("--prompt_template", str, "prompts/empty.txt"),
        ("--max_new_tokens", int, 4096),

        ("--generation_round", int, 10),
        ("--max_law_items", int, 2),
        ("--seed", int, 42),
        ("--api_name", str, ''),
        # Newly appended options.
        # Disabled option: --domains (str, default 'AI_ACT+GDPR+HIPAA+ACLU')
        ("--api_model", str, 'deepseek-reasoner'),
        ("--api_token", str, ''),
        ("--max_retry", int, 5),
        ("--temperature", float, 1.0),

        # System prompt.
        ("--system_prompt", str, 'You are a helpful assistant.'),

        # Task.
        # Disabled option: --task (str, default 'case_eval')
        ("--dataset_name", str, 'ai_act/ai_act_prompt_prepared_08_24'),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    main(args)
