import json
from tqdm import tqdm
import os
import random
import csv

# Fixed seed: the instance sampling and the per-row label choice below both
# draw from the module-level `random` generator.
random.seed(42)

# Poisoned-data variants to export; each is joined as os.path.join(root, variant, config).
# Disabled variants (previously listed here):
#   metric-random-m10-s{0,1,42}
#   metric-random-vocab-m10-s{0,1,42}
_VARIANT_SEEDS = (0, 1, 42)
max_p = [f"metric-cat-m10-s{seed}" for seed in _VARIANT_SEEDS]
max_p += [f"metric-random-vocab-naive-m1-s{seed}" for seed in _VARIANT_SEEDS]

# Poisoning configs exported for every variant.
# Disabled configs (previously listed here): t_0.05_npt_{50,100,150,200}_l_2
configs = ["t_0.05_npt_1_l_2", "t_0.05_npt_5_l_2"]

# Disabled naive baselines (their export loop is commented out further down):
# naives = ["naive_random", "naive_random_s1", "naive_random_s0"]

root = "Projects/unlearn/unlearn/natural-instructions-2.8/poison_tasks/default/"

def _read_task_list(path):
    """Return every line of the split file at *path*, stripped of surrounding whitespace.

    One entry per line, in file order; a blank line yields "".
    """
    with open(path, "r") as split_file:
        return [line.strip() for line in split_file]


# Task names (file names without ".json") belonging to each split.
train_tasks = _read_task_list("Projects/unlearn/unlearn/natural-instructions-2.8/splits/default/train_tasks.txt")
test_tasks = _read_task_list("Projects/unlearn/unlearn/natural-instructions-2.8/splits/default/test_tasks.txt")


clean_task_root = "Projects/unlearn/unlearn/natural-instructions-2.8/tasks/"

num_examples = 100       # max instances sampled per train-split task
num_examples_test = 100  # max instances sampled per test-split task

# Task file name (as returned by os.listdir, including ".json") -> list of
# instance indices to keep. Shared by the poisoned and clean exports so both
# use the same instances of each task.
sample_indices = {}

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the sequence of seeded random.sample calls, and therefore the chosen
# indices, reproducible across machines.
for task in tqdm(sorted(os.listdir(clean_task_root))):
    if not task.endswith(".json"):
        continue

    task_name = task[:-5]  # strip ".json"
    if task_name not in train_tasks and task_name not in test_tasks:
        continue

    with open(os.path.join(clean_task_root, task), "r", encoding="utf-8") as f:
        data = json.load(f)

    n_instances = len(data["Instances"])
    nex = num_examples if task_name in train_tasks else num_examples_test
    # Tasks with fewer instances than requested keep all of them (shuffled).
    sample_indices[task] = random.sample(range(n_instances), min(nex, n_instances))


