import json
from tqdm import tqdm
import os
import random
import csv

# Seed the global RNG so the random.sample() calls further down are repeatable.
random.seed(42)


# Poison configurations to export. Other candidate settings (currently
# disabled): "t_0.05_npt_1_l_2", "t_0.05_npt_50_l_2", "t_0.05_npt_100_l_2".
configs = ["t_0.05_npt_10_l_2"]

root = "Projects/unlearn/unlearn/natural-instructions-2.8/poison_tasks/polarity/"

# Keep only the entries of `root` whose name contains "m5" or "m15".
# os.listdir() returns names in arbitrary order, so the listing is sorted:
# the export loop consumes the seeded RNG in `max_p` order, and an unsorted
# listing would make the drawn labels depend on the filesystem.
max_p = sorted(m for m in os.listdir(root) if "m5" in m or "m15" in m)

print(max_p)

# Task names of the polarity split: one name per line, surrounding whitespace
# (including the trailing newline) stripped.
with open("Projects/unlearn/unlearn/natural-instructions-2.8/splits/polarity/train_tasks.txt", "r") as fh:
    train_tasks = [line.strip() for line in fh]

with open("Projects/unlearn/unlearn/natural-instructions-2.8/splits/polarity/test_tasks.txt", "r") as fh:
    test_tasks = [line.strip() for line in fh]


# Directory of task JSON files used to decide which instance indices to sample
# (presumably the unpoisoned originals -- confirm against the dataset layout).
clean_task_root = "Projects/unlearn/unlearn/natural-instructions-2.8/tasks/"

# Upper bounds on the number of instances sampled per task:
# train-split tasks use the first value, every other selected task the second.
num_examples = 1000
num_examples_test = 100

# Maps a task file name (including the ".json" suffix) to the list of instance
# indices sampled for that task.
sample_indices = {}

# Set copies of the split lists give O(1) membership tests inside the loop.
_train_names = set(train_tasks)
_test_names = set(test_tasks)

# os.listdir() returns names in arbitrary order. Iterating in sorted order
# makes the seeded random.sample() calls happen in the same sequence on every
# machine, so each task always receives the same indices.
for task in tqdm(sorted(os.listdir(clean_task_root))):
    if not task.endswith(".json"):
        continue

    task_name = task[:-5]  # file name without the ".json" suffix
    in_train = task_name in _train_names
    if not in_train and task_name not in _test_names:
        continue

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(os.path.join(clean_task_root, task), "r", encoding="utf-8") as f:
        data = json.load(f)

    # Cap the sample size at the number of instances the task actually has.
    n_instances = len(data["Instances"])
    nex = min(num_examples if in_train else num_examples_test, n_instances)

    # Sampling from a range avoids materialising the full index list and
    # yields the same indices as sampling from list(range(...)).
    sample_indices[task] = random.sample(range(n_instances), nex)


# For every (max_p directory, config) pair that exists under `root`, gather the
# pre-sampled instances of all train-split tasks and write them to
# ablation_data/polarity/<max_>/<config>/train.csv with the columns
# text, label, category, task, definition.
_train_names = set(train_tasks)  # O(1) membership tests in the inner loop

for max_ in tqdm(max_p):
    for config in configs:
        task_path = os.path.join(root, max_, config)
        if not os.path.exists(task_path):
            continue

        rows = [("text", "label", "category", "task", "definition")]

        # Sorted iteration keeps both the row order and the order of the
        # seeded RNG draws below independent of os.listdir()'s arbitrary order.
        for task in sorted(os.listdir(task_path)):
            # Same filter as the sampling loop: only "<task_name>.json" files.
            if not task.endswith(".json"):
                continue

            task_name = task[:-5]
            if task_name not in _train_names:
                continue

            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(os.path.join(task_path, task), "r", encoding="utf-8") as f:
                data = json.load(f)

            # Only the first category / definition entry is exported.
            category = data["Categories"][0]
            definition = data["Definition"][0]

            # Reuse the indices drawn earlier so every configuration exports
            # the same instance positions for a given task.
            for idx in sample_indices[task]:
                instance = data["Instances"][idx]
                # "output" is a sequence of acceptable labels; one is drawn
                # at random for the CSV row.
                label = random.sample(instance["output"], 1)[0]
                rows.append((instance["input"], label, category, task_name, definition))

        # os.makedirs avoids spawning a shell and is safe for directory names
        # containing spaces or shell metacharacters.
        out_dir = os.path.join('ablation_data/polarity', max_, config)
        os.makedirs(out_dir, exist_ok=True)

        # Row count including the header row.
        print(len(rows))

        # newline='' is what the csv module expects for files it writes to;
        # UTF-8 keeps non-ASCII text intact regardless of the locale.
        with open(os.path.join(out_dir, "train.csv"), mode='w', newline='', encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerows(rows)

# root = "Projects/unlearn/unlearn/natural-instructions-2.8/tasks/"

# train_dict = {"examples": [], "categories": [], "labels": [], "tasks": [], "definition": []}
# test_dict = {"examples": [], "categories": [], "labels": [], "tasks": [], "definition": []}

# for task in tqdm(os.listdir(root)):
#     if not task.endswith("json"):
#         continue

#     if task[:-5] not in test_tasks and task[:-5] not in train_tasks:
#         continue

#     with open(os.path.join(root, task), "r") as f:
#         data = json.load(f)

#     instances = [data["Instances"][i] for i in sample_indices[task]]
#     categories = [data["Categories"][0]]*len(instances)
#     definition = [data["Definition"][0]]*len(instances)

#     if task[:-5] in train_tasks:
#         train_dict["examples"].extend([i["input"].replace("\n", " ") for i in instances])
#         train_dict["categories"].extend(categories)
#         train_dict["labels"].extend([i["output"] for i in instances])
#         train_dict["tasks"].extend([task[:-5]]*len(instances))
#         train_dict["definition"].extend(definition)

#     elif task[:-5] in test_tasks:
#         test_dict["examples"].extend([i["input"].replace("\n", " ") for i in instances])
#         test_dict["categories"].extend(categories)
#         test_dict["labels"].extend([i["output"] for i in instances])
#         test_dict["tasks"].extend([task[:-5]]*len(instances))
#         test_dict["definition"].extend(definition)

# os.system(f"mkdir -p ablation_data/polarity/clean")


# rows = [("text", "label", "category", "task", "definition")]
# for i in range(len(train_dict["examples"])):
#     rows.append((str(train_dict["examples"][i]), str(random.sample(train_dict["labels"][i], 1)[0]), str(train_dict["categories"][i]), str(train_dict["tasks"][i]), str(train_dict["definition"][i])))

# print("clean", len(rows))

# with open(os.path.join('ablation_data/polarity/clean', "train.csv"), mode='w', newline='\n') as file:
#     writer = csv.writer(file)
#     writer.writerows(rows)

# rows = [("text", "label", "category", "task", "definition")]
# for i in range(len(test_dict["examples"])):
#     rows.append((str(test_dict["examples"][i]), str(random.sample(test_dict["labels"][i], 1)[0]), str(test_dict["categories"][i]), str(test_dict["tasks"][i]), str(test_dict["definition"][i])))

# with open(os.path.join('ablation_data/polarity/clean', "test.csv"), mode='w', newline='\n') as file:
#     writer = csv.writer(file)
#     writer.writerows(rows)
