import os
import json
import torch

from glob import glob
from copy import deepcopy
from torch import nn

from transformers import CLIPTokenizer
from functools import partial

from latent_diffusion.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder

# Default list of placeholder strings (a single "*").
DEFAULT_PLACEHOLDER_TOKEN = ["*"]

# Divisor for the progressive-words schedule: in the multi-vector branch of the
# managers' ``forward``, the number of vectors inserted per placeholder is
# ``1 + progressive_counter // PROGRESSIVE_SCALE`` (the counter is incremented
# once per placeholder per such forward call).
PROGRESSIVE_SCALE = 2000

def get_clip_token_for_string(tokenizer, string):
    """Tokenize ``string`` with a CLIP tokenizer and return its single token id.

    The string is padded/truncated to 77 ids. Every id other than 49407 is
    counted, and exactly two are required: the start token plus one real
    token. (Assumes 49407 is the end-of-text id that is also used as padding;
    TODO confirm for non-default tokenizers.)

    Returns:
        ``(token_id, input_ids)``: the id at position 1 and the full id tensor.

    Raises:
        AssertionError: if ``string`` does not map to exactly one token.
    """
    encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"]
    non_eos_count = torch.count_nonzero(input_ids - 49407)
    assert non_eos_count == 2, f"String '{string}' maps to more than a single token. Please use another string"
    first_token = input_ids[0, 1]
    return first_token, input_ids

def get_bert_token_for_string(tokenizer, string):
    """Tokenize ``string`` with a BERT-style tokenizer and return its single token id.

    Exactly three non-zero ids are required (presumably [CLS], the token and
    [SEP], with zero padding — TODO confirm against the tokenizer).

    Returns:
        ``(token_id, token_ids)``: the id at position ``[0, 1]`` and the full id tensor.

    Raises:
        AssertionError: if ``string`` does not map to exactly one token.
    """
    token_ids = tokenizer(string)
    nonzero_count = torch.count_nonzero(token_ids)
    assert nonzero_count == 3, f"String '{string}' maps to more than a single token. Please use another string"
    single = token_ids[0, 1]
    return single, token_ids

def get_embedding_for_clip_token(embedder, token):
    """Embed ``token`` without position embeddings and return the vector at ``[0, 1]``.

    ``embedder`` must accept a ``require_position_embed`` keyword. Position 1
    is the first token after the start token for ids produced by
    ``get_clip_token_for_string``.
    """
    embeddings = embedder(token, require_position_embed=False)
    return embeddings[0, 1]


# Extra single-character placeholder strings (the 22 non-final Hebrew letters,
# alef through tav) that ``start`` registers when ``per_image_tokens`` is set.
per_img_token_list = list("אבגדהוזחטיכלמנסעפצקרשת")

def replace_for_first_epoch(
        tokenized_text,
        embedded_text,
        pos_condition=True,
        input_text="",
        key_text="",
        **kwargs
        ):
    """Negate the embedding positions covered by ``key_text`` when ``pos_condition`` is False.

    ``input_text`` and ``key_text`` are split on single spaces. The first key
    word must occur exactly once in ``input_text``. Starting at that word's
    index, one position along dim 1 of ``embedded_text`` is negated per key
    word. ``embedded_text`` is modified in place and returned.
    ``tokenized_text`` and ``**kwargs`` are unused.

    NOTE(review): word indices are used directly as embedding positions; this
    presumably assumes one position per word and no start token at position
    0 — confirm against the caller's tokenization.

    Raises:
        AssertionError: if the first key word does not occur exactly once.
    """
    if not pos_condition:
        words = input_text.split(" ")
        keys = key_text.split(" ")
        assert words.count(keys[0]) == 1
        first = words.index(keys[0])
        for offset, _ in enumerate(keys):
            position = first + offset
            embedded_text[:, position] = -embedded_text[:, position]
    return embedded_text


