import gc

import numpy as np
import torch
import torch.nn as nn
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPTJForCausalLM,
    GPTNeoXForCausalLM,
    LlamaForCausalLM,
    MistralForCausalLM,
    GemmaForCausalLM,
)


def get_embedding_layer(model):
    """Return the input token-embedding module of a supported causal LM.

    Dispatch is on the concrete ``transformers`` model class. OpenELM is
    loaded with ``trust_remote_code`` and has no class imported here, so it is
    recognised by its ``name_or_path`` instead.

    Raises
    ------
    ValueError
        If ``model`` is none of the supported architectures.
    """
    # GPT-J and GPT-2 keep their token embeddings under ``transformer.wte``.
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel)):
        return model.transformer.wte
    if isinstance(model, GPTNeoXForCausalLM):
        return model.base_model.embed_in
    # Llama, Mistral and Gemma all expose ``model.embed_tokens``.
    if isinstance(model, (LlamaForCausalLM, MistralForCausalLM, GemmaForCausalLM)):
        return model.model.embed_tokens
    if model.name_or_path == "apple/OpenELM-3B-Instruct":
        return model.transformer.token_embeddings
    raise ValueError(f"Unknown model type: {type(model)}")


def get_embedding_matrix(model):
    """Return the ``.weight`` of ``model``'s input token-embedding layer.

    The parameter itself is returned (no copy). Supported architectures,
    and the ``ValueError`` raised for unsupported ones, are exactly those of
    ``get_embedding_layer``.
    """
    embedding_layer = get_embedding_layer(model)
    return embedding_layer.weight


def get_embeddings(model, input_ids):
    """Embed ``input_ids`` with ``model``'s input token-embedding layer.

    For GPT-J, GPT-2 and GPT-NeoX the result is cast with ``.half()``. For the
    other supported models it keeps the embedding layer's output dtype.
    Unsupported models raise ``ValueError`` (via ``get_embedding_layer``).
    """
    embedded = get_embedding_layer(model)(input_ids)
    # NOTE(review): the fp16 cast presumably assumes these families are loaded
    # in half precision; confirm before loading them in another dtype.
    if isinstance(model, (GPTJForCausalLM, GPT2LMHeadModel, GPTNeoXForCausalLM)):
        embedded = embedded.half()
    return embedded


def token_gradients(model, input_ids, input_slice, target_slice, loss_slice):
    """
    Computes gradients of the loss with respect to the coordinates.

    Parameters
    ----------
    model : Transformer Model
        The transformer model to be used.
    input_ids : torch.Tensor
        The input sequence in the form of token ids (1-D). It is used to index
        into tensors created on ``model.device``, so it should live there too.
    input_slice : slice
        The slice of the input sequence for which gradients need to be computed.
    target_slice : slice
        The slice of the input sequence to be used as targets.
    loss_slice : slice
        The slice of the logits to be used for computing the loss.

    Returns
    -------
    torch.Tensor
        Shape ``(n_input_tokens, embed_weights.shape[0])``. Each row is the
        L2-normalised gradient of the cross-entropy loss w.r.t. the one-hot
        indicator of the token at that position of ``input_slice``. Rows whose
        gradient is exactly zero are returned as zeros instead of NaN.

    Notes
    -----
    The gradient is taken only w.r.t. the one-hot matrix via
    ``torch.autograd.grad``. The model parameters' ``.grad`` fields are
    therefore neither populated nor accumulated across calls.
    """
    embed_weights = get_embedding_matrix(model)
    control_ids = input_ids[input_slice]

    # One-hot encode the optimised tokens so the loss is differentiable with
    # respect to every vocabulary choice at every optimised position.
    one_hot = torch.zeros(
        control_ids.shape[0],
        embed_weights.shape[0],
        device=model.device,
        dtype=embed_weights.dtype,
    )
    one_hot.scatter_(
        1,
        control_ids.unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embed_weights.dtype),
    )
    one_hot.requires_grad_()
    input_embeds = (one_hot @ embed_weights).unsqueeze(0)

    # Splice the differentiable embeddings between the detached embeddings of
    # the surrounding prompt tokens.
    embeds = get_embeddings(model, input_ids.unsqueeze(0)).detach()
    full_embeds = torch.cat(
        [
            embeds[:, : input_slice.start, :],
            input_embeds,
            embeds[:, input_slice.stop :, :],
        ],
        dim=1,
    )

    logits = model(inputs_embeds=full_embeds).logits
    targets = input_ids[target_slice]
    loss = nn.CrossEntropyLoss()(logits[0, loss_slice, :], targets)

    # Differentiate w.r.t. one_hot only. Unlike loss.backward(), this neither
    # computes nor accumulates gradients for the model's parameters.
    (grad,) = torch.autograd.grad(loss, one_hot)

    norms = grad.norm(dim=-1, keepdim=True)
    # Keep all-zero rows at zero instead of producing 0 / 0 = NaN.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    return grad / norms


