## 
import torch
import torch.nn.functional as F
import copy


def parse_substitution(tokenizer, embeddings=None, path=None):
    """Parse a substitution file into per-prompt word records.

    The file is a sequence of blocks separated by blank lines::

        <prompt sentence>
        seed=<int>,scale=<float>        (optional, either or both)
        <word> -> <sub1>, <sub2>, ...
        <word> -> ...

    Every comma in the prompt gets a blank inserted before it so that the
    comma is handled as a stand-alone word.  The prompt is then split into
    word records.  A word listed in the block is substitutable
    (``do_sub=True``) when the tokenizer gives it exactly one id between the
    first and last id of ``input_ids`` (presumably BOS/EOS -- confirm against
    the tokenizer in use); substitutes that are not single-token are dropped.
    All other words are kept as-is (``do_sub=False``).

    Args:
        tokenizer: callable; ``tokenizer(text).input_ids`` must be a sequence
            of ids whose first and last entries are stripped.
        embeddings: optional table indexable by a token id or a list of ids.
            When ``None`` every record carries ``'embeddings': None``.
        path: substitution file to read.  Defaults to the original
            hard-coded ``substitute.txt`` location.

    Returns:
        dict mapping the comma-spaced prompt to
        ``{'all_words': [record, ...], 'seed': int|None, 'scale': float|None}``
        where each record has the keys ``ori``, ``do_sub``, ``sub_words``,
        ``sub_ids`` and ``embeddings``.

    Raises:
        ValueError: on a substitution line outside a prompt block, or a line
            inside a block that lacks the ``' -> '`` separator.
    """
    if path is None:
        path = '/nfs/data/$user/projects/crash/diffusion/diffusers/examples/textual_inversion/attack_concepts/substitute.txt'

    def embed(ids):
        ## embeddings are optional; without a table records carry None
        return None if embeddings is None else embeddings[ids]

    def fixed_word(word, ids=None):
        ## record for a word that is never substituted
        if ids is None:
            ids = tokenizer(word).input_ids[1:-1]
        return {
            'ori': word,
            'do_sub': False,
            'sub_words': word,
            'sub_ids': ids,
            'embeddings': embed(ids),
        }

    res_dict = {}  ## sentence -> substitutes

    with open(path, 'r') as f:
        lines = [raw.strip() for raw in f]
    lines.append('')  ## sentinel so that the last block is always terminated

    lid = 0
    while lid < len(lines):
        line = lines[lid]
        if line == '':
            lid += 1
            continue
        if '->' in line:
            ## previously this case never advanced `lid` and looped forever
            raise ValueError(
                'line %d: substitution line outside of a sentence block: %r'
                % (lid + 1, line)
            )

        ## sentence line: put a blank before each comma so it becomes a word
        sentence = line.replace(',', ' ,')
        lid += 1

        ## optional "seed=..,scale=.." line; require '=' and no '->' so that a
        ## substitution for the word "seed"/"scale" is not mistaken for it
        seed, scale = None, None
        header = lines[lid]
        if '=' in header and '->' not in header and ('seed' in header or 'scale' in header):
            for param_val in header.split(','):
                param, _, val = param_val.partition('=')
                param = param.strip()  ## tolerate "seed=1, scale=7.5"
                if param == 'seed':
                    seed = int(val)
                elif param == 'scale':
                    scale = float(val)
            lid += 1

        ## collect the substitution lines of this block
        words = []
        while lines[lid] != '':
            ori_word, sep, subs = lines[lid].partition(' -> ')
            if not sep:
                raise ValueError(
                    "line %d: expected '<word> -> <sub>, ...', got %r"
                    % (lid + 1, lines[lid])
                )
            candidates = [ori_word] + subs.split(', ')

            ori_id = tokenizer(ori_word).input_ids[1:-1]
            if len(ori_id) > 1:
                ## original word spans several tokens: do not substitute
                words.append(fixed_word(ori_word, ori_id))
            else:
                ## keep only candidates that are exactly one token
                kept_words, kept_ids = [], []
                for cand in candidates:
                    cand_ids = tokenizer(cand).input_ids
                    if len(cand_ids) == 3:
                        kept_words.append(cand)
                        kept_ids.append(cand_ids[1])

                if not kept_ids:
                    ## nothing usable (torch.stack would fail on an empty list)
                    words.append(fixed_word(ori_word, ori_id))
                else:
                    if embeddings is None:
                        stacked = None
                    else:
                        stacked = torch.stack([embeddings[tid] for tid in kept_ids])
                    words.append({
                        'ori': ori_word,
                        'do_sub': True,
                        'sub_words': kept_words,
                        'sub_ids': kept_ids,
                        'embeddings': stacked,
                    })
            lid += 1

        ## walk the sentence word by word and attach the matching record
        all_words = []
        n = len(sentence)
        cid = 0
        while cid < n:
            matched = False
            for word_dict in words:
                ori_word = word_dict['ori']
                end = cid + len(ori_word)
                if ori_word.lower() != sentence[cid:end].lower():
                    continue
                ## must not be a mere prefix of a longer word ("cat" in "cats")
                if end < n and sentence[end].isalnum():
                    continue
                all_words.append(copy.deepcopy(word_dict))
                cid = end + 1  ## skip the following blank
                matched = True
                break

            if not matched:
                end = sentence.find(' ', cid)
                if end == -1:
                    ## last word: runs to the end of the sentence
                    end = n
                all_words.append(fixed_word(sentence[cid:end]))
                cid = end + 1

        ## register prompt data into result dict
        res_dict[sentence] = {
            'all_words': all_words,
            'seed': seed,
            'scale': scale,
        }

    return res_dict


