"""
store util functions for attack pipelien
agnostic to attack methods
"""
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GPT2LMHeadModel, GPT2Config


def load_gpt2_from_dict(dict_path, output_hidden_states=False, map_location='cpu'):
    """Build a GPT2LMHeadModel and load weights from a checkpoint file.

    The architecture is fixed by the config below (vocab_size=30522, n_embd=1024,
    n_head=8, n_layer=24, ReLU activation). ``load_state_dict`` runs in strict
    mode, so it raises if the checkpoint does not match this architecture.

    Args:
        dict_path: path given to ``torch.load``. The loaded object must contain
            a ``'model'`` key holding the state dict.
        output_hidden_states: forwarded to ``GPT2Config``.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved from GPU tensors can be loaded on a CPU-only
            machine. The model is constructed on CPU regardless.

    Returns:
        The ``GPT2LMHeadModel`` with the checkpoint weights loaded.
    """
    def embedding_from_weights(w):
        # Wrap an existing (num_embeddings, dim) weight tensor in an nn.Embedding
        # by assigning it as the layer's data (no copy).
        layer = torch.nn.Embedding(w.size(0), w.size(1))
        layer.weight.data = w
        return layer

    state_dict = torch.load(dict_path, map_location=map_location)['model']

    config = GPT2Config(
        vocab_size=30522,
        n_embd=1024,
        n_head=8,
        activation_function='relu',
        n_layer=24,
        output_hidden_states=output_hidden_states
    )
    model = GPT2LMHeadModel(config)
    model.load_state_dict(state_dict)
    # The input embedding is not loaded automatically
    model.set_input_embeddings(embedding_from_weights(state_dict['transformer.wte.weight'].cpu()))

    return model