class EmbeddingManager(nn.Module):
    """Learnable placeholder-token embeddings spliced into embedded prompts.

    Each placeholder string is mapped to one tokenizer id and to a trainable
    ``(num_vectors_per_token, token_dim)`` parameter. ``forward`` writes those
    parameters into an already-embedded prompt at the placeholder positions:
    it overwrites them for one vector, or inserts for several.

    All constructor keyword arguments are stored in ``self.config`` and
    forwarded to :meth:`start`.
    """

    def __init__(
            self, **kwargs
    ):
        super().__init__()
        self.string_to_embedding_pool = {}
        self.config = kwargs
        # Initialise before ``start``: ``start`` sets this to True when a
        # placeholder falls back to random initialisation, and resetting it
        # afterwards would silently discard that information.
        self.random_start = False
        self.start(**self.config)

    def forward(
            self,
            tokenized_text,
            embedded_text,
            pos_condition=True,
            replace=False,  # True: replace, False: decompose
            zero=False,
            **kwargs
    ):
        """Write the learned placeholder embeddings into ``embedded_text``.

        Args:
            tokenized_text: 2-D integer token ids, indexed ``(batch, seq)``.
                Modified in place when ``max_vectors_per_token > 1``.
            embedded_text: embeddings of ``tokenized_text``, indexed
                ``(batch, seq, dim)``. It is detached first, but the detached
                tensor shares storage with the argument, so the caller's
                tensor is updated in place as well.
            pos_condition: when False, the placeholder embedding is negated.
            replace: accepted for API compatibility; currently unused.
            zero: when True, the placeholder embedding is multiplied by zero.
            **kwargs: ignored.

        Returns:
            ``embedded_text`` with the placeholder positions filled in.
        """
        _, n = tokenized_text.shape
        device = tokenized_text.device

        embedded_text = embedded_text.detach()

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)

            if not pos_condition:
                placeholder_embedding = - placeholder_embedding

            if zero:
                placeholder_embedding = placeholder_embedding * 0

            placeholder_token = placeholder_token.to(device)

            if self.max_vectors_per_token == 1:
                # One vector per token: overwrite the placeholder positions.
                placeholder_idx = torch.where(tokenized_text == placeholder_token)
                embedded_text[placeholder_idx] = placeholder_embedding
                continue

            # Several vectors per token: each occurrence is expanded into
            # ``num_vectors_for_token`` positions, shifting the rest of the row
            # right and truncating it back to length ``n``.
            if self.progressive_words:
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

            placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token)

            if placeholder_rows.nelement() == 0:
                continue

            # Handle right-most occurrences first so that the column indices of
            # occurrences further left stay valid after each insertion.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]
            inserted_tokens = placeholder_token.repeat(num_vectors_for_token)
            inserted_vectors = placeholder_embedding[:num_vectors_for_token]

            for row, col in zip(sorted_rows, sorted_cols):
                new_token_row = torch.cat([tokenized_text[row][:col], inserted_tokens, tokenized_text[row][col + 1:]], axis=0)[:n]
                new_embed_row = torch.cat([embedded_text[row][:col], inserted_vectors, embedded_text[row][col + 1:]], axis=0)[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    def get_embedding_norms_squared(self):
        """Return the squared L2 norm of every learned vector, concatenated over placeholders."""
        all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0)  # (total_vectors, dim)
        return (all_params * all_params).sum(axis=-1)                           # (total_vectors,)

    def embedding_parameters(self):  # returns the learned embeddings
        """Stack every placeholder parameter along a new leading dimension."""
        return torch.stack(list(self.string_to_param_dict.parameters()), 0)

    def embedding_to_coarse_loss(self):
        """Return the mean over placeholders of ``(p - p0) @ (p - p0).T``.

        ``p`` is the current parameter and ``p0`` its initial value. For
        ``(k, dim)`` parameters each term is a ``(k, k)`` matrix; with one
        vector per token it is the 1x1 squared L2 distance. Returns ``0.``
        when there are no placeholders.
        """
        loss = 0.
        num_embeddings = len(self.initial_embeddings)

        for key in self.initial_embeddings:
            optimized = self.string_to_param_dict[key]
            coarse = self.initial_embeddings[key].clone().to(optimized.device)
            diff = optimized - coarse
            loss = loss + diff @ diff.T / num_embeddings

        return loss

    def _resolve_embedder(self, embedder, device):
        """Return ``(get_token_for_string, get_embedding_for_tkn, token_dim)`` and set ``self.is_clip``.

        A string ``embedder`` is first turned into an encoder instance:
        ``"stable_diffusion"`` builds a ``FrozenCLIPEmbedder``, and any other
        string builds a ``BERTEmbedder``.
        """
        if isinstance(embedder, str):
            if embedder == "stable_diffusion":  # Stable Diffusion's CLIP encoder
                embedder = FrozenCLIPEmbedder().to(device)
            else:  # LDM's BERT encoder
                embedder = BERTEmbedder(n_embed=1280, n_layer=32).to(device)

        if isinstance(embedder, FrozenCLIPEmbedder):
            self.is_clip = True
            # Both construction paths use the token-embedding module, so
            # initializer vectors are raw token embeddings.
            return (
                partial(get_clip_token_for_string, embedder.tokenizer),
                partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings),
                768,
            )

        if isinstance(embedder, BERTEmbedder):
            self.is_clip = False
            token_emb = embedder.transformer.token_emb

            def get_embedding_for_bert_token(tokens):
                # ``tokens`` is the 2-D id tensor from get_bert_token_for_string;
                # keep only the vector at [0, 1] (mirroring the CLIP helper) so
                # the caller receives a single ``(dim,)`` vector.
                return token_emb(tokens)[0, 1]

            return partial(get_bert_token_for_string, embedder.tknz_fn), get_embedding_for_bert_token, 1280

        raise TypeError(f"Unsupported embedder: {type(embedder).__name__}")

    def start(
            self,
            embedder="stable_diffusion",
            placeholder_strings=None,
            initializer_words=None,
            per_image_tokens=False,
            num_vectors_per_token=1,
            progressive_words=False,
            device="cuda:0",
            **kwargs
    ):
        """Build the placeholder vocabulary and its trainable embeddings.

        Args:
            embedder: ``"stable_diffusion"``, another string (BERT), or an
                existing ``FrozenCLIPEmbedder`` / ``BERTEmbedder`` instance.
            placeholder_strings: strings that must each map to exactly one
                token. Defaults to ``DEFAULT_PLACEHOLDER_TOKEN``. The caller's
                list is copied, never modified.
            initializer_words: optional per-placeholder initial words.
                - A single-word entry initialises the placeholder from that
                  word's token embedding.
                - A multi-word entry, or a placeholder without an entry, gets
                  uniform random values in [0, 1) and sets ``self.random_start``.
            per_image_tokens: also register every string in ``per_img_token_list``.
            num_vectors_per_token: number of rows of each placeholder parameter.
            progressive_words: grow the number of inserted vectors over time
                (multi-vector ``forward`` only).
            device: used when building an embedder from a string, and for the
                initializer-word token ids.
            **kwargs: ignored.

        Raises:
            TypeError: if ``embedder`` is not a supported type.
            AssertionError: if a placeholder or initializer word maps to more
                than one token.
        """
        self.string_to_token_dict = {}

        self.string_to_param_dict = nn.ParameterDict()

        self.initial_embeddings = nn.ParameterDict()  # These should not be optimized

        self.progressive_words = progressive_words
        self.progressive_counter = 0

        self.max_vectors_per_token = num_vectors_per_token

        get_token_for_string, get_embedding_for_tkn, token_dim = self._resolve_embedder(embedder, device)

        # Work on a copy: extending the caller's list would also grow
        # ``self.config`` and any other holder of that list.
        placeholder_strings = list(DEFAULT_PLACEHOLDER_TOKEN if placeholder_strings is None else placeholder_strings)
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):

            token, _ = get_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                if len(initializer_words[idx].split(" ")) == 1:
                    _, init_word_token_all = get_token_for_string(initializer_words[idx])
                    with torch.no_grad():
                        init_word_embedding = get_embedding_for_tkn(init_word_token_all.to(device))
                else:
                    print("More than one word! Using random initialization!!!!!!!!!!!!")
                    self.random_start = True
                    init_word_embedding = torch.rand(size=(token_dim,), requires_grad=True)

                token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
            else:
                self.random_start = True
                token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params
            # Frozen snapshot used by ``restart`` and ``embedding_to_coarse_loss``.
            self.initial_embeddings[placeholder_string] = torch.nn.Parameter(token_params.clone().detach(), requires_grad=False)

    def restart(self, use_initial_embedding=True):
        """Archive the current embeddings, then re-initialise them.

        Each placeholder's current value is appended along dim 0 to
        ``self.string_to_embedding_pool[string]``. The parameter is then reset
        to its initial value, or, when ``use_initial_embedding`` is False, to
        uniform random values in [0, 1).
        """
        # Archive used for the orthogonality constraint on embeddings; the dict
        # is kept for future multi-task ablation extensions.
        for string, param in self.string_to_param_dict.items():
            snapshot = param.data.clone().detach()
            if string in self.string_to_embedding_pool:
                snapshot = torch.cat([self.string_to_embedding_pool[string], snapshot], dim=0)
            self.string_to_embedding_pool[string] = snapshot

        for string, param in self.string_to_param_dict.items():
            initial = self.initial_embeddings[string].data
            param.data = initial.clone() if use_initial_embedding else torch.rand_like(initial)

    def get_embedding(self):
        """Return the ``ParameterDict`` mapping placeholder strings to learned embeddings."""
        return self.string_to_param_dict



