# TODO: support multi-node; the current Ray-based implementation only works on a single node
import os
import json
from typing import List
import random
import torch
from tqdm import tqdm
from ..baseline import RedTeamingMethod
# from ..eval_utils import jailbroken
from .gcg_ray_actors import GCGModelWorker, sample_control, get_nonascii_toks
import ray
# ============================== EnsembleGCG CLASS DEFINITION ============================== #
class EnsembleGCG(RedTeamingMethod):
    use_ray = True

    def __init__(self,
                 target_models,
                 num_steps=50,
                 adv_string_init="! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
                 allow_non_ascii=False,
                 search_width=512,
                 use_prefix_cache=True,
                 targets_path=None,
                 early_stopping=False,
                 eval_steps=10,
                 num_eval_behaviors=10,
                 eval_with_generate=True,
                 progressive_behaviors=True,
                 **kwargs):
        """
        Load the behavior targets and start one GCGModelWorker Ray actor per target model.

        :param target_models: a list of dicts; each holds a 'target_model' dict whose entries are
            forwarded as keyword arguments to GCGModelWorker, and optionally 'num_gpus'
            (default 1), the number of GPUs requested from Ray for that worker
        :param num_steps: the number of optimization steps to run
        :param adv_string_init: the string the optimized tokens are initialized from
        :param allow_non_ascii: if False, the tokens returned by `get_nonascii_toks` are excluded
            when sampling candidates; if True, no token is excluded
        :param search_width: the search width passed to every worker and to `sample_control`
        :param use_prefix_cache: forwarded to every worker
        :param targets_path: path to a JSON file mapping behavior IDs to targets (required)
        :param early_stopping: stored on the instance; not read elsewhere in this class
        :param eval_steps: progress is printed every `eval_steps` steps when verbose
        :param num_eval_behaviors: stored on the instance; not read elsewhere in this class
        :param eval_with_generate: stored on the instance; not read elsewhere in this class
        :param progressive_behaviors: if True, optimization starts with one behavior and
            adds further behaviors as it progresses; if False, all behaviors are used from the start
        :param kwargs: accepted and ignored
        :raises ValueError: if `targets_path` is None
        """
        self.num_steps = num_steps
        self.adv_string_init = adv_string_init
        self.allow_non_ascii = allow_non_ascii
        self.search_width = search_width
        self.use_prefix_cache = use_prefix_cache

        if targets_path is None:
            raise ValueError("targets_path must be specified")
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)
        self.targets_path = targets_path

        self.early_stopping = early_stopping
        self.eval_steps = eval_steps
        self.num_eval_behaviors = num_eval_behaviors
        self.eval_with_generate = eval_with_generate
        self.progressive_behaviors = progressive_behaviors

        ## Init workers across gpus with ray
        num_gpus_list = [m.get('num_gpus', 1) for m in target_models]
        target_models = [m['target_model'] for m in target_models]
        self.models = target_models

        # One worker actor per model, each requesting the number of GPUs its config asks for
        self.gcg_models = [
            GCGModelWorker.options(num_gpus=required_gpus).remote(search_width=search_width,
                                                                  use_prefix_cache=use_prefix_cache,
                                                                  **model_config)
            for model_config, required_gpus in zip(target_models, num_gpus_list)
        ]

        # Block until every worker reports back, then print their status messages
        initialization_models = [w.is_initialized.remote() for w in self.gcg_models]
        print("\n".join(ray.get(initialization_models)))

        self.tokenizers = ray.get([w.get_attribute_.remote('tokenizer') for w in self.gcg_models])
        # Check that all tokenizers are the same type
        assert len({type(t) for t in self.tokenizers}) == 1, "All tokenizers must be the same type"
        self.sample_tokenizer = self.tokenizers[0]  # Used for initializing optimized tokens and decoding result; assumes all tokenizers are the same

        # torch.unique only accepts tensors, so the "nothing is disallowed" case must bypass it.
        # NOTE(review): assumes sample_control treats not_allowed_tokens=None as "no restriction" - confirm in gcg_ray_actors
        if allow_non_ascii:
            self.not_allowed_tokens = None
        else:
            self.not_allowed_tokens = torch.unique(get_nonascii_toks(self.sample_tokenizer))

    def generate_test_cases(self, behaviors, targets=None, verbose=False):
        """
        Optimize a single adversarial suffix shared by all behaviors and all target models.

        :param behaviors: a list of behavior dicts; each must provide the keys
            'Behavior', 'ContextString' and 'BehaviorID'
        :param targets: accepted for interface compatibility but ignored; targets are always
            looked up by BehaviorID in the mapping loaded from `targets_path`
        :param verbose: whether to print progress every `eval_steps` steps and on the last step
        :return: a tuple `(test_cases, logs)`. `test_cases` maps each BehaviorID to a one-element
            list holding the behavior prompt followed by the final optimized string. `logs` is a
            dict with the keys 'final_loss' (None if no optimization step completed),
            'all_losses', 'all_test_cases' and 'num_behaviors'
        :raises KeyError: if a behavior dict lacks a required key or its BehaviorID has no target
        """

        # ===================== generate test cases ===================== #
        # ========== Behavior meta data ==========
        behavior_dicts = behaviors
        behaviors = [b['Behavior'] for b in behavior_dicts]
        context_str_list = [b['ContextString'] for b in behavior_dicts]
        behavior_id_list = [b['BehaviorID'] for b in behavior_dicts]

        targets = [self.behavior_id_to_target[behavior_id] for behavior_id in behavior_id_list]
        behaviors = [b + ' ' for b in behaviors]  # add a space before the optimized string

        # Adding context string to behavior string if available
        behaviors = [
            f"{context_str}\n\n---\n\n{behavior}" if context_str else behavior
            for behavior, context_str in zip(behaviors, context_str_list)
        ]

        total_behaviors = len(behaviors)
        num_behaviors = 1 if self.progressive_behaviors else total_behaviors
        all_behavior_indices = list(range(total_behaviors))

        gcg_models = self.gcg_models
        # Init behaviors, targets, embeds
        ray.get([w.init_behaviors_and_targets.remote(behaviors, targets) for w in gcg_models])

        optim_ids = self.sample_tokenizer(self.adv_string_init, return_tensors="pt", add_special_tokens=False)['input_ids']
        num_optim_tokens = optim_ids.shape[1]

        all_losses = []
        all_test_cases = []
        steps_since_last_added_behavior = 0

        # Defined before the loop so the return value is well-formed even when no
        # optimization step completes (num_steps == 0 or every step is skipped).
        optim_str = self.sample_tokenizer.decode(optim_ids[0])
        test_cases = {b_id: [b + optim_str] for b, b_id in zip(behaviors, behavior_id_list)}
        current_loss = None

        # ========== run optimization ========== #
        for i in tqdm(range(self.num_steps)):
            # select behaviors to use in search step
            current_behavior_indices = all_behavior_indices[:num_behaviors]

            # ========= Get normalized token gradients for all models and average across models ========= #
            futures = [worker.get_token_gradient.remote(optim_ids, num_optim_tokens, current_behavior_indices)
                       for worker in gcg_models]
            grads = ray.get(futures)

            # clipping extra tokens: truncate the last dimension to the smallest size across
            # models so the gradients can be averaged
            min_embedding_value = min(grad.shape[2] for grad in grads)
            grads = [grad[:, :, :min_embedding_value] for grad in grads]
            token_grad = sum(grads) / len(grads)

            # ========= Select candidate replacements ========= #
            sampled_top_indices = sample_control(optim_ids.squeeze(0), token_grad.squeeze(0), self.search_width,
                                                 topk=256, temp=1, not_allowed_tokens=self.not_allowed_tokens)

            # ========= Filter sampled_top_indices that aren't the same after detokenize->retokenize ========= #
            sampled_top_indices_text = self.sample_tokenizer.batch_decode(sampled_top_indices)
            new_sampled_top_indices = []
            for j, candidate_text in enumerate(sampled_top_indices_text):
                # tokenize again
                retokenized = self.sample_tokenizer(candidate_text, return_tensors="pt",
                                                    add_special_tokens=False).to(sampled_top_indices.device)['input_ids'][0]
                # keep only candidates whose token ids survive the round trip unchanged
                if torch.equal(retokenized, sampled_top_indices[j]):
                    new_sampled_top_indices.append(sampled_top_indices[j])
            count = len(sampled_top_indices_text) - len(new_sampled_top_indices)

            if count >= self.search_width // 2:
                print('\nLots of removals:', count)

            if not new_sampled_top_indices:
                # torch.stack rejects an empty list; keep the current optim_ids and try again next step
                print(f'\nNo candidate survived the retokenization filter at step {i}; skipping step')
                continue

            sampled_top_indices = torch.stack(new_sampled_top_indices)

            # ========= Evaluate candidate optim ========= #
            futures = [worker.compute_loss_on_candidates_batch.remote(sampled_top_indices, current_behavior_indices)
                       for worker in gcg_models]
            # One entry per model along dim 0.
            # NOTE(review): assumes each worker returns (n_behaviors, n_candidates) - confirm in gcg_ray_actors
            losses = torch.stack(ray.get(futures))

            # mean across all behaviors
            model_losses = losses.mean(1)

            # mean across all models
            behavior_losses = losses.mean(0)

            # mean across all behaviors and models
            avg_losses = model_losses.mean(0)
            selected_id = avg_losses.argmin().item()

            # ========== Update the optim_ids with the best candidate ========== #
            optim_ids = sampled_top_indices[selected_id].unsqueeze(0)

            # ========== Get test cases and append to lists for logging ========== #
            optim_str = self.sample_tokenizer.decode(optim_ids[0])
            test_cases = {b_id: [b + optim_str] for b, b_id in zip(behaviors, behavior_id_list)}
            all_test_cases.append(test_cases)
            current_loss = avg_losses.min().item()
            all_losses.append(current_loss)

            if verbose and ((i % self.eval_steps == 0) or (i == self.num_steps - 1)):
                print("\n" + "="*50)
                print(f'\nStep {i} | Optimized Suffix: {optim_str} | Avg Loss: {current_loss}')
                # Per-model loss, minimized over the evaluated candidates
                model_loss_lines = [f"->Model: {model['model_name_or_path']} | Loss {model_loss.min().item()}"
                                    for model, model_loss in zip(self.models, model_losses)]

                print("\n".join(model_loss_lines))
                print()

            steps_since_last_added_behavior += 1

            # Add another behavior once every current behavior's loss for the selected candidate
            # is below 0.2, or after 100 steps without an addition. Stop once all behaviors are
            # in use so num_behaviors never exceeds total_behaviors.
            if self.progressive_behaviors and num_behaviors < total_behaviors:
                if all(behavior_losses[:, selected_id] < 0.2):
                    print('\nAdding behavior to list of behaviors to optimize (case 1)\n')
                    num_behaviors += 1
                    steps_since_last_added_behavior = 0
                elif steps_since_last_added_behavior == 100:
                    print('\nAdding behavior to list of behaviors to optimize (case 2)\n')
                    num_behaviors += 1
                    steps_since_last_added_behavior = 0

            # TODO: generation-based evaluation is not implemented in this method; the
            # early_stopping, eval_with_generate and num_eval_behaviors settings are not read here.

        logs = {'final_loss': current_loss, 'all_losses': all_losses, 'all_test_cases': all_test_cases, 'num_behaviors': num_behaviors}

        return test_cases, logs

    
    # ===== Override saving methods ====
    # Saves test cases for all behaviors into one test_cases.json file
    def get_output_file_path(self, save_dir, behavior_id, file_type, run_id=None):
        run_id = f"_{run_id}" if run_id else ""
        return os.path.join(save_dir, f"{file_type}{run_id}.json")
    
    def save_test_cases(self,
            save_dir,
            test_cases,
            logs=None,
            method_config=None,
            run_id=None):

        test_cases_save_path = self.get_output_file_path(save_dir, None, 'test_cases', run_id)
        logs_save_path = self.get_output_file_path(save_dir, None, 'logs', run_id)
        method_config_save_path = self.get_output_file_path(save_dir, None, 'method_config', run_id)

        os.makedirs(os.path.dirname(test_cases_save_path), exist_ok=True)
        if test_cases is not None:
            with open(test_cases_save_path, 'w', encoding='utf-8') as f:
                json.dump(test_cases, f, indent=4)

        if logs is not None:
            with open(logs_save_path, 'w', encoding='utf-8') as f:
                json.dump(logs, f, indent=4)
        
        if method_config is not None:
            # mask token or api_key
            self._replace_tokens(method_config)
            with open(method_config_save_path, 'w', encoding='utf-8') as f:
                method_config["dependencies"] = {l.__name__: l.__version__ for l in self.default_dependencies}
                json.dump(method_config, f, indent=4)