import torch
import torch.nn.functional as F
import copy
import string
import numpy as np

from utils.attack_utils import get_constraint_fn
from utils.exp_utils import json_load, text_load
from .paths import NEG_PROMPTS_LIBRARY, NEG_PROMPTS_LIBRARY_V2


def substitute_sample(sub_dict, args, argmax=False, hard_sample=True):
    """ Build one adversarial prompt (text only) by choosing a word per position.

    Positions flagged with 'do_sub' receive one of their 'sub_words'; all
    other positions keep their 'ori_word'.

    Args:
        sub_dict: dict whose 'all_words' list holds one entry per position
            with the keys 'do_sub' and 'ori_word' and, for substitutable
            positions, 'sub_words' and 'log_coeff'.
        args: namespace that is only read when sampling (argmax=False); it
            supplies batch_size, sample_temp and reject.
        argmax: if True take the candidate with the largest log-coefficient,
            breaking ties uniformly at random; otherwise draw the candidate
            with a Gumbel-softmax sample.
        hard_sample: forwarded as `hard` to F.gumbel_softmax.

    Returns:
        The chosen words joined with single blank spaces.
    """
    chosen_words = []
    for entry in sub_dict['all_words']:
        if not entry['do_sub']:  ## keep the original word
            chosen_words.append(entry['ori_word'])
            continue

        candidates = entry['sub_words']
        log_coeff = entry['log_coeff']
        # NOTE(review): log_coeff[0] is compared against the global max and the
        # sampled argmax is reduced with .item(), so log_coeff looks like a
        # (1, n_candidates) tensor and args.batch_size like 1 -- confirm.
        if argmax:
            best = torch.where(log_coeff[0] == torch.max(log_coeff))[0]
            pick = np.random.choice(best.cpu().data)
        else:
            batched = log_coeff.unsqueeze(0).repeat(args.batch_size, 1, 1)
            while True:
                one_hot = F.gumbel_softmax(batched, tau=args.sample_temp, hard=hard_sample)
                pick = one_hot.argmax(dim=-1).item()
                ## with args.reject, redraw until the pick has a positive log-coefficient
                accepted = (not args.reject) or log_coeff[..., pick].item() > 0
                if accepted:
                    break

        chosen_words.append(candidates[pick])

    return ' '.join(chosen_words)


def substitute2embed_dict(substitute, tokenizer, embeddings):
    """ Convert the 'pos' / 'neg' substitutes of one prompt into sub_dicts.

    Args:
        substitute: dict with the key 'prompt' and optionally 'pos' and/or 'neg'.
        tokenizer, embeddings: passed through to the per-domain converters.

    Returns:
        A dict holding a 'pos' and/or 'neg' entry, one for each of those keys
        present in `substitute` (empty if neither is present).
    """
    prompt_text = substitute['prompt']

    def _convert(kind):
        ## look the converter up lazily, only for domains that are present
        if kind == 'pos':
            return substitute2embed_dict_pos(prompt_text, substitute[kind], tokenizer, embeddings)
        return substitute2embed_dict_neg(prompt_text, substitute[kind], tokenizer, embeddings)

    return {kind: _convert(kind) for kind in ('pos', 'neg') if kind in substitute}


def substitute2embed_dict_neg(ori_prompt, negs, tokenizer, embeddings):
    """ map json substitute (one prompt) to sub_dict, neg opp version

    Only the substitute words are kept: the original word is never offered as
    a candidate, and the words are not matched against the prompt text.
    Candidates equal to the original word (ignoring case) and empty strings
    are removed, as are candidates whose tokenization has more than 3 ids.

    Args:
        ori_prompt: prompt text, stored unchanged in the result.
        negs: dict mapping each original word to a list of substitute words.
        tokenizer: callable returning an object with an `input_ids` list; the
            id at index 1 is used as the word's token id.
        embeddings: indexable by token id; each item must be a tensor
            accepted by torch.stack.

    Returns:
        dict with 'prompt', 'all_words' (one entry per original word that has
        at least one usable candidate; empty list if there is none) and the
        placeholders 'seed' / 'scale' set to None.
    """
    ## bound before the loop so that an empty `negs`, or one where every
    ## entry is skipped, yields an empty list instead of an unbound name
    all_words = []
    for ori_word, sub_words in negs.items():
        # remove duplicate subs and empty strings
        effective_sub_words = [w for w in sub_words
                               if w.lower() != ori_word.lower() and w != '']

        ## keep candidates with at most 3 ids, tokenizing each word only once
        kept_words, kept_ids = [], []
        for w in effective_sub_words:
            input_ids = tokenizer(w).input_ids
            if len(input_ids) <= 3:
                kept_words.append(w)
                kept_ids.append(input_ids[1])

        if not kept_words:
            ## nothing left to stack; skipping avoids torch.stack([]) failing
            print('no effective sub words found in {}'.format(sub_words))
            continue

        ## register value
        all_words.append({
            'ori_word': ori_word,
            'do_sub': True,
            'sub_words': kept_words,
            'sub_ids': kept_ids,
            'embeddings': torch.stack([embeddings[token_id] for token_id in kept_ids])
        })

    ## register prompt data into result dict
    res_dict = {
        'prompt': ori_prompt,
        'all_words': all_words,
        'seed': None,
        'scale': None,
    }

    return res_dict


