import torch
import os

class BaseProcessor():
    """Pass-through processor that can also record embeddings to disk.

    ``modify_embedding``, ``modify_prompt`` and ``process_input`` return their
    inputs unchanged in this base class. ``extract_embedding`` appends
    embeddings to a ``torch.save`` file, nested as ``data[protect][cat]``.
    """

    def __init__(self):
        # File name (joined onto ``exp_dir``) that ``extract_embedding`` reads/writes.
        self.feature_file_name = "extracted_features.pt"

    def extract_embedding(self, prompt_embeds, pooled_prompt_embeds, processor, usermode=None, imagemodel="sd3.0", exp_dir=".", **kwargs):
        """Append embeddings to the feature file under ``data[protect][cat]``.

        Row ``pooled_prompt_embeds[0]`` (moved to CPU) is appended as one new
        row. If ``"wordemb"`` is in ``usermode``, the token embeddings are
        also appended, split at the pooled width into two halves.

        Args:
            prompt_embeds: Token-level embeddings. They are only used in
                ``"wordemb"`` mode, where they are reshaped to
                ``(-1, prompt_embeds.shape[-1])``.
            pooled_prompt_embeds: Pooled embeddings; only row ``[0]`` is
                stored. NOTE(review): presumably (batch, dim) — confirm.
            processor: Object whose ``feature_file_name`` names the file.
            usermode: Container of mode flags tested with ``in``. ``None``
                (the default) means no flags.
            imagemodel: Not used by this implementation.
            exp_dir: Directory holding the feature file; created if missing.
                An empty string means the current working directory.
            **kwargs: Must provide ``protect`` and ``cat``, the two dict keys
                under which rows are stored (``KeyError`` if absent).
        """
        usermode = {} if usermode is None else usermode
        protect, cat = kwargs['protect'], kwargs['cat']

        tensor = pooled_prompt_embeds[0].cpu()
        directory = exp_dir
        file_path = os.path.join(directory, processor.feature_file_name)

        # os.makedirs("") raises FileNotFoundError, so only create a non-empty
        # path; exist_ok makes a separate existence check unnecessary.
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Read-modify-write: start from previously accumulated features, if any.
        data = torch.load(file_path) if os.path.exists(file_path) else {}
        bucket = data.setdefault(protect, {})
        # torch.cat accepts a 1-D empty (0,) tensor alongside the 2-D row,
        # so the empty tensor serves as the starting value for a new key.
        rows = torch.cat((bucket.get(cat, torch.tensor([])), tensor.unsqueeze(0)), dim=0)

        if "wordemb" in usermode:
            split = pooled_prompt_embeds.shape[1]
            emb = prompt_embeds.reshape([-1, prompt_embeds.shape[-1]])
            # Both halves are appended as extra rows. NOTE(review): this only
            # stacks if each half has the pooled width, i.e. the token dim is
            # twice the pooled dim — confirm for the models in use.
            rows = torch.cat([rows, emb[:, :split].cpu(), emb[:, split:].cpu()])
        bucket[cat] = rows

        # Save to a sibling temp file, then atomically swap it in, so an
        # interrupted save cannot truncate the accumulated feature file.
        tmp_path = file_path + ".tmp"
        torch.save(data, tmp_path)
        os.replace(tmp_path, file_path)

    def modify_embedding(self, pipe, prompt_embeds, pooled_prompt_embeds, usermode=None, exp_dir="."):
        """Return ``(prompt_embeds, pooled_prompt_embeds)`` unchanged.

        ``pipe``, ``usermode`` and ``exp_dir`` are not used in this base class.
        """
        return prompt_embeds, pooled_prompt_embeds

    def modify_prompt(self, prompt, usermode, num_images):
        """Return ``prompt`` unchanged; the other arguments are not used here."""
        return prompt

    def process_input(self, prompts, usermode, protect):
        """Return ``prompts`` unchanged; the other arguments are not used here."""
        return prompts