def sample_control(
    control_toks,
    grad,
    batch_size,
    topk=256,
    temp=1,
    not_allowed_tokens=None,
    verbose: bool = False,
):
    """Sample ``batch_size`` single-token substitutions of ``control_toks``.

    Each candidate row is a copy of ``control_toks`` with one token replaced.
    The replacement is drawn uniformly at random from the ``topk`` tokens with
    the most negative gradient at that position. Positions are spread evenly
    over the control: row ``i`` edits position
    ``i * len(control_toks) // batch_size``.

    Parameters
    ----------
    control_toks : torch.Tensor
        1-D tensor of current control token ids.
    grad : torch.Tensor
        ``(len(control_toks), vocab)`` gradient, e.g. from ``token_gradients``.
        NOTE: modified in place when ``not_allowed_tokens`` is given.
    batch_size : int
        Number of candidates to produce.
    topk : int
        Number of best tokens per position to sample from.
    temp : float
        Unused; kept for interface compatibility.
    not_allowed_tokens : torch.Tensor, optional
        Token ids that must never be sampled (their gradient is set to +inf).
    verbose : bool
        Print intermediate tensor shapes.

    Returns
    -------
    torch.Tensor
        ``(batch_size, len(control_toks))`` tensor of candidate token ids.
    """
    if not_allowed_tokens is not None:
        # np.inf, not the np.infty alias (removed in NumPy 2.0). A +inf gradient
        # ranks the token last in the top-k over -grad below.
        grad[:, not_allowed_tokens.to(grad.device)] = np.inf

    top_indices = (-grad).topk(topk, dim=1).indices
    control_toks = control_toks.to(grad.device)
    n_ctrl = len(control_toks)

    original_control_toks = control_toks.repeat(batch_size, 1)

    # Evenly spaced positions in exact integer arithmetic. A float-step
    # torch.arange(0, n_ctrl, n_ctrl / batch_size) can round to
    # batch_size + 1 elements, which breaks the scatter_ below, and its
    # float-to-int truncation can misplace boundary positions.
    new_token_pos = torch.div(
        torch.arange(batch_size, device=grad.device, dtype=torch.int64) * n_ctrl,
        batch_size,
        rounding_mode="floor",
    )

    sample_idx = torch.randint(0, topk, (batch_size, 1), device=grad.device)
    new_token_val = torch.gather(top_indices[new_token_pos], 1, sample_idx)
    new_control_toks = original_control_toks.scatter_(
        1, new_token_pos.unsqueeze(-1), new_token_val
    )

    if verbose:
        print("\ncontrol_toks: ", n_ctrl)
        print("\n top_indices: ")
        print(top_indices.shape)
        print("\n original_control_toks: ")
        print(original_control_toks.shape)
        print("\n new_token_pos: ")
        print(new_token_pos.shape)
        print("\ntop_indices[new_token_pos]")
        print(top_indices[new_token_pos].shape)
        print("\ntorch.randint(0, topk, (batch_size, 1), device=grad.device),")
        # Report the draw actually used instead of consuming fresh RNG state.
        print(sample_idx.shape)
        print("\n new_token_val: ")
        print(new_token_val.shape)
        print("\n new_control_toks: ")
        print(new_control_toks.shape)

    return new_control_toks


