import copy
import re
from .common import extract_json
from .judges import GPTJudge, OpenSourceJudge
from .language_models import GPT, Claude, Gemini, HuggingFace, VLLM, TogetherAPI
from ..model_utils import get_template
import ray

def load_PAIR_models(attackLM_args,
                    targetLM_args,
                    judgeLM_args,
                    attack_max_n_tokens,
                    target_max_n_tokens,
                    judge_max_n_tokens,
                    max_retries,
                    attack_temp, # default 1
                    ):
    """
    Build the attacker, target and judge models used by PAIR.

    Whenever two roles name the same ``model_name_or_path``, the model object
    that was already loaded is handed to the later role (together with the
    earlier role's argument dict) instead of loading it a second time.

    Parameters:
    - attackLM_args / targetLM_args / judgeLM_args: dicts of model keyword
      arguments; each must contain the key 'model_name_or_path'.
    - attack_max_n_tokens / target_max_n_tokens / judge_max_n_tokens:
      generation budgets forwarded to the respective wrapper.
    - max_retries: forwarded to AttackLM.
    - attack_temp: sampling temperature for the attacker.

    Returns:
    - Tuple (attackLM, targetLM, judgeLM).
    """
    # Safe to call repeatedly: an already-initialised Ray runtime is reused.
    ray.init(ignore_reinit_error=True)

    attack_name = attackLM_args['model_name_or_path']
    target_name = targetLM_args['model_name_or_path']
    judge_name = judgeLM_args['model_name_or_path']

    # ====== Attacker =========
    attackLM = AttackLM(
        max_n_tokens=attack_max_n_tokens,
        max_retries=max_retries,
        temperature=attack_temp,
        **attackLM_args,
    )

    # ====== Target (reuses the attacker's model when the names match) =========
    target_shares_attacker = target_name == attack_name
    if target_shares_attacker:
        targetLM_args = attackLM_args
    targetLM = TargetLM(
        max_n_tokens=target_max_n_tokens,
        temperature=0.0,
        preloaded_model=attackLM.model if target_shares_attacker else None,
        **targetLM_args,
    )

    # ====== Judge (attacker's model takes precedence over the target's) =========
    if judge_name == attack_name:
        shared_judge_model, judgeLM_args = attackLM.model, attackLM_args
    elif judge_name == target_name:
        shared_judge_model, judgeLM_args = targetLM.model, targetLM_args
    else:
        shared_judge_model = None

    judgeLM = load_judge(
        max_n_tokens=judge_max_n_tokens,
        preloaded_model=shared_judge_model,
        **judgeLM_args,
    )
    return attackLM, targetLM, judgeLM

