import gc
import numpy as np
import random
from tqdm import tqdm
import re
from accelerate.utils import find_executable_batch_size
import sys
import time
from .mutate_models import GPT, Synonyms, VLLM, HuggingFace, TogetherAPI

### GA ###
def autodan_sample_control(control_prefixes, score_list, num_elites, batch_size, crossover=0.5,
                           num_points=5, mutation=0.01, mutate_model=None, reference=None, if_softmax=True, no_mutate_words=[]):
    """Produce the next population of control prefixes with one GA step.

    The incoming scores are negated before ranking and sampling, so a *lower*
    input score is treated as fitter (presumably the scores are losses --
    confirm against the caller).

    Args:
        control_prefixes: Current population of candidate strings.
        score_list: One numeric score per entry of ``control_prefixes``.
        num_elites: Number of best-ranked candidates carried over unchanged.
        batch_size: Required size of the returned population.
        crossover: Probability that a pair of parents is crossed over.
        num_points: Upper bound on the number of crossover points per pair.
        mutation: Per-offspring probability of being sent to ``mutate_model``.
        mutate_model: Object exposing ``batched_generate``; forwarded to
            ``apply_crossover_and_mutation``.
        reference: Forwarded unchanged to ``apply_crossover_and_mutation``.
        if_softmax: Forwarded to ``roulette_wheel_selection``.
        no_mutate_words: Forwarded to the mutation step; not modified here.

    Returns:
        A list of ``batch_size`` strings: the elites followed by the offspring.

    Raises:
        AssertionError: If the assembled population does not contain exactly
            ``batch_size`` entries (e.g. fewer than ``num_elites`` candidates
            were supplied).
    """
    # Negate so that "larger is better" holds for both ranking and sampling.
    fitness = [-score for score in score_list]

    # Step 1: rank candidate indices from fittest to least fit.
    ranked = sorted(range(len(fitness)), key=lambda k: fitness[k], reverse=True)

    # Step 2: the top-ranked candidates survive unchanged.
    elites = [control_prefixes[i] for i in ranked[:num_elites]]

    # Step 3: fill the remaining slots by roulette-wheel selection.
    num_offspring = batch_size - num_elites
    parents_list = roulette_wheel_selection(control_prefixes, fitness, num_offspring, if_softmax)

    # Step 4: recombine and mutate the selected parents.
    mutated_offspring = apply_crossover_and_mutation(parents_list, crossover_probability=crossover,
                                                     num_points=num_points,
                                                     mutation_rate=mutation,
                                                     reference=reference,
                                                     mutate_model=mutate_model,
                                                     no_mutate_words=no_mutate_words)

    # Parents are processed in pairs, so an odd number of parents may come
    # back with one surplus child; keep only the slots that were requested.
    mutated_offspring = mutated_offspring[:num_offspring]

    next_generation = elites + mutated_offspring
    assert len(next_generation) == batch_size, f"{len(next_generation)} == {batch_size}"
    return next_generation


def roulette_wheel_selection(data_list, score_list, num_selected, if_softmax=True):
    """Sample items from ``data_list`` with replacement, weighted by score.

    Args:
        data_list: Candidates to draw from.
        score_list: One numeric score per candidate; larger means more likely
            when ``if_softmax`` is true.
        num_selected: Number of draws to make.
        if_softmax: If true, probabilities are the softmax of the scores.
            Otherwise each probability is ``score / sum(scores)``; when the
            scores sum to zero a uniform distribution is used instead of
            dividing by zero.

    Returns:
        A list of ``num_selected`` entries of ``data_list`` (repeats allowed).

    Raises:
        ValueError: Propagated from NumPy when the resulting probabilities are
            invalid (e.g. negative entries) or the inputs are empty.
    """
    scores = np.asarray(score_list, dtype=float)

    if if_softmax:
        # Subtract the maximum before exponentiating so exp() cannot overflow.
        weights = np.exp(scores - np.max(scores))
        selection_probs = weights / weights.sum()
    else:
        total_score = sum(score_list)
        if total_score == 0 and scores.size:
            # Proportional weights are undefined for a zero total; treat every
            # candidate as equally likely rather than raising ZeroDivisionError.
            selection_probs = np.full(scores.size, 1.0 / scores.size)
        else:
            selection_probs = scores / total_score

    selected_indices = np.random.choice(len(data_list), size=num_selected, p=selection_probs, replace=True)
    return [data_list[i] for i in selected_indices]


