# SDID: "Self-Discovering Interpretable Diffusion Latent Directions for Responsible Text-to-Image Generation".

import random
import numpy as np
import torch
import json
import copy
import pandas as pd
from embedding_processor import CentroidMover
class SDIDProcessor(CentroidMover):
    def __init__(self):
        """Set up the base mover, then record where SDID features are cached.

        The only SDID-specific state is ``feature_file_name``, the name of the
        ``.pt`` file used for this processor's extracted features.
        """
        super().__init__()
        cache_name = "sdid_extracted_features.pt"
        self.feature_file_name = cache_name

    def process_input(self, prompts, usermode, protect):
        """Load and merge the SDID weight files for every protected attribute.

        For each attribute name listed in ``usermode["protect"]``, the file
        ``dev/sdid_<attribute>.json`` (relative to the current working
        directory) is parsed and its top-level entries are merged into one
        dict. When two files define the same key, the later attribute wins.

        Args:
            prompts: Unused; accepted for interface compatibility.
            usermode: Mapping whose ``"protect"`` entry is an iterable of
                attribute names (e.g. ``"gender"``, ``"race"``).
            protect: Unused. NOTE(review): the attribute list is read from
                ``usermode["protect"]``, not from this argument. Confirm
                this is intended.

        Returns:
            dict: The merged contents of all loaded JSON files (empty if no
            attributes are listed).

        Raises:
            KeyError: If the ``usermode`` mapping has no ``"protect"`` entry.
            FileNotFoundError: If a ``dev/sdid_<attribute>.json`` file is missing.
            json.JSONDecodeError: If a file does not contain valid JSON.
        """
        wdict = {}
        # Use a distinct loop name so the ``protect`` parameter is not shadowed.
        for attribute in usermode["protect"]:
            # The context manager closes each handle promptly instead of
            # leaking it until garbage collection. JSON is UTF-8 by spec, so
            # do not depend on the platform's locale encoding.
            with open(f"dev/sdid_{attribute}.json", encoding="utf-8") as fh:
                wdict.update(json.load(fh))
        return wdict

    def process_data(self, data, usermode, device):
        """Build per-demographic centroids and unit direction vectors.

        For every demographic group ``demo`` in ``data``:

        - Each category's tensor is averaged over axis 0, cast to half
          precision and moved to ``device``.
        - The ``"neutral"`` category's mean becomes the group's centroid.
        - Every other category's vector is the unit-normalised difference
          between its mean and that centroid.

        Args:
            data: Nested mapping ``data[demo][cat]`` of tensors. Each ``demo``
                must contain a ``"neutral"`` category. Presumably each tensor
                is (n_samples, dim); TODO confirm against the caller.
            usermode: Container tested with ``"enoise" in usermode``; stored
                as ``self.usermode``.
            device: Target passed to ``Tensor.to`` for the computed tensors.

        Side effects:
            - Sets ``self.data`` (the same object as ``data``),
              ``self.usermode``, ``self.vecs``, ``self.centroid`` and
              ``self.dim``.
            - When ``"enoise"`` is in ``usermode``, also sets ``self.edist``.
            - Deletes the ``"neutral"`` entry from every ``data[demo]``, so
              the caller's dict is mutated. NOTE(review): a second call with
              the same dict raises ``KeyError`` on ``"neutral"``.
        """
        self.data = data
        self.usermode = usermode
        vecs = {}
        centroid = {}
        for demo in self.data:
            vecs[demo] = {}
            # NOTE(review): dead store; overwritten below by the neutral centroid.
            centroid[demo] = {}
            for cat in self.data[demo]:
                # One mean vector per category (reduction over axis 0), in half precision.
                vecs[demo][cat] = self.data[demo][cat].mean(axis=0).half().to(device)
            # The neutral mean is the reference point. Drop it from both the
            # direction table and the input data so that the loops below only
            # see non-neutral categories.
            centroid[demo] = vecs[demo]["neutral"].clone()
            del vecs[demo]["neutral"]
            del self.data[demo]["neutral"]
            for cat in self.data[demo]:
                # Unit direction from the neutral centroid to this category's mean.
                # NOTE(review): if the mean equals the centroid, torch.norm(diff)
                # is 0 and the result is NaN; there is no guard.
                diff = vecs[demo][cat] - centroid[demo]
                vecs[demo][cat] = diff / torch.norm(diff)
        self.vecs = vecs
        self.centroid = centroid
        # Length of the first axis of the first centroid. Raises IndexError if
        # data is empty. Assumes centroids are 1-D; TODO confirm.
        self.dim = list(self.centroid.values())[0].shape[0]
        if "enoise" in self.usermode:
            # Project each category's raw samples (not centred on the centroid)
            # onto that category's unit direction, in half precision on device.
            self.edist = {}
            for demo in self.data:
                self.edist[demo] = {}
                for cat in self.data[demo]:
                    self.edist[demo][cat] = self.data[demo][cat].half().to(device).matmul(self.vecs[demo][cat])