class EmbeddingManagerReorgan(nn.Module):
    """Placeholder-embedding manager that reuses embeddings archived by ``restart``.

    The effective embedding of a placeholder is its trainable parameter plus a
    ``self.weights``-weighted sum of the normalised embeddings archived in
    ``self.string_to_embedding_pool``. ``self.weights`` has up to 1000 rows,
    and ``restart`` resets it to zero.

    All constructor keyword arguments are stored in ``self.config`` and
    forwarded to :meth:`start`.
    """

    def __init__(
            self, **kwargs
    ):
        super().__init__()
        self.string_to_embedding_pool = torch.nn.ParameterDict()
        self.config = kwargs
        # Initialise before ``start`` so the value ``start`` sets is kept.
        self.random_start = False
        self.start(**self.config)

        self.inital_weights = torch.zeros([1000, 1])
        # Clone: nn.Parameter shares storage with the tensor it wraps, so
        # without a copy every optimizer step would also overwrite
        # ``inital_weights`` and ``restart`` could never reset the weights.
        self.weights = torch.nn.Parameter(self.inital_weights.clone(), requires_grad=True)

    def _add_pool_combination(self, string, embedding):
        """Return ``embedding`` plus the ``self.weights``-weighted sum of the pooled vectors for ``string``.

        Returns ``embedding`` unchanged when nothing is archived for ``string``.
        """
        if string not in self.string_to_embedding_pool:
            return embedding
        pool = self.string_to_embedding_pool[string].detach().to(embedding.device)
        weights = self.weights[:pool.shape[0]].to(embedding.device)
        return embedding + (pool * weights).sum(0).unsqueeze(0)

    def forward(
            self,
            tokenized_text,
            embedded_text,
            pos_condition=True,
            replace=False,  # True: replace, False: decompose
            zero=False,
            **kwargs
    ):
        """Write the effective placeholder embeddings into ``embedded_text``.

        Args:
            tokenized_text: 2-D integer token ids, indexed ``(batch, seq)``.
                Modified in place when ``max_vectors_per_token > 1``.
            embedded_text: embeddings of ``tokenized_text``, indexed
                ``(batch, seq, dim)``. It is detached first, but the detached
                tensor shares storage with the argument, so the caller's
                tensor is updated in place as well.
            pos_condition: when False, the placeholder embedding is negated.
            replace: accepted for API compatibility; currently unused.
            zero: when True, the placeholder embedding is multiplied by zero.
            **kwargs: ignored.

        Returns:
            ``embedded_text`` with the placeholder positions filled in.
        """
        _, n = tokenized_text.shape
        device = tokenized_text.device

        embedded_text = embedded_text.detach()

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self._add_pool_combination(
                placeholder_string, self.string_to_param_dict[placeholder_string].to(device))

            if not pos_condition:
                placeholder_embedding = - placeholder_embedding

            if zero:
                placeholder_embedding = placeholder_embedding * 0

            placeholder_token = placeholder_token.to(device)

            if self.max_vectors_per_token == 1:
                # One vector per token: overwrite the placeholder positions.
                placeholder_idx = torch.where(tokenized_text == placeholder_token)
                embedded_text[placeholder_idx] = placeholder_embedding
                continue

            # Several vectors per token: each occurrence is expanded into
            # ``num_vectors_for_token`` positions, shifting the rest of the row
            # right and truncating it back to length ``n``.
            if self.progressive_words:
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

            placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token)

            if placeholder_rows.nelement() == 0:
                continue

            # Right-most occurrences first so earlier column indices stay valid.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]
            inserted_tokens = placeholder_token.repeat(num_vectors_for_token)
            inserted_vectors = placeholder_embedding[:num_vectors_for_token]

            for row, col in zip(sorted_rows, sorted_cols):
                new_token_row = torch.cat([tokenized_text[row][:col], inserted_tokens, tokenized_text[row][col + 1:]], axis=0)[:n]
                new_embed_row = torch.cat([embedded_text[row][:col], inserted_vectors, embedded_text[row][col + 1:]], axis=0)[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    def _resolve_embedder(self, embedder, device):
        """Return ``(get_token_for_string, get_embedding_for_tkn, token_dim)`` and set ``self.is_clip``.

        A string ``embedder`` is first turned into an encoder instance:
        ``"stable_diffusion"`` builds a ``FrozenCLIPEmbedder``, and any other
        string builds a ``BERTEmbedder``.
        """
        if isinstance(embedder, str):
            if embedder == "stable_diffusion":  # Stable Diffusion's CLIP encoder
                embedder = FrozenCLIPEmbedder().to(device)
            else:  # LDM's BERT encoder
                embedder = BERTEmbedder(n_embed=1280, n_layer=32).to(device)

        if isinstance(embedder, FrozenCLIPEmbedder):
            self.is_clip = True
            # Both construction paths use the token-embedding module, so
            # initializer vectors are raw token embeddings.
            return (
                partial(get_clip_token_for_string, embedder.tokenizer),
                partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings),
                768,
            )

        if isinstance(embedder, BERTEmbedder):
            self.is_clip = False
            token_emb = embedder.transformer.token_emb

            def get_embedding_for_bert_token(tokens):
                # ``tokens`` is the 2-D id tensor from get_bert_token_for_string;
                # keep only the vector at [0, 1] (mirroring the CLIP helper) so
                # the caller receives a single ``(dim,)`` vector.
                return token_emb(tokens)[0, 1]

            return partial(get_bert_token_for_string, embedder.tknz_fn), get_embedding_for_bert_token, 1280

        raise TypeError(f"Unsupported embedder: {type(embedder).__name__}")

    def start(
            self,
            embedder="stable_diffusion",
            placeholder_strings=None,
            initializer_words=None,
            per_image_tokens=False,
            num_vectors_per_token=1,
            progressive_words=False,
            device="cuda:0",
            **kwargs
    ):
        """Build the placeholder vocabulary and its trainable embeddings.

        Args:
            embedder: ``"stable_diffusion"``, another string (BERT), or an
                existing ``FrozenCLIPEmbedder`` / ``BERTEmbedder`` instance.
            placeholder_strings: strings that must each map to exactly one
                token. Defaults to ``DEFAULT_PLACEHOLDER_TOKEN``. The caller's
                list is copied, never modified.
            initializer_words: optional per-placeholder initial words.
                - A single-word entry initialises the placeholder from that
                  word's token embedding.
                - A multi-word entry, or a placeholder without an entry, gets
                  uniform random values in [0, 1) and sets ``self.random_start``.
            per_image_tokens: also register every string in ``per_img_token_list``.
            num_vectors_per_token: number of rows of each placeholder parameter.
            progressive_words: grow the number of inserted vectors over time
                (multi-vector ``forward`` only).
            device: used when building an embedder from a string, and for the
                initializer-word token ids.
            **kwargs: ignored.

        Raises:
            TypeError: if ``embedder`` is not a supported type.
            AssertionError: if a placeholder or initializer word maps to more
                than one token.
        """
        self.string_to_token_dict = {}

        self.string_to_param_dict = nn.ParameterDict()

        self.initial_embeddings = nn.ParameterDict()  # These should not be optimized

        self.progressive_words = progressive_words
        self.progressive_counter = 0

        self.max_vectors_per_token = num_vectors_per_token

        get_token_for_string, get_embedding_for_tkn, token_dim = self._resolve_embedder(embedder, device)

        # Work on a copy: extending the caller's list would also grow
        # ``self.config`` and any other holder of that list.
        placeholder_strings = list(DEFAULT_PLACEHOLDER_TOKEN if placeholder_strings is None else placeholder_strings)
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):

            token, _ = get_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                if len(initializer_words[idx].split(" ")) == 1:
                    _, init_word_token_all = get_token_for_string(initializer_words[idx])
                    with torch.no_grad():
                        init_word_embedding = get_embedding_for_tkn(init_word_token_all.to(device))
                else:
                    print("More than one word! Using random initialization!!!!!!!!!!!!")
                    init_word_embedding = torch.rand(size=(token_dim,), requires_grad=True)
                    self.random_start = True

                token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
            else:
                self.random_start = True
                token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params
            # Frozen snapshot used by ``restart``.
            self.initial_embeddings[placeholder_string] = torch.nn.Parameter(token_params.clone().detach(), requires_grad=False)

    def restart(self, use_initial_embedding=True):
        """Archive the normalised current embeddings, reset ``weights``, and re-initialise the parameters.

        For each placeholder, ``param / param.norm()`` (norm over the whole
        tensor) is appended along dim 0 to its pool entry, which is frozen
        (``requires_grad = False``). ``self.weights`` is reset to
        ``self.inital_weights`` on its current device. Each parameter is reset
        to its initial value, or, when ``use_initial_embedding`` is False, to
        uniform random values in [0, 1).
        """
        # Archive used for the orthogonality constraint on embeddings; the dict
        # is kept for future multi-task ablation extensions.
        for string, param in self.string_to_param_dict.items():
            last_token = param.data.clone() / param.data.norm()
            if string in self.string_to_embedding_pool:
                last_token = torch.cat([self.string_to_embedding_pool[string], last_token], dim=0)
            self.string_to_embedding_pool[string] = last_token
            self.string_to_embedding_pool[string].requires_grad = False

        # Keep the device: ``inital_weights`` is a plain CPU tensor that is not
        # moved by ``Module.to``.
        self.weights.data = self.inital_weights.clone().to(self.weights.device)

        for string, param in self.string_to_param_dict.items():
            initial = self.initial_embeddings[string].data
            param.data = initial.clone() if use_initial_embedding else torch.rand_like(initial)

    def get_embedding(self):
        """Return a dict mapping each placeholder string to its effective embedding (parameter plus pooled combination)."""
        return {
            name: self._add_pool_combination(name, value)
            for name, value in self.string_to_param_dict.items()
        }