class AttackLM():
    """
        Base class for attacker language models.
        
        Generates attacks for conversations using a language model. The self.model attribute contains the underlying generation model.
    """
    def __init__(self, 
                max_n_tokens: int, 
                max_retries: int, 
                temperature: float,
                **model_kwargs, 
                ):
        self.model_name_or_path = model_kwargs['model_name_or_path']
        self.temperature = temperature
        self.max_n_tokens = max_n_tokens
        self.max_retries = max_retries
        
        self.template = get_template(**model_kwargs, return_fschat_conv=True)
        self.model = load_indiv_model(**model_kwargs)

        # check if model is an opensource model which is a RayActor
        self.use_ray = isinstance(self.model, ray.actor.ActorHandle)

    def get_attack(self, convs_list, prompts_list):
        """
        Generates responses for a batch of conversations and prompts using a language model. 
        Only valid outputs in proper JSON format are returned. If an output isn't generated 
        successfully after max_retries, it's returned as None.
        
        Parameters:
        - convs_list: List of conversation objects.
        - prompts_list: List of prompts corresponding to each conversation.
        
        Returns:
        - List of generated outputs (dictionaries) or None for failed generations.
        """
        
        assert len(convs_list) == len(prompts_list), "Mismatch between number of conversations and prompts."
        
        batchsize = len(convs_list)
        indices_to_regenerate = list(range(batchsize))
        valid_outputs = [None] * batchsize

        # Initalize the attack model's generated output to match format
        if len(convs_list[0].messages) == 0:
            init_message = """{\"improvement\": \"\",\"prompt\": \""""
        else:
            init_message = """{\"improvement\": \"""" 

        full_prompts = []
        # Add prompts and initial seeding messages to conversations (only once)
        for conv, prompt in zip(convs_list, prompts_list):
            conv.append_message(conv.roles[0], prompt)
            # Get prompts
            if "gpt" in self.model_name_or_path:
                full_prompts.append(conv.to_openai_api_messages())
            else:
                conv.append_message(conv.roles[1], init_message)
                full_prompts.append(conv.get_prompt()[:-len(conv.sep2)])
            
        for attempt in range(self.max_retries):
            # Subset conversations based on indices to regenerate
            full_prompts_subset = [full_prompts[i] for i in indices_to_regenerate]

            generate_function = self.model.batched_generate.remote if self.use_ray else self.model.batched_generate
            outputs_list = generate_function(full_prompts_subset,
                                            max_n_tokens = self.max_n_tokens,  
                                            temperature = self.temperature,
                                            stop_tokens = ["}"])
            
            if self.use_ray:
                outputs_list = ray.get(outputs_list)
            # Check for valid outputs and update the list
            new_indices_to_regenerate = []

            for i, full_output in enumerate(outputs_list):
                orig_index = indices_to_regenerate[i]
                if "gpt" not in self.model_name_or_path:
                    full_output = init_message + full_output
                    if full_output[-1] != "}": full_output += "}"

                attack_dict, json_str = extract_json(full_output)
                
                if attack_dict is not None:
                    valid_outputs[orig_index] = attack_dict
                    convs_list[orig_index].update_last_message(json_str)  # Update the conversation with valid generation
                else:
                    new_indices_to_regenerate.append(orig_index)
            
            # Update indices to regenerate for the next iteration
            indices_to_regenerate = new_indices_to_regenerate
            
            # If all outputs are valid, break
            if not indices_to_regenerate:
                break
        
        if any([output for output in valid_outputs if output is None]):
            print(f"Failed to generate output after {self.max_retries} attempts. Terminating.")
        return valid_outputs

class TargetLM():
    """
        Base class for target language models.
        
        Generates responses for prompts using a language model. The self.model attribute contains the underlying generation model.
    """
    def __init__(self, 
            max_n_tokens: int, 
            temperature: float,
            preloaded_model: object = None,
            **model_kwargs, 
            ):
        """
        Parameters:
        - max_n_tokens: generation budget passed to the model on every call.
        - temperature: sampling temperature passed to the model.
        - preloaded_model: an already-loaded model to reuse; when None a new one
          is loaded through load_indiv_model.
        - model_kwargs: must contain 'model_name_or_path'; forwarded in full to
          get_template and (when loading) load_indiv_model.
        """
        
        self.model_name_or_path = model_kwargs['model_name_or_path']
        self.temperature = temperature
        self.max_n_tokens = max_n_tokens

        self.template = get_template(**model_kwargs, return_fschat_conv=True)
        # Compare with None explicitly: a supplied model object must be reused
        # even if it happens to evaluate as falsy.
        self.model = preloaded_model if preloaded_model is not None else load_indiv_model(**model_kwargs)
        
        # check if model is an opensource model which is a RayActor
        self.use_ray = isinstance(self.model, ray.actor.ActorHandle)

    def get_response(self, prompts_list, context_str=None):
        """
        Generates one response per prompt, each in a fresh copy of the template.

        Parameters:
        - prompts_list: List of prompt strings.
        - context_str: Optional text placed before every prompt, separated from
          it by a "---" divider line.

        Returns:
        - Whatever the model's batched_generate returns for the batch (resolved
          through ray.get when the model is a Ray actor).
        """
        batchsize = len(prompts_list)
        # Deep copies keep self.template untouched between calls.
        convs_list = [copy.deepcopy(self.template) for _ in range(batchsize)]
        full_prompts = []
        for conv, prompt in zip(convs_list, prompts_list):
            if context_str:
                prompt = f"{context_str}\n\n---\n\n{prompt}"
                
            conv.append_message(conv.roles[0], prompt)
            if "gpt-" in self.model_name_or_path:
                # Openai does not have separators
                full_prompts.append(conv.to_openai_api_messages())
            elif "gemini-" in self.model_name_or_path:
                # Only the raw text of the message just appended is sent.
                full_prompts.append(conv.messages[-1][1])
            else:
                # Open an empty assistant turn so the prompt ends where the
                # model is expected to start generating.
                conv.append_message(conv.roles[1], None) 
                full_prompts.append(conv.get_prompt())

        generate_function = self.model.batched_generate if not self.use_ray else self.model.batched_generate.remote
        # NOTE(review): stop_tokens=["}"] mirrors the attacker's JSON stop token;
        # confirm that stopping target responses at "}" is intended.
        outputs_list = generate_function(full_prompts, 
                                        max_n_tokens = self.max_n_tokens,  
                                        temperature = self.temperature,
                                        stop_tokens = ["}"]
                                        )
        return outputs_list if not self.use_ray else ray.get(outputs_list) 

