import logging
import torch
import numpy as np
import torch.nn.functional as F

from torch.distributions import Dirichlet
from prompts.common import substitute2embed_dict
from attackers.gbda.gbda_adam import Adam
from attackers.gbda.gbda_rmsprop import RMSprop
from attackers.gbda.gbda_signsgd import SignSGD


def gumbel_mean(log_coeff, temp, hard=False, n_samples=500):
    """Estimate the mean of Gumbel-softmax samples drawn from ``log_coeff``.

    Args:
        log_coeff: 3-D tensor of logits; the last dimension indexes the choices.
        temp: Gumbel-softmax temperature (``tau``).
        hard: if True, no sampling is done and the deterministic
            ``softmax(log_coeff / temp)`` over the last dimension is used instead.
        n_samples: number of copies of ``log_coeff`` stacked along dim 0 before
            sampling (soft branch only).

    Returns:
        Tensor of shape ``(1, *log_coeff.shape[1:])``: the average over dim 0,
        i.e. over all samples and all leading (batch) rows.
    """
    assert len(log_coeff.shape) == 3
    if hard:
        # NOTE(review): for hard (one-hot) Gumbel samples the exact expectation is
        # softmax(log_coeff) independent of temperature; the `/ temp` is kept from
        # the original implementation — confirm whether it is intended.
        return torch.softmax(log_coeff / temp, dim=-1).mean(0, keepdim=True)

    samples = log_coeff.repeat(n_samples, 1, 1)
    coeff = F.gumbel_softmax(samples, tau=temp, hard=False)
    return coeff.mean(0, keepdim=True)


def padding(inputs_embeds, pad_inputs_embeddings, max_length):
    """Right-pad or truncate ``inputs_embeds`` along dim 1 to exactly ``max_length``.

    Args:
        inputs_embeds: 3-D tensor ``(batch, seq_len, dim)``.
        pad_inputs_embeddings: the embedding used for padding; any tensor that
            broadcasts to ``(1 or batch, k, dim)`` after leading dims are added
            (e.g. a single ``(dim,)`` vector). It is tiled ``pad_length`` times
            along dim 1 and, if it has batch 1, across the input's batch.
        max_length: target sequence length.

    Returns:
        Tensor ``(batch, max_length, dim)`` when padding with a single-token pad;
        longer inputs are cut to their first ``max_length`` positions.
    """
    batch_size, seq_len = inputs_embeds.shape[:2]
    pad_length = max_length - seq_len
    if pad_length <= 0:
        return inputs_embeds[:, :max_length, :]

    pad = pad_inputs_embeddings
    # Promote e.g. (dim,) or (k, dim) to (1, k, dim) — the same view repeat() uses.
    while pad.dim() < 3:
        pad = pad.unsqueeze(0)
    # Tile along the sequence axis, then broadcast a batch-1 pad to the input batch
    # so torch.cat also works for batch_size > 1.
    pad = pad.repeat(1, pad_length, 1).expand(batch_size, -1, -1)
    return torch.cat([inputs_embeds, pad], dim=1)