class EmbeddingManagerReorganRandom(nn.Module):
    """Reorganising placeholder-embedding manager that can also replay archived embeddings.

    Like the pooled manager, the effective embedding of a placeholder is its
    trainable parameter plus a ``self.weights``-weighted sum of the normalised
    embeddings in ``self.string_to_embedding_pool``. In addition,
    ``self.sampled_pool`` records the effective embedding at every ``restart``.
    ``forward(..., sample=i)`` uses entry ``i`` of that record instead.

    All constructor keyword arguments are stored in ``self.config`` and
    forwarded to :meth:`start`.
    """

    def __init__(
            self, **kwargs
    ):
        super().__init__()
        self.string_to_embedding_pool = torch.nn.ParameterDict()
        self.config = kwargs
        # Initialise before ``start`` so the value ``start`` sets is kept.
        self.random_start = False
        self.start(**self.config)

        self.inital_weights = torch.zeros([1000, 1])
        # Clone: nn.Parameter shares storage with the tensor it wraps, so
        # without a copy every optimizer step would also overwrite
        # ``inital_weights`` and ``restart`` could never reset the weights.
        self.weights = torch.nn.Parameter(self.inital_weights.clone(), requires_grad=True)
        self.sampled_pool = torch.nn.ParameterDict()

    def _add_pool_combination(self, string, embedding):
        """Return ``embedding`` plus the ``self.weights``-weighted sum of the pooled vectors for ``string``.

        Returns ``embedding`` unchanged when nothing is archived for ``string``.
        """
        if string not in self.string_to_embedding_pool:
            return embedding
        pool = self.string_to_embedding_pool[string].detach().to(embedding.device)
        weights = self.weights[:pool.shape[0]].to(embedding.device)
        return embedding + (pool * weights).sum(0).unsqueeze(0)

    def forward(
            self,
            tokenized_text,
            embedded_text,
            pos_condition=True,
            replace=False,  # True: replace, False: decompose
            sample=-1,
            zero=False,
            **kwargs
    ):
        """Write placeholder embeddings into ``embedded_text``.

        Args:
            tokenized_text: 2-D integer token ids, indexed ``(batch, seq)``.
                Modified in place when ``max_vectors_per_token > 1``.
            embedded_text: embeddings of ``tokenized_text``, indexed
                ``(batch, seq, dim)``. It is detached first, but the detached
                tensor shares storage with the argument, so the caller's
                tensor is updated in place as well.
            pos_condition: when False, the placeholder embedding is negated.
            replace: accepted for API compatibility; currently unused.
            sample: ``-1`` uses the current effective embedding; any other
                value indexes ``self.sampled_pool[placeholder]``.
            zero: when True, the placeholder embedding is multiplied by zero.
            **kwargs: ignored.

        Returns:
            ``embedded_text`` with the placeholder positions filled in.
        """
        _, n = tokenized_text.shape
        device = tokenized_text.device

        embedded_text = embedded_text.detach()

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            if sample == -1:
                placeholder_embedding = self._add_pool_combination(
                    placeholder_string, self.string_to_param_dict[placeholder_string].to(device))
            else:
                placeholder_embedding = self.sampled_pool[placeholder_string][sample].to(device)

            if not pos_condition:
                placeholder_embedding = - placeholder_embedding

            if zero:
                placeholder_embedding = placeholder_embedding * 0

            placeholder_token = placeholder_token.to(device)

            if self.max_vectors_per_token == 1:
                # One vector per token: overwrite the placeholder positions.
                placeholder_idx = torch.where(tokenized_text == placeholder_token)
                embedded_text[placeholder_idx] = placeholder_embedding
                continue

            # Several vectors per token: each occurrence is expanded into
            # ``num_vectors_for_token`` positions, shifting the rest of the row
            # right and truncating it back to length ``n``.
            if self.progressive_words:
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

            placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token)

            if placeholder_rows.nelement() == 0:
                continue

            # Right-most occurrences first so earlier column indices stay valid.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]
            inserted_tokens = placeholder_token.repeat(num_vectors_for_token)
            inserted_vectors = placeholder_embedding[:num_vectors_for_token]

            for row, col in zip(sorted_rows, sorted_cols):
                new_token_row = torch.cat([tokenized_text[row][:col], inserted_tokens, tokenized_text[row][col + 1:]], axis=0)[:n]
                new_embed_row = torch.cat([embedded_text[row][:col], inserted_vectors, embedded_text[row][col + 1:]], axis=0)[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    def _resolve_embedder(self, embedder, device):
        """Return ``(get_token_for_string, get_embedding_for_tkn, token_dim)`` and set ``self.is_clip``.

        A string ``embedder`` is first turned into an encoder instance:
        ``"stable_diffusion"`` builds a ``FrozenCLIPEmbedder``, and any other
        string builds a ``BERTEmbedder``.
        """
        if isinstance(embedder, str):
            if embedder == "stable_diffusion":  # Stable Diffusion's CLIP encoder
                embedder = FrozenCLIPEmbedder().to(device)
            else:  # LDM's BERT encoder
                embedder = BERTEmbedder(n_embed=1280, n_layer=32).to(device)

        if isinstance(embedder, FrozenCLIPEmbedder):
            self.is_clip = True
            # Both construction paths use the token-embedding module, so
            # initializer vectors are raw token embeddings.
            return (
                partial(get_clip_token_for_string, embedder.tokenizer),
                partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings),
                768,
            )

        if isinstance(embedder, BERTEmbedder):
            self.is_clip = False
            token_emb = embedder.transformer.token_emb

            def get_embedding_for_bert_token(tokens):
                # ``tokens`` is the 2-D id tensor from get_bert_token_for_string;
                # keep only the vector at [0, 1] (mirroring the CLIP helper) so
                # the caller receives a single ``(dim,)`` vector.
                return token_emb(tokens)[0, 1]

            return partial(get_bert_token_for_string, embedder.tknz_fn), get_embedding_for_bert_token, 1280

        raise TypeError(f"Unsupported embedder: {type(embedder).__name__}")

    def start(
            self,
            embedder="stable_diffusion",
            placeholder_strings=None,
            initializer_words=None,
            per_image_tokens=False,
            num_vectors_per_token=1,
            progressive_words=False,
            device="cuda:0",
            **kwargs
    ):
        """Build the placeholder vocabulary and its trainable embeddings.

        Args:
            embedder: ``"stable_diffusion"``, another string (BERT), or an
                existing ``FrozenCLIPEmbedder`` / ``BERTEmbedder`` instance.
            placeholder_strings: strings that must each map to exactly one
                token. Defaults to ``DEFAULT_PLACEHOLDER_TOKEN``. The caller's
                list is copied, never modified.
            initializer_words: optional per-placeholder initial words.
                - A single-word entry initialises the placeholder from that
                  word's token embedding.
                - A multi-word entry, or a placeholder without an entry, gets
                  uniform random values in [0, 1) and sets ``self.random_start``.
            per_image_tokens: also register every string in ``per_img_token_list``.
            num_vectors_per_token: number of rows of each placeholder parameter.
            progressive_words: grow the number of inserted vectors over time
                (multi-vector ``forward`` only).
            device: used when building an embedder from a string, and for the
                initializer-word token ids.
            **kwargs: ignored.

        Raises:
            TypeError: if ``embedder`` is not a supported type.
            AssertionError: if a placeholder or initializer word maps to more
                than one token.
        """
        self.string_to_token_dict = {}

        self.string_to_param_dict = nn.ParameterDict()

        self.initial_embeddings = nn.ParameterDict()  # These should not be optimized

        self.progressive_words = progressive_words
        self.progressive_counter = 0

        self.max_vectors_per_token = num_vectors_per_token

        get_token_for_string, get_embedding_for_tkn, token_dim = self._resolve_embedder(embedder, device)

        # Work on a copy: extending the caller's list would also grow
        # ``self.config`` and any other holder of that list.
        placeholder_strings = list(DEFAULT_PLACEHOLDER_TOKEN if placeholder_strings is None else placeholder_strings)
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):

            token, _ = get_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                if len(initializer_words[idx].split(" ")) == 1:
                    _, init_word_token_all = get_token_for_string(initializer_words[idx])
                    with torch.no_grad():
                        init_word_embedding = get_embedding_for_tkn(init_word_token_all.to(device))
                else:
                    print("More than one word! Using random initialization!!!!!!!!!!!!")
                    init_word_embedding = torch.rand(size=(token_dim,), requires_grad=True)
                    self.random_start = True

                token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
            else:
                self.random_start = True
                token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params
            # Frozen snapshot used by ``restart``.
            self.initial_embeddings[placeholder_string] = torch.nn.Parameter(token_params.clone().detach(), requires_grad=False)

    def restart(self, use_initial_embedding=True):
        """Archive the current embeddings, reset ``weights``, and re-initialise the parameters.

        For each placeholder:
        - ``param / param.norm()`` (norm over the whole tensor) is appended
          along dim 0 to its entry in ``self.string_to_embedding_pool``.
        - The effective embedding (parameter plus pooled combination, computed
          before this append) is appended along a new leading dim to its entry
          in ``self.sampled_pool``.
        Both entries are frozen. ``self.weights`` is reset to
        ``self.inital_weights`` on its current device. Each parameter is reset
        to its initial value, or, when ``use_initial_embedding`` is False, to
        uniform random values in [0, 1).
        """
        # Archive used for the orthogonality constraint on embeddings; the
        # dicts are kept for future multi-task ablation extensions.
        for string, param in self.string_to_param_dict.items():
            current = param.data.clone()
            normalized = current / param.data.norm()
            if string not in self.string_to_embedding_pool:
                self.string_to_embedding_pool[string] = normalized
                self.sampled_pool[string] = current.unsqueeze(0)
            else:
                # Uses the parameter's own device instead of a hard-coded "cuda".
                last_token = self._add_pool_combination(string, current)
                self.sampled_pool[string] = torch.cat([self.sampled_pool[string], last_token.unsqueeze(0)], dim=0)
                self.string_to_embedding_pool[string] = torch.cat([self.string_to_embedding_pool[string], normalized], dim=0)

            self.string_to_embedding_pool[string].requires_grad = False
            self.sampled_pool[string].requires_grad = False

        # Keep the device: ``inital_weights`` is a plain CPU tensor that is not
        # moved by ``Module.to``.
        self.weights.data = self.inital_weights.clone().to(self.weights.device)

        for string, param in self.string_to_param_dict.items():
            initial = self.initial_embeddings[string].data
            param.data = initial.clone() if use_initial_embedding else torch.rand_like(initial)

    @torch.no_grad()
    def get_embedding(self):
        """Return a dict mapping each placeholder string to its effective embedding, computed without autograd."""
        return {
            name: self._add_pool_combination(name, value)
            for name, value in self.string_to_param_dict.items()
        }