def apply_crossover_and_mutation(selected_data, crossover_probability=0.5, num_points=3, mutation_rate=0.01,
                                 mutate_model=None,
                                 reference=None,
                                 no_mutate_words=[]):
    """Recombine consecutive parent pairs, then mutate a random subset.

    Args:
        selected_data: Parent strings; consecutive entries are paired up.
        crossover_probability: Chance that a pair is passed through
            ``crossover`` instead of being copied unchanged.
        num_points: Forwarded to ``crossover`` as its ``num_points`` argument.
        mutation_rate: Per-offspring probability of being mutated.
        mutate_model: Object whose ``batched_generate`` rewrites the chosen
            offspring. Only used when at least one offspring is chosen.
        reference: Unused by this function; kept for call compatibility.
        no_mutate_words: Forwarded to ``mutate_model.batched_generate``.

    Returns:
        A list with exactly ``len(selected_data)`` offspring strings.
    """
    offspring = []
    for i in range(0, len(selected_data), 2):
        parent1 = selected_data[i]
        # An odd-sized selection pairs its last parent with the first one.
        parent2 = selected_data[i + 1] if (i + 1) < len(selected_data) else selected_data[0]

        if random.random() < crossover_probability:
            offspring.extend(crossover(parent1, parent2, num_points))
        else:
            offspring.extend((parent1, parent2))

    # Every pair contributes two children, so an odd-sized selection ends up
    # with one surplus child; drop it so output size matches input size.
    del offspring[len(selected_data):]

    # ====== Sample offspring to mutate =======
    # Remember positions rather than values: the population often contains
    # duplicate strings, and a value lookup would always resolve to the first
    # occurrence instead of the slot that was actually sampled.
    mutate_positions = [pos for pos in range(len(offspring)) if random.random() <= mutation_rate]
    if not mutate_positions:
        return offspring

    # ====== Mutate =====
    offspring_to_mutate = [offspring[pos] for pos in mutate_positions]
    mutated_offspring = mutate_model.batched_generate(offspring_to_mutate, max_n_tokens=1400, temperature=1.0, no_mutate_words=no_mutate_words)

    # ===== Merge the mutated offspring back into the original set =====
    for i, pos in enumerate(mutate_positions):
        offspring[pos] = mutated_offspring[i]

    return offspring


def crossover(str1, str2, num_points):
    """Exchange sentence segments between two strings at random cut points.

    Both strings are split into sentences after ``.``, ``!`` or ``?`` followed
    by whitespace. Up to ``num_points`` cut positions are drawn, and for each
    resulting segment a coin flip decides whether the children keep their own
    parent's sentences or receive the other parent's.

    Args:
        str1: First parent string.
        str2: Second parent string.
        num_points: Maximum number of cut points to use.

    Returns:
        A ``(child1, child2)`` tuple of sentences joined by single spaces. The
        parents are returned unchanged when no valid set of cut points can be
        drawn (e.g. one parent has no sentences, or ``num_points`` is negative).
    """
    sentences1 = [s for s in re.split(r'(?<=[.!?])\s+', str1) if s]
    sentences2 = [s for s in re.split(r'(?<=[.!?])\s+', str2) if s]

    try:
        # Cut positions 1..max_swaps are valid for the shorter parent, so the
        # candidate pool must hold max_swaps entries for the clamp below to be
        # sufficient.
        max_swaps = min(len(sentences1), len(sentences2)) - 1
        num_swaps = min(num_points, max_swaps)
        swap_indices = sorted(random.sample(range(1, max_swaps + 1), num_swaps))
    except (ValueError, TypeError):
        # Best effort: fall back to the untouched parents.
        return str1, str2

    new_str1, new_str2 = [], []
    last_swap = 0
    # The trailing None makes the final slice run to the end of each parent.
    for swap in swap_indices + [None]:
        segment1 = sentences1[last_swap:swap]
        segment2 = sentences2[last_swap:swap]
        if random.choice([True, False]):
            new_str1.extend(segment1)
            new_str2.extend(segment2)
        else:
            new_str1.extend(segment2)
            new_str2.extend(segment1)
        last_swap = swap

    return ' '.join(new_str1), ' '.join(new_str2)

def load_mutate_model(model_name_or_path, use_vllm=False, together_api_key=None, **model_kwargs):
    """Instantiate the model used for the mutation step.

    The backend is chosen in this order:
      1. names matching ``gpt-<digits>`` at the start -> ``GPT``
      2. ``"synonyms"`` (or the legacy spelling ``"symnonyms"``),
         case-insensitive -> ``Synonyms``
      3. a truthy ``together_api_key`` -> ``TogetherAPI``
      4. otherwise ``VLLM`` when ``use_vllm`` is true, else ``HuggingFace``

    Args:
        model_name_or_path: Model identifier used to pick and build the backend.
        use_vllm: Selects ``VLLM`` over ``HuggingFace`` for case 4.
        together_api_key: API key that enables the ``TogetherAPI`` backend.
        **model_kwargs: Forwarded to the backend constructor (not to
            ``Synonyms``). ``num_gpus`` is additionally read for case 4.

    Returns:
        A ``(lm, num_gpus)`` tuple. ``num_gpus`` is 0 for cases 1-3 and
        ``model_kwargs.get("num_gpus", 1)`` for case 4.
    """
    num_gpus = 0
    if re.match(r'gpt-\d+', model_name_or_path):
        lm = GPT(model_name_or_path, **model_kwargs)
    elif model_name_or_path.lower() in ("symnonyms", "synonyms"):
        # "symnonyms" is the historical selector; keep accepting it so
        # existing callers and configs continue to work.
        lm = Synonyms()
    elif together_api_key:
        print(f"Found together_api_key, default to use {model_name_or_path} on together.ai, you can find a list of supported models at https://docs.together.ai/docs/inference-models")
        lm = TogetherAPI(model_name_or_path, together_api_key, **model_kwargs)
    else:
        # We use ray to manage opensource models here
        num_gpus = model_kwargs.get("num_gpus", 1)
        if use_vllm:
            lm = VLLM(model_name_or_path, **model_kwargs)
        else:
            lm = HuggingFace(model_name_or_path, **model_kwargs)
    return lm, num_gpus