import copy
import re

import ray

from ..model_utils import get_template
from .common import extract_json
from .judges import GPTJudge, OpenSourceJudge
from .language_models import GPT, VLLM, Claude, Gemini, HuggingFace


def load_PAIR_models(
    attackLM_args,
    targetLM_args,
    judgeLM_args,
    attack_max_n_tokens,
    target_max_n_tokens,
    judge_max_n_tokens,
    max_retries,
    attack_temp,  # default 1
):
    """
    Build the attacker, target and judge models used by PAIR.

    Ray is initialised first (re-initialisation is tolerated). Whenever two
    roles name the same ``model_name_or_path``, the already-loaded model is
    handed to the later role instead of being loaded a second time, and the
    later role also takes over the earlier role's argument dict.

    Parameters:
    - attackLM_args / targetLM_args / judgeLM_args: Keyword-argument dicts for
      each role; each must contain "model_name_or_path".
    - attack_max_n_tokens / target_max_n_tokens / judge_max_n_tokens:
      Generation length limits forwarded to the respective role.
    - max_retries: Forwarded to AttackLM.
    - attack_temp: Sampling temperature for the attacker (the target always
      runs at temperature 0.0).

    Returns:
    - Tuple ``(attackLM, targetLM, judgeLM)``.
    """
    ray.init(ignore_reinit_error=True)

    attack_name = attackLM_args["model_name_or_path"]
    target_name = targetLM_args["model_name_or_path"]
    judge_name = judgeLM_args["model_name_or_path"]

    # ====== Attacker =========
    attackLM = AttackLM(
        max_n_tokens=attack_max_n_tokens,
        max_retries=max_retries,
        temperature=attack_temp,
        **attackLM_args,
    )

    # ====== Target (reuses the attacker's model when the names match) =========
    shared_target_model = None
    if target_name == attack_name:
        shared_target_model, targetLM_args = attackLM.model, attackLM_args
    targetLM = TargetLM(
        max_n_tokens=target_max_n_tokens,
        temperature=0.0,
        preloaded_model=shared_target_model,
        **targetLM_args,
    )

    # ====== Judge (reuses the attacker's model first, otherwise the target's) =========
    if judge_name == attack_name:
        shared_judge_model, judgeLM_args = attackLM.model, attackLM_args
    elif judge_name == target_name:
        shared_judge_model, judgeLM_args = targetLM.model, targetLM_args
    else:
        shared_judge_model = None

    judgeLM = load_judge(
        max_n_tokens=judge_max_n_tokens,
        preloaded_model=shared_judge_model,
        **judgeLM_args,
    )
    return attackLM, targetLM, judgeLM