def substitute2embed_dict_pos(ori_prompt, subs, tokenizer, embeddings):
    """ map json substitute (one prompt) to sub_dict

    Every word of the prompt becomes one entry of 'all_words'. Words listed
    in `subs` that tokenize to a single id become substitutable ('do_sub'
    True) with the original word as first candidate followed by its
    single-token substitutes (duplicates of the original word, ignoring case,
    are removed). Every other word is kept verbatim ('do_sub' False).

    Args:
        ori_prompt: prompt text; punctuation characters are split off into
            words of their own before matching.
        subs: dict mapping an original word to a list of substitute words.
        tokenizer: callable returning an object with an `input_ids` list; the
            first and last ids are stripped, presumably start/end markers.
        embeddings: tensor-like indexable by a token id or a list of ids.

    Returns:
        dict with 'prompt', 'all_words' (entries in prompt order) and the
        placeholders 'seed' / 'scale' set to None.
    """
    ## augment a blank space before each punctuation mark, so that e.g. a comma
    ## is treated as a word; the trailing blank terminates the last word
    prompt = ''.join(' ' + c if c in string.punctuation else c for c in ori_prompt) + ' '

    ## get the sub_words list
    words = []
    for ori_word in subs:
        # remove duplicate subs, the original word is always the first candidate
        sub_words = [ori_word] + [w for w in subs[ori_word] if w.lower() != ori_word.lower()]

        ori_id = tokenizer(ori_word).input_ids[1:-1]

        ## candidates are only collected when the original word got at most one
        ## token; a candidate is kept when its tokenization has exactly 3 ids
        cand_words, cand_ids = [], []
        if len(ori_id) <= 1:
            for w in sub_words:
                input_ids = tokenizer(w).input_ids
                if len(input_ids) == 3:
                    cand_words.append(w)
                    cand_ids.append(input_ids[1])

        ## register value
        if cand_words:
            words.append({
                'ori_word': ori_word,
                'do_sub': True,
                'sub_words': cand_words,
                'sub_ids': cand_ids,
                'embeddings': torch.stack([embeddings[token_id] for token_id in cand_ids])
            })
        else:
            ## got assigned more than 1 tokens, or no usable candidate is left
            ## (torch.stack would fail on an empty list): do not substitute
            words.append({
                'ori_word': ori_word,
                'do_sub': False,
                'sub_words': [ori_word],
                'sub_ids': ori_id,
                'embeddings': embeddings[ori_id]
            })

    ## match sub_words list with the sentence; done once, after all words are
    ## registered, so it also works when `subs` is empty
    all_words = []
    cid = 0
    while cid < len(prompt):
        for word_dict in words:
            key = word_dict['ori_word']
            end = cid + len(key)
            ## the match has to end at a word boundary, otherwise 'cat' would
            ## match the start of 'cats' and the rest of that word would be lost
            if prompt[cid:end].lower() == key.lower() and prompt[end:end + 1] in ('', ' '):
                all_words.append(copy.deepcopy(word_dict))
                cid = end + 1  # add blank space
                break
        else:  # unmatched: keep the next space-delimited word verbatim
            end = prompt.find(' ', cid)  # always found, prompt ends with a blank
            plain_word = prompt[cid:end]
            cid = end + 1
            plain_id = tokenizer(plain_word).input_ids[1:-1]
            all_words.append({
                'ori_word': plain_word,
                'do_sub': False,
                'sub_words': [plain_word],  # a list, like the other entries
                'sub_ids': plain_id,
                'embeddings': embeddings[plain_id]
            })

    ## register prompt data into result dict
    res_dict = {
        'prompt': ori_prompt,
        'all_words': all_words,
        'seed': None,
        'scale': None,
    }

    return res_dict


