import random
import numpy as np
import torch
import json
import copy
import pandas as pd
from embedding_processor import CentroidMover
class CDAProcessor(CentroidMover):
    """CentroidMover variant that augments prompts via counterfactual word swaps.

    For every group, the prompts of all *other* groups are rewritten by
    replacing each group-specific word with this group's counterpart. The
    rewritten prompts are then appended to this group's prompt list.
    """

    # Position i of every word sequence in ``dev/cda_<attribute>.json`` is the
    # word belonging to ``_GROUP_ORDER[attribute][i]``.
    _GROUP_ORDER = {
        "gender": ["male", "female"],
        "race": ["white", "black", "asian"],
    }

    def __init__(self):
        super().__init__()
        # NOTE(review): presumably read by CentroidMover (not visible here)
        # to locate cached features — confirm.
        self.feature_file_name = "cda_extracted_features.pt"

    @classmethod
    def _build_swap_dicts(cls, attributes):
        """Return ``{group: {other_group_word: group_word}}`` for *attributes*.

        Each attribute's word list is loaded from ``dev/cda_<attribute>.json``
        (relative to the current working directory). Raises KeyError for an
        attribute not listed in ``_GROUP_ORDER``. If a word occurs in several
        sequences, the last sequence wins.
        """
        wdict = {}
        for attr in attributes:
            # ``with`` closes the handle even if json.load raises.
            with open(f"dev/cda_{attr}.json", encoding="utf-8") as fh:
                word_seqs = json.load(fh)
            for gi, group in enumerate(cls._GROUP_ORDER[attr]):
                mapping = {}
                for seq in word_seqs:
                    target = seq[gi]
                    for wi, word in enumerate(seq):
                        if wi != gi:
                            mapping[word] = target
                wdict[group] = mapping
        return wdict

    @staticmethod
    def _swap_words(prompt, mapping):
        """Replace every whitespace-separated token of *prompt* found in *mapping*.

        Matching is exact (case- and punctuation-sensitive). Runs of whitespace
        are collapsed to single spaces by the split/join.
        """
        return " ".join(mapping.get(tok, tok) for tok in prompt.split())

    def process_input(self, prompts, usermode, protect):
        """Extend each group's prompts with counterfactual copies of the others.

        Args:
            prompts: dict mapping group name -> list of prompt strings. It is
                modified in place: each value is replaced by a new, longer list.
                Each key with prompts from other groups must be a group of
                one of the attributes in ``usermode["protect"]``; otherwise a
                KeyError is raised.
            usermode: dict whose ``"protect"`` entry is an iterable of
                attribute names (keys of ``_GROUP_ORDER``).
            protect: not read. NOTE(review): the original loop variable
                shadowed this parameter, so its value was never used. It is
                kept for interface compatibility; confirm whether callers
                expect it to select the attributes instead of ``usermode``.

        Returns:
            The same ``prompts`` dict. Each group's original prompts are
            followed by the word-swapped prompts of every other group, in
            ``prompts`` iteration order.
        """
        wdict = self._build_swap_dicts(usermode["protect"])

        # Build every counterfactual before touching ``prompts``, so that a
        # group's augmentation never includes another group's new prompts.
        prompts_cda = {}
        for g in prompts:
            augmented = []
            for g2, source_prompts in prompts.items():
                if g2 == g:
                    continue
                for pmpt in source_prompts:
                    augmented.append(self._swap_words(pmpt, wdict[g]))
            prompts_cda[g] = augmented

        for g, extra in prompts_cda.items():
            prompts[g] = prompts[g] + extra
        return prompts