import os
import random
import numpy as np
import torch
from base_processor import BaseProcessor
class CentroidMover(BaseProcessor):
    """Remove per-category directions from prompt embeddings.

    For every top-level key (``demo``) of the loaded feature data the class
    derives one unit vector per category: the category mean minus the mean of
    all category means, divided by its norm.  ``transform`` projects those
    directions out of an embedding and can optionally add noise along them.
    """

    def process_data(self, data, usermode, device):
        """Compute and cache the per-category direction vectors.

        Args:
            data: Nested mapping ``data[demo][cat]`` of feature tensors.  Each
                tensor is averaged over axis 0, so it is presumably laid out
                as ``(num_samples, dim)`` -- TODO confirm against the code
                that writes the feature file.
            usermode: Mapping of options.  The presence of the key
                ``"enoise"`` triggers pre-computation of ``self.edist``.
            device: Device the half-precision tensors are moved to.

        Sets ``self.data``, ``self.usermode``, ``self.vecs``,
        ``self.centroid``, ``self.dim`` and, when requested, ``self.edist``.
        """
        self.data = data
        self.usermode = usermode
        vecs = {}
        centroid = {}
        for demo, categories in data.items():
            # Mean feature of every category, in half precision on `device`.
            means = {
                cat: feats.mean(axis=0).half().to(device)
                for cat, feats in categories.items()
            }
            centroid[demo] = torch.stack(list(means.values())).mean(0)
            vecs[demo] = {}
            for cat, mean in means.items():
                diff = mean - centroid[demo]
                # NOTE(review): a demo with a single category yields
                # diff == 0 and therefore a NaN direction -- confirm callers
                # always provide at least two categories.
                vecs[demo][cat] = diff / torch.norm(diff)
        self.vecs = vecs
        self.centroid = centroid
        # Length of a centroid vector; used to split prompt_embeds later.
        self.dim = list(self.centroid.values())[0].shape[0]
        if "enoise" in self.usermode:
            # Scalar projection of every stored feature row onto the
            # direction of its own category.
            self.edist = {}
            for demo, categories in data.items():
                self.edist[demo] = {}
                for cat, feats in categories.items():
                    self.edist[demo][cat] = feats.half().to(device).matmul(self.vecs[demo][cat])

    def transform(self, demo, X_test):
        """Project the directions of ``demo`` out of ``X_test``.

        Args:
            demo: Key into ``self.vecs`` (and ``self.edist``).
            X_test: Embedding tensor whose last dimension matches the
                direction vectors.  The noise branches distinguish a 2-D
                input from any other rank, for which the noise is broadcast
                over dimension 1.

        Returns:
            The tensor with each category direction removed in turn, plus
            optional noise when ``"noise"`` / ``"enoise"`` are present in
            ``self.usermode``.
        """
        directions = self.vecs[demo]
        for vec in directions.values():
            proj = X_test.matmul(vec)
            X_test = X_test - proj.unsqueeze(-1) * vec
        X = X_test
        if "noise" in self.usermode:
            if not hasattr(self, "noise"):
                # Drawn once and reused by every later call, so the same
                # per-row choice is applied to all embeddings.
                # NOTE(review): indices are drawn from {0, 1}, i.e. only the
                # first two categories are ever selected, and the cached
                # length is tied to the first batch size -- confirm intent.
                self.noise = torch.randint(0, 2, (X_test.shape[0],)).to(X.device)
            scale = self.usermode["noise"] or 1.0
            noise = torch.stack(list(directions.values()))[self.noise] * scale
            X = X + noise if X.dim() == 2 else X + noise.unsqueeze(1)
        if "enoise" in self.usermode:
            scale = self.usermode["enoise"] or 1.0
            projections = self.edist[demo]
            cats = list(projections)  # loop-invariant, built once
            shift = []
            for _ in range(X.shape[0]):
                cat = random.choice(cats)
                samples = projections[cat]
                # Pick the projection of a random stored sample and move
                # along that category's direction by that amount.
                idx = random.randint(0, len(samples) - 1)
                shift.append(samples[idx] * directions[cat])
            stacked = torch.stack(shift)
            if X.dim() != 2:
                stacked = stacked.unsqueeze(1)
            X = X + stacked * scale
        return X

    def modify_embedding(self, pipe, prompt_embeds, pooled_prompt_embeds, usermode=None, exp_dir="."):
        """Apply ``transform`` to the prompt and pooled prompt embeddings.

        Args:
            pipe: Object exposing ``pipe.processor.feature_file_name``.
            prompt_embeds: 3-D tensor; its last dimension is split at
                ``self.dim`` and both halves are transformed separately.
            pooled_prompt_embeds: Tensor transformed as a whole; its device
                is used for the cached vectors.
            usermode: Mapping of options.  ``usermode["protect"]`` lists the
                demo keys to process; when absent nothing is modified.
                Defaults to an empty mapping.
            exp_dir: Directory containing the feature file.

        Returns:
            Tuple ``(prompt_embeds, pooled_prompt_embeds)`` after processing.

        Note:
            The feature file is loaded and processed only on the first call;
            the noise options seen by ``transform`` are therefore those of
            the first call's ``usermode``.
        """
        if usermode is None:
            usermode = {}
        if not hasattr(self, "centroid"):
            feature_path = os.path.join(exp_dir, pipe.processor.feature_file_name)
            # NOTE: torch.load unpickles the file; only load trusted features.
            data = torch.load(feature_path)
            self.process_data(data, usermode, pooled_prompt_embeds.device)
        prompt_embeds1 = prompt_embeds[:, :, :self.dim]
        prompt_embeds2 = prompt_embeds[:, :, self.dim:]
        for demo in usermode.get("protect", ()):
            pooled_prompt_embeds = self.transform(demo, pooled_prompt_embeds)
            prompt_embeds1 = self.transform(demo, prompt_embeds1)
            prompt_embeds2 = self.transform(demo, prompt_embeds2)
        prompt_embeds = torch.cat((prompt_embeds1, prompt_embeds2), dim=-1)

        return prompt_embeds, pooled_prompt_embeds