class AttackLM:
    """
    Base class for attacker language models.

    Generates attacks for conversations using a language model. The self.model attribute contains the underlying generation model.
    """

    # Optional prompt filter. When set to a callable it is invoked as
    # ``apply_filter_partial(prompt)`` and must return ``(passed, metric_out)``;
    # candidates that do not pass are regenerated. The class-level default of
    # None disables filtering, so get_attack works when no filter is attached.
    apply_filter_partial = None

    def __init__(
        self,
        max_n_tokens: int,
        max_retries: int,
        temperature: float,
        **model_kwargs,
    ):
        """
        Load the attacker model and its conversation template.

        Parameters:
        - max_n_tokens: Generation length limit passed to batched_generate.
        - max_retries: Maximum number of generation rounds in get_attack.
        - temperature: Sampling temperature passed to batched_generate.
        - model_kwargs: Must contain "model_name_or_path"; forwarded to
          get_template and load_indiv_model.
        """
        self.model_name_or_path = model_kwargs["model_name_or_path"]
        self.temperature = temperature
        self.max_n_tokens = max_n_tokens
        self.max_retries = max_retries

        self.template = get_template(**model_kwargs, return_fschat_conv=True)
        self.model = load_indiv_model(**model_kwargs)

        # check if model is an opensource model which is a RayActor
        self.use_ray = isinstance(self.model, ray.actor.ActorHandle)

    def get_attack(self, convs_list, prompts_list):
        """
        Generates responses for a batch of conversations and prompts using a language model.
        Only valid outputs in proper JSON format are returned. If an output isn't generated
        successfully after max_retries, it's returned as None.

        The conversations in convs_list are modified in place: the prompt is
        appended as a new user message and, for every successful generation,
        the last message is replaced by the extracted JSON string.

        Parameters:
        - convs_list: List of conversation objects.
        - prompts_list: List of prompts corresponding to each conversation.

        Returns:
        - List of generated outputs (dictionaries) or None for failed generations.
          An empty batch yields an empty list.
        """

        assert len(convs_list) == len(
            prompts_list
        ), "Mismatch between number of conversations and prompts."

        batchsize = len(convs_list)
        if batchsize == 0:
            # Nothing to generate; also avoids indexing convs_list[0] below.
            return []

        indices_to_regenerate = list(range(batchsize))
        valid_outputs = [None] * batchsize
        is_gpt = "gpt" in self.model_name_or_path

        # Initialize the attack model's generated output to match format.
        # The first conversation decides the seed for the whole batch.
        if len(convs_list[0].messages) == 0:
            init_message = '{"improvement": "","prompt": "'
        else:
            init_message = '{"improvement": "'

        full_prompts = []
        # Add prompts and initial seeding messages to conversations (only once)
        for conv, prompt in zip(convs_list, prompts_list):
            conv.append_message(conv.roles[0], prompt)
            if is_gpt:
                full_prompts.append(conv.to_openai_api_messages())
            else:
                # Seed the assistant turn with the start of the JSON object.
                conv.append_message(conv.roles[1], init_message)
                # Presumably strips the end-of-turn separator that get_prompt
                # appends, so the model continues the seeded message.
                # NOTE(review): an empty-string sep2 makes this slice [:0] and
                # yields an empty prompt; confirm no template uses sep2 == "".
                if conv.sep2 is not None:
                    trunc_len = -len(conv.sep2)
                else:
                    trunc_len = -1
                full_prompts.append(conv.get_prompt()[:trunc_len])

        for _ in range(self.max_retries):
            # Subset conversations based on indices to regenerate
            full_prompts_subset = [full_prompts[i] for i in indices_to_regenerate]

            generate_function = (
                self.model.batched_generate.remote if self.use_ray else self.model.batched_generate
            )
            outputs_list = generate_function(
                full_prompts_subset,
                max_n_tokens=self.max_n_tokens,
                temperature=self.temperature,
                stop_tokens=["}"],
            )
            if self.use_ray:
                outputs_list = ray.get(outputs_list)

            # Check for valid outputs and update the list
            new_indices_to_regenerate = []
            for orig_index, full_output in zip(indices_to_regenerate, outputs_list):
                if not is_gpt:
                    # Re-attach the seed and restore the closing brace that the
                    # "}" stop token may have cut off.
                    full_output = init_message + full_output
                    if not full_output.endswith("}"):
                        full_output += "}"

                attack_dict, json_str = extract_json(full_output)

                if self.apply_filter_partial is not None and attack_dict is not None:
                    print('APPLYING FILTER to the' , attack_dict["prompt"])
                    passed_filter, metric_out = self.apply_filter_partial(attack_dict["prompt"])
                    if not passed_filter:
                        print('THE PROMPT DID NOT PASS THE FILTER', attack_dict["prompt"])
                        print(metric_out)
                        # Treat a filtered prompt like an invalid generation.
                        attack_dict = None

                if attack_dict is not None:
                    valid_outputs[orig_index] = attack_dict
                    # Update the conversation with valid generation
                    convs_list[orig_index].update_last_message(json_str)
                else:
                    new_indices_to_regenerate.append(orig_index)

            # Update indices to regenerate for the next iteration
            indices_to_regenerate = new_indices_to_regenerate

            # If all outputs are valid, break
            if not indices_to_regenerate:
                break

        # Report when at least one slot is still unfilled after all retries.
        if any(output is None for output in valid_outputs):
            print(f"Failed to generate output after {self.max_retries} attempts. Terminating.")
        return valid_outputs