def multi_cord_sample_control(
    control_toks,
    grad,
    batch_size,
    topk=256,
    temp=1,
    num_coordinates=8,
    not_allowed_tokens=None,
    verbose: bool = False,
):
    """Sample ``batch_size`` candidates that each change ``num_coordinates`` tokens.

    For every candidate row, ``num_coordinates`` distinct positions are drawn
    at random (one ``torch.randperm`` per row, so no position repeats within a
    row). Each chosen position gets a token drawn uniformly from the ``topk``
    tokens with the most negative gradient there.

    Parameters
    ----------
    control_toks : torch.Tensor
        1-D tensor of current control token ids.
    grad : torch.Tensor
        ``(len(control_toks), vocab)`` gradient, e.g. from ``token_gradients``.
        NOTE: modified in place when ``not_allowed_tokens`` is given.
    batch_size : int
        Number of candidates to produce.
    topk : int
        Number of best tokens per position to sample from.
    temp : float
        Unused; kept for interface compatibility.
    num_coordinates : int
        Positions to change per candidate; must be ``<= len(control_toks)``.
        ``0`` returns unchanged copies of ``control_toks``.
    not_allowed_tokens : torch.Tensor, optional
        Token ids that must never be sampled (their gradient is set to +inf).
    verbose : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    torch.Tensor
        ``(batch_size, len(control_toks))`` tensor of candidate token ids.

    Raises
    ------
    AssertionError
        If ``num_coordinates`` exceeds ``len(control_toks)`` (not checked
        under ``python -O``).
    """
    n_ctrl = len(control_toks)
    assert num_coordinates <= n_ctrl, (
        f"Set the number of coordinates to at most the length of control tokens: {n_ctrl}"
    )

    if not_allowed_tokens is not None:
        # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
        grad[:, not_allowed_tokens.to(grad.device)] = np.inf

    top_indices = (-grad).float().topk(topk, dim=1).indices
    control_toks = control_toks.to(grad.device)

    # Start from unchanged copies so num_coordinates == 0 is well defined
    # (previously the result variable was never bound in that case).
    candidates = control_toks.repeat(batch_size, 1)

    # Independent random permutation per row gives distinct positions per row.
    all_indices = torch.stack(
        [torch.randperm(n_ctrl, device=grad.device) for _ in range(batch_size)]
    )
    multi_token_pos = all_indices[:, :num_coordinates]

    for cord in range(num_coordinates):
        new_token_pos = multi_token_pos[:, cord]
        new_token_val = torch.gather(
            top_indices[new_token_pos],
            1,
            torch.randint(0, topk, (batch_size, 1), device=grad.device),
        )
        # In place: later coordinates build on the earlier substitutions.
        candidates.scatter_(1, new_token_pos.unsqueeze(-1), new_token_val)

    return candidates


def get_filtered_cands(
    tokenizer,
    control_cand,
    max_toks,
    filter_cand=True,
    curr_control=None,
    verbose: bool = False,
):
    """Decode candidate token rows to strings, optionally dropping invalid ones.

    With ``filter_cand`` set, a decoded candidate is kept only if it differs
    from ``curr_control`` and re-tokenizes (without special tokens) to the
    same number of tokens as its row. The kept list is then padded back to
    ``len(control_cand)`` by repeating the last kept candidate, so callers get
    one string per row. If nothing survives, ``curr_control`` is used as the
    filler.

    Parameters
    ----------
    tokenizer : tokenizer with ``decode`` and ``__call__(...).input_ids``.
    control_cand : sequence of token-id rows (e.g. a 2-D tensor).
    max_toks : unused; kept for interface compatibility.
    filter_cand : bool
        Apply the validity filter described above.
    curr_control : str, optional
        Current control string; excluded from results and used as fallback.
    verbose : bool
        Print the fraction of rejected candidates.

    Returns
    -------
    list[str]
        ``len(control_cand)`` candidate strings.

    Raises
    ------
    ValueError
        If filtering rejects every candidate and ``curr_control`` is None,
        leaving nothing to fill the batch with.
    """
    n_total = len(control_cand)
    cands, count = [], 0
    for row in control_cand:
        decoded_str = tokenizer.decode(row, skip_special_tokens=False)
        if not filter_cand:
            cands.append(decoded_str)
            continue
        # Reject the unchanged control and strings that do not round-trip to
        # the same token count.
        if decoded_str != curr_control and len(
            tokenizer(decoded_str, add_special_tokens=False).input_ids
        ) == len(row):
            cands.append(decoded_str)
        else:
            count += 1

    if filter_cand:
        if not cands and n_total:
            if curr_control is None:
                raise ValueError(
                    "get_filtered_cands: every candidate was filtered out and "
                    "no curr_control was given to fall back on"
                )
            # Re-evaluate the current control rather than crash on cands[-1].
            cands = [curr_control]
        missing = n_total - len(cands)
        if missing:
            # Replicate the last valid candidate to keep the batch size fixed.
            cands.extend([cands[-1]] * missing)
        if verbose and n_total:
            print(
                f"\nWarning: {round(count / n_total, 2)} control candidates were not valid"
            )

    return cands


