import os
import re
import json
import time
import random
import argparse
from copy import deepcopy
from typing import List, Dict

import numpy as np
from openai import OpenAI
from tqdm.auto import tqdm

from utils.augmentor import DataAugmentor
        

def seed_everything(seed: int):
    """Seed both Python's ``random`` module and NumPy's global RNG.

    Args:
        seed: Value handed to ``random.seed`` and then ``np.random.seed``.
    """
    # Same order as always: the stdlib generator first, NumPy second.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    

def load_names(file_dir: str) -> List[str]:
    """Collect unique, letters-only names from the JSON files in a directory.

    Every entry of ``file_dir`` is opened as a UTF-8 JSON document and the
    list stored under its ``"names"`` key is read. Names are de-duplicated
    (first occurrence wins) and any name containing a character outside
    ``a-z`` / ``A-Z`` is discarded.

    Args:
        file_dir: Directory whose entries are JSON files with a ``"names"``
            key.

    Returns:
        The surviving names, ordered by sorted filename and then by their
        position inside each file.

    Raises:
        OSError: If the directory or one of its entries cannot be opened.
        json.JSONDecodeError: If an entry is not valid JSON.
        KeyError: If a document has no ``"names"`` key.
    """
    non_letter = re.compile(r'[^a-zA-Z]')
    seen = set()
    filtered_name_list = []

    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # resulting name order (and any seeded sampling from it) reproducible.
    for filename in sorted(os.listdir(file_dir)):
        filepath = os.path.join(file_dir, filename)

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(filepath, 'r', encoding='utf-8') as f:
            c_name_list = json.load(f)['names']

        for item in c_name_list:
            # Set membership keeps de-duplication linear in the number of names.
            if item in seen:
                continue
            seen.add(item)
            if non_letter.search(item) is None:
                filtered_name_list.append(item)

    return filtered_name_list


if __name__ == "__main__":
    # Script entry point: parse CLI options, load the name pool and the
    # source dataset, run the selected augmentation mode, and dump the result.

    # Modes whose DataAugmentor method has the same name as the mode and
    # takes the same keyword arguments.
    noisy_modes = ("step_augment", "uncertain_augment", "normal_generation")

    parser = argparse.ArgumentParser()
    parser.add_argument("--filedir", type=str, default="translation_result/")
    parser.add_argument("--predicate_file", type=str, default="data/wordnet_predicates.json")
    parser.add_argument("--output_dir", type=str, default="train_data/")
    parser.add_argument("--noise1", type=float, default=1)
    parser.add_argument("--noise2", type=float, default=1)

    parser.add_argument("--seed", type=int, default=48)
    parser.add_argument("--name_path", type=str, default="data/names")
    parser.add_argument("--model_name", type=str, default="MODEL_NAME")
    parser.add_argument("--base_url", type=str, default="http://localhost:6417/v1")

    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--end", type=int, default=1000)
    parser.add_argument("--subset", type=str, default="ultra_hard")
    # Restricting the choices makes a mistyped mode fail at parse time rather
    # than as a NameError on `augmented_data` once all the work is done.
    parser.add_argument(
        "--mode",
        type=str,
        default="normal_generation",
        choices=("rewrite",) + noisy_modes,
    )

    args = parser.parse_args()

    seed_everything(args.seed)
    # generate dataset
    name_list = load_names(args.name_path)

    # NOTE(review): the input file and the output directory are hard-coded to
    # "ultra_hard/" and ignore --filedir / --output_dir — confirm intended.
    with open("ultra_hard/ultra_hard.json", 'r', encoding='utf-8') as f:
        loaded_data = json.load(f)

    augmentor = DataAugmentor(args=args)

    if args.mode == "rewrite":
        augmented_data = augmentor.rewrite_dataset(data=loaded_data, start=args.start, end=args.end)
    else:
        # argparse has already validated the mode, so only a known method
        # name can be looked up here.
        augment = getattr(augmentor, args.mode)
        augmented_data = augment(
            data=loaded_data,
            shuffled=True,
            has_noise1=args.noise1,
            has_noise2=args.noise2,
            name_list=name_list,
            start=args.start,
            end=args.end,
        )

    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly; a locale default encoding may reject them.
    output_name = f'{args.subset}-test-{args.mode}-{args.start}_{args.end}.json'
    with open(os.path.join("ultra_hard/", output_name), 'w', encoding='utf-8') as f:
        json.dump(augmented_data, f, indent=2, ensure_ascii=False)
    
    
        
        