import torch
import torch.nn.functional as F

def switch_embed(neg_embed, load_direct=False):
    """Load a pair of objects with ``torch.load`` for a negative-prompt embedding.

    Args:
        neg_embed: With ``load_direct`` set, the path of the file to load.
            Otherwise one of the preset names listed in the table below
            (matched with ``==``, so the comparison is case-sensitive).
        load_direct: When truthy, ``neg_embed`` is treated as a file path and
            that single file is loaded twice.

    Returns:
        A 2-tuple of whatever ``torch.load`` returns for the two files, or
        ``None`` when ``load_direct`` is falsy and ``neg_embed`` matches no
        preset name.
    """
    if load_direct:
        # Both tuple slots come from the same file, each from its own load call.
        first = torch.load(neg_embed)
        second = torch.load(neg_embed)
        return first, second

    # (preset name, first file, second file). Note the "Nature" files are
    # stored with a lowercase stem.
    presets = (
        ("Fruit",
         "Negative_Prompt_Embedding/Fruit_embeddings.pt",
         "Negative_Prompt_Embedding/Fruit_final_embeddings.pt"),
        ("Nature",
         "Negative_Prompt_Embedding/nature_embeddings.pt",
         "Negative_Prompt_Embedding/nature_final_embeddings.pt"),
        ("Red",
         "Negative_Prompt_Embedding/Red_embeddings.pt",
         "Negative_Prompt_Embedding/Red_final_embeddings.pt"),
        ("Strawberry",
         "Negative_Prompt_Embedding/Strawberry_embeddings.pt",
         "Negative_Prompt_Embedding/Strawberry_final_embeddings.pt"),
        ("Red Strawberry",
         "Negative_Prompt_Embedding/Red_Strawberry_embeddings.pt",
         "Negative_Prompt_Embedding/Red_Strawberry_final_embeddings.pt"),
        ("Red Strawberry Fruit",
         "Negative_Prompt_Embedding/Red_Strawberry_Fruit_embeddings.pt",
         "Negative_Prompt_Embedding/Red_Strawberry_Fruit_final_embeddings.pt"),
        ("Strawberry is Red",
         "Negative_Prompt_Embedding/Strawberry_is_Red_embeddings.pt",
         "Negative_Prompt_Embedding/Strawberry_is_Red_final_embeddings.pt"),
        ("Strawberry Color is Red",
         "Negative_Prompt_Embedding/Strawberry_Color_is_Red_embeddings.pt",
         "Negative_Prompt_Embedding/Strawberry_Color_is_Red_final_embeddings.pt"),
    )
    for preset_name, first_path, second_path in presets:
        if neg_embed == preset_name:
            return torch.load(first_path), torch.load(second_path)

    # Unknown preset: nothing is loaded and the caller receives None.
    return None

class Configuration:
    """Attribute bag for run settings supplied as keyword arguments.

    Every keyword argument except ``normalized_neg_embed`` is stored as an
    instance attribute of the same name. Attributes the class itself reads:

    * ``negative_embedding_tensor`` (required) - normalized in ``__init__``.
    * ``warmup_init`` (required) - initial value of the ``warmup`` counter.
    * ``system_prompt``, ``safety_alpha``, ``safety_lambda``,
      ``operation_mode``, ``length_beta`` - read by ``to_filename``.
    * ``embedding_filename`` (optional) - read by ``to_filename``.
    """

    def __init__(self, **kwargs):
        """Store the keyword arguments and derive the computed attributes.

        Sets ``normalized_neg_embed`` to the L2-normalized (along ``dim=1``)
        ``negative_embedding_tensor`` moved to the ``"cuda"`` device, and
        ``warmup`` to ``warmup_init``.

        Raises:
            AttributeError: if ``negative_embedding_tensor`` or
                ``warmup_init`` was not supplied.
        """
        # A caller-supplied normalized_neg_embed is ignored; it is always
        # recomputed from negative_embedding_tensor below.
        special_keys = ["normalized_neg_embed"]
        for key, value in kwargs.items():
            if key not in special_keys:
                setattr(self, key, value)
        # dim=1 means the tensor needs at least two dimensions.
        self.normalized_neg_embed = F.normalize(self.negative_embedding_tensor, p=2, dim=1).to("cuda")
        self.warmup = self.warmup_init

    def __str__(self):
        """Return one ``name-value`` line per instance attribute.

        Tensors are summarized as ``name-Tensor of size <size>`` instead of
        printing their contents. Every line, including the last, ends with a
        newline.
        """
        lines = []
        for key, value in self.__dict__.items():
            if isinstance(value, torch.Tensor):
                lines.append(f"{key}-Tensor of size {value.size()}\n")
            else:
                # A value of None is rendered as "name-None" by the f-string.
                lines.append(f"{key}-{value}\n")
        return "".join(lines)

    def to_filename(self):
        """Build a ``.txt`` file name that encodes the main settings.

        The system-prompt tag is taken from the second ``/``-separated
        component of ``system_prompt``: the text before its first ``.``,
        then the piece after its last ``_``. The embedding tag is the second
        ``/``-separated component of ``embedding_filename`` with spaces
        turned into underscores and ``.pt`` removed, or ``"default"`` when
        ``embedding_filename`` is ``None`` or was never set.

        Raises:
            IndexError: if a path that is used contains no ``/``.
        """
        prompt_component = self.system_prompt.split('/')[1]
        sys_prompt_filename = prompt_component.split('.')[0].split("_")[-1]

        # Treat a missing attribute like None so configurations created
        # without embedding_filename still get the "default" tag.
        embedding_filename = getattr(self, "embedding_filename", None)
        if embedding_filename is not None:
            neg_embed_filename = embedding_filename.replace(' ', '_').split("/")[1].replace('.pt', '')
        else:
            neg_embed_filename = "default"

        # The "length_alpha" label carries length_beta; the label is kept
        # unchanged so generated file names stay the same.
        return (f"alpha-{self.safety_alpha}_lambda-{self.safety_lambda}_neg_embed-{neg_embed_filename}"
                f"_method-{self.operation_mode}_sys_prompt-{sys_prompt_filename}_length_alpha-{self.length_beta}.txt")

    def update_warmup(self):
        """Decrease ``warmup`` by one, never going below zero."""
        self.warmup = max(0, self.warmup - 1)

    def reset_warmup(self):
        """Restore ``warmup`` to ``warmup_init``."""
        self.warmup = self.warmup_init

    def write_to_file(self, file):
        """Write ``str(self)`` to ``file`` via its ``write`` method."""
        file.write(str(self))

    def save(self, name):
        """Write ``str(self)`` to ``llama_prompts/<name>.txt``, overwriting it.

        The ``llama_prompts`` directory is not created here, so it must
        already exist.
        """
        with open(f"llama_prompts/{name}.txt", 'w') as file:
            self.write_to_file(file)

def print_memory(str):
    """Print ``str`` followed by ``torch.cuda.memory_allocated()`` in GB.

    The byte count is divided by 1e9 (decimal gigabytes). The parameter name
    shadows the builtin ``str`` inside this function; it is kept because
    callers may pass it by keyword.
    """
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    print(str, allocated_gb, "GB")