from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import ray
from torch.nn import CrossEntropyLoss
import numpy as np
from typing import List, Union
import random
from ..model_utils import get_template, load_model_and_tokenizer
# https://huggingface.co/docs/accelerate/v0.11.0/en/memory#accelerate.find_executable_batch_size
from accelerate.utils import find_executable_batch_size
import transformers
import gc

@ray.remote
class GCGModelWorker:
    """Ray actor wrapping one causal LM for GCG-style suffix optimisation.

    The worker loads a model/tokenizer pair, pre-embeds the chat template that
    surrounds the instruction, optionally caches the key/values of the template
    prefix, and then answers three kinds of requests:

    * token gradients of the target loss w.r.t. a one-hot encoding of the
      optimised tokens (``get_token_gradient``),
    * target losses for a batch of candidate token sequences
      (``compute_loss_on_candidates_batch``),
    * a short greedy generation for inspection (``generate``).
    """

    def __init__(self, 
                 model_name_or_path,
                 allow_non_ascii=False, 
                 return_tensors="pt",
                 seed=0,
                 num_gpus=1,
                 search_width=512,
                 use_prefix_cache=True,
                 **model_kwargs):
        """Load the model and precompute template embeddings (and prefix cache).

        Args:
            model_name_or_path: Identifier forwarded to ``load_model_and_tokenizer``
                and ``get_template``.
            allow_non_ascii: Accepted for interface compatibility; not read by
                this class.
            return_tensors: ``return_tensors`` value passed to every tokenizer call.
            seed: Seed for Python's ``random`` module (only that RNG is seeded here).
            num_gpus: Stored on the instance; not otherwise used in this class.
            search_width: Initial value for ``search_batch_size``.
            use_prefix_cache: If true, run the template prefix through the model
                once and reuse its ``past_key_values`` in later forward passes.
            **model_kwargs: Forwarded to ``load_model_and_tokenizer`` and
                ``get_template``.
        """
        random.seed(seed)

        self.search_width = search_width
        # starting search_batch_size, will auto reduce batch_size later if go OOM (distinct for each worker, but not each behavior for simplicity)
        self.search_batch_size = search_width

        self.use_prefix_cache = use_prefix_cache

        self.model_name = model_name_or_path.split("/")[-1]
        self.model_name_or_path = model_name_or_path

        self.num_gpus = num_gpus
        self.model, self.tokenizer = load_model_and_tokenizer(model_name_or_path, **model_kwargs)
        self.device = self.model.device
        self.return_tensors = return_tensors

        self.embed_layer = self.model.get_input_embeddings()
        self.vocab_size = self.embed_layer.weight.shape[0]  # can be larger than tokenizer.vocab_size for some models
        self.vocab_embeds = self.embed_layer(torch.arange(0, self.vocab_size).long().to(self.device))

        # Init instruction template, embeds and cache.
        # The prompt template is split around "{instruction}": the part before it
        # is tokenized with special tokens, the part after it without.
        template = get_template(self.model_name_or_path, **model_kwargs)['prompt'].split("{instruction}")
        self.before_tc, self.after_tc = template[0], template[1]
        before_ids = self.tokenizer(self.before_tc, return_tensors=self.return_tensors, add_special_tokens=True).to(self.device)['input_ids']
        self.before_embeds = self.embed_layer(before_ids)
        after_ids = self.tokenizer(self.after_tc, return_tensors=self.return_tensors, add_special_tokens=False).to(self.device)['input_ids']
        self.after_embeds = self.embed_layer(after_ids)

        if self.use_prefix_cache:
            with torch.no_grad():
                outputs = self.model(inputs_embeds=self.before_embeds, use_cache=True)
                self.prefix_cache = outputs.past_key_values

        self.behaviors, self.targets = None, None
        self.all_target_ids = []
        self.all_target_embeds = []
        self.all_behavior_embeds = []

    def init_behaviors_and_targets(self, behaviors: List[str], targets: List[str]):
        """Tokenize and embed every (behavior, target) pair and cache them on CPU.

        Any previously cached pairs are discarded first, so the cached tensors
        always line up index-for-index with ``self.behaviors``.

        Args:
            behaviors: Instruction strings; also used later by ``generate``.
            targets: Target completion strings, paired with ``behaviors`` by position.
        """
        # NOTE: self.targets is ONLY used here, self.behaviors only used here and in generate
        self.behaviors, self.targets = behaviors, targets

        # Start from empty caches: appending to the old lists would leave stale
        # entries at the indices that self.behaviors now maps to new strings.
        self.all_target_ids = []
        self.all_target_embeds = []
        self.all_behavior_embeds = []

        for behavior, target in zip(behaviors, targets):
            behavior_ids = self.tokenizer(behavior, return_tensors=self.return_tensors, add_special_tokens=False).to(self.device)['input_ids']
            target_ids = self.tokenizer(target, return_tensors=self.return_tensors, add_special_tokens=False).to(self.device)['input_ids']
            behavior_embeds, target_embeds = self.embed_layer(behavior_ids), self.embed_layer(target_ids)

            # Move to CPUs here as number of behaviors can be large thus go OOM
            self.all_target_ids.append(target_ids.cpu())
            self.all_target_embeds.append(target_embeds.cpu())
            self.all_behavior_embeds.append(behavior_embeds.cpu())

    def get_token_gradient(self, worker_optim_ids, num_optim_tokens, selected_behavior_indices):
        """
        Calculate the token gradient for the selected behavior indices, average them, and normalize them for ensembling multiple models

        Args:
            worker_optim_ids: Token ids of the optimised string (see ``compute_token_gradient``).
            num_optim_tokens: Number of optimised tokens.
            selected_behavior_indices: Indices into the cached behaviors/targets.

        Returns:
            The mean gradient over the selected behaviors, divided by its L2 norm
            along the last (vocabulary) dimension.
        """
        token_grads = [
            self.compute_token_gradient(worker_optim_ids, num_optim_tokens, behavior_idx)
            for behavior_idx in selected_behavior_indices
        ]
        token_grad = sum(token_grads) / len(token_grads)
        return token_grad / token_grad.norm(dim=-1, keepdim=True)

    def compute_token_gradient(self, worker_optim_ids, num_optim_tokens, selected_behavior_idx):
        """Gradient of the target cross-entropy w.r.t. the one-hot optimised tokens.

        Args:
            worker_optim_ids: 2-D tensor of token ids; it is unsqueezed and used to
                index a ``(1, num_optim_tokens, vocab_size)`` one-hot tensor.
            num_optim_tokens: Length of the optimised token sequence.
            selected_behavior_idx: Index into the cached behaviors/targets.

        Returns:
            Tensor with the same shape as the one-hot encoding,
            ``(1, num_optim_tokens, vocab_size)``.
        """
        behavior_embeds = self.all_behavior_embeds[selected_behavior_idx].to(self.device)
        target_ids = self.all_target_ids[selected_behavior_idx].to(self.device)
        target_embeds = self.all_target_embeds[selected_behavior_idx].to(self.device)

        # Differentiable one-hot encoding of the optimised tokens; multiplying it
        # with the embedding matrix yields the usual token embeddings.
        optim_ids_onehot = torch.zeros((1, num_optim_tokens, self.vocab_size), device=self.device).to(behavior_embeds.dtype)
        optim_ids_onehot.scatter_(2, worker_optim_ids.to(self.device).unsqueeze(2), 1.0).requires_grad_()
        optim_embeds = torch.matmul(optim_ids_onehot.squeeze(0), self.vocab_embeds).unsqueeze(0)

        if self.use_prefix_cache:
            # The template prefix is supplied through the cached key/values, so it
            # is left out of the embedded input.
            input_embeds = torch.cat([behavior_embeds, optim_embeds, self.after_embeds, target_embeds], dim=1)
            outputs = self.model(inputs_embeds=input_embeds, past_key_values=self.prefix_cache)
        else:
            input_embeds = torch.cat([self.before_embeds, behavior_embeds, optim_embeds, self.after_embeds, target_embeds], dim=1)
            outputs = self.model(inputs_embeds=input_embeds)

        logits = outputs.logits
        # Logits at position t predict token t+1, so take the slice that starts one
        # step before the first target token and stops before the last position.
        target_start = input_embeds.shape[1] - target_embeds.shape[1]
        shift_logits = logits[..., target_start-1:-1, :].contiguous()
        shift_labels = target_ids
        loss_fct = CrossEntropyLoss()

        ## GCG Logit Loss ##
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        token_grad = torch.autograd.grad(outputs=[loss], inputs=[optim_ids_onehot])[0]

        return token_grad

    def compute_loss_on_candidates_batch(self, sampled_top_indices, selected_behavior_indices):
        """Evaluate every candidate against every selected behavior.

        Returns:
            Tensor stacking the per-behavior loss vectors along dim 0, i.e. one
            row per entry of ``selected_behavior_indices``.
        """
        all_losses = [
            self.compute_loss_on_candidates(sampled_top_indices, behavior_idx)
            for behavior_idx in selected_behavior_indices
        ]
        return torch.stack(all_losses, dim=0)

    def compute_loss_on_candidates(self, sampled_top_indices, selected_behavior_idx):
        """Mean target cross-entropy of each candidate for one behavior.

        Args:
            sampled_top_indices: 2-D tensor of candidate token ids, one candidate
                per row. It is moved to the model device before embedding.
            selected_behavior_idx: Index into the cached behaviors/targets.

        Returns:
            1-D tensor with one loss per candidate row.
        """
        behavior_embeds = self.all_behavior_embeds[selected_behavior_idx].to(self.device)
        target_ids = self.all_target_ids[selected_behavior_idx].to(self.device)
        target_embeds = self.all_target_embeds[selected_behavior_idx].to(self.device)
        search_width = sampled_top_indices.shape[0]

        # Candidates may arrive on another device (e.g. deserialised on CPU);
        # the embedding lookup needs them next to the embedding weights.
        sampled_top_embeds = self.embed_layer(sampled_top_indices.to(self.device))

        segments = [
            behavior_embeds.repeat(search_width, 1, 1),
            sampled_top_embeds,
            self.after_embeds.repeat(search_width, 1, 1),
            target_embeds.repeat(search_width, 1, 1),
        ]
        if not self.use_prefix_cache:
            # Without the cache the template prefix has to be part of the input.
            segments.insert(0, self.before_embeds.repeat(search_width, 1, 1))
        input_embeds = torch.cat(segments, dim=1)

        # find_executable_batch_size retries the forward pass with a smaller batch
        # size when it runs out of memory; the size that worked is remembered in
        # self.search_batch_size by compute_candidates_logits.
        logits = find_executable_batch_size(self.compute_candidates_logits, self.search_batch_size)(input_embeds)

        target_start = input_embeds.shape[1] - target_embeds.shape[1]
        shift_logits = logits[..., target_start-1:-1, :].contiguous()
        shift_labels = target_ids.repeat(search_width, 1)
        loss_fct = CrossEntropyLoss(reduction='none')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return loss.view(search_width, -1).mean(dim=1)

    def compute_candidates_logits(self, search_batch_size, input_embeds):
        """Forward ``input_embeds`` through the model in chunks, without gradients.

        Args:
            search_batch_size: Chunk size along dim 0 (supplied by
                ``find_executable_batch_size``).
            input_embeds: Embedded candidate inputs, batched along dim 0.

        Returns:
            Logits of all chunks concatenated along dim 0.
        """
        if self.search_batch_size != search_batch_size:
            print(f"INFO: Setting candidates search_batch_size to {search_batch_size}")
            self.search_batch_size = search_batch_size
            gc.collect() 
            torch.cuda.empty_cache() 

        logits = []
        with torch.no_grad():
            for start in range(0, input_embeds.shape[0], search_batch_size):
                input_embeds_batch = input_embeds[start:start+search_batch_size]

                if self.use_prefix_cache:
                    prefix_cache = self.prefix_cache
                    # The last chunk can be smaller than search_batch_size, so the
                    # cache is expanded to the actual chunk size each time.
                    current_batch_size = input_embeds_batch.shape[0]
                    # Index-based access is kept on purpose: it only relies on
                    # len() and [] of the cache container.
                    prefix_cache_batch = []
                    for layer_idx in range(len(prefix_cache)):
                        layer_states = prefix_cache[layer_idx]
                        prefix_cache_batch.append([
                            layer_states[state_idx].expand(current_batch_size, -1, -1, -1)
                            for state_idx in range(len(layer_states))
                        ])
                    outputs = self.model(inputs_embeds=input_embeds_batch, past_key_values=prefix_cache_batch)
                else:
                    outputs = self.model(inputs_embeds=input_embeds_batch)
                logits.append(outputs.logits)
        return torch.cat(logits, dim=0)

    def generate(self, optim_str, selected_behavior_idx):
        """Greedily generate up to 16 new tokens for one behavior plus ``optim_str``.

        Args:
            optim_str: Optimised string appended directly after the behavior text.
            selected_behavior_idx: Index into ``self.behaviors``.

        Returns:
            The decoded continuation only (the prompt tokens are sliced off
            before decoding).
        """
        inputs = self.before_tc + self.behaviors[selected_behavior_idx] + optim_str + self.after_tc
        encoded = self.tokenizer(inputs, return_tensors=self.return_tensors).to(self.device)
        prompt_length = encoded['input_ids'].shape[1]

        with torch.no_grad():
            output_ids = self.model.generate(**encoded, do_sample=False, max_new_tokens=16, use_cache=True)

        # Drop the prompt by token position instead of string replacement: the
        # decoded text is not guaranteed to contain the prompt verbatim.
        return self.tokenizer.batch_decode(output_ids[:, prompt_length:])[0]

    def is_initialized(self):
        """Return a status string with the GPU ids Ray reports and the model path."""
        return f"==cuda:{ray.get_gpu_ids()}. {self.model_name_or_path}"

    def get_attribute_(self, attr_name):
        """Return the instance attribute ``attr_name``, or ``None`` if it does not exist."""
        return getattr(self, attr_name, None)