def get_ada_log_coeff(sub_embeddings, args):
    """Create the trainable logit row for one substitutable word.

    The logits start at zero for every candidate; only index 0 (the first
    candidate) is raised, depending on ``args.init_method``:
      * ``'default'``: index 0 is set to ``args.init_coeff``.
      * ``'ada-<thresh>'``: index 0 is raised in steps of 0.1 until the empirical
        mean of 1000 Gumbel-softmax samples has a max entry above
        ``max(thresh, 1 / n_choices)``.

    Missing ``init_method`` / ``init_coeff`` attributes are written onto ``args``
    (defaults ``'default'`` and ``0``).

    Args:
        sub_embeddings: candidate embeddings; ``len()`` gives the number of choices.
        args: namespace-like config supporting ``in`` and attribute access.

    Returns:
        Tensor ``(1, n_choices)`` on ``sub_embeddings.device`` with ``requires_grad=True``.

    Raises:
        ValueError: for an unknown ``init_method``, a missing ``ada`` threshold, or
            a threshold >= 1 (the search could never terminate).
    """
    n_choices = len(sub_embeddings)
    log_coeff = torch.zeros(1, n_choices).to(sub_embeddings.device)

    if 'init_method' not in args:
        args.init_method = 'default'
    if 'init_coeff' not in args:
        args.init_coeff = 0

    if args.init_method == 'default':
        log_coeff[0, 0] = args.init_coeff

    elif 'ada' in args.init_method:  ## ada-<init>
        parts = args.init_method.split('-', 1)  # maxsplit keeps e.g. "1e-3" intact
        if len(parts) < 2 or not parts[1]:
            raise ValueError(args.init_method)
        thresh = float(parts[1])
        if thresh >= 1:
            # Soft sample means never exceed 1, so `mean > thresh` would never hold.
            raise ValueError(args.init_method)

        # With a single choice every sample is exactly 1, which can never be
        # > max(thresh, 1 / 1) == 1; there is nothing to adapt, so skip the search.
        if n_choices > 1:
            target = max(thresh, 1 / n_choices)
            init_coeff = 0
            with torch.no_grad():
                while True:
                    log_coeff[0, 0] = init_coeff
                    # 1000 independent samples in one call instead of 1000 calls.
                    samples = F.gumbel_softmax(log_coeff.repeat(1000, 1), hard=False)
                    mean_coeff = samples.mean(dim=0)
                    if mean_coeff.max() > target:
                        break
                    init_coeff += 0.1
        print('initial coeff:', log_coeff)

    else:
        raise ValueError(args.init_method)

    log_coeff.requires_grad = True
    return log_coeff