def load_judge(model_name_or_path, max_n_tokens, preloaded_model=None, token=None, **model_kwargs):
    """
    Build the judge wrapper for the given model name.

    Parameters:
    - model_name_or_path: judge model identifier; names containing "gpt" get a
      GPTJudge backed by a new GPT client, all others an OpenSourceJudge.
    - max_n_tokens: generation budget handed to the judge.
    - preloaded_model: already-loaded model to reuse for the OpenSourceJudge
      branch; not used on the "gpt" branch.
    - token: credential passed to the GPT client, or to load_indiv_model when
      an open-source judge has to be loaded here.
    - model_kwargs: extra arguments forwarded to load_indiv_model.

    Returns:
    - A GPTJudge or an OpenSourceJudge instance.
    """
    if "gpt" in model_name_or_path:
        judge_model = GPT(model_name=model_name_or_path, token=token)
        return GPTJudge(model_name_or_path, max_n_tokens, judge_model)

    if preloaded_model is None:
        # `token` is captured by this function's signature, so it is not part of
        # model_kwargs; forward it explicitly or the judge loads without it.
        preloaded_model = load_indiv_model(model_name_or_path=model_name_or_path,
                                           token=token,
                                           **model_kwargs)
    return OpenSourceJudge(model_name_or_path, max_n_tokens, judge_model=preloaded_model)

def load_indiv_model(model_name_or_path, use_vllm=False, token=None, **model_kwargs):
    """
    Instantiate a single language model chosen by substrings of its name.

    Names containing 'gpt-', 'claude-', 'gemini-' or 'Mixtral-' (checked in that
    order) return the matching API wrapper directly. Every other name is
    started as a Ray actor (VLLM when use_vllm is set, otherwise HuggingFace)
    and this call blocks until the actor reports it is initialized.

    Parameters:
    - model_name_or_path: model identifier.
    - use_vllm: select the VLLM actor instead of the HuggingFace one.
    - token: credential forwarded to the chosen wrapper.
    - model_kwargs: forwarded to the Ray-hosted model; 'num_gpus' (default 1)
      is also read here to size the actor.

    Returns:
    - An API wrapper instance or a Ray actor handle.
    """
    # ---- API-backed models: nothing to schedule, return immediately ----
    if 'gpt-' in model_name_or_path:
        return GPT(model_name_or_path, token=token)
    if 'claude-' in model_name_or_path:
        return Claude(model_name_or_path, token=token)
    if 'gemini-' in model_name_or_path:
        return Gemini(model_name_or_path, token=token)
    if 'Mixtral-' in model_name_or_path:
        return TogetherAPI(model_name_or_path, api_key=token)

    # ---- Open-source models are managed through ray ----
    gpu_count = model_kwargs.get("num_gpus", 1)
    if not use_vllm:
        actor = ray.remote(HuggingFace).options(num_gpus=gpu_count).remote(
            model_name_or_path, token=token, **model_kwargs
        )
    else:
        # vllm models must be ray actors so several can be loaded in one environment
        # (Related issue: https://github.com/vllm-project/vllm/issues/1676).
        # vllm.engine.ray_utils.py allocates the GPUs itself for tensor_parallel
        # (num_gpus) > 1, so num_gpus is only requested for the single-GPU case.
        actor_options = {}
        if gpu_count == 1:
            actor_options["num_gpus"] = 1
        actor = VLLM.options(**actor_options).remote(
            model_name_or_path, token=token, **model_kwargs
        )

    # Block until the remote model has finished initializing.
    ray.get(actor.is_initialized.remote())
    return actor
