import torch
import numpy as np
from tqdm import tqdm
import random
from torch.nn import CrossEntropyLoss
from ..baseline import SingleBehaviorRedTeamingMethod
from ..check_refusal_utils import check_refusal_completions
from ..model_utils import get_template, load_model_and_tokenizer, load_vllm_model
from accelerate.utils import find_executable_batch_size
from vllm import SamplingParams
import re
import json
import os

class FewShot(SingleBehaviorRedTeamingMethod):
    """Few-shot red-teaming method driven by an attack model.

    At every step the attack model is asked to continue a numbered list of
    variations of a behavior. Each new variation is scored with the target
    model's loss on a fixed target string, and the lowest-loss variations are
    fed back as few-shot examples for the next step.
    """

    use_ray = True # vllm uses ray to handle multi-gpus (tensor_parallel)

    # Number of sampling rounds tried per step before giving up on reaching `sample_size`.
    _MAX_SAMPLING_ATTEMPTS = 11
    # Generated variations shorter than this many characters are discarded.
    _MIN_CANDIDATE_LENGTH = 32

    def __init__(self, target_model, attack_model, targets_path, method_template, completion_prefix,
                generation_config, sample_size=1000, n_shots=5, num_steps=5, method_template_contextual=None,
                completion_prefix_contextual=None, **kwargs):
        """
        :param target_model: a dictionary specifying the target model (kwargs to load_model_and_tokenizer)
        :param attack_model: a dictionary specifying the attack model (kwargs to load_vllm_model)
        :param targets_path: the path to the targets JSON file
        :param method_template: few-shot prompt template for behaviors w/o context string
        :param completion_prefix: the completion prefix for behaviors w/o context string
        :param generation_config: the generation config for attack model (kwargs to SamplingParams)
        :param sample_size: the number of samples to generate at each step
        :param n_shots: the number of shots to use for fewshot
        :param num_steps: the number of steps to use
        :param method_template_contextual: few-shot prompt template for behaviors w/ context string
        :param completion_prefix_contextual: the completion prefix for behaviors w/ context string
        :raises AssertionError: if `method_template` has no `{behavior}` format key
        """

        self.method_template = method_template
        self.completion_prefix = completion_prefix
        self.method_template_contextual = method_template_contextual
        self.completion_prefix_contextual = completion_prefix_contextual
        assert '{behavior}' in self.method_template, \
             "method_template needs to include `{behavior}` format keys"
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)

        # ===== Load Attack Model ======
        allocated_gpus = attack_model.get("num_gpus", 1)
        self.attack_model_template = get_template(**attack_model)['prompt']
        self.attack_model = load_vllm_model(**attack_model)
        self.generation_config = SamplingParams(**generation_config)
        # dummy stop in case the generation goes past a list of 10, as each query is fixed to generate 10 cases
        self.generation_config.stop = ['11.']

        # manually allocate target model into unallocated gpus:
        # devices below `allocated_gpus` get a budget of 0, the following
        # `num_gpus` devices get their currently free memory.
        self.model_name_or_path = target_model['model_name_or_path']
        num_gpus = target_model.get("num_gpus", 1)
        max_memory = {i: torch.cuda.mem_get_info(i)[0] if i >= allocated_gpus else 0 for i in range(0, allocated_gpus + num_gpus)}
        self.model, self.tokenizer = load_model_and_tokenizer(**target_model, max_memory=max_memory)
        self.template = get_template(**target_model)['prompt']

        self._foward_batch_size = 128 # starting _foward_batch_size to get loss, will decrease if go OOM

        self.sample_size = sample_size
        self.n_shots = n_shots
        self.num_steps = num_steps

    def generate_test_cases_single_behavior(self, behavior_dict, verbose=False, **kwargs):
        """Generate a test case for one behavior by iterative few-shot refinement.

        :param behavior_dict: a dictionary with the keys `Behavior`, `ContextString` and `BehaviorID`
        :param verbose: unused by this method
        :return: a tuple `(test_case, logs)`. `test_case` is the lowest-loss variation found
            (prefixed with the context string when there is one); if no variation was ever
            scored it falls back to the behavior itself. `logs` holds, for every completed
            step, the few-shot examples (dicts with `candidate` and `loss`) used at that step.
        :raises ValueError: if the behavior has a context string but the contextual
            template or completion prefix was not provided
        """
        # Unpack behavior_dict
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']
        behavior_id = behavior_dict['BehaviorID']
        target = self.behavior_id_to_target[behavior_id]

        if context_str:
            method_template = self.method_template_contextual
            completion_prefix = self.completion_prefix_contextual
            if method_template is None or completion_prefix is None:
                raise ValueError(
                    f"behavior `{behavior_id}` has a context string but `method_template_contextual` "
                    "and/or `completion_prefix_contextual` were not provided"
                )
        else:
            method_template = self.method_template
            completion_prefix = self.completion_prefix

        model, tokenizer = self.model, self.tokenizer
        template = self.template
        n_shots = self.n_shots

        query = method_template.format(behavior=behavior, context=context_str) # will not throw error for self.method_template w/o context
        query = self.attack_model_template.format(instruction=query)
        prefix_line = completion_prefix.format(behavior=behavior)

        candidates = []
        logs = []
        seen_candidates = set()
        prev_candidates = None

        # Fallback returned when no variation is ever scored (e.g. the attack model
        # produced nothing usable at the first step, or num_steps is 0).
        test_case = self._wrap_with_context(behavior, context_str)

        for step in tqdm(range(self.num_steps)):
            # If first step, will do zeroshot to get variations
            # Else will use top {n_shots} lowest loss variations as examples for fewshot
            fs_examples = candidates[:n_shots]

            # check if variations are converged
            if fs_examples == prev_candidates:
                break
            prev_candidates = fs_examples

            starting_idx = len(fs_examples) + 1
            shots = [f'{idx}. {s["candidate"]}' for idx, s in enumerate(fs_examples, 1)]

            # === construct completion_prefix query===
            # Sure, here's 10 variants of '{behavior}':
            # 1. ...
            # 2. ...
            # ...
            # {n_shots}. ...
            # {starting_idx}.
            input_query = query + '\n'.join([prefix_line] + shots + [f'{starting_idx}. '])

            # === Sampling until got {sample_size} new variations ===
            current_step_candidates = self._sample_step_candidates(input_query, starting_idx, seen_candidates)

            # ==== break if current_step_candidates empty after filtering ====
            if not current_step_candidates:
                break

            # ===== Calculating Loss for each candidates =====
            loss_function = find_executable_batch_size(self._get_loss_bs, self._foward_batch_size)
            formated_candidates = [
                template.format(instruction=self._wrap_with_context(c, context_str))
                for c in current_step_candidates
            ]
            candidates_loss = loss_function(formated_candidates, target)
            for c, loss in zip(current_step_candidates, candidates_loss):
                candidates.append(dict(candidate=c, loss=loss.item()))

            # ===== Sort lowest loss candidates as shot examples for next step =====
            candidates = sorted(candidates, key=lambda x: x['loss'])

            # == Add current step to logs ==
            logs.append(fs_examples)

            # ===== Check if lowest loss candidates jailbreak ======
            test_case = self._wrap_with_context(candidates[0]['candidate'], context_str)
            _input = template.format(instruction=test_case)
            is_refusal, completions, _ = check_refusal_completions(model, tokenizer, inputs=[_input])
            if not is_refusal[0]:
                break

        return test_case, logs

    @staticmethod
    def _wrap_with_context(text, context_str):
        """Prefix `text` with the context string and a `---` separator when a context exists."""
        return f"{context_str}\n\n---\n\n{text}" if context_str else text

    def _sample_step_candidates(self, input_query, starting_idx, seen_candidates):
        """Sample the attack model until `sample_size` unseen variations are collected.

        :param input_query: the fully formatted prompt, ending with `{starting_idx}. `
        :param starting_idx: the list index the attack model is asked to continue from
        :param seen_candidates: set of variations accepted so far; updated in place
        :return: the list of new variations accepted in this step (may be empty or
            shorter than `sample_size` if the attempts run out)
        """
        sample_size = self.sample_size
        # Guard against n_shots == 0, which would otherwise divide by zero below.
        divisor = max(self.n_shots, 1)
        collected = []
        attempt = 0
        while len(collected) < sample_size and attempt < self._MAX_SAMPLING_ATTEMPTS:
            remaining = sample_size - len(collected)
            # Always send at least one query: a rounded-down count of 0 would
            # waste the attempt on an empty batch.
            n_queries = max(1, round(remaining / divisor))
            outputs = self._vllm_generate(self.attack_model, [input_query] * n_queries, self.generation_config)

            for o in outputs:
                # prepend a starting idx back to outputs to extract_numbered_lines
                numbered = f'{starting_idx}. ' + o.strip()
                # === Dedup, filtered and update seen candidates ===
                for g in extract_numbered_lines(numbered):
                    if g in seen_candidates:
                        continue
                    if not g.strip() or len(g) < self._MIN_CANDIDATE_LENGTH:
                        continue
                    collected.append(g)
                    seen_candidates.add(g)
            attempt += 1
        return collected

    def _vllm_generate(self, model, queries, generation_config):
        """Run `model.generate` on `queries` and return the stripped text of each first output."""
        outputs = model.generate(queries, sampling_params=generation_config, use_tqdm=False)
        generations = [o.outputs[0].text.strip() for o in outputs]
        return generations

    def _get_loss_bs(self, batch_size, inputs, target):
        """Compute the target model's mean per-token loss on `target` for every input.

        Meant to be wrapped by `find_executable_batch_size`, which supplies `batch_size`.

        :param batch_size: the number of inputs per forward pass
        :param inputs: a list of formatted prompt strings; `target` is appended to each
        :param target: the target string whose tokens are scored
        :return: a 1-D tensor with one loss value per input, in input order
        """
        tokenizer = self.tokenizer
        model = self.model

        # === Update _foward_batch_size for later calls as this function is called by find_executable_batch_size ===
        if batch_size != self._foward_batch_size:
            self._foward_batch_size = batch_size
            print(f"INFO: setting _foward_batch_size={self._foward_batch_size}")

        target_ids = tokenizer(target, return_tensors='pt', add_special_tokens=False).input_ids.to(model.device)
        target_len = target_ids.size(1)
        loss_fct = CrossEntropyLoss(reduction='none')

        all_candidate_losses = []
        for i in range(0, len(inputs), batch_size):
            inputs_b = [_input + target for _input in inputs[i:i+batch_size]]
            # NOTE(review): the slicing below takes the last `target_len` positions of every
            # row, which assumes the tokenizer pads on the left -- confirm padding_side.
            encoded_b = tokenizer(inputs_b, return_tensors='pt', padding=True).to(model.device)

            with torch.no_grad():
                outputs = model(**encoded_b)
            batch_logits = outputs.logits

            # compute loss
            # Shift so that tokens < n predict n
            shift_logits = batch_logits[:, -target_len-1:-1, :].contiguous()
            shift_labels = target_ids.repeat(shift_logits.size(0), 1)
            # Flatten the tokens, then average the per-token losses of each row
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss.view(shift_logits.size(0), -1).mean(dim=1)
            all_candidate_losses.append(loss)

        return torch.cat(all_candidate_losses)
        
# ====== Helper Functions ===== #

### Extract the items of a numbered list (`1. ...`, `2. ...`) from text into a Python list
def extract_numbered_lines(text):
    """Return the unique items of a numbered list found in `text`, in order of appearance.

    Only lines that start (after optional whitespace) with a digit are kept.
    For each kept line the leading `N.` marker, one pair of surrounding double
    quotes, any `Request: ` substring and a possible second `N.` marker are
    removed. Duplicates are dropped, keeping the first occurrence.

    :param text: the raw text, one list item per line
    :return: a list of cleaned item strings
    """
    cleaned = []
    for line in text.split("\n"):
        if not re.match(r'^\s*\d', line):
            continue
        item = re.sub(r'^\s*\d+\.\s*"?(.*?)"?\s*$', r'\1', line)
        item = item.replace("Request: ", "").strip()
        item = re.sub(r'^\d+\.\s*', '', item) # some lines still have `{number}. ...`
        cleaned.append(item)
    # dict.fromkeys dedups while preserving insertion order, so the result is
    # deterministic (a set would order items by their randomized string hash).
    return list(dict.fromkeys(cleaned))