class PromptOptimizer():
    """ a class to manage 1. embeddings 2. prompt distribution

    Holds the text encoder's input-embedding table and, for every substitutable
    word of the prompt(s) in ``self.domain``, a logit tensor (``log_coeff``) over
    that word's candidate substitutes. The same tensors are referenced from
    ``self.log_coeffs`` (the list passed to the optimizer) and from the matching
    ``sub_word_dict['log_coeff']`` entries in ``self.domain``.
    """
    def __init__(self, text_encoder, tokenizer, substitutes, args):
        self.device = text_encoder.device
        self.args = args

        ## set model
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer

        ## extract embeds
        # Input embeddings for every token id in [0, tokenizer.vocab_size).
        with torch.no_grad():
            self.embeddings = text_encoder.get_input_embeddings()(torch.arange(0, tokenizer.vocab_size).long().to(self.device))
        self.bos_input_embeds = self.embeddings[tokenizer.bos_token_id]
        self.eos_input_embeds = self.embeddings[tokenizer.eos_token_id]
        self.model_max_length = tokenizer.model_max_length
        self.ori_prompt = substitutes['prompt']
        input_ids = tokenizer(self.ori_prompt, padding='max_length', truncation=True, return_tensors='pt').input_ids
        # NOTE(review): this runs outside torch.no_grad(), so an autograd graph is kept
        # if the encoder's parameters require grad — confirm whether that is intended.
        self.ori_embeds = text_encoder(input_ids.to(self.device))[0]

        ## init substitutes dict
        self.domain = substitute2embed_dict(substitutes, self.tokenizer, self.embeddings)
        self.log_coeffs = self._init_log_coeffs(self.domain)
        self.optimizer = self._init_optimizer()

    def _init_optimizer(self):
        """Build the optimizer named by ``args.opt`` over ``self.log_coeffs``.

        Returns None when ``args`` has no ``opt`` entry (pure evolutionary search).

        Raises:
            ValueError: if ``args.opt`` is not one of 'adam', 'rmsprop', 'signsgd'.
        """
        if 'opt' not in self.args:
            return None  ## uniform distribution, for pure EA
        if self.args.opt == 'adam':
            return Adam(self.log_coeffs, lr=self.args.lr, betas=self.args.betas)
        if self.args.opt == 'rmsprop':
            return RMSprop(self.log_coeffs, lr=self.args.lr, momentum=self.args.betas[0])
        if self.args.opt == 'signsgd':
            return SignSGD(self.log_coeffs, lr=self.args.lr, momentum=self.args.betas[0])
        # Fail loudly instead of hitting an unbound local for an unknown name.
        raise ValueError(self.args.opt)

    def _init_log_coeffs(self, domain):
        """Create one logit tensor per substitutable word (``do_sub``).

        Each tensor comes from ``get_ada_log_coeff``, is stored in its word's dict
        under ``'log_coeff'`` and is also returned, in domain iteration order.
        """
        log_coeffs = []
        for _type in domain:
            sub_dict = domain[_type]

            for sub_word_dict in sub_dict['all_words']:
                do_sub = sub_word_dict['do_sub']
                sub_embeddings = sub_word_dict['embeddings']
                if do_sub:  ## substitute (1, sub_choices)
                    log_coeff = get_ada_log_coeff(sub_embeddings, self.args)
                    sub_word_dict['log_coeff'] = log_coeff
                    log_coeffs.append(log_coeff)
                    # The dict entry and the list entry must be the same object so that
                    # in-place optimizer updates are visible to the samplers.
                    assert id(sub_word_dict['log_coeff']) == id(log_coeffs[-1])
        return log_coeffs

    def get_max_grad_norm(self):
        """Return the largest L2 norm among the logits' ``.grad`` tensors."""
        return max(grad.norm().item() for grad in self._get_log_coeffs_grads())

    def _get_log_coeffs_grads(self):
        """Return the ``.grad`` of every logit tensor (None where no grad exists)."""
        return [log_coeff.grad for log_coeff in self.log_coeffs]

    def _inplace_clip(self, params, min, max):
        """Clamp every tensor in ``params`` to [min, max], writing into its ``.data``."""
        for param in params:
            param.data.copy_(torch.clip(param, min, max))

    def _inplace_sub(self, ps1, ps2):
        """Subtract each tensor of ``ps2`` from its partner in ``ps1`` via ``.data``."""
        assert len(ps1) == len(ps2)
        for p1, p2 in zip(ps1, ps2):
            p1.data.sub_(p2)

    def set_log_coeffs(self, new_log_coeffs):
        """Overwrite the substitution logits.

        Two input forms are accepted:
          * a list of lists — one inner list per candidate, each holding one tensor
            per substitutable word. For every word the candidates are concatenated
            along dim 0; the results *replace* the entries of ``self.log_coeffs`` and
            the matching ``'log_coeff'`` references in ``self.domain``.
          * a flat list with one tensor per word — values are copied in place into
            the existing tensors (object identity is preserved).
        """
        if isinstance(new_log_coeffs[0], list):  # multiple log_coeffs
            # Transpose candidate-major -> word-major.
            per_word = list(map(list, zip(*new_log_coeffs)))
            assert len(self.log_coeffs) == len(per_word)
            for idx, word_candidates in enumerate(per_word):
                self.log_coeffs[idx] = torch.cat(word_candidates, dim=0)
            ## update sub_dict as well
            idx = 0
            for _type in self.domain:
                sub_dict = self.domain[_type]
                for sub_word_dict in sub_dict['all_words']:
                    if sub_word_dict['do_sub']:
                        assert self.log_coeffs[idx].shape[-1] == sub_word_dict['log_coeff'].shape[-1]
                        sub_word_dict['log_coeff'] = self.log_coeffs[idx]
                        idx += 1
            assert idx == len(self.log_coeffs)
        else:
            assert len(self.log_coeffs) == len(new_log_coeffs)
            for log_coeff, new_log_coeff in zip(self.log_coeffs, new_log_coeffs):
                log_coeff.data.copy_(new_log_coeff)

    def get_log_coeffs(self):
        """Return ``clone()``s of the current logit tensors."""
        assert not isinstance(self.log_coeffs[0], list), 'not implemented'
        return [log_coeff.clone() for log_coeff in self.log_coeffs]

    def display_log_coeffs(self):
        """Log each logit tensor (detached, on CPU) at INFO level."""
        for log_coeff in self.log_coeffs:
            logging.info(log_coeff.cpu().detach())

    def sample_prompt(self, argmax=False, return_ids=False):
        """Pick one substitute per substitutable word and assemble the prompt text.

        Args:
            argmax: if True, take the highest logit of the first row (uniform random
                tie-break); otherwise take the argmax of a Gumbel-softmax sample at
                temperature ``args.sample_temp``.
            return_ids: if True, return the chosen substitute indices (flattened
                over all prompt types, as Python ints) instead of the text.

        Returns:
            ``(pos_prompt, neg_prompt)``, where a type missing from the domain is None;
            or the list of indices when ``return_ids`` is set.

        Raises:
            ValueError: for a domain type other than 'pos' / 'neg'.
        """
        ## sample sub words
        all_prompt = {'pos': None, 'neg': None}
        all_sub_ids = []  ## flattened, for EA
        for _type in self.domain:
            sub_dict = self.domain[_type]

            prompt = []
            sub_ids = []
            for sub_word_dict in sub_dict['all_words']:
                do_sub = sub_word_dict['do_sub']
                if do_sub:  ## substitute
                    sub_words = sub_word_dict['sub_words']
                    sub_log_coeff = sub_word_dict['log_coeff']
                    if argmax:
                        # Compare the first row against its own max: using the max of the
                        # whole tensor finds no index when a later row holds a larger value.
                        first_row = sub_log_coeff[0]
                        max_indices = torch.where(first_row == torch.max(first_row))[0]
                        sub_id = int(np.random.choice(max_indices.cpu().numpy()))
                    else:
                        sub_coeff = F.gumbel_softmax(sub_log_coeff.unsqueeze(0).repeat(self.args.batch_size, 1, 1),
                                                    tau=self.args.sample_temp)
                        # NOTE(review): .item() needs a single element, i.e. batch_size == 1
                        # and a single-row log_coeff.
                        sub_id = sub_coeff.argmax(dim=-1).item()

                    sub_word = sub_words[sub_id]
                    sub_ids.append(sub_id)
                else:
                    sub_word = sub_word_dict['ori_word']
                prompt.append(sub_word)
            prompt = ' '.join(prompt)

            if _type == 'pos':
                all_prompt['pos'] = prompt
            elif _type == 'neg':
                ## TODO under dev ##
                all_prompt['neg'] = prompt.replace(' ', ',')
            else:
                raise ValueError(_type)

            all_sub_ids += sub_ids

        if return_ids:
            return all_sub_ids
        return all_prompt['pos'], all_prompt['neg']

    def step(self):
        """Apply one optimizer update to the substitution logits.

        Gradients are clipped elementwise to [-clip_grad, clip_grad]; the update
        from ``self.optimizer.get_update()`` (which already includes the learning
        rate) is subtracted in place; the logits are clipped to
        [min_coeff, max_coeff]; finally ``self.optimizer.zero_grad()`` is called.
        """
        ## gradient clipping
        self._inplace_clip(self._get_log_coeffs_grads(), -self.args.clip_grad, self.args.clip_grad)
        ## update
        adam_updates = self.optimizer.get_update() ## include lr
        self._inplace_sub(self.log_coeffs, adam_updates)
        self._inplace_clip(self.log_coeffs, self.args.min_coeff, self.args.max_coeff)
        self.optimizer.zero_grad()

    def get_mixed_embeds(self, temp, use_mean=False):
        """Return ``(pos_embeds, neg_embeds)`` from ``get_mixed_inputs_embeds``."""
        all_mixed_embeds = self.get_mixed_inputs_embeds(temp, use_mean)
        return all_mixed_embeds['pos'], all_mixed_embeds['neg']

    def get_mixed_inputs_embeds(self, temp, use_mean=False):
        """Build soft (mixed) input embeddings for each prompt type in the domain.

        For each type the sequence is ``[bos] + word embeddings + [eos]``, repeated
        ``args.batch_size`` times, then padded with the eos embedding or truncated to
        ``model_max_length``. Substitutable words contribute ``coeff @ embeddings``
        with coefficients from ``_mixer_sampler``; other words use their own
        embeddings.

        Returns:
            dict with keys 'pos' / 'neg'. 'pos' falls back to ``self.ori_embeds``
            when the domain has no 'pos' entry; 'neg' stays None in that case.
        """
        all_mixed_embeds = {'pos':None, 'neg':None}
        for _type in self.domain:
            sub_dict = self.domain[_type]

            mixed_embeds = self.bos_input_embeds.repeat(self.args.batch_size, 1, 1)
            for sub_word_dict in sub_dict['all_words']:
                do_sub = sub_word_dict['do_sub']
                sub_embeddings = sub_word_dict['embeddings']
                if do_sub:  ## substitute
                    sub_log_coeff = sub_word_dict['log_coeff']
                    sub_coeff = self._mixer_sampler(sub_log_coeff, temp, use_mean)
                    # NOTE(review): autocast is hard-coded to 'cuda'; on other devices it
                    # presumably has no effect — confirm.
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        sub_mixed_inputs_embeds = sub_coeff @ sub_embeddings
                else:
                    sub_mixed_inputs_embeds = sub_embeddings.repeat(self.args.batch_size, 1, 1)
                mixed_embeds = torch.cat([mixed_embeds, sub_mixed_inputs_embeds], dim=1)
            mixed_embeds = torch.cat([mixed_embeds, self.eos_input_embeds.repeat(self.args.batch_size, 1, 1)], dim=1)
            mixed_embeds = padding(mixed_embeds, self.eos_input_embeds, self.model_max_length)

            if _type == 'pos':
                all_mixed_embeds['pos'] = mixed_embeds
            elif _type == 'neg':
                all_mixed_embeds['neg'] = mixed_embeds
            else:
                raise ValueError(_type)

        if all_mixed_embeds['pos'] is None:
            all_mixed_embeds['pos'] = self.ori_embeds

        return all_mixed_embeds

    def _mixer_sampler(self, sub_log_coeff, temp, use_mean):
        """Draw mixing coefficients over one word's substitutes.

        ``args.mixer`` selects the sampler:
          * 'gumbel': a Gumbel-softmax sample at temperature ``temp``, one per batch
            row; with ``use_mean`` the Monte-Carlo mean from ``gumbel_mean`` is
            tiled to every batch row instead.
          * 'dirichlet': a reparameterised Dirichlet sample whose concentration is
            ``elu(log_coeff / temp) + 1`` (always > 0).

        Raises:
            ValueError: for an unknown ``args.mixer``.
        """
        batch_size = self.args.batch_size
        if self.args.mixer == 'gumbel':
            tiled = sub_log_coeff.unsqueeze(0).repeat(batch_size, 1, 1)
            if use_mean:
                print('// use mean')
                # gumbel_mean requires `hard` and averages over dim 0 into a single
                # row, so tile the result back to the batch size.
                sub_coeff = gumbel_mean(tiled, temp, False).repeat(batch_size, 1, 1)
            else:
                sub_coeff = F.gumbel_softmax(tiled, tau=temp)
        elif self.args.mixer == 'dirichlet':
            concentration = F.elu(sub_log_coeff.repeat(batch_size, 1) / temp) + 1
            dirichlet_dist = Dirichlet(concentration)
            sub_coeff = dirichlet_dist.rsample()
            sub_coeff = sub_coeff.unsqueeze(1)
        else:
            raise ValueError(self.args.mixer)

        return sub_coeff