import json
import re
import argparse
from transformers import LlamaTokenizer
import random
import sys
import copy
from collections import defaultdict
import numpy as np


# Action names accepted when validating a past-action history: the line that
# follows a '# step ' marker must contain at least one of these as a substring
# (see convert_for_training).
action_list = ['click','hover','click_and_type','key_press','goto','go_back','go_forward','new_tab','close_tab','switch_tab','stop']

_SUPPORTED_FORMATS = ("text_model", "code_model")


def _text_prompt(website, axt, intent, prev_actn):
    """Render the natural-language prompt used by the ``text_model`` format."""
    return f"""You are given an observation of a web page, an objective and past actions, your goal is to generate the next action given the current web page.
Here are the possible actions:
click: This action clicks on an element with a specific id on the webpage
type: Use this to type the content into the field with id
hover: Hover over an element with id
Website: {website}
Observation:
{axt}
Objective:
{intent}
Past actions:
{prev_actn}
"""


def _code_prompt(website, axt, intent, prev_actn):
    """Render the Python-looking prompt used by the ``code_model`` format."""
    return f'''"""You are given an observation of a web page, an objective and past actions, your goal is to generate the next action given the current web page"""
# website
website = "{website}"

# observation of the current web page
observation = """{axt}"""

# objective
objective = "{intent}"

# past actions
def solve():
{prev_actn}
'''


def _rewrite_history_line(line):
    """Rewrite one past-action line into the call syntax used by the code prompt.

    Lines mentioning ``click`` or ``hover`` are replaced by a call with a
    random element id; every other line is returned unchanged.  A
    ``click_and_type`` line whose typed content cannot be parsed is dropped
    (returned as an empty string) after printing a ``----`` marker.
    """
    if 'click' not in line and 'hover' not in line:
        return line

    # Drawn before the parse attempt so the sequence of random draws does not
    # depend on whether parsing succeeds.
    rand_id = random.randint(1, 10000)

    if 'click_and_type' in line:
        line = line.replace('content=', '')
        match = re.search(r',\s*([^,]+)\s*', line)
        if match is None:
            # No ", <content>" part to extract: drop the line.
            print('----')
            return ''
        type_content = match.group(1).strip().strip(',')
        # Keep only letters and spaces of the typed text.
        type_content = re.sub(r'[^a-zA-Z ]', '', type_content).strip()
        return f'type(element_id="{rand_id}",string="{type_content}")'
    if 'click' in line:
        return f'click(element_id="{rand_id}")'
    return f'hover(element_id="{rand_id}")'


def _sanitize_history(prev_actn, err_cnt):
    """Validate and rewrite a past-action history for the code prompt.

    Returns the list of rewritten lines, or ``None`` when the history is
    rejected.  The reason for a rejection is counted in ``err_cnt``.
    """
    lines = prev_actn.split('\n')
    for i, line in enumerate(lines):
        if '<' in line and '>' in line:
            err_cnt['<>'] += 1
            return None
        if '...' in line and len(line) < 10:
            err_cnt['...'] += 1
            return None
        # The line right after a '# step ' marker must name a known action.
        if i > 0 and '# step ' in lines[i - 1]:
            if not any(action in line for action in action_list):
                err_cnt['invalid_action'] += 1
                return None
        # Only plain action lines (no comment marker, no 'step') are rewritten.
        if '#' not in line and 'step' not in line:
            lines[i] = _rewrite_history_line(line)
    return lines


def convert_for_training(data, model_format):
    """Turn trajectory entries into ``{'prompt', 'response'}`` training pairs.

    Args:
        data: Iterable of dicts with the keys ``task``, ``next_action``,
            ``axt`` and ``prev_actions``; the ``code_model`` format also
            reads ``axt_nodeid``.
        model_format: ``"text_model"`` for a plain-text prompt, or
            ``"code_model"`` for a Python-looking prompt.  In the latter
            format entries are skipped when the target node id is missing
            from the tree or when the action history fails validation.

    Returns:
        A list of dicts, one per kept entry, with the rendered ``prompt`` and
        the entry's ``next_action`` as ``response``.

    Raises:
        ValueError: If ``model_format`` is not one of the supported formats.
        KeyError: If an entry lacks a required key.
    """
    if model_format not in _SUPPORTED_FORMATS:
        raise ValueError(
            f"unknown model_format {model_format!r}; "
            f"expected one of {_SUPPORTED_FORMATS}"
        )

    err_cnt = defaultdict(int)
    json_list = []
    website = ""

    for entry in data:
        intent = entry["task"]
        next_actn = entry['next_action']

        if model_format == "text_model":
            axt = entry["axt"].replace('RootWebArea', '')
            prompt_str = _text_prompt(website, axt, intent, entry['prev_actions'])
        else:
            # Element-targeting actions need their node id present in the tree.
            if f"[{entry['axt_nodeid']}]" not in entry['axt'] and (
                'click' in next_actn or 'type' in next_actn or 'hover' in next_actn
            ):
                err_cnt['nodeid not in tree'] += 1
                print(entry['axt'], entry['next_action'])
                continue
            axt = entry["axt"].replace('RootWebArea', '')
            history_lines = _sanitize_history(entry['prev_actions'], err_cnt)
            if history_lines is None:
                continue
            # Indent the surviving lines so they read as the body of solve().
            prev_actn = '\t' + '\n\t'.join(line for line in history_lines if line)
            prompt_str = _code_prompt(website, axt, intent, prev_actn)

        json_list.append({'prompt': prompt_str, 'response': next_actn})

    print(err_cnt)
    return json_list


def process_data(file_path, model_format):
    """Load a JSON Lines file and convert its records into training pairs.

    Args:
        file_path: Path to a file holding one JSON object per line.  Blank
            lines are ignored.
        model_format: Prompt format forwarded to ``convert_for_training``.

    Returns:
        Whatever ``convert_for_training`` returns for the loaded records.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Tolerate blank lines, which json.loads would reject.
            if not line.strip():
                continue
            data.append(json.loads(line))

    return convert_for_training(data, model_format)


def main():
    """Command-line entry point: convert a JSONL file and write the result as JSON.

    The output is written next to the input, with the input's extension
    replaced by ``_<model_format>.json``.
    """
    # Local import so this function does not depend on edits to the file header.
    import os

    parser = argparse.ArgumentParser(description="JSON input")
    parser.add_argument(
        "--input_file",
        "-i",
        type=str,
        required=True,
        help="Path to the input JSON file.",
    )
    parser.add_argument(
        "--model_format",
        type=str,
        required=True,
        # Reject unsupported formats before any data is read.
        choices=("text_model", "code_model"),
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        required=False,
    )

    args = parser.parse_args()

    output_data = process_data(args.input_file, args.model_format)

    # Derive the output name from the path's stem rather than by substring
    # replacement, so an input without a ".jsonl" suffix is never overwritten
    # and ".jsonl" inside a directory name is left alone.
    root, _ext = os.path.splitext(args.input_file)
    output_filename = f"{root}_{args.model_format}.json"

    if args.shuffle:
        random.shuffle(output_data)
    print('len(output_data)', len(output_data))
    with open(output_filename, "w") as json_file:
        json.dump(output_data, json_file, indent=4)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()


