import os
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification

from sentence_transformers.util import (semantic_search,
                                        dot_score,
                                        normalize_embeddings)
import pandas as pd
import numpy as np
import numpy as np
import fire
from tqdm import tqdm
import json


def main(
    reward_model: str = "microsoft/deberta-v3-xsmall",
    reward_checkpoint_path: str = "/data/private_models/proxy_models/deberta-v3-xsmall_hh_1e-05_32_2.pkl",
    gold_model: str = "microsoft/deberta-v3-large",
    gold_checkpoint_path_1: str = "/data/private_models/proxy_models/gold_ensemble/deberta_large_gold1.pkl",
    gold_checkpoint_path_2: str = "/data/private_models/proxy_models/gold_ensemble/deberta_large_gold2.pkl",
    input_file: str = "/data/private_models/misc/wb_data/pythia6b_inputs_1_train.json",
    N: int = 512,
    num_steps: int = 100,
    eval_steps: int = 5,
    num_optim_tokens: int = 4,
    lr: float = 0.5,
    batch_size: int = 32,
    mode: str = "prepend",
    output_file: str = "",
    one_prompt: bool = False,  # one PEZ prompt for the entire dataset
    verbose: bool = False,
    normalization: bool = True,
):
    """Optimize adversarial tokens against a proxy reward model and log scores.

    For each batch of (prompt, completion) pairs, ``num_optim_tokens`` token
    slots are relaxed with a Gumbel-softmax over the proxy model's vocabulary
    embeddings and optimized with Adam to maximize the proxy model's logit.
    Every ``eval_steps`` steps the current argmax tokens are written into the
    sequence, decoded to text, and re-scored by the proxy model and by the
    gold model(s); both scores are standardized with the mean/std stored in
    the normalization-coefficients JSON.

    Args:
        reward_model: Hub name of the proxy (reward) model architecture.
        reward_checkpoint_path: State-dict file for the proxy model; also the
            key used to look up its normalization statistics.
        gold_model: Hub name of the gold model architecture.
        gold_checkpoint_path_1: State-dict file for the first gold model.
        gold_checkpoint_path_2: State-dict file for an optional second gold
            model. When falsy, only the first gold model is used; otherwise
            the two normalized gold scores are averaged.
        input_file: JSON file holding a list of strings. Each string is split
            at its last ``"Assistant: "`` into prompt and completion.
        N: Upper bound on the batch start offsets (clamped to the number of
            loaded examples).
        num_steps: Number of optimization steps per batch.
        eval_steps: Evaluate on steps where ``step % eval_steps == 0``.
        num_optim_tokens: Number of token slots being optimized per example.
        lr: Adam learning rate (cosine-annealed over ``num_steps``).
        batch_size: Number of examples optimized together.
        mode: ``"prepend"`` inserts the slots between prompt and completion,
            ``"append"`` puts them after the completion, ``"random_replace"``
            overwrites randomly chosen completion positions (falling back to
            the ``"prepend"`` layout when the completion is not longer than
            ``num_optim_tokens``).
        output_file: Destination path; must contain ``"npy"``.
        one_prompt: NOTE(review): currently only echoed in the start-up
            message; no shared-prompt behaviour is implemented here.
        verbose: Print the mean proxy logit every 10 optimization steps.
        normalization: NOTE(review): currently unused; scores are always
            normalized.

    Raises:
        ValueError: If ``output_file``/``input_file`` fail the format checks
            or ``mode`` is not one of the supported values.

    Side effects:
        Writes three arrays back to back into ``output_file`` with
        ``np.save``: proxy scores, gold scores, and the decoded adversarial
        texts.
    """
    # Explicit raises (rather than `assert`) so validation survives `python -O`.
    if not (output_file and "npy" in output_file):
        raise ValueError("output_file need to in .npy format")
    if not (input_file and "json" in input_file):
        raise ValueError("input_file need to be a json file")

    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def load_model(model: str, load_path: str):
        """Build a single-label sequence classifier for `model` and load weights from `load_path`."""
        config = AutoConfig.from_pretrained(model, num_labels=1)
        classifier = AutoModelForSequenceClassification.from_pretrained(model, config=config)
        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        # Loading onto CPU avoids depending on the device the checkpoint was
        # saved from; the caller moves the model to the GPU afterwards.
        state_dict = torch.load(load_path, map_location="cpu")
        classifier.load_state_dict(state_dict, strict=True)
        return classifier

    rw_tokenizer = AutoTokenizer.from_pretrained(reward_model)
    rw_tokenizer.truncation_side = "left"
    gold_tokenizer = AutoTokenizer.from_pretrained(gold_model)
    gold_tokenizer.truncation_side = "left"

    rw_model = load_model(reward_model, reward_checkpoint_path).cuda().eval()
    gold_model_1 = load_model(gold_model, gold_checkpoint_path_1).cuda().eval()
    gold_model_2 = None
    if gold_checkpoint_path_2:
        gold_model_2 = load_model(gold_model, gold_checkpoint_path_2).cuda().eval()

    # Only the relaxed token coefficients are trained; every model is frozen.
    for frozen_model in (rw_model, gold_model_1, gold_model_2):
        if frozen_model is None:
            continue
        for p in frozen_model.parameters():
            p.requires_grad = False

    # Per-checkpoint mean/std used to standardize the raw logits.
    with open("../../reward/normalization_calibration/normalization_coeffs.json") as stats_file:
        metadata = json.load(stats_file)
    RM_STATS = metadata[reward_checkpoint_path]
    GM_STATS_1 = metadata[gold_checkpoint_path_1]
    GM_STATS_2 = metadata[gold_checkpoint_path_2] if gold_model_2 is not None else None

    device = rw_model.device
    # Embedding vectors for token ids [0, vocab_size); the relaxed tokens are
    # convex combinations of these rows.
    with torch.no_grad():
        vocab_ids = torch.arange(0, rw_tokenizer.vocab_size).long().to(device)
        embeddings = rw_model.get_input_embeddings()(vocab_ids)

    def run_gbda(prompts, completions, tokenizer, num_optim_tokens, num_steps, eval_steps, mode='append', maximize=True):
        """Optimize the attack tokens for one batch.

        Returns a tuple ``(proxy_scores, gold_scores, perturbed_examples)``:
        two arrays with one row per evaluation step and one column per
        example, plus a list (one entry per evaluation step) of the decoded
        texts.
        """
        if mode not in ('prepend', 'append', 'random_replace'):
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'prepend', 'append' or 'random_replace'"
            )

        multiplier = -1 if maximize else 1

        # ========== build token sequences and attack positions ========== #
        all_input_tokens = []
        attack_positions = []
        # Placeholder ids: the embeddings at these positions are overwritten
        # by the optimized embeddings on every step.
        placeholder_tokens = torch.full((num_optim_tokens,), 42, dtype=torch.long)
        for prompt, completion in zip(prompts, completions):
            prompt_tokens = tokenizer(prompt, return_tensors="pt").input_ids[0, :-1]  # skip eos
            completion_tokens = tokenizer(completion, return_tensors="pt").input_ids[0, 1:-1]  # skip bos and eos
            n_prompt = len(prompt_tokens)
            n_completion = len(completion_tokens)

            if mode == 'random_replace' and num_optim_tokens < n_completion:
                # Overwrite randomly chosen (distinct) completion positions.
                sampled_idx = np.random.choice(n_completion, num_optim_tokens, replace=False)
                example_tokens = torch.cat([prompt_tokens, completion_tokens])
                positions = [n_prompt + offset for offset in sampled_idx.tolist()]
            elif mode == 'append':
                example_tokens = torch.cat([prompt_tokens, completion_tokens, placeholder_tokens])
                positions = [n_prompt + n_completion + offset for offset in range(num_optim_tokens)]
            else:
                # 'prepend', and the 'random_replace' fallback for completions
                # that are not longer than num_optim_tokens.
                example_tokens = torch.cat([prompt_tokens, placeholder_tokens, completion_tokens])
                positions = [n_prompt + offset for offset in range(num_optim_tokens)]

            attack_positions.append(positions)
            example_tokens = torch.cat([example_tokens, torch.tensor([tokenizer.eos_token_id])])
            all_input_tokens.append(example_tokens)

        all_input_tokens = pad_sequence(
            all_input_tokens, batch_first=True, padding_value=tokenizer.pad_token_id
        ).to(device)
        attention_mask = all_input_tokens != tokenizer.pad_token_id
        all_input_embeds = rw_model.get_input_embeddings()(all_input_tokens).detach()
        attack_positions = torch.tensor(attack_positions, device=device)
        # Row selector paired with attack_positions for advanced indexing.
        batch_index = torch.arange(all_input_tokens.size(0), device=device).unsqueeze(1)

        # One logit vector over the vocabulary per optimized slot. squeeze(0)
        # only drops the batch axis for a single-example batch; the
        # unsqueeze(0) in the loop adds a leading axis back.
        log_coeffs = torch.randn(
            all_input_tokens.size(0), num_optim_tokens, embeddings.size(0), dtype=embeddings.dtype
        ).squeeze(0).to(device)
        log_coeffs.requires_grad = True

        # ========== setup optimizer and scheduler ========== #
        optimizer = torch.optim.Adam([log_coeffs], lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_steps)
        # Gumbel-softmax temperature, annealed linearly from 1 to 0.1.
        taus = np.linspace(1, 0.1, num_steps)

        # ========== run optimization ========== #
        proxy_scores, gold_scores, perturbed_examples = [], [], []
        for step in range(num_steps):
            optimizer.zero_grad()

            # Soft token mixture -> embeddings, spliced into the attack positions.
            coeffs = torch.nn.functional.gumbel_softmax(log_coeffs.unsqueeze(0), hard=False, tau=taus[step])
            optim_embeds = coeffs @ embeddings[None, :, :]
            input_embeds = all_input_embeds.clone()
            input_embeds[batch_index, attack_positions] = optim_embeds

            logits = rw_model(inputs_embeds=input_embeds, attention_mask=attention_mask).logits
            loss = (multiplier * logits).mean()

            if verbose and step % 10 == 0:
                print('{} {:.3f}'.format(step, logits.mean().item()))

            loss.backward()
            optimizer.step()
            scheduler.step()

            if step % eval_steps != 0:
                continue

            # ========== eval on the discretized (argmax) tokens ========== #
            optim_tokens = torch.argmax(log_coeffs, dim=-1).detach()
            all_input_tokens[batch_index, attack_positions] = optim_tokens
            # Decode with the tokenizer that produced these ids.
            adv_prompts = tokenizer.batch_decode(all_input_tokens, skip_special_tokens=True)
            perturbed_examples.append(adv_prompts)

            with torch.no_grad():
                # Each model is fed text tokenized by its own tokenizer.
                proxy_inputs = tokenizer(adv_prompts, padding=True, return_tensors="pt")
                proxy_inputs = {k: v.to(device) for k, v in proxy_inputs.items()}
                proxy_logits = rw_model(**proxy_inputs).logits.reshape((-1)).detach().cpu().numpy()
                proxy_scores.append((proxy_logits - RM_STATS["mean"]) / RM_STATS["std"])

                gold_inputs = gold_tokenizer(adv_prompts, padding=True, return_tensors="pt")
                gold_inputs = {k: v.cuda() for k, v in gold_inputs.items()}

                gold_logits_1 = gold_model_1(**gold_inputs).logits
                gold_logits_1 = (gold_logits_1 - GM_STATS_1["mean"]) / GM_STATS_1["std"]
                gold_logits = gold_logits_1.reshape((-1)).detach().cpu().numpy()

                if gold_model_2 is not None:
                    gold_logits_2 = gold_model_2(**gold_inputs).logits
                    gold_logits_2 = (gold_logits_2 - GM_STATS_2["mean"]) / GM_STATS_2["std"]
                    gold_logits_2 = gold_logits_2.reshape((-1)).detach().cpu().numpy()
                    # Ensemble: average of the two normalized gold scores.
                    gold_logits = (gold_logits + gold_logits_2) / 2.0

                gold_scores.append(gold_logits)

        return np.array(proxy_scores), np.array(gold_scores), perturbed_examples

    with open(input_file) as file:
        raw_prompts = json.load(file)

    # Split each example at its LAST "Assistant: ": everything up to and
    # including it is the prompt, the remainder is the completion.
    prompts, completions = [], []
    for raw in raw_prompts:
        head, _, tail = raw.rpartition("Assistant: ")
        prompts.append(head + "Assistant: ")
        completions.append(tail)

    all_proxy, all_gold, all_adv_prompts = [], [], []
    print("Optimizing PEZ with cached_prompt=",one_prompt)
    # Clamp to the data actually loaded so no empty batch reaches run_gbda.
    num_examples = min(N, len(prompts))
    for i in tqdm(range(0, num_examples, batch_size)):
        prompt_batch = prompts[i:i+batch_size]
        completion_batch = completions[i:i+batch_size]
        proxy_scores, gold_scores, adv_prompts = run_gbda(
            prompt_batch, completion_batch, rw_tokenizer,
            num_optim_tokens, num_steps, eval_steps, mode=mode, maximize=True,
        )

        # Each extend adds one row per evaluation step of this batch; a row
        # has one entry per example in the batch.
        all_proxy.extend(proxy_scores)
        all_gold.extend(gold_scores)
        all_adv_prompts.extend(adv_prompts)

    print("Writing logs to ", output_file)
    with open(output_file, 'wb') as f:
        np.save(f, np.array(all_proxy))
        np.save(f, np.array(all_gold))
        np.save(f, np.array(all_adv_prompts))

# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --output_file=scores.npy --mode=append).
if __name__ == "__main__":
    fire.Fire(main)