import gc

import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm

from llm_attacks import AttackPrompt, MultiPromptAttack, PromptManager
from llm_attacks import get_embedding_matrix, get_embeddings


def token_gradients(model, input_ids, input_slice, target_slice, loss_slice, use_pez, n_epoch,
                    favors_starting=True, decay_rate=0.9):
    """
    Computes gradients of the target loss with respect to the adversarial control.

    Two modes are supported. With ``use_pez`` false, the control tokens are
    expressed as a one-hot matrix and the gradient is taken with respect to
    that matrix. With ``use_pez`` true, the gradient is taken directly with
    respect to the embeddings of the control tokens.

    Parameters
    ----------
    model : Transformer Model
        The transformer model to be used. It must accept ``inputs_embeds`` and
        return an object exposing ``.logits``.
    input_ids : torch.Tensor
        The input sequence in the form of token ids (1-D).
    input_slice : slice
        The slice of the input sequence for which gradients need to be computed.
    target_slice : slice
        The slice of the input sequence to be used as targets.
    loss_slice : slice
        The slice of the logits to be used for computing the loss.
    use_pez : bool
        Selects the embedding-gradient mode (True) or the one-hot token
        gradient mode (False).
    n_epoch : int
        Current epoch index. Accepted for interface compatibility; it does not
        influence the returned gradient.
    favors_starting : bool, optional
        PEZ mode only. When True (default), the per-position cross-entropy
        losses are weighted by ``decay_rate ** position`` and summed, so earlier
        target positions contribute more. When False, the plain mean
        cross-entropy is used.
    decay_rate : float, optional
        PEZ mode only. Base of the positional weights used when
        ``favors_starting`` is True. Defaults to 0.9.

    Returns
    -------
    torch.Tensor
        Without PEZ: gradient on the one-hot control matrix, with one row per
        control token and one column per row of the embedding matrix.
        With PEZ: gradient on the control embeddings, with a leading batch
        dimension of size 1.
    """
    if not use_pez:
        embed_weights = get_embedding_matrix(model)
        one_hot = torch.zeros(
            input_ids[input_slice].shape[0],
            embed_weights.shape[0],
            device=model.device,
            dtype=embed_weights.dtype
        )
        one_hot.scatter_(
            1,
            input_ids[input_slice].unsqueeze(1),
            torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embed_weights.dtype)
        )
        one_hot.requires_grad_()
        # Differentiable embedding lookup for the control tokens.
        input_embeds = (one_hot @ embed_weights).unsqueeze(0)

        # Everything outside the control slice is treated as a constant.
        embeds = get_embeddings(model, input_ids.unsqueeze(0)).detach()
        full_embeds = torch.cat(
            [
                embeds[:, :input_slice.start, :],
                input_embeds,
                embeds[:, input_slice.stop:, :]
            ],
            dim=1)

        logits = model(inputs_embeds=full_embeds).logits
        targets = input_ids[target_slice]
        loss = nn.CrossEntropyLoss()(logits[0, loss_slice, :], targets)

        loss.backward()

        return one_hot.grad.clone()

    # PEZ mode: differentiate directly with respect to the control embeddings.
    with torch.no_grad():
        base_embeds = get_embeddings(model, input_ids.unsqueeze(0))
    full_embeds = base_embeds.detach().clone()

    control_span = slice(input_slice.start, input_slice.stop)
    control_embeds = full_embeds[:, control_span, :].clone().detach().requires_grad_(True)
    # Writing the leaf tensor back in place links it into the autograd graph
    # of ``full_embeds`` so that backward() populates ``control_embeds.grad``.
    full_embeds[:, control_span, :] = control_embeds

    logits = model(inputs_embeds=full_embeds).logits
    targets = input_ids[target_slice]
    target_logits = logits[0, loss_slice, :]

    if favors_starting:
        per_token = nn.CrossEntropyLoss(reduction='none')(target_logits, targets)
        positions = torch.arange(len(per_token), device=per_token.device, dtype=torch.float32)
        position_weights = decay_rate ** positions
        loss = (per_token * position_weights).sum()
    else:
        loss = nn.CrossEntropyLoss()(target_logits, targets)

    # NOTE: an auxiliary perplexity term over the control slice used to be
    # computed here but was never part of the back-propagated loss, so it had
    # no effect on the gradient; it is intentionally not computed.
    loss.backward()
    gc.collect()

    return control_embeds.grad.clone()