def negative_prompt_substitute(ori_prompt, length, version):
    """ Build a substitutes dict whose slots are filled from the negative prompt libraries.

    Both library files (NEG_PROMPTS_LIBRARY and NEG_PROMPTS_LIBRARY_V2) are
    read, one stripped line per word. Slots are named 'NPLib-<index>'.

    Args:
        ori_prompt: stored unchanged under 'prompt'.
        length: length of negative prompt, i.e. the number of slots for the
            'v1' and 'v2' layouts (not used by 'v3').
        version: string tested for the substrings 'v1', 'v2', 'v3', in that order:
            'v1' -- every slot gets its own copy of the whole v1 library;
            'v2' -- the v2 library is split into `length` chunks, one per slot;
            'v3' -- v2 + v1 are concatenated and split into chunks.
            Any other value leaves 'sub' empty.

    Returns:
        dict with the keys 'sub', 'prompt_id' (None) and 'prompt'.
    """
    def load_neg_prompt(path):
        with open(path, 'r') as f:
            return [line.strip() for line in f]

    ## load negative prompts
    words_v1 = load_neg_prompt(NEG_PROMPTS_LIBRARY)
    words_v2 = load_neg_prompt(NEG_PROMPTS_LIBRARY_V2)

    ## make substitutes
    slots = {}
    if 'v1' in version:  ## copy
        for wid in range(length):
            slots[f'NPLib-{wid}'] = copy.deepcopy(words_v1)
    else:
        if 'v2' in version:
            chunks = np.array_split(words_v2, length)
        elif 'v3' in version:
            # NOTE(review): the chunk count is derived from len(v2) + 3 * len(v1)
            # although only v2 + v1 is split -- confirm the factor 3 is intended.
            n_chunks = (len(words_v2) + 3 * len(words_v1)) // 50
            chunks = np.array_split(words_v2 + words_v1, n_chunks)
        else:
            chunks = []
        for wid, chunk in enumerate(chunks):
            slots[f'NPLib-{wid}'] = chunk.tolist()

    return {
        'sub': slots,
        'prompt_id': None,
        'prompt': ori_prompt,
    }


def validate_synonyms(substitutes, constraint, thresh, tokenizer, text_encoder):
    """ use the constraint score (bert score) to filter sub_dict

    For each original word, every substitute is tried by replacing the word
    in the lower-cased prompt; the substitute is kept when the score returned
    by the constraint function is at least `thresh`. Each score is printed
    together with the prompt it was computed for.

    Args:
        substitutes: dict with 'prompt' and 'pos' (original word -> list of
            substitute words). Its 'pos' entry is replaced in place.
        constraint, tokenizer, text_encoder: forwarded to get_constraint_fn
            together with the original prompt.
        thresh: minimum score a substitute needs to be kept.

    Returns:
        `substitutes`, with 'pos' holding only the kept substitutes; original
        words without any kept substitute are dropped.
    """
    ori_prompt = substitutes['prompt']
    subs = substitutes['pos']

    constraint_fn = get_constraint_fn(constraint, tokenizer, text_encoder, ori_prompt)

    prompt_lower = ori_prompt.lower()
    subs_filtered = {}
    for ori_word, sub_words in subs.items():  ## each position
        ## the prompt is lower-cased, so the word searched for must be too;
        ## otherwise a key containing upper-case letters is never replaced and
        ## every substitute is scored on the unchanged prompt
        target = ori_word.lower()

        sub_words_filtered = []  ## substitutes that pass the threshold
        for sub_word in sub_words:
            adv_prompt = prompt_lower.replace(target, sub_word)
            bert_score = constraint_fn(adv_prompt)
            if bert_score >= thresh:
                sub_words_filtered.append(sub_word)
            print('{:.8f} {}'.format(bert_score, adv_prompt))
        print('='*10)

        if len(sub_words_filtered) > 0:
            subs_filtered[ori_word] = sub_words_filtered

    substitutes['pos'] = subs_filtered
    return substitutes


def is_float(string):
    """ Tell whether `string` can be converted with float().

    Returns False instead of raising for values float() rejects, including
    non-string inputs such as None (TypeError) and integers too large for a
    float (OverflowError).
    """
    try:
        float(string)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def load_prompts(path):
    """ Load prompts from a file, choosing the loader from the path.

    A path containing '.json' is read with json_load, otherwise one
    containing '.txt' is read with text_load.

    Raises:
        ValueError: if the path contains neither '.json' nor '.txt'.
    """
    if '.json' in path:
        return json_load(path)
    if '.txt' in path:
        return text_load(path)

    ## previously this fell through to an UnboundLocalError on `prompts`
    raise ValueError('unsupported prompt file (expected .json or .txt): {}'.format(path))


def sub_to_lower(substitutes):
    """ convert every string in substitutes to lower case.

    Walks nested dicts, lists and tuples and lower-cases every string found,
    dict keys included. The input is not modified; a converted copy is
    returned. Values of any other type (numbers, None, ...) are returned
    unchanged, so None still maps to None. If two dict keys differ only by
    case, the one that comes later wins.
    """
    def _lower(obj):
        if isinstance(obj, str):
            return obj.lower()
        if isinstance(obj, dict):
            return {_lower(key): _lower(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [_lower(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_lower(item) for item in obj)
        return obj

    return _lower(substitutes)