import os
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
from sentence_transformers.util import (semantic_search,
                                        dot_score,
                                        normalize_embeddings)
import numpy as np
import fire
from tqdm import tqdm
import json


def main(
    reward_model: str = "microsoft/deberta-v3-xsmall",
    reward_checkpoint_path: str =  "/data/private_models/proxy_models/deberta-v3-xsmall_hh_1e-05_32_2.pkl",
    gold_model: str = "microsoft/deberta-v3-large",
    gold_checkpoint_path_1: str = "/data/private_models/proxy_models/gold_ensemble/deberta_large_gold1.pkl",
    gold_checkpoint_path_2: str = "/data/private_models/proxy_models/gold_ensemble/deberta_large_gold2.pkl",
    input_file: str = "/data/private_models/misc/wb_data/pythia6b_inputs_1_train.json",
    N: int = 512,
    num_steps: int = 100,
    eval_steps: int = 5,
    num_optim_tokens: int = 4,
    lr: float = 0.5,
    batch_size: int = 32,
    pez_mode: str = "prepend",
    output_file: str = "",
    one_prompt: bool = False, # one PEZ prompt for the entire dataset
    verbose: bool = False,
    normalization: bool = False,
):
    """Run PEZ prompt optimisation against a proxy reward model and log scores.

    Args:
        reward_model: Hub name used for the proxy model's tokenizer and
            architecture.
        reward_checkpoint_path: State-dict file loaded into the proxy model;
            also the key used to look up its normalisation statistics.
        gold_model: Hub name used for the gold models' tokenizer and
            architecture.
        gold_checkpoint_path_1: State-dict file for the first gold model.
        gold_checkpoint_path_2: State-dict file for an optional second gold
            model; pass an empty string to score with the first one only.
        input_file: JSON file holding the list of prompt+completion strings.
        N: Upper bound of the example offsets visited, in steps of
            ``batch_size``.
        num_steps: Optimisation steps per batch.
        eval_steps: Scores are recorded every ``eval_steps`` steps.
        num_optim_tokens: Number of token positions that are optimised.
        lr: Learning rate of the Adam optimiser.
        batch_size: Number of examples optimised together.
        pez_mode: Placement of the optimised tokens, forwarded to ``run_pez``
            as ``mode``.
        output_file: Destination of the ``.npy`` log; required.
        one_prompt: When true, the optimised embeddings are carried over from
            one batch to the next instead of being re-initialised.
        verbose: Print the mean proxy logit every 10 steps.
        normalization: Accepted for compatibility; the visible code does not
            consult it.
    """
    # Both paths are mandatory and are only checked by substring.
    assert output_file and "npy" in output_file, "output_file need to in .npy format"
    assert input_file and "json" in input_file, "input_file need to be a json file"

    # Create the output directory up front; a bare file name has no
    # directory component and needs nothing created.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def nn_project(curr_embeds, embedding_layer):
        """Snap every embedding vector to its nearest vocabulary embedding.

        Nearness is the dot product between L2-normalised vectors, computed
        with the sentence-transformers exact kNN ``semantic_search``.

        Args:
            curr_embeds: Tensor of shape ``(B, seq_len, emb_dim)``.
            embedding_layer: Embedding module whose ``weight`` rows form the
                search corpus and which is called to look up the result.

        Returns:
            ``(projected_embeds, nn_indices)``: the looked-up embeddings for
            the nearest rows and their indices, shaped ``(B, seq_len)``.
        """
        batch, seq_len, emb_dim = curr_embeds.shape

        # Flatten to one query per token position.
        queries = normalize_embeddings(curr_embeds.reshape((-1, emb_dim)))
        corpus = normalize_embeddings(embedding_layer.weight)

        # Only the best match per query is consumed, so ask for a single hit.
        hits = semantic_search(queries, corpus,
                               query_chunk_size=queries.shape[0],
                               top_k=1,
                               score_function=dot_score)

        # Build the indices on the embedding layer's own device instead of
        # assuming CUDA, so the lookup below never mixes devices.
        nn_indices = torch.tensor(
            [hit[0]["corpus_id"] for hit in hits],
            device=embedding_layer.weight.device,
        ).reshape((batch, seq_len))
        projected_embeds = embedding_layer(nn_indices)

        return projected_embeds, nn_indices

    class project_soft_embeds(torch.autograd.Function):
        """
        Straight-through estimator around the nearest-neighbour projection.

        The forward pass replaces each soft embedding with the nearest hard
        (vocabulary) embedding; the backward pass hands the incoming gradient
        to the soft embeddings unchanged.
        """
        @staticmethod
        def forward(ctx, input):
            """
            Project ``input`` onto the proxy model's word-embedding table.

            The embedding layer is read from ``rw_model.deberta``. The original
            author noted that other backbones keep it elsewhere, e.g.
            ``model.transformer.wte`` for GPT-2 and ``model.gpt_neox.embed_in``
            for Pythia.
            """
            # Nothing is stashed on ctx: the backward pass is the identity and
            # needs no tensors from the forward pass.
            projected_embeds, _ = nn_project(input, rw_model.deberta.embeddings.word_embeddings)
            return projected_embeds

        @staticmethod
        def backward(ctx, grad_output):
            """
            Return the gradient w.r.t. the output as the gradient w.r.t. the
            input (straight-through estimator).
            """
            return grad_output

    def load_model(model: str, load_path: str):
        """Build a single-logit sequence classifier and load its weights.

        Args:
            model: Hub name or path passed to ``from_pretrained`` for both the
                config and the architecture.
            load_path: File containing the state dict to load (strictly).

        Returns:
            The model with the checkpoint weights loaded, still on CPU.
        """
        config = AutoConfig.from_pretrained(model, num_labels=1)
        classifier = AutoModelForSequenceClassification.from_pretrained(model, config=config)
        # NOTE: torch.load unpickles the file; only point this at trusted
        # checkpoints. map_location keeps the tensors on CPU regardless of the
        # device they were saved from; the caller moves the model afterwards.
        state_dict = torch.load(load_path, map_location="cpu")
        classifier.load_state_dict(state_dict, strict=True)
        return classifier

    # Tokenizers for the proxy (reward) model and the gold models.
    rw_tokenizer = AutoTokenizer.from_pretrained(reward_model)
    rw_tokenizer.truncation_side = "left"
    gold_tokenizer = AutoTokenizer.from_pretrained(gold_model)
    gold_tokenizer.truncation_side = "left"

    # All models run on GPU in eval mode; the second gold model is optional.
    rw_model = load_model(reward_model, reward_checkpoint_path).cuda().eval()
    gold_model_1 = load_model(gold_model, gold_checkpoint_path_1).cuda().eval()
    gold_model_2 = None
    if gold_checkpoint_path_2:
        gold_model_2 = load_model(gold_model, gold_checkpoint_path_2).cuda().eval()

    # Only the prompt embeddings are optimised, so freeze every model weight.
    for frozen_model in (rw_model, gold_model_1, gold_model_2):
        if frozen_model is not None:
            frozen_model.requires_grad_(False)

    # Per-checkpoint normalisation statistics, keyed by checkpoint path.
    # The path is relative to the current working directory.
    with open("../../reward/normalization_calibration/normalization_coeffs.json") as coeff_file:
        metadata = json.load(coeff_file)
    RM_STATS = metadata[reward_checkpoint_path]
    GM_STATS_1 = metadata[gold_checkpoint_path_1]
    # Always bind GM_STATS_2 so the name exists even without a second gold model.
    GM_STATS_2 = metadata[gold_checkpoint_path_2] if gold_model_2 is not None else None

    # Embedding dimension of the proxy model's word-embedding table.
    shape = rw_model.deberta.embeddings.word_embeddings.weight.shape[1]

    def run_pez(prompts, completions, tokenizer, num_optim_tokens, num_steps, eval_steps, mode='append', maximize=True, optim_embeds=None):
        """Optimise a block of tokens against the proxy reward model (PEZ).

        For every (prompt, completion) pair, ``num_optim_tokens`` positions are
        reserved. Their soft embeddings are optimised with Adam; on each
        forward pass they are projected to the nearest vocabulary embedding
        with a straight-through gradient. Every ``eval_steps`` steps the
        current hard tokens are decoded and scored by the proxy and gold
        models. All recorded scores are normalised with the calibration
        statistics of the respective checkpoint.

        Args:
            prompts: Prompt strings, parallel to ``completions``.
            completions: Completion strings.
            tokenizer: Tokenizer of the proxy model; supplies token ids, the
                EOS/pad ids and decodes the optimised sequences.
            num_optim_tokens: Number of optimised positions per example.
            num_steps: Number of optimisation steps.
            eval_steps: Evaluate whenever ``step % eval_steps == 0``.
            mode: ``'prepend'`` puts the optimised tokens between prompt and
                completion, ``'append'`` puts them after the completion,
                ``'random_replace'`` overwrites randomly chosen completion
                positions (and behaves like ``'prepend'`` when the completion
                has no more than ``num_optim_tokens`` tokens).
            maximize: Maximise the proxy logit when true, minimise otherwise.
            optim_embeds: Existing parameter of shape
                ``(len(prompts), num_optim_tokens, emb_dim)`` to keep
                optimising; a fresh random one is created when ``None``.

        Returns:
            ``(proxy_scores, gold_scores, perturbed_examples, optim_embeds)``:
            two arrays of shape ``(num_evals, batch)``, the decoded texts per
            evaluation, and the embeddings after the final step (not the
            best-loss ones).

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        valid_modes = ('prepend', 'append', 'random_replace')
        if mode not in valid_modes:
            # Without this check an unknown mode would leave example_tokens
            # unbound (or silently reuse the previous example's tokens).
            raise ValueError(f"unknown mode {mode!r}; expected one of {valid_modes}")

        embedding_layer = rw_model.deberta.embeddings.word_embeddings
        # Follow the proxy model's device; the gold models are assumed to live
        # on the same device.
        device = embedding_layer.weight.device
        multiplier = -1 if maximize else 1

        # ========== orig prompts ========== #
        all_input_tokens = []
        attack_positions = []
        # Arbitrary filler id: the embeddings at these positions are
        # overwritten by the optimised embeddings on every step.
        placeholder_tokens = torch.full((num_optim_tokens,), 42, dtype=torch.long)
        eos_token = torch.tensor([tokenizer.eos_token_id])
        for prompt, completion in zip(prompts, completions):
            prompt_tokens = tokenizer(prompt, return_tensors="pt").input_ids[0, :-1]  # drop last token (eos)
            completion_tokens = tokenizer(completion, return_tensors="pt").input_ids[0, 1:-1]  # drop first (bos) and last (eos)
            n_prompt = len(prompt_tokens)
            n_completion = len(completion_tokens)

            if mode == 'random_replace' and num_optim_tokens < n_completion:
                # Overwrite distinct, randomly chosen completion positions.
                sampled_idx = np.random.choice(n_completion, num_optim_tokens, replace=False)
                example_tokens = torch.cat([prompt_tokens, completion_tokens])
                positions = [n_prompt + int(offset) for offset in sampled_idx]
            elif mode == 'append':
                example_tokens = torch.cat([prompt_tokens, completion_tokens, placeholder_tokens])
                positions = [n_prompt + n_completion + offset for offset in range(num_optim_tokens)]
            else:
                # 'prepend', or 'random_replace' on a completion too short to
                # host all optimised tokens.
                example_tokens = torch.cat([prompt_tokens, placeholder_tokens, completion_tokens])
                positions = [n_prompt + offset for offset in range(num_optim_tokens)]

            attack_positions.append(positions)
            all_input_tokens.append(torch.cat([example_tokens, eos_token]))

        all_input_tokens = pad_sequence(all_input_tokens, batch_first=True, padding_value=tokenizer.pad_token_id).to(device)
        attention_mask = all_input_tokens != tokenizer.pad_token_id
        all_input_embeds = embedding_layer(all_input_tokens).detach()
        attack_positions = torch.tensor(attack_positions, device=device)
        # Row index that pairs each example with its own attack positions.
        batch_index = torch.arange(all_input_embeds.size(0), device=device).unsqueeze(1)

        # ========== setup optim_embeds ========== #
        if optim_embeds is None:
            # One random vector, repeated for every example and position.
            init_vector = torch.randn(shape, device=device)
            optim_embeds = torch.nn.Parameter(init_vector.repeat(len(prompts), num_optim_tokens, 1).clone())

        # ========== setup optimizer and scheduler ========== #
        optimizer = torch.optim.Adam([optim_embeds], lr=lr, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)

        # ========== run optimization ========== #
        proxy_scores, gold_scores, perturbed_examples = [], [], []

        for step in range(num_steps):
            optimizer.zero_grad()

            # Forward pass on hard (projected) embeddings; gradients flow to
            # the soft embeddings through the straight-through estimator.
            projected_optim_embeds = project_soft_embeds.apply(optim_embeds)
            input_embeds = all_input_embeds.clone()
            input_embeds[batch_index, attack_positions] = projected_optim_embeds

            logits = rw_model(inputs_embeds=input_embeds, attention_mask=attention_mask).logits
            loss = (multiplier * logits).mean()

            if verbose and step % 10 == 0:
                print('{} {:.3f}'.format(step, logits.mean().item()))

            # ========== update optim_embeds ========== #
            loss.backward()
            optimizer.step()
            scheduler.step()

            if step % eval_steps != 0:
                continue

            # ========== eval on the sequence used in this step ========== #
            with torch.no_grad():
                # Recover token ids for the whole sequence and decode them with
                # the tokenizer that owns these ids, then re-tokenize the text
                # for the gold models.
                _, nn_indices = nn_project(input_embeds.detach(), embedding_layer)
                adv_prompts = tokenizer.batch_decode(nn_indices, skip_special_tokens=True)
                perturbed_examples.append(adv_prompts)

                # input_embeds is unchanged since the forward pass above and
                # the model is frozen, so its logits are the proxy scores.
                proxy_logits = logits.detach().reshape((-1)).cpu().numpy()
                proxy_scores.append((proxy_logits - RM_STATS["mean"]) / RM_STATS["std"])

                gold_inputs = gold_tokenizer(adv_prompts, padding=True, return_tensors="pt")
                gold_inputs = {k: v.to(device) for k, v in gold_inputs.items()}

                gold_logits_1 = gold_model_1(**gold_inputs).logits
                gold_logits_1 = (gold_logits_1 - GM_STATS_1["mean"]) / GM_STATS_1["std"]
                gold_logits = gold_logits_1.reshape((-1)).cpu().numpy()

                if gold_model_2 is not None:
                    gold_logits_2 = gold_model_2(**gold_inputs).logits
                    gold_logits_2 = (gold_logits_2 - GM_STATS_2["mean"]) / GM_STATS_2["std"]
                    gold_logits_2 = gold_logits_2.reshape((-1)).cpu().numpy()
                    # Ensemble: average of the two normalised gold scores.
                    gold_logits = (gold_logits + gold_logits_2) / 2.0

                gold_scores.append(gold_logits)

        return np.array(proxy_scores), np.array(gold_scores), perturbed_examples, optim_embeds

    with open(input_file) as file:
        raw_prompts = json.load(file)

    # Split each example at its last "Assistant: " marker: everything up to
    # and including the marker is the prompt, the remainder is the completion.
    prompts, completions = [], []
    for example in raw_prompts:
        head, _, tail = example.rpartition("Assistant: ")
        prompts.append(head + "Assistant: ")
        completions.append(tail)

    all_proxy, all_gold, all_adv_prompts = [],[],[]
    cached_embeds = None
    print("Optimizing PEZ with cached_prompt=",one_prompt)
    # Never step past the end of the dataset: an empty batch cannot be padded.
    num_examples = min(N, len(raw_prompts))
    for i in tqdm(range(0, num_examples, batch_size)):
        raw_batch = raw_prompts[i:i+batch_size]
        prompt_batch = prompts[i:i+batch_size]
        completion_batch = completions[i:i+batch_size]

        proxy_scores, gold_scores, adv_prompts, optim_embeds = run_pez(prompt_batch, completion_batch, rw_tokenizer,
                                            num_optim_tokens, num_steps, eval_steps, mode=pez_mode, maximize=True, optim_embeds=cached_embeds)
        # With one_prompt the same embeddings keep being optimised across batches.
        cached_embeds = optim_embeds if one_prompt else None

        # ========= Raw Scores =========== #
        # NOTE(review): the proxy model is fed ids produced by gold_tokenizer;
        # this presumes both model families share a vocabulary - confirm.
        inputs = gold_tokenizer(raw_batch, padding=True, return_tensors="pt")
        inputs = {k: v.cuda() for k, v in inputs.items()}
        with torch.no_grad():
            raw_proxy = rw_model(**inputs).logits.reshape((-1)).cpu().numpy()
            raw_proxy = (raw_proxy - RM_STATS["mean"]) / RM_STATS["std"]

            gold_logits_1 = gold_model_1(**inputs).logits
            gold_logits_1 = (gold_logits_1 - GM_STATS_1["mean"]) / GM_STATS_1["std"]
            raw_gold = gold_logits_1.reshape((-1)).cpu().numpy()

            if gold_model_2 is not None:
                gold_logits_2 = gold_model_2(**inputs).logits
                # Centre with the mean (the original subtracted the std here).
                gold_logits_2 = (gold_logits_2 - GM_STATS_2["mean"]) / GM_STATS_2["std"]
                gold_logits_2 = gold_logits_2.reshape((-1)).cpu().numpy()
                raw_gold = (raw_gold + gold_logits_2) / 2.0

        # Row 0 holds the unperturbed example; the rows after it follow the
        # evaluation steps recorded by run_pez.
        proxy_scores = np.vstack((raw_proxy, proxy_scores))
        gold_scores = np.vstack((raw_gold, gold_scores))
        adv_prompts = np.vstack((raw_batch, adv_prompts))

        all_proxy.extend(proxy_scores)
        all_gold.extend(gold_scores)
        all_adv_prompts.extend(adv_prompts)

    # Three arrays are written back to back into the same file; read them
    # back with three consecutive np.load calls on one open handle.
    print("Writing logs to ", output_file)
    with open(output_file, 'wb') as f:
        np.save(f, np.array(all_proxy))
        np.save(f, np.array(all_gold))
        np.save(f, np.array(all_adv_prompts))

if __name__ == "__main__":
    # Hand main to python-fire so it can be driven from the command line.
    fire.Fire(main)