def get_logits(
    *,
    model,
    tokenizer,
    input_ids,
    control_slice,
    test_controls=None,
    return_ids=False,
    batch_size=512,
):
    """Run ``model`` on ``input_ids`` with each candidate control spliced in.

    Each string in ``test_controls`` is tokenized without special tokens and
    truncated/padded by the tokenizer to ``control_slice.stop -
    control_slice.start`` tokens. The result is written over
    ``input_ids[control_slice]``, and the batch is evaluated with ``forward``.

    Parameters
    ----------
    model : causal LM passed to ``forward``.
    tokenizer : tokenizer used to encode ``test_controls``.
    input_ids : torch.Tensor
        1-D prompt token ids. The spliced batch is built on this tensor's
        device; ``forward`` moves each mini-batch to ``model.device``.
    control_slice : slice
        Positions of ``input_ids`` holding the control tokens.
    test_controls : list[str]
        Non-empty list of candidate control strings.
    return_ids : bool
        Also return the spliced ``(len(test_controls), len(input_ids))`` ids.
    batch_size : int
        Mini-batch size passed to ``forward``.

    Returns
    -------
    torch.Tensor or tuple[torch.Tensor, torch.Tensor]
        Logits from ``forward``, optionally followed by the spliced ids.

    Raises
    ------
    ValueError
        If ``test_controls`` is not a non-empty list of strings, or the
        tokenized controls do not match the control slice's length.
    """
    if (
        test_controls is None
        or len(test_controls) == 0
        or not isinstance(test_controls[0], str)
    ):
        raise ValueError(
            f"test_controls must be a list of strings, got {type(test_controls)}"
        )

    max_len = control_slice.stop - control_slice.start
    test_ids = [
        torch.tensor(
            tokenizer(
                control,
                add_special_tokens=False,
                max_length=max_len,
                truncation=True,
                padding="max_length",
            ).input_ids
        )
        for control in test_controls
    ]

    # Pick a filler id that occurs in neither the prompt nor any candidate, so
    # the attention mask below masks only filler positions. With
    # padding="max_length" every row should already be max_len long, so
    # normally nothing is filled and the mask is all ones.
    pad_tok = 0
    while pad_tok in input_ids or any(pad_tok in ids for ids in test_ids):
        pad_tok += 1
    nested_ids = torch.nested.nested_tensor(test_ids)
    test_ids = torch.nested.to_padded_tensor(
        nested_ids, pad_tok, (len(test_ids), max_len)
    )

    if test_ids[0].shape[0] != max_len:
        raise ValueError(
            (
                f"test_controls must have shape "
                f"(n, {control_slice.stop - control_slice.start}), "
                f"got {test_ids.shape}"
            )
        )

    # Build everything on input_ids' device so torch.scatter never mixes
    # CPU and GPU tensors.
    device = input_ids.device
    test_ids = test_ids.to(device)
    locs = torch.arange(control_slice.start, control_slice.stop, device=device).repeat(
        test_ids.shape[0], 1
    )
    ids = torch.scatter(
        input_ids.unsqueeze(0).repeat(test_ids.shape[0], 1),
        1,
        locs,
        test_ids,
    )
    attn_mask = (ids != pad_tok).type(ids.dtype)

    logits = forward(
        model=model, input_ids=ids, attention_mask=attn_mask, batch_size=batch_size
    )
    return (logits, ids) if return_ids else logits


def forward(*, model, input_ids, attention_mask, batch_size=64):
    """Run ``model`` over ``input_ids`` in mini-batches and collect logits on CPU.

    Parameters
    ----------
    model : callable returning an object with ``.logits``; must expose ``.device``.
    input_ids : torch.Tensor
        Token ids, split along dim 0 into chunks of ``batch_size`` rows.
    attention_mask : torch.Tensor or None
        Mask with the same leading dimension as ``input_ids``, or ``None`` to
        call the model without one.
    batch_size : int
        Rows per model call.

    Returns
    -------
    torch.Tensor
        Per-batch logits concatenated along dim 0, on CPU.
    """
    # NOTE(review): runs under the caller's grad mode; callers presumably wrap
    # this in torch.no_grad() so no autograd graph is kept alive.
    logits = []
    for start in range(0, input_ids.shape[0], batch_size):
        stop = start + batch_size
        batch_input_ids = input_ids[start:stop].to(model.device)
        # Only move the mask when one was given; None has no .to().
        batch_attention_mask = (
            None
            if attention_mask is None
            else attention_mask[start:stop].to(model.device)
        )

        logits.append(
            model(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
            ).logits.cpu()
        )

        # Drop the device copies and release cached allocator memory between
        # batches, presumably to keep peak GPU usage down.
        del batch_input_ids, batch_attention_mask
        gc.collect()
        torch.cuda.empty_cache()

    return torch.cat(logits, dim=0)