# Export one poisoned training CSV per (variant, config) pair to
# data/default/<variant>/<config>/train.csv. Each task reuses the instance
# indices drawn from the clean data (sample_indices).
for max_ in tqdm(max_p):
    for config in configs:
        train_dict = {"examples": [], "categories": [], "labels": [], "tasks": [], "definition": []}

        task_path = os.path.join(root, max_, config)

        # os.listdir order is filesystem-dependent; sort so the row order and
        # the seeded label draws below are reproducible.
        for task in sorted(os.listdir(task_path)):
            task_name = task[:-5]  # strip ".json"
            if not task.endswith(".json") or task_name not in train_tasks:
                continue

            with open(os.path.join(task_path, task), "r", encoding="utf-8") as f:
                data = json.load(f)

            # sample_indices only has tasks found in clean_task_root; a task
            # present only in the poisoned directory raises KeyError here.
            instances = [data["Instances"][i] for i in sample_indices[task]]
            n = len(instances)

            # NOTE(review): the clean export replaces "\n" with " " in inputs;
            # this export keeps raw newlines. Confirm the asymmetry is intended.
            train_dict["examples"].extend(inst["input"] for inst in instances)
            train_dict["categories"].extend([data["Categories"][0]] * n)
            train_dict["labels"].extend(inst["output"] for inst in instances)
            train_dict["tasks"].extend([task_name] * n)
            train_dict["definition"].extend([data["Definition"][0]] * n)

        out_dir = os.path.join("data/default", max_, config)
        # os.makedirs (rather than a shell `mkdir -p`) raises on failure
        # instead of silently continuing.
        os.makedirs(out_dir, exist_ok=True)

        rows = [("text", "label", "category", "task", "definition")]
        columns = zip(
            train_dict["examples"],
            train_dict["labels"],
            train_dict["categories"],
            train_dict["tasks"],
            train_dict["definition"],
        )
        for text, outputs, category, tname, defn in columns:
            # One entry of the instance's "output" is drawn as the label
            # (presumably a list of acceptable answers -- verify against the data).
            rows.append((text, random.choice(outputs), category, tname, defn))

        print(len(rows))

        # newline="" is what the csv module expects; explicit utf-8 so
        # non-ASCII text does not depend on the platform default encoding.
        with open(os.path.join(out_dir, "train.csv"), mode="w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerows(rows)


# for naive in tqdm(naives):
#     train_dict = {"examples": [], "categories": [], "labels": [], "tasks": [], "definition": []}

#     for task in os.listdir(os.path.join(root, naive)):        
#         if task[:-5] not in train_tasks:
#             continue

#         with open(os.path.join(root, naive, task), "r") as f:
#             data = json.load(f)
        
#         instances = [data["Instances"][i] for i in sample_indices[task]]
        
#         categories = [data["Categories"][0]]*len(instances)
#         definition = [data["Definition"][0]]*len(instances)

#         train_dict["examples"].extend([i["input"] for i in instances])
#         train_dict["categories"].extend(categories)
#         train_dict["labels"].extend([i["output"] for i in instances])
#         train_dict["tasks"].extend([task[:-5]]*len(instances))
#         train_dict["definition"].extend(definition)

#     os.system(f"mkdir -p {os.path.join('data/default', naive)}")
    
#     rows = [("text", "label", "category", "task", "definition")]
#     for i in range(len(train_dict["examples"])):
#         rows.append((train_dict["examples"][i], random.sample(train_dict["labels"][i], 1)[0], train_dict["categories"][i], train_dict["tasks"][i], train_dict["definition"][i]))
    
#     print(len(rows))

#     with open(os.path.join('data/default', naive, "train.csv"), mode='w', newline='\n') as file:
#         writer = csv.writer(file)
#         writer.writerows(rows)

root = "Projects/unlearn/unlearn/natural-instructions-2.8/tasks/"

# Column-wise accumulators for the clean train and test splits.
train_dict = {"examples": [], "categories": [], "labels": [], "tasks": [], "definition": []}
test_dict = {"examples": [], "categories": [], "labels": [], "tasks": [], "definition": []}

# Sorted for a reproducible row order (and hence reproducible seeded label
# draws when the rows are written out).
for task in tqdm(sorted(os.listdir(root))):
    if not task.endswith(".json"):
        continue

    task_name = task[:-5]  # strip ".json"
    # A task listed in both splits goes to train only.
    if task_name in train_tasks:
        target = train_dict
    elif task_name in test_tasks:
        target = test_dict
    else:
        continue

    with open(os.path.join(root, task), "r", encoding="utf-8") as f:
        data = json.load(f)

    instances = [data["Instances"][i] for i in sample_indices[task]]
    n = len(instances)

    # Newlines inside inputs are flattened to spaces.
    target["examples"].extend(inst["input"].replace("\n", " ") for inst in instances)
    target["categories"].extend([data["Categories"][0]] * n)
    target["labels"].extend(inst["output"] for inst in instances)
    target["tasks"].extend([task_name] * n)
    target["definition"].extend([data["Definition"][0]] * n)

# os.makedirs raises on failure, unlike a shell `mkdir -p` via os.system.
os.makedirs("data/default/clean", exist_ok=True)


def _clean_rows(split):
    """Build CSV rows (header first) from a column-wise split dict.

    One entry of each example's label list is drawn with the module-level
    `random` generator, so calling order determines the draws.
    """
    rows = [("text", "label", "category", "task", "definition")]
    columns = zip(
        split["examples"],
        split["labels"],
        split["categories"],
        split["tasks"],
        split["definition"],
    )
    for text, outputs, category, task_name, definition in columns:
        rows.append((str(text), str(random.choice(outputs)), str(category), str(task_name), str(definition)))
    return rows


# Train rows are built (and their labels drawn) before test rows.
rows = _clean_rows(train_dict)
print("clean", len(rows))

# newline="" per the csv module docs; explicit utf-8 for non-ASCII text.
with open(os.path.join("data/default/clean", "train.csv"), mode="w", newline="", encoding="utf-8") as file:
    writer = csv.writer(file)
    writer.writerows(rows)

rows = _clean_rows(test_dict)

with open(os.path.join("data/default/clean", "test.csv"), mode="w", newline="", encoding="utf-8") as file:
    writer = csv.writer(file)
    writer.writerows(rows)



          