class EmbeddingManagerRandom(nn.Module):
    """Placeholder-token embedding manager whose vectors start from random values.

    Each placeholder string maps to one tokenizer id (``string_to_token_dict``) and a
    trainable ``(num_vectors_per_token, token_dim)`` parameter
    (``string_to_param_dict``). ``restart()`` snapshots the current parameters into
    ``sampled_pool`` so that ``forward(sample=i)`` can reuse snapshot ``i`` later.
    """

    def __init__(
            self, **kwargs
    ):
        super().__init__()
        self.config = kwargs
        # Must be set before start(): start() flips it to True when it falls back to
        # random initialisation, and that result must not be overwritten afterwards.
        self.random_start = False
        self.start(**self.config)
        # Registered after start() so submodule (and thus parameter) order is unchanged.
        self.sampled_pool = torch.nn.ParameterDict()

    def forward(
            self,
            tokenized_text,
            embedded_text,
            pos_condition=True,
            replace=False,  # currently unused
            sample=-1,
            zero=False,
            **kwargs
    ):
        """Write the placeholder embeddings into ``embedded_text`` at every placeholder token.

        Args:
            tokenized_text: 2-D integer token ids indexed ``[row, col]``. NOTE: in the
                multi-vector path this tensor is modified in place.
            embedded_text: embeddings aligned with ``tokenized_text``. It is detached
                and then written in place; ``detach()`` shares storage, so the
                caller's tensor is modified as well.
            pos_condition: when False, the placeholder embedding is negated.
            replace: accepted for compatibility; not used by this implementation.
            sample: -1 uses the trainable parameters; otherwise snapshot ``sample``
                from ``sampled_pool`` (filled by ``restart()``).
            zero: when True, the placeholder embedding is multiplied by 0.
            **kwargs: ignored.

        Returns:
            The modified ``embedded_text``.
        """
        _, n = tokenized_text.shape
        device = tokenized_text.device

        embedded_text = embedded_text.detach()

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            if sample == -1:
                placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            else:
                # squeeze(0) drops the vector axis when the snapshot holds one vector.
                placeholder_embedding = self.sampled_pool[placeholder_string][sample, :, :].to(device).squeeze(0)
            if not pos_condition:
                placeholder_embedding = -placeholder_embedding
            if zero:
                placeholder_embedding = placeholder_embedding * 0

            token_on_device = placeholder_token.to(device)

            if self.max_vectors_per_token == 1:
                # One vector per token: overwrite every placeholder position.
                embedded_text[torch.where(tokenized_text == token_on_device)] = placeholder_embedding
                continue

            # Several vectors per token: splice them in, shifting later tokens right and
            # truncating each row back to its original length ``n``.
            if self.progressive_words:
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

            placeholder_rows, placeholder_cols = torch.where(tokenized_text == token_on_device)
            if placeholder_rows.nelement() == 0:
                continue

            # Handle right-most columns first so earlier column indices stay valid.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]

            repeated_token = placeholder_token.repeat(num_vectors_for_token).to(device)
            inserted = placeholder_embedding[:num_vectors_for_token]
            for row, col in zip(sorted_rows, sorted_cols):
                new_token_row = torch.cat(
                    [tokenized_text[row][:col], repeated_token, tokenized_text[row][col + 1:]], dim=0)[:n]
                new_embed_row = torch.cat(
                    [embedded_text[row][:col], inserted, embedded_text[row][col + 1:]], dim=0)[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    @staticmethod
    def _resolve_embedder(embedder, device):
        """Return ``(is_clip, get_token_for_string, get_embedding_for_tkn, token_dim)``.

        Raises:
            TypeError: if ``embedder`` is not a string, FrozenCLIPEmbedder or BERTEmbedder.
        """
        if isinstance(embedder, str):
            if embedder == "stable_diffusion":  # Stable Diffusion's CLIP encoder
                embedder = FrozenCLIPEmbedder().to(device)
                return (True,
                        partial(get_clip_token_for_string, embedder.tokenizer),
                        partial(get_embedding_for_clip_token, embedder.transformer),
                        768)
            # Any other string selects LDM's BERT encoder.
            embedder = BERTEmbedder(n_embed=1280, n_layer=32).to(device)
            return (False,
                    partial(get_bert_token_for_string, embedder.tknz_fn),
                    embedder.transformer.token_emb,
                    1280)
        if isinstance(embedder, FrozenCLIPEmbedder):
            return (True,
                    partial(get_clip_token_for_string, embedder.tokenizer),
                    partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings),
                    768)
        if isinstance(embedder, BERTEmbedder):
            return (False,
                    partial(get_bert_token_for_string, embedder.tknz_fn),
                    embedder.transformer.token_emb,
                    1280)
        raise TypeError("Unsupported embedder type: {!r}".format(type(embedder)))

    def start(self,
            embedder="stable_diffusion",
            placeholder_strings=None,
            initializer_words=None,
            per_image_tokens=False,
            num_vectors_per_token=1,
            progressive_words=False,
            device="cuda:0",
            **kwargs
            ):
        """(Re)build the placeholder-token tables and their trainable parameters.

        Args:
            embedder: "stable_diffusion" (builds a FrozenCLIPEmbedder), any other
                string (builds a BERTEmbedder), or a FrozenCLIPEmbedder /
                BERTEmbedder instance.
            placeholder_strings: strings that must each map to a single token. The
                caller's list is not modified.
            initializer_words: accepted but overridden to None below, so every
                parameter is initialised with ``torch.rand``.
            per_image_tokens: also register every string in ``per_img_token_list``.
            num_vectors_per_token: number of rows of each trainable parameter.
            progressive_words: enable progressive growth of inserted vectors in forward().
            device: device the embedder is built on and token ids are moved to.
            **kwargs: ignored.

        Raises:
            TypeError: if ``embedder`` is not a supported kind.
        """
        self.string_to_token_dict = {}

        self.string_to_param_dict = nn.ParameterDict()

        self.initial_embeddings = nn.ParameterDict()  # starting values; not optimized

        self.progressive_words = progressive_words
        self.progressive_counter = 0

        self.max_vectors_per_token = num_vectors_per_token

        self.is_clip, get_token_for_string, get_embedding_for_tkn, token_dim = \
            self._resolve_embedder(embedder, device)

        # Copy rather than extend in place so the caller's list is not mutated.
        placeholder_strings = list(placeholder_strings)
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        # NOTE(review): initializer_words is overridden here, so the initializer
        # branch below never runs and every parameter starts from torch.rand.
        # Remove this line to re-enable word-based initialisation.
        initializer_words = None

        for idx, placeholder_string in enumerate(placeholder_strings):
            token, _ = get_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                if len(initializer_words[idx].split(" ")) == 1:
                    _, init_word_token_all = get_token_for_string(initializer_words[idx])
                    with torch.no_grad():
                        init_word_embedding = get_embedding_for_tkn(init_word_token_all.to(device))
                else:
                    print("More than one word! Using random initialization!!!!!!!!!!!!")
                    init_word_embedding = torch.rand(size=(token_dim,), requires_grad=True)
                    self.random_start = True

                token_params = torch.nn.Parameter(
                    init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
            else:
                self.random_start = True
                token_params = torch.nn.Parameter(
                    torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))

            self.initial_embeddings[placeholder_string] = torch.nn.Parameter(
                token_params.clone().detach(), requires_grad=False)
            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params

    def restart(self, use_initial_embedding=True):
        """Append a snapshot of every parameter to ``sampled_pool``, then reset the parameters.

        Snapshots are stacked along a new leading axis and marked
        ``requires_grad = False``. Parameters are reset to their initial embedding,
        or to uniform random values when ``use_initial_embedding`` is False.

        Author's note (translated): used for the orthogonality constraint on the
        embeddings; the dict is kept for future multi-task / ablation extensions.
        """
        for string, token in self.string_to_param_dict.items():
            snapshot = token.data.clone().unsqueeze(0)
            if string not in self.sampled_pool:
                self.sampled_pool[string] = snapshot
            else:
                self.sampled_pool[string] = torch.cat([self.sampled_pool[string], snapshot], dim=0)

            self.sampled_pool[string].requires_grad = False

        for string, param in self.string_to_param_dict.items():
            initial = self.initial_embeddings[string].data
            param.data = initial.clone() if use_initial_embedding else torch.rand_like(initial)

    @torch.no_grad()
    def get_embedding(self):
        """Return a ``{placeholder_string: parameter}`` dict of the current parameters."""
        return dict(self.string_to_param_dict.items())