class MarageAttackPrompt(AttackPrompt):
    """AttackPrompt whose ``grad`` forwards the PEZ flag and epoch to ``token_gradients``."""

    def __init__(self, *args, **kwargs):
        # No additional state: construction is delegated to AttackPrompt.
        super(MarageAttackPrompt, self).__init__(*args, **kwargs)

    def grad(self, model):
        """
        Compute the control gradient for this prompt.

        Parameters
        ----------
        model : tuple
            A ``(model, n_epoch)`` pair; it is unpacked here because the
            argument arrives as a single value.

        Returns
        -------
        torch.Tensor
            Whatever ``token_gradients`` returns for this prompt's slices.
        """
        net, epoch = model
        ids_on_device = self.input_ids.to(net.device)
        return token_gradients(
            model=net,
            input_ids=ids_on_device,
            input_slice=self._control_slice,
            target_slice=self._target_slice,
            loss_slice=self._loss_slice,
            use_pez=self.use_pez,
            n_epoch=epoch,
        )

class MaragePromptManager(PromptManager):
    """PromptManager that samples single-token control substitutions from a gradient."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def sample_control(self, grad, batch_size, topk=256, temp=1, allow_non_ascii=False):
        """
        Sample ``batch_size`` candidate controls, each differing from the
        current control in exactly one token position.

        Parameters
        ----------
        grad : torch.Tensor
            2-D gradient with one row per control position; columns index
            candidate tokens. Lower values are preferred (top-k of ``-grad``).
            The caller's tensor is not modified.
        batch_size : int
            Number of candidates to produce.
        topk : int, optional
            Number of lowest-gradient tokens per position to sample from.
        temp : float, optional
            Unused; kept for interface compatibility.
        allow_non_ascii : bool, optional
            When False, the tokens listed in ``self._nonascii_toks`` are
            excluded by setting their gradient to +inf.

        Returns
        -------
        torch.Tensor
            Tensor with ``batch_size`` rows, each a copy of
            ``self.control_toks`` with one position replaced.
        """
        if not allow_non_ascii:
            # Mask on a copy so the caller's gradient tensor stays untouched.
            grad = grad.clone()
            grad[:, self._nonascii_toks.to(grad.device)] = np.inf

        top_indices = (-grad).topk(topk, dim=1).indices
        control_toks = self.control_toks.to(grad.device)
        n_control = len(control_toks)
        original_control_toks = control_toks.repeat(batch_size, 1)

        # Spread the batch evenly over the control positions using exact
        # integer arithmetic. A float-step arange is subject to rounding and
        # may return a number of elements different from batch_size.
        new_token_pos = torch.div(
            torch.arange(batch_size, device=grad.device) * n_control,
            batch_size,
            rounding_mode="floor",
        )
        new_token_val = torch.gather(
            top_indices[new_token_pos], 1,
            torch.randint(0, topk, (batch_size, 1), device=grad.device)
        )
        new_control_toks = original_control_toks.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)
        return new_control_toks

class ProjectionModel(nn.Module):
    """Single bias-free linear layer mapping ``input_dim`` features to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super(ProjectionModel, self).__init__()
        # The attribute name fixes the state-dict key ("proj.weight").
        linear_map = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
        self.proj = linear_map

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.proj(x)
        return projected


# Alternative two-layer MLP projection (disabled; kept for reference).
# class ProjectionModel(nn.Module):
#     def __init__(self, input_dim, output_dim, hidden_dim=4096):
#         super().__init__()
#         self.fc1 = nn.Linear(input_dim, hidden_dim, bias=True)
#         self.activation = nn.ReLU()
#         self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)
        
#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.activation(x)
#         x = self.fc2(x)
#         return x
    