#### Helper
def sample_control(control_toks, grad, search_width, topk=256, temp=1, not_allowed_tokens=None):
    """Sample ``search_width`` candidates that each differ from ``control_toks`` in one token.

    For every position the ``topk`` token ids with the lowest gradient are
    collected. Candidate ``i`` replaces position ``floor(i * L / search_width)``
    (``L = len(control_toks)``) with a token drawn uniformly from that
    position's top-k set, so the replaced positions are spread evenly over the
    sequence.

    Args:
        control_toks: 1-D integer tensor with the current token ids.
        grad: Gradient tensor with one row per position and one column per
            vocabulary entry. It is not modified.
        search_width: Number of candidates to produce.
        topk: Number of lowest-gradient tokens considered per position.
        temp: Unused; kept for interface compatibility.
        not_allowed_tokens: Optional tensor of token ids that must not be
            sampled.

    Returns:
        Tensor of shape ``(search_width, len(control_toks))`` with the candidates.
    """
    if not_allowed_tokens is not None:
        # Work on a copy and push banned tokens above every other gradient value
        # so they can never be among the lowest-gradient top-k.
        grad = grad.clone()
        grad[:, not_allowed_tokens.to(grad.device)] = grad.max() + 1

    top_indices = (-grad).topk(topk, dim=1).indices
    control_toks = control_toks.to(grad.device)
    num_positions = len(control_toks)

    original_control_toks = control_toks.repeat(search_width, 1)

    # Exact integer arithmetic: a float-step arange can round to one entry too
    # many (or truncate a position one too low), which breaks the scatter below.
    candidate_idx = torch.arange(search_width, device=grad.device, dtype=torch.int64)
    new_token_pos = torch.div(candidate_idx * num_positions, search_width, rounding_mode="floor")

    random_rank = torch.randint(0, topk, (search_width, 1), device=grad.device)
    new_token_val = torch.gather(top_indices[new_token_pos], 1, random_rank)
    new_control_toks = original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)

    return new_control_toks


def get_nonascii_toks(tokenizer, device='cpu'):
    """Collect the token ids that should be excluded from sampling.

    The result contains, in this order:

    1. every id from 3 up to ``tokenizer.vocab_size`` whose decoded text is not
       both ASCII and printable,
    2. the BOS, EOS, PAD and UNK ids, for each one that is not ``None``,
    3. ids 101..999 when ``"Baichuan2"`` occurs in ``tokenizer.name_or_path``.

    Ids may appear more than once; no de-duplication is performed.

    Args:
        tokenizer: Object exposing ``vocab_size``, ``decode``, the four special
            token id attributes and ``name_or_path``.
        device: Device on which the returned tensor is created.

    Returns:
        A tensor holding the collected ids.
    """
    def _is_printable_ascii(text):
        return text.isascii() and text.isprintable()

    excluded_ids = [
        token_id
        for token_id in range(3, tokenizer.vocab_size)
        if not _is_printable_ascii(tokenizer.decode([token_id]))
    ]

    special_ids = (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    )
    excluded_ids.extend(token_id for token_id in special_ids if token_id is not None)

    if "Baichuan2" in tokenizer.name_or_path:
        excluded_ids.extend(range(101, 1000))

    return torch.tensor(excluded_ids, device=device)