class EmbeddingWraper(nn.Module):
    """Builds one embedding manager from a JSON model config and handles its checkpoints.

    ``config`` must provide ``"device"`` and ``"model_config"`` (path to a JSON file
    of keyword arguments for the manager). Optional keys:
      * ``"modify"``: dict of overrides applied on top of the JSON config;
      * ``"embedding_type"``: ``"reweight"``, ``"reweight_random"`` or ``"random"``
        selects the manager class; when absent, ``EmbeddingManager`` is used.
    """

    def __init__(self, name: str, config: dict, logger, initializer_words=None, embedder="stable_diffusion"):
        super().__init__()
        # The first initializer word (when given) becomes part of the name and hence
        # of the checkpoint filenames.
        self.name = name + "_" + initializer_words[0] if initializer_words else name
        self.config = config
        self.logger = logger
        self.device = config["device"]

        with open(config["model_config"], "r") as f:
            model_config = json.load(f)
        # initializer_words are only used for naming; they are not added to model_config.
        model_config["embedder"] = embedder

        for k, v in self.config.get("modify", {}).items():
            model_config[k] = v
            print(k, v)

        if "embedding_type" not in self.config:
            self.model = EmbeddingManager(**model_config)
        elif self.config["embedding_type"] == "reweight":
            self.model = EmbeddingManagerReorgan(**model_config)
        elif self.config["embedding_type"] == "reweight_random":
            self.model = EmbeddingManagerReorganRandom(**model_config)
        elif self.config["embedding_type"] == "random":
            self.model = EmbeddingManagerRandom(**model_config)
        else:
            # Fail fast instead of leaving self.model unset.
            raise ValueError(
                "EmbeddingWraper: unknown embedding_type {!r}".format(self.config["embedding_type"]))

    def _checkpoint_stem(self):
        """Filename prefix shared by save_checkpoint and init_from_log (spaces -> underscores)."""
        return self.name.replace(" ", "_")

    def save_checkpoint(self, epoch, path):
        """Save the manager state to ``<path>/checkpoints/<stem>_epoch_<epoch>.pth``.

        Always stores the token and parameter dicts and ``model.get_embedding()``;
        adds ``weights``/``string_to_embedding_pool`` when the model has ``weights``
        and ``sampled_pool`` when the model has one.
        """
        ckpt_dir = os.path.join(path, "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)

        state_dict = {"string_to_token": self.model.string_to_token_dict, "string_to_param": self.model.string_to_param_dict}
        if hasattr(self.model, "weights"):
            state_dict["weights"] = self.model.weights
            state_dict["string_to_embedding_pool"] = self.model.string_to_embedding_pool
        if hasattr(self.model, "sampled_pool"):
            state_dict["sampled_pool"] = self.model.sampled_pool
        state_dict["embedding"] = self.model.get_embedding()

        filename = os.path.join(ckpt_dir, self._checkpoint_stem() + "_epoch_{}.pth".format(epoch))
        torch.save(state_dict, filename)

    def init(self, config: dict):
        """Load a checkpoint written by ``save_checkpoint`` into ``self.model``.

        ``config["pretrained_model"]`` is the checkpoint path; tensors are mapped to
        ``self.device``. Besides the token/parameter dicts, every saved key naming
        an existing attribute of ``self.model`` is assigned to that attribute.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        assert "pretrained_model" in config.keys()
        ckpt_path = config["pretrained_model"]
        if not os.path.isfile(ckpt_path):
            self.logger.info("EmbeddingWraper: Load pretrained model fail! ({})".format(ckpt_path))
            raise FileNotFoundError(ckpt_path)

        # NOTE(review): depending on the torch version / weights_only default,
        # torch.load may unpickle arbitrary objects; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=torch.device(self.device))
        self.model.string_to_token_dict = state_dict["string_to_token"]
        self.model.string_to_param_dict = state_dict["string_to_param"]
        for n in state_dict.keys():
            if hasattr(self.model, n):
                setattr(self.model, n, state_dict[n])
                self.logger.info("EmbeddingWraper: Set {} success!".format(n))
        self.logger.info("EmbeddingWraper: Load the checkpoint {}".format(ckpt_path))

    def get_embeddings(self):
        """Return the model's ``string_to_param_dict``."""
        return self.model.string_to_param_dict

    def restart(self, use_initial_embedding=True):
        """Delegate to ``self.model.restart``."""
        self.model.restart(use_initial_embedding)

    def init_from_log(self, log_path, file_path=None):
        """Load ``file_path`` if given, else the highest-epoch checkpoint under ``<log_path>/checkpoints``.

        Raises:
            FileNotFoundError: if no checkpoint for this wrapper's name exists.
        """
        if file_path:
            self.init({"pretrained_model": file_path})
            return

        # Use the same stem as save_checkpoint so names containing spaces are found.
        pattern = os.path.join(log_path, "checkpoints", self._checkpoint_stem() + "_epoch_*.pth")
        files = glob(pattern)
        if not files:
            self.logger.info("{} checkpoints file not found error!!!!".format(self.name))
            raise FileNotFoundError(pattern)
        # Compare epochs numerically so that epoch 10 wins over epoch 9.
        latest = max(files, key=lambda x: int(x.split("_epoch_")[-1].split(".pth")[0]))
        self.init({"pretrained_model": latest})

    def parameters(self, recurse: bool = True):
        """Expose only the wrapped model's parameters."""
        return self.model.parameters(recurse=recurse)

    def get_embedding_tool(self):
        """Return the model's ``string_to_embedding_pool``."""
        return self.model.string_to_embedding_pool