def process_substitution(sub_path):
    """Read a substitution file into a lower-cased prompt -> substitutes map.

    The file is a sequence of blocks separated by blank lines::

        <prompt sentence>
        seed=<int>,scale=<float>        (optional; skipped, values unused)
        <word> -> <sub1>, <sub2>, ...
        <word> -> ...

    Every comma in the prompt gets a blank inserted before it, and the
    prompt, the original words and their substitutes are all lower-cased.

    Args:
        sub_path: path of the substitution file.

    Returns:
        dict mapping the lower-cased, comma-spaced prompt to
        ``{'sub': {word: [substitute, ...]}, 'prompt_id': None,
        'prompt': <same prompt string>}``.

    Raises:
        ValueError: on a substitution line outside a prompt block, or a line
            inside a block that lacks the ``' -> '`` separator.
    """
    substitutes = {}  ## sentence -> substitutes

    with open(sub_path, 'r') as f:
        lines = [raw.strip() for raw in f]
    lines.append('')  ## sentinel so that the last block is always terminated

    lid = 0
    while lid < len(lines):
        line = lines[lid]
        if line == '':
            lid += 1
            continue
        if '->' in line:
            ## previously this case never advanced `lid` and looped forever
            raise ValueError(
                'line %d: substitution line outside of a sentence block: %r'
                % (lid + 1, line)
            )

        ## sentence line: put a blank before each comma so it becomes a word
        sentence = line.replace(',', ' ,').lower()
        lid += 1

        ## skip the optional "seed=..,scale=.." line; require '=' and no '->'
        ## so a substitution for the word "seed"/"scale" is not swallowed
        header = lines[lid]
        if '=' in header and '->' not in header and ('seed' in header or 'scale' in header):
            lid += 1

        ## collect the substitution lines of this block
        substitute = {}
        while lines[lid] != '':
            ori_word, sep, subs = lines[lid].partition(' -> ')
            if not sep:
                raise ValueError(
                    "line %d: expected '<word> -> <sub>, ...', got %r"
                    % (lid + 1, lines[lid])
                )
            ## the original called .lower() on the dict itself, which raises
            ## AttributeError; lower-case keys and values individually instead
            substitute[ori_word.lower()] = [w.lower() for w in subs.split(', ')]
            lid += 1

        ## register prompt data into result dict
        substitutes[sentence] = {
            'sub': substitute,
            'prompt_id': None,
            'prompt': sentence,
        }

    return substitutes