class MarageMultiPromptAttack(MultiPromptAttack):
    """
    Multi-prompt attack whose ``step`` runs one optimisation iteration in one
    of two modes, chosen by ``self.use_pez``:

    * token search (``use_pez`` false): sample discrete control candidates
      from the aggregated token gradient and keep the lowest-loss candidate;
    * embedding update (``use_pez`` true): apply the aggregated embedding
      gradient to ``self.control_embed`` through ``self.optimizer``.

    ``self.workers``, ``self.prompts``, ``self.models``, ``self.control_str``,
    ``self.control_embed``, ``self.optimizer`` and ``self.scheduler`` are not
    defined in this class and are expected to be provided elsewhere.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def step(self,
             batch_size=1024,
             topk=256,
             temp=1,
             allow_non_ascii=False,
             target_weight=1,
             control_weight=0.1,
             verbose=False,
             n_epoch=1,
             opt_only=False,
             filter_cand=True,
             proj_model_path=None):
        """
        Run one attack iteration.

        Parameters
        ----------
        batch_size : int
            Number of candidates per candidate set (token search) and length
            of the per-set loss buffer.
        topk, temp, allow_non_ascii
            Forwarded to ``sample_control`` (token search only).
        target_weight, control_weight : float
            Weights of the target and control losses. The control loss is
            skipped entirely when ``control_weight`` is 0.
        verbose : bool
            Show a tqdm progress bar with the running loss.
        n_epoch : int
            Epoch index forwarded to the workers' ``"grad"`` jobs.
        opt_only : bool
            Accepted for interface compatibility; ignored.
        filter_cand : bool
            Forwarded to ``get_filtered_cands`` (token search only).
        proj_model_path : str, optional
            Checkpoint of a ``ProjectionModel`` used in embedding mode when a
            worker's hidden size differs from the aggregated gradient. Falls
            back to ``self.proj_model_path`` when not given.

        Returns
        -------
        tuple
            Token search: ``(next_control, loss)``. Embedding mode:
            ``(nn.Parameter, loss)``. The loss is divided by the number of
            prompts and the number of workers.
        """
        if not self.use_pez:
            return self._token_search_step(
                batch_size, topk, temp, allow_non_ascii,
                target_weight, control_weight, verbose, n_epoch, filter_cand,
            )
        return self._embedding_step(
            batch_size, target_weight, control_weight, verbose, n_epoch, proj_model_path,
        )

    def _token_search_step(self, batch_size, topk, temp, allow_non_ascii,
                           target_weight, control_weight, verbose, n_epoch, filter_cand):
        """Sample discrete control candidates from the gradient and return the best one."""
        main_device = self.models[0].device
        control_cands = []

        # Submit every gradient job before collecting any result.
        for j, worker in enumerate(self.workers):
            worker(self.prompts[j], "grad", (worker.model, n_epoch))

        grad = None
        for j, worker in enumerate(self.workers):
            new_grad = worker.results.get().to(main_device)
            new_grad = new_grad / new_grad.norm(dim=-1, keepdim=True)
            if grad is None:
                grad = torch.zeros_like(new_grad)
            if grad.shape != new_grad.shape:
                # Gradients of a different shape cannot be summed: flush the
                # accumulated gradient into its own candidate set first.
                with torch.no_grad():
                    control_cand = self.prompts[j-1].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
                    control_cands.append(self.get_filtered_cands(j-1, control_cand, filter_cand=filter_cand, curr_control=self.control_str))
                grad = new_grad
            else:
                grad += new_grad

        with torch.no_grad():
            control_cand = self.prompts[j].sample_control(grad, batch_size, topk, temp, allow_non_ascii)
            control_cands.append(self.get_filtered_cands(j, control_cand, filter_cand=filter_cand, curr_control=self.control_str))
        del grad, control_cand ; gc.collect()

        n_prompts = len(self.prompts[0])
        loss = torch.zeros(len(control_cands) * batch_size).to(main_device)
        with torch.no_grad():
            for j, cand in enumerate(control_cands):
                span = slice(j * batch_size, (j + 1) * batch_size)
                # Iterate over prompt indices in both cases; the progress bar
                # only wraps the same range.
                progress = tqdm(range(n_prompts), total=n_prompts) if verbose else range(n_prompts)
                for i in progress:
                    for k, worker in enumerate(self.workers):
                        worker(self.prompts[k][i], "logits", worker.model, cand, return_ids=True)
                    logits, ids = zip(*[worker.results.get() for worker in self.workers])
                    loss[span] += sum([
                        target_weight*self.prompts[k][i].target_loss(logit, ids_k).mean(dim=-1).to(main_device)
                        for k, (logit, ids_k) in enumerate(zip(logits, ids))
                    ])
                    if control_weight != 0:
                        loss[span] += sum([
                            control_weight*self.prompts[k][i].control_loss(logit, ids_k).mean(dim=-1).to(main_device)
                            for k, (logit, ids_k) in enumerate(zip(logits, ids))
                        ])
                    del logits, ids ; gc.collect()

                    if verbose:
                        progress.set_description(f"loss={loss[span].min().item()/(i+1):.4f}")

            min_idx = loss.argmin()
            model_idx = min_idx // batch_size
            batch_idx = min_idx % batch_size
            next_control, cand_loss = control_cands[model_idx][batch_idx], loss[min_idx]

        del control_cands, loss ; gc.collect()

        print('Current length:', len(self.workers[0].tokenizer(next_control).input_ids[1:]))
        print(next_control)

        return next_control, cand_loss.item() / len(self.prompts[0]) / len(self.workers)

    def _project_grad(self, new_grad, worker, target_dim, main_device, proj_model_path=None):
        """Map ``new_grad`` from the worker's hidden size to ``target_dim`` with a saved ProjectionModel."""
        model_path = proj_model_path or getattr(self, "proj_model_path", None)
        if not model_path:
            raise FileNotFoundError(
                "Worker hidden size differs from the aggregated gradient, but no "
                "projection checkpoint is configured: pass proj_model_path to "
                "step() or set self.proj_model_path."
            )
        proj_model = ProjectionModel(worker.model.config.hidden_size, target_dim)
        # Load on CPU first so the checkpoint does not depend on the device it
        # was saved from, then move to the worker's device and gradient dtype.
        proj_model.load_state_dict(torch.load(model_path, map_location="cpu"))
        proj_model = proj_model.to(new_grad.dtype)
        proj_model = proj_model.to(worker.model.device)

        with torch.no_grad():
            projected = proj_model(new_grad.to(worker.model.device)).to(main_device)
        del proj_model
        gc.collect()
        return projected

    def _embedding_step(self, batch_size, target_weight, control_weight, verbose, n_epoch, proj_model_path):
        """Apply one optimizer update to ``self.control_embed`` and score the result."""
        main_device = self.models[0].device

        updated_embeddings = []
        for j, worker in enumerate(self.workers):
            worker(self.prompts[j], "grad", (worker.model, n_epoch))

        aggregated_grad = None
        # Norm-based re-weighting of the raw gradients; disabled.
        # NOTE(review): the two-gradient and general cases below weight by
        # opposite norms and use the unprojected gradients - confirm intent
        # before enabling.
        use_gradnorm = False
        target_norm = 1
        grad_norms = []
        grads = []
        for j, worker in enumerate(self.workers):
            new_grad = worker.results.get().to(main_device)

            grad_norms.append(new_grad.norm().item())
            grads.append(new_grad)

            if aggregated_grad is not None and aggregated_grad.shape[2] != worker.model.config.hidden_size:
                new_grad = self._project_grad(
                    new_grad, worker, aggregated_grad.shape[2], main_device, proj_model_path
                )

            new_grad = new_grad / new_grad.norm(dim=-1, keepdim=True)

            if aggregated_grad is None:
                aggregated_grad = torch.zeros_like(new_grad)
            if aggregated_grad.shape != new_grad.shape:
                print("Shape is different*******************************\n\n\n")
                aggregated_grad = new_grad
            else:
                aggregated_grad += new_grad

        if use_gradnorm:
            if len(grads) == 2:
                norm1, norm2 = grad_norms

                alpha = target_norm / (norm1 + norm2 + 1e-8) * norm2
                beta = target_norm / (norm1 + norm2 + 1e-8) * norm1

                aggregated_grad = alpha * grads[0] + beta * grads[1]
            else:
                total_norm = sum(grad_norms) + 1e-8
                alphas = [target_norm / total_norm * norm for norm in grad_norms]
                scaled_grads = [alpha * grad for alpha, grad in zip(alphas, grads)]
                aggregated_grad = torch.sum(torch.stack(scaled_grads), dim=0)

        self.optimizer.zero_grad()
        self.control_embed.grad = aggregated_grad.squeeze()
        self.optimizer.step()
        self.scheduler.step()
        updated_embeddings.append(self.control_embed.clone().detach())

        n_prompts = len(self.prompts[0])
        loss = torch.zeros(len(updated_embeddings) * batch_size).to(main_device)
        with torch.no_grad():
            for j, embed in enumerate(updated_embeddings):
                span = slice(j * batch_size, (j + 1) * batch_size)
                progress = tqdm(range(n_prompts), total=n_prompts) if verbose else range(n_prompts)
                for i in progress:
                    for k, worker in enumerate(self.workers):
                        worker(self.prompts[k][i], "logits", worker.model, embed, return_ids=True)

                    logits, ids = zip(*[worker.results.get() for worker in self.workers])

                    for m, (logit, id_) in enumerate(zip(logits, ids)):
                        target_loss = target_weight * self.prompts[m][i].target_loss(logit, id_).mean(dim=-1).to(main_device)
                        control_loss = 0.0
                        if control_weight != 0:
                            control_loss = control_weight * self.prompts[m][i].control_loss(logit, id_).mean(dim=-1).to(main_device)
                        loss[span] += (target_loss + control_loss)
                        if verbose:
                            # The slice holds batch_size entries, so reduce
                            # before converting to a Python number.
                            progress.set_description(
                                f"loss={loss[span].min().item():.4f}"
                            )

                    del logits, ids
                    gc.collect()

            min_idx = loss.argmin()
            model_idx = min_idx // batch_size
            best_embedding = updated_embeddings[model_idx]
            best_loss = loss[min_idx].item()

        del updated_embeddings, loss
        gc.collect()
        return nn.Parameter(best_embedding), best_loss / len(self.prompts[0]) / len(self.workers)