class MakeCutouts(nn.Module):
    """Sample random square crops from an image batch and resize each to ``cut_size``.

    ``forward`` treats dims 2 and 3 of its input as height and width
    (presumably an NCHW tensor). Each crop side is drawn between
    ``min(height, width, cut_size)`` and ``min(height, width)``.
    The uniform draw is raised to ``cut_power`` before scaling, so
    ``cut_power > 1`` biases crops toward the smaller end.
    """

    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()
        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        """Return ``num_cutouts`` resized crops concatenated along dim 0.

        Every crop keeps the whole batch, so the result has
        ``num_cutouts * batch`` entries along dim 0, ordered crop-major.
        """
        height, width = pixel_values.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)

        def sample_cutout():
            # Random draws happen in a fixed order: size, x offset, y offset.
            side = int(torch.rand([]) ** self.cut_power * (largest - smallest) + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            window = pixel_values[:, :, top:top + side, left:left + side]
            return F.adaptive_avg_pool2d(window, self.cut_size)

        return torch.cat([sample_cutout() for _ in range(num_cutouts)])


def get_inputs_embeddings(input_ids, embeddings, batch_size):
    """Look up the embedding rows for ``input_ids`` and replicate them over a batch.

    Args:
        input_ids: token ids (a list of ints or a 1-D integer tensor).
        embeddings: embedding table indexed as (vocab_size, dim).
        batch_size: number of copies stacked along a new leading axis.

    Returns:
        A new tensor of shape (batch_size, len(input_ids), dim) with the same
        dtype and device as ``embeddings``. It owns its memory and is detached
        from autograd, so callers can modify or optimize it freely.
    """
    # Gather rows directly instead of multiplying a (batch, T, vocab) one-hot
    # tensor by the whole table. The values are the same, but this costs
    # O(T * dim) rather than O(batch * T * vocab * dim). It also stops a
    # non-finite entry in an unrelated vocab row from leaking into the result
    # through 0 * inf = nan.
    index = torch.as_tensor(input_ids, dtype=torch.long, device=embeddings.device)
    rows = embeddings.detach()[index]  # (T, dim)
    # repeat() always allocates fresh memory, so the output never aliases `embeddings`.
    return rows.unsqueeze(0).repeat(batch_size, 1, 1)


#### constraints
def bert_score(refs, cands, add_bos, add_eos, weights=None):
    """Compute a BERTScore-style recall between reference and candidate embeddings.

    For every reference token, take the maximum cosine similarity to any
    candidate token, then aggregate over reference tokens. The aggregate is a
    uniform mean when ``weights`` is None, otherwise a ``weights``-scaled sum.

    Args:
        refs: reference embeddings indexed as (batch, ref_tokens, dim).
        cands: candidate embeddings indexed as (batch, cand_tokens, dim).
        add_bos: truthy if the first token of each sequence is a BOS token to
            exclude from scoring. A one-element bool tensor is accepted.
        add_eos: truthy if the last token of each sequence is an EOS token to
            exclude from scoring.
        weights: optional per-reference-token weights, applied as
            ``weights[:, None]`` after BOS/EOS removal (so presumably 1-D of
            length ref_tokens; TODO confirm with callers).

    Returns:
        Tensor of shape (batch,) with one score per batch element.
    """
    # Drop special tokens along the token axis only. Slicing the last axis as
    # well would silently discard hidden-feature dimensions.
    if add_bos:
        refs = refs[:, 1:]
        cands = cands[:, 1:]
    if add_eos:
        refs = refs[:, :-1]
        cands = cands[:, :-1]
    refs_norm = refs / refs.norm(2, -1).unsqueeze(-1)
    if weights is not None:
        refs_norm = refs_norm * weights[:, None]
    else:
        # Uniform weighting: the final sum over tokens becomes a mean.
        refs_norm = refs_norm / refs.size(1)
    cands_norm = cands / cands.norm(2, -1).unsqueeze(-1)
    cosines = refs_norm @ cands_norm.transpose(1, 2)  # (batch, ref_tokens, cand_tokens)
    # Greedy match each reference token to its best candidate, then aggregate.
    return cosines.max(-1)[0].sum(1)


class EmbedDist():
    """Callable that scores two prompts with ``bert_score`` over encoder embeddings.

    ``metric`` selects which encoder output is used as the token embeddings:
    ``'bert_score_clip'`` takes element 0 of the encoder output;
    ``'bert_score_gpt2'`` / ``'bert_score_gpt2_large'`` take ``hidden_states[-1]``.
    """

    def __init__(self, tokenizer, text_encoder, metric):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.metric = metric
        self.device = text_encoder.device

    def get_embed(self, prompt):
        """Tokenize ``prompt``, encode it, and return L2-normalized token embeddings.

        Side effect: sets ``self.add_bos`` / ``self.add_eos`` to whether the
        tokenized sequence starts with the tokenizer's BOS id / ends with its
        EOS id. Each call overwrites the values from the previous call.

        Raises:
            ValueError: if ``self.metric`` is not a supported metric name.
        """
        ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        # Detect whether this tokenizer wraps sequences in special tokens.
        self.add_bos = ids[:, 0] == self.tokenizer.bos_token_id
        self.add_eos = ids[:, -1] == self.tokenizer.eos_token_id

        outputs = self.text_encoder(ids)
        if self.metric == 'bert_score_clip':
            hidden = outputs[0]
        elif self.metric in ('bert_score_gpt2', 'bert_score_gpt2_large'):
            hidden = outputs.hidden_states[-1]
        else:
            raise ValueError(self.metric)
        return hidden / hidden.norm(p=2, dim=-1, keepdim=True)

    @torch.no_grad()
    def __call__(self, prompt1, prompt2):
        """Return ``bert_score(embed(prompt1), embed(prompt2))`` as a Python number.

        The BOS/EOS flags used for trimming are the ones set while embedding
        ``prompt2``, since it is embedded last.
        """
        first, second = self.get_embed(prompt1), self.get_embed(prompt2)
        return bert_score(first, second, self.add_bos, self.add_eos).item()


def get_constraint_fn(constraint, tokenizer, text_encoder, ori_prompt):
    """Build a function that scores how close a candidate prompt stays to ``ori_prompt``.

    Args:
        constraint: one of ``'none'``, ``'bert_score_clip'``, ``'bert_score_gpt2'``
            or ``'bert_score_gpt2_large'``.
        tokenizer: tokenizer paired with ``text_encoder``. Used directly for
            ``'bert_score_clip'``.
        text_encoder: encoder model. Used directly for ``'bert_score_clip'``.
            For the GPT-2 constraints only its ``device`` is read, to place the
            reference model.
        ori_prompt: the reference prompt every candidate is compared against.

    Returns:
        A callable ``fn(prompt)`` returning the ``EmbedDist`` score as a Python
        number. For ``'none'`` the callable accepts any arguments (including
        none) and returns None.

    Raises:
        ValueError: if ``constraint`` is not recognized.
    """
    if constraint == 'none':
        # Accept and ignore any arguments so it can be called like every other
        # constraint function; a zero-argument call still works as before.
        return lambda *args, **kwargs: None

    if constraint == 'bert_score_clip':
        return lambda p: EmbedDist(tokenizer, text_encoder, constraint)(ori_prompt, p)

    ref_model_names = {
        'bert_score_gpt2': 'gpt2',
        'bert_score_gpt2_large': 'gpt2-large',
    }
    if constraint not in ref_model_names:
        raise ValueError(f"Unknown constraint: {constraint!r}")

    ref_name = ref_model_names[constraint]
    ref_tokenizer = AutoTokenizer.from_pretrained(ref_name, use_fast=True)
    ref_tokenizer.model_max_length = 512
    ref_tokenizer.padding_side = "right"
    # GPT-2 ships without a pad token; reuse the reference tokenizer's own EOS
    # token rather than the (possibly different) victim tokenizer's.
    ref_tokenizer.pad_token = ref_tokenizer.eos_token
    ref_model = AutoModelForCausalLM.from_pretrained(
        ref_name, output_hidden_states=True
    ).to(text_encoder.device)
    return lambda p: EmbedDist(ref_tokenizer, ref_model, constraint)(ori_prompt, p)


class TempScheduler():
    """Linearly move a temperature from ``max_temp`` to ``min_temp`` over ``max_iter`` steps.

    The temperature starts at ``max_temp``. Each ``step()`` advances it by
    ``(min_temp - max_temp) / max_iter``. Once ``max_iter`` steps have been
    taken (or immediately, if ``max_iter <= 0``), it stays at ``min_temp``
    instead of extrapolating past it.
    """

    def __init__(self, min_temp, max_temp, max_iter):
        self.max_iter = max_iter
        self.min_temp = min_temp
        self.max_temp = max_temp

        self.cur_iter = 0
        self.cur_temp = self.max_temp

    def step(self):
        """Advance one iteration and return the new temperature."""
        self.cur_iter += 1

        if self.max_iter <= 0 or self.cur_iter >= self.max_iter:
            # Schedule exhausted (or empty): hold the end value. This avoids
            # overshooting below min_temp (eventually negative) and a
            # division by zero when max_iter == 0.
            self.cur_temp = self.min_temp
        else:
            self.cur_temp = self.max_temp + (self.min_temp - self.max_temp) * self.cur_iter / self.max_iter
        return self.cur_temp

    def get_temp(self):
        """Return the current temperature without advancing the schedule."""
        return self.cur_temp


def get_t_schedule(args):
    """Build a per-iteration schedule of integer ``t`` values from ``args.t``.

    Reads ``args.num_iters`` and ``args.t``. Supported ``args.t`` formats:
        * ``rand-MIN-MAX``: each entry is ``np.random.randint(MIN, MAX)`` (MAX
          exclusive); ``avg_t = (MIN + MAX) // 2``.
        * ``fix-T``: every entry is T; ``avg_t = T``.
        * ``step-START-REPEAT``: START repeated REPEAT times, then START + 1,
          and so on; ``avg_t`` is the integer midpoint of the first and last
          entries.
    When ``args.num_iters == 0`` the schedule is empty and ``avg_t`` is 0,
    regardless of ``args.t``.

    Returns:
        ``(sche, avg_t)``: a list of ``num_iters`` values and an int.

    Raises:
        ValueError: if the schedule name in ``args.t`` is not recognized.
    """
    num_iters = args.num_iters

    if num_iters == 0:  ## ES only
        return [], 0

    # split() rather than slicing at str.find('-'): find() returns -1 when
    # there is no '-', which would chop the last character off the name.
    name, *params = args.t.split('-')

    if name == 'rand':  ## rand-15-25
        min_t, max_t = [int(t) for t in params]
        avg_t = (max_t + min_t) // 2
        sche = [np.random.randint(min_t, max_t) for _ in range(num_iters)]
    elif name == 'fix':  ## fix-25
        fix_t = int(params[0])
        avg_t = fix_t
        sche = [fix_t] * num_iters
    elif name == 'step':  ## step-10-5
        start_t, repeat = [int(t) for t in params]
        sche = [start_t + (i // repeat) for i in range(num_iters)]
        avg_t = (sche[0] + sche[-1]) // 2
    else:
        # Fail loudly; otherwise `sche`/`avg_t` would be unbound at the return.
        raise ValueError(
            f"Unknown t schedule {args.t!r}; expected rand-MIN-MAX, fix-T or step-START-REPEAT"
        )

    return sche, avg_t
