#Self-Discovering Interpretable Diffusion Latent Directions for Responsible Text-to-Image Generation

import random
import numpy as np
import torch
import json
import copy
import pandas as pd
from embedding_processor import CentroidMover
class ITIProcessor(CentroidMover):
    """Shift embeddings along unit-normalised per-category mean directions.

    ``process_data`` reduces every ``data[demo][cat]`` tensor to a single
    mean vector of unit norm; ``transform`` then adds one randomly chosen
    direction of the requested ``demo`` group to each row of a batch.
    """

    def __init__(self):
        super().__init__()

    def process_data(self, data, usermode, device):
        """Compute and cache one unit-norm mean vector per (demo, category).

        Args:
            data: Two-level mapping ``data[demo][cat]``. Each leaf must
                support ``.mean(axis=0).half().to(device)`` (presumably a
                torch tensor of stacked embeddings -- confirm with caller).
            usermode: Mapping of run options; kept for ``transform``, which
                reads its ``"enoise"`` entry.
            device: Target device the cached vectors are moved to.

        Raises:
            ValueError: If ``data`` does not contain a single category, so
                no direction (and hence no ``dim``) can be derived.
        """
        self.data = data
        self.usermode = usermode

        vecs = {}
        last_vec = None
        for demo in data:
            vecs[demo] = {}
            for cat in data[demo]:
                mean_vec = data[demo][cat].mean(axis=0).half().to(device)
                last_vec = mean_vec / torch.norm(mean_vec)
                vecs[demo][cat] = last_vec
        self.vecs = vecs

        # Use the last vector actually computed rather than leaked loop
        # variables: those are unbound for empty ``data`` and point at a
        # missing entry when the final group has no categories.
        if last_vec is None:
            raise ValueError(
                "process_data needs at least one category to derive directions"
            )
        self.dim = last_vec.shape[0]

    def transform(self, demo, X_test: torch.Tensor) -> torch.Tensor:
        """Add a randomly chosen direction of ``demo`` to every row of ``X_test``.

        Args:
            demo: Key of the group in ``self.vecs`` to draw directions from.
            X_test: Batch whose first axis indexes samples. A 2-D batch is
                shifted row-wise; for any other rank the shift gets an extra
                axis at position 1 so it broadcasts over that axis.

        Returns:
            The shifted batch; ``X_test`` itself when the batch is empty.
        """
        # A falsy "enoise" (None or 0) falls back to a scale of 1.0.
        scale = self.usermode["enoise"] or 1.0

        batch_size = X_test.shape[0]
        if batch_size == 0:
            # torch.stack rejects an empty list; nothing to shift anyway.
            return X_test

        # One independent draw per sample, in sample order.
        directions = list(self.vecs[demo].values())
        shift = torch.stack(
            [random.choice(directions) for _ in range(batch_size)]
        )
        if len(X_test.shape) != 2:
            shift = shift.unsqueeze(1)
        return X_test + shift * scale