class TargetLM:
    """
    Base class for target language models.

    Generates responses for prompts using a language model. The self.model attribute contains the underlying generation model.
    """

    def __init__(
        self,
        max_n_tokens: int,
        temperature: float,
        preloaded_model: object = None,
        **model_kwargs,
    ):
        """
        Set up the target model and its conversation template.

        Parameters:
        - max_n_tokens: Generation length limit passed to batched_generate.
        - temperature: Sampling temperature passed to batched_generate.
        - preloaded_model: An already-loaded model to reuse; when None a model
          is loaded with load_indiv_model(**model_kwargs).
        - model_kwargs: Must contain "model_name_or_path"; forwarded to
          get_template and, if needed, load_indiv_model.
        """
        self.model_name_or_path = model_kwargs["model_name_or_path"]
        self.temperature = temperature
        self.max_n_tokens = max_n_tokens

        self.template = get_template(**model_kwargs, return_fschat_conv=True)
        # Compare against None explicitly so a supplied model is always reused,
        # regardless of how the object evaluates in a boolean context.
        self.model = (
            preloaded_model if preloaded_model is not None else load_indiv_model(**model_kwargs)
        )

        # check if model is an opensource model which is a RayActor
        self.use_ray = isinstance(self.model, ray.actor.ActorHandle)

    def get_response(self, prompts_list, context_str=None):
        """
        Generate one response per prompt, each in a fresh copy of the template.

        Parameters:
        - prompts_list: Prompts to send to the target model.
        - context_str: Optional context; when truthy it is placed before every
          prompt, separated by a "---" line.

        Returns:
        - The list returned by the model's batched_generate (resolved through
          ray.get when the model is a Ray actor).
        """
        full_prompts = []
        for prompt in prompts_list:
            # Each prompt gets its own conversation so turns never leak between prompts.
            conv = copy.deepcopy(self.template)
            if context_str:
                prompt = f"{context_str}\n\n---\n\n{prompt}"

            conv.append_message(conv.roles[0], prompt)
            if "gpt-" in self.model_name_or_path:
                # Openai does not have separators
                full_prompts.append(conv.to_openai_api_messages())
            elif "gemini-" in self.model_name_or_path:
                # Send only the raw text of the message that was just appended.
                full_prompts.append(conv.messages[-1][1])
            else:
                # Open an empty assistant turn for the model to complete.
                conv.append_message(conv.roles[1], None)
                full_prompts.append(conv.get_prompt())

        generate_function = (
            self.model.batched_generate.remote if self.use_ray else self.model.batched_generate
        )
        # NOTE(review): the "}" stop token matches the attacker's JSON format;
        # confirm that cutting target responses at "}" is intended.
        outputs_list = generate_function(
            full_prompts,
            max_n_tokens=self.max_n_tokens,
            temperature=self.temperature,
            stop_tokens=["}"],
        )
        return ray.get(outputs_list) if self.use_ray else outputs_list


def load_judge(model_name_or_path, max_n_tokens, preloaded_model=None, token=None, **model_kwargs):
    """
    Build the judge for the given model name.

    Names containing "gpt" get a GPTJudge backed by a new GPT client; in that
    case preloaded_model and model_kwargs are not used. Every other name gets
    an OpenSourceJudge, which reuses preloaded_model when one is given and
    otherwise loads the model through load_indiv_model (token is not forwarded
    on that path).

    Parameters:
    - model_name_or_path: Name or path of the judge model.
    - max_n_tokens: Generation length limit handed to the judge.
    - preloaded_model: Already-loaded model to reuse for a non-GPT judge.
    - token: API token, used only for the GPT client.
    - model_kwargs: Extra arguments for load_indiv_model.

    Returns:
    - A GPTJudge or OpenSourceJudge instance.
    """
    if "gpt" not in model_name_or_path:
        backend = preloaded_model
        if backend is None:
            backend = load_indiv_model(model_name_or_path=model_name_or_path, **model_kwargs)
        return OpenSourceJudge(model_name_or_path, max_n_tokens, judge_model=backend)

    api_client = GPT(model_name=model_name_or_path, token=token)
    return GPTJudge(model_name_or_path, max_n_tokens, api_client)


def load_indiv_model(model_name_or_path, use_vllm=False, token=None, **model_kwargs):
    """
    Load a single generation model.

    Names containing "gpt-", "claude-" or "gemini-" (checked in that order)
    return the matching API wrapper directly. Any other name is started as a
    Ray actor, using VLLM when use_vllm is set and HuggingFace otherwise, and
    the call blocks until the actor reports that it is initialized.

    Parameters:
    - model_name_or_path: Name or path of the model.
    - use_vllm: Select the VLLM actor instead of the HuggingFace one.
    - token: Access token forwarded to whichever backend is created.
    - model_kwargs: Forwarded to the open-source backends; "num_gpus"
      (default 1) also sets the actor's GPU reservation.

    Returns:
    - An API wrapper instance or a Ray actor handle.
    """
    api_backends = (
        ("gpt-", GPT),
        ("claude-", Claude),
        ("gemini-", Gemini),
    )
    for marker, backend_cls in api_backends:
        if marker in model_name_or_path:
            return backend_cls(model_name_or_path, token=token)

    # We use ray to manage opensource models here
    gpu_count = model_kwargs.get("num_gpus", 1)
    if use_vllm:
        # need to init as ray actors to load multiple vllm models in one environment (Related issue: https://github.com/vllm-project/vllm/issues/1676)
        # vllm.engine.ray_utils.py will allocate num_gpus for tensor_parallel (num_gpus) > 1 so we don't pass in num_gpus
        actor_options = {} if gpu_count != 1 else {"num_gpus": 1}
        actor_cls = VLLM.options(**actor_options)
    else:
        actor_cls = ray.remote(HuggingFace).options(num_gpus=gpu_count)

    handle = actor_cls.remote(model_name_or_path, token=token, **model_kwargs)
    ray.get(handle.is_initialized.remote())  # init ray model
    return handle