def target_loss(logits, ids, target_slice):
    """Per-row mean cross-entropy of the target tokens.

    The token at position ``t`` in ``target_slice`` is scored against the
    logits at position ``t - 1`` (next-token prediction). Returns one mean
    loss per row of ``ids``.
    """
    shifted = slice(target_slice.start - 1, target_slice.stop - 1)
    # cross_entropy wants the class dimension second: (batch, vocab, seq).
    per_token = nn.functional.cross_entropy(
        logits[:, shifted, :].transpose(1, 2),
        ids[:, target_slice],
        reduction="none",
    )
    return per_token.mean(dim=-1)


def load_model_and_tokenizer(
    model_path,
    tokenizer_path=None,
    device="cuda:0",
    quant8: bool = False,
    cache_dir="/mnt/data/prj_rag",
    **kwargs,
):
    """Load a causal LM and its tokenizer, with per-family special cases.

    Model loading:

    * ``quant8``: 8-bit weights (``load_in_8bit``) placed with
      ``device_map=device``; ``kwargs`` are not forwarded in this branch.
    * ``"guanaco"`` in ``model_path``: bfloat16 ``huggyllama/llama-7b`` base
      with ``device_map="auto"``, wrapped with the PEFT adapter at
      ``model_path``.
    * ``"OpenELM"`` in ``model_path``: no ``torch_dtype`` forced (it cannot be
      loaded in float16), then moved to ``device``.
    * otherwise: float16, moved to ``device``.

    The last two branches also clear ``generation_config.temperature`` and
    ``top_p``.

    Tokenizer: the slow tokenizer (``use_fast=False``) from ``tokenizer_path``,
    which defaults to ``model_path``; OpenELM always uses
    ``meta-llama/Llama-2-7b-hf``. Model-specific special-token and padding
    fixes are applied, and a tokenizer without a pad token falls back to EOS.

    Returns
    -------
    tuple
        ``(model, tokenizer)``.
    """
    if quant8:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            load_in_8bit=True,
            # A single device; device_map="auto" would spread across GPUs.
            device_map=device,
            trust_remote_code=True,
            cache_dir=cache_dir,
        ).eval()

    elif "guanaco" in model_path:
        # Guanaco is a PEFT adapter on top of the Llama-7B base model. Its
        # tokenizer is loaded once below, like every other branch's.
        model = AutoModelForCausalLM.from_pretrained(
            "huggyllama/llama-7b",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            max_memory={i: "24000MB" for i in range(torch.cuda.device_count())},
            cache_dir=cache_dir,
        )
        model = PeftModel.from_pretrained(model, model_path)

    # OpenELM cannot be loaded in float 16
    elif "OpenELM" in model_path:
        model = (
            AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                cache_dir=cache_dir,
                **kwargs,
            )
            .to(device)
            .eval()
        )
        # Presumably cleared so greedy generation does not warn about
        # sampling-only settings.
        model.generation_config.temperature = None
        model.generation_config.top_p = None

    else:
        model = (
            AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                cache_dir=cache_dir,
                **kwargs,
            )
            .to(device)
            .eval()
        )
        model.generation_config.temperature = None
        model.generation_config.top_p = None

    tokenizer_path = model_path if tokenizer_path is None else tokenizer_path

    # OpenELM uses Llama tokenizer
    if "OpenELM" in model_path:
        tokenizer_path = "meta-llama/Llama-2-7b-hf"

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path, trust_remote_code=True, use_fast=False
    )

    if "oasst-sft-6-llama-30b" in tokenizer_path:
        tokenizer.bos_token_id = 1
        tokenizer.unk_token_id = 0
    if "guanaco" in tokenizer_path:
        tokenizer.eos_token_id = 2
        tokenizer.unk_token_id = 0
    if any(tag in tokenizer_path for tag in ("llama-2", "Llama-2", "llama3")):
        # Pad with UNK on the left. If the tokenizer has no UNK token, the
        # pad token stays unset here and the EOS fallback below applies.
        tokenizer.pad_token = tokenizer.unk_token
        tokenizer.padding_side = "left"
    if "falcon" in tokenizer_path:
        tokenizer.padding_side = "left"
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
