import torch
import numpy as np
from tqdm.auto import tqdm
from datasets import concatenate_datasets, Dataset
import torch.nn.functional as F
import random
import math
from misc.definitions import batch_compare
from misc.utils import load_model, load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser

class IndelSimilarityP3:
    def __init__(self, sim, dataset_name, model_name="EleutherAI/pythia-6.9b"):
        """Load datasets and model, then precompute token-embedding neighbours.

        Args:
            sim: Indel-similarity threshold that the other methods compare
                ``batch_compare`` scores against.
            dataset_name: Forwarded to ``load_canary_dataset``.
            model_name: Forwarded to ``load_model``. The returned model must
                expose its input embedding matrix as
                ``model.gpt_neox.embed_in.weight``.

        After construction, row ``t`` of ``self.sorted_indices`` lists token
        ids ordered by decreasing cosine similarity to the embedding of token
        ``t``; ``self.sorted_distances`` holds the matching similarity values.
        Both tensors live on the CPU.
        """
        self.sim = sim
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()
        self.model, self.tokenizer = load_model(model_name=model_name)

        # The table is a lookup structure only. Building it under no_grad keeps
        # autograd from retaining a graph over a vocab x vocab matrix, and the
        # in-place divisions avoid two extra full-size temporaries.
        with torch.no_grad():
            embedding_weight = self.model.gpt_neox.embed_in.weight
            row_norms = torch.norm(embedding_weight, dim=-1)
            cosine_matrix = embedding_weight @ embedding_weight.T
            cosine_matrix /= row_norms[:, None]
            cosine_matrix /= row_norms[None, :]
            self.sorted_distances, self.sorted_indices = torch.sort(cosine_matrix.cpu(), dim=1, descending=True)

    def generate_members_and_nonmembers_labels(self):
        """Relabel the canary splits using the indel-similarity threshold.

        Every text in the ``train`` split is a member. A ``test`` text is also
        labelled a member when its highest indel similarity to any ``train``
        text is at least ``self.sim``; otherwise it is a non-member.

        Returns:
            A ``(members, non_members)`` pair of ``Dataset`` objects, each with
            a single ``"text"`` column.
        """
        # Read each column once: indexing the dataset inside the loop would
        # re-read the whole training column for every test text.
        train_texts = list(self.canary_dataset['train']['text'])
        test_texts = self.canary_dataset['test']['text']

        # Work on a private copy so the result never aliases whatever object
        # the dataset handed back for the column.
        actual_members = list(train_texts)
        actual_nonmembers = []
        for x in tqdm(test_texts, desc="Generating member and non-member labels"):
            # With no training texts there is nothing to be similar to, so the
            # text is a non-member (and np.max of an empty result is avoided).
            is_member = bool(train_texts) and np.max(
                batch_compare(x, train_texts, units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="indel_similarity", lower=False)
            ) >= self.sim
            if is_member:
                actual_members.append(x)
            else:
                actual_nonmembers.append(x)

        members = Dataset.from_dict({"text": actual_members})
        non_members = Dataset.from_dict({"text": actual_nonmembers})
        return members, non_members

    def _mutate_a_towards_b(self, a, b, threshold, alphabet="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", max_iterations=10000):
        """Hill-climb ``a`` towards ``b`` with random single-character edits.

        Each iteration applies one random insertion, deletion or substitution
        to the current string and keeps the edit only when it strictly raises
        the indel similarity to ``b`` reported by ``batch_compare``.

        Args:
            a: Starting string.
            b: Target string.
            threshold: Similarity at which the search stops.
            alphabet: Characters drawn from for insertions and substitutions.
            max_iterations: Upper bound on the number of attempted edits.

        Returns:
            The best string found. It is returned as soon as its similarity
            reaches ``threshold`` (checked after each scored edit), or after
            ``max_iterations`` attempts, in which case it may still be below
            ``threshold``.
        """
        def similarity_to_b(candidate):
            return batch_compare(candidate, [b], units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="indel_similarity", lower=False)[0]

        current = a
        current_score = similarity_to_b(current)

        for _ in range(max_iterations):
            operation = random.choice(["insert", "delete", "substitute"])
            chars = list(current)

            if operation == "insert":
                position = random.randint(0, len(chars))
                chars[position:position] = [random.choice(alphabet)]
            else:
                # Deletion and substitution both need an existing character.
                if not chars:
                    continue
                position = random.randint(0, len(chars) - 1)
                if operation == "delete":
                    chars.pop(position)
                else:
                    chars[position] = random.choice(alphabet)

            candidate = "".join(chars)
            candidate_score = similarity_to_b(candidate)

            # Greedy acceptance: only strict improvements replace the current string.
            if candidate_score > current_score:
                current, current_score = candidate, candidate_score

            if current_score >= threshold:
                return current

        return current
    
    def generate_worst_case_neighbor(self, x, max_attempts=1000):
        """
        Here, x is a member already. We will sample a point outside Nr(x), and project back to an inside point (by random insert/delete/substitute mutations).

        Args:
            x: Member text to build a neighbour for.
            max_attempts: Maximum number of random draws from the background
                dataset when looking for a starting text at least as long as
                ``x``.

        Returns:
            A one-element list holding the mutated background text.

        Raises:
            ValueError: If ``max_attempts`` is smaller than 1.
        """
        if max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")

        # Pick random sequence from background dataset. The search is bounded:
        # if no drawn text is as long as x (for instance because the background
        # texts are all shorter), fall back to the longest one seen instead of
        # looping forever -- insert mutations can still lengthen it.
        random_sequence = None
        longest_seen = None
        for _ in range(max_attempts):
            candidate = random.choice(self.background_dataset)["text"]
            if len(candidate) >= len(x):
                random_sequence = candidate
                break
            if longest_seen is None or len(candidate) > len(longest_seen):
                longest_seen = candidate
        if random_sequence is None:
            random_sequence = longest_seen

        # Mutate random_sequence towards x
        mutated_sequence = self._mutate_a_towards_b(random_sequence, x, self.sim)
        return [mutated_sequence]
        
    
    def generate_worst_case_non_neighbor(self, x, indel_lambda=1.5, nearest_neighbors=16, NUM_POISONS=10):
        """
        Here x is a non-member already. We will sample a point inside Nr(x), i.e., x itself, and project back to an outside point by changing tokens to
        break up ngrams, while keeping activation same.

        Starting from ``NUM_POISONS`` copies of the tokenised ``x``, each round
        picks one not-yet-visited position per copy and tries replacing that
        token with each of its nearest embedding neighbours. The replacement
        kept is the one maximising (mean per-position cosine similarity of the
        last hidden states to those of ``x``) minus ``indel_lambda`` times the
        indel similarity of the decoded text to ``x``. A copy is finished as
        soon as its decoded text scores below ``self.sim`` against ``x``.

        Args:
            x: Text to perturb.
            indel_lambda: Weight of the indel-similarity penalty.
            nearest_neighbors: Size of the embedding neighbourhood inspected
                per token; the token itself is excluded, so
                ``nearest_neighbors - 1`` replacements are tried.
            NUM_POISONS: Number of independent perturbed copies to evolve.

        Returns:
            Decoded texts of all copies that dropped below ``self.sim``.
            Empty when none did, or when ``x`` tokenises to nothing.

        Raises:
            ValueError: If ``nearest_neighbors`` is smaller than 2.
        """
        if nearest_neighbors < 2:
            raise ValueError("nearest_neighbors must be at least 2 so every token has a replacement candidate")
        if NUM_POISONS < 1:
            return []

        sequence = x
        num_candidates = nearest_neighbors - 1
        batch_size = 32
        # Send inputs to wherever the embedding layer lives rather than assuming "cuda".
        device = self.model.gpt_neox.embed_in.weight.device
        final_poisons = []

        original_tokens = self.tokenizer(sequence, return_tensors='pt')["input_ids"].to(device)
        seq_len = original_tokens.shape[1]
        if seq_len == 0:
            return []

        with torch.no_grad():
            original_output = self.model(original_tokens, output_hidden_states=True, return_dict=True)
            # Loop-invariant: normalise the reference hidden states once.
            original_norm = F.normalize(original_output['hidden_states'][-1], p=2, dim=-1)

            # One row per still-active poison: (num_active, seq_len).
            poisoned_tokens = original_tokens.clone().repeat(NUM_POISONS, 1)
            indices_not_visited = [set(range(seq_len)) for _ in range(NUM_POISONS)]

            for _ in tqdm(range(seq_len), leave=False):
                num_active = len(indices_not_visited)
                if num_active == 0:
                    break  # every poison has already evaded the threshold

                # pick a random unvisited token index for each poison
                # (random.sample no longer accepts sets on Python >= 3.11)
                random_indices = []
                for unvisited in indices_not_visited:
                    idx = random.choice(sorted(unvisited))
                    unvisited.remove(idx)
                    random_indices.append(idx)

                # for each poison, build one candidate per replacement token:
                # (num_active, num_candidates, seq_len)
                candidates = poisoned_tokens.unsqueeze(1).repeat(1, num_candidates, 1)
                for i, idx in enumerate(random_indices):
                    token_id = original_tokens[0, idx].item()
                    neighbours = self.sorted_indices[token_id][:nearest_neighbors]
                    # Drop the token itself, then truncate so every poison has
                    # the same number of candidates even if the token was not
                    # among its own top-k.
                    replacements = neighbours[neighbours != token_id][:num_candidates]
                    candidates[i, :, idx] = replacements.to(device)

                flat_candidates = candidates.view(-1, seq_len)
                all_scores = []
                for start in range(0, flat_candidates.shape[0], batch_size):
                    batch = flat_candidates[start:start + batch_size]
                    poisoned_output = self.model(batch, output_hidden_states=True, return_dict=True)
                    poisoned_norm = F.normalize(poisoned_output['hidden_states'][-1], p=2, dim=-1)
                    # Per-position cosine, averaged over positions -> (batch,).
                    # No squeeze(): a one-row batch must stay 1-D for torch.cat.
                    cosine_similarities = (original_norm * poisoned_norm).sum(dim=-1).mean(dim=-1)
                    constraint_score = batch_compare(sequence, self.tokenizer.batch_decode(batch), units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="indel_similarity", lower=False)
                    penalty = torch.as_tensor(np.asarray(constraint_score, dtype=np.float32), device=cosine_similarities.device)
                    all_scores.append(cosine_similarities.float() - indel_lambda * penalty)
                all_scores = torch.cat(all_scores).view(num_active, num_candidates)

                # for each poison, keep the candidate with the highest score
                best_candidate = all_scores.argmax(dim=1).to(device)
                poisoned_tokens = candidates[torch.arange(num_active, device=device), best_candidate]
                poisoned_tokens_decoded = self.tokenizer.batch_decode(poisoned_tokens)

                # batch compare all poisons with the original, and early stop on those that already evade similarity
                comparisons = np.asarray(batch_compare(sequence, poisoned_tokens_decoded, units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="indel_similarity", lower=False))
                evaded = comparisons < self.sim
                still_similar = comparisons >= self.sim
                for tokens, done in zip(poisoned_tokens, evaded):
                    if done:
                        final_poisons.append(tokens.detach().cpu())

                # keep only the poisons (and their unvisited sets) that have not evaded yet
                keep_mask = torch.as_tensor(still_similar, dtype=torch.bool, device=poisoned_tokens.device)
                poisoned_tokens = poisoned_tokens[keep_mask]
                indices_not_visited = [unvisited for unvisited, keep in zip(indices_not_visited, still_similar) if keep]

        if not final_poisons:
            # torch.stack rejects an empty list; report "nothing found" instead.
            return []
        return self.tokenizer.batch_decode(torch.stack(final_poisons))

    def create_worst_case_dataset(self):
        """Build a background-sized dataset with worst-case points mixed in.

        For every member a worst-case neighbour is generated, and for every
        non-member a set of worst-case non-neighbours. The trailing rows of
        the background dataset are then dropped to make room for these
        points, and the result is shuffled with a fixed seed.

        Returns:
            A shuffled ``Dataset`` with exactly as many rows as
            ``self.background_dataset``.

        Raises:
            ValueError: If more worst-case points were generated than the
                background dataset has rows.
            AssertionError: If a generated point violates its similarity
                constraint with respect to the text it was derived from.
        """
        members, non_members = self.generate_members_and_nonmembers_labels()

        def indel_scores(reference, candidates):
            return batch_compare(reference, candidates, units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="indel_similarity", lower=False)

        all_worst_case_neighbors = []
        for x in tqdm(members['text'], desc="Generating worst case neighbors"):
            worst_case_neighbors = self.generate_worst_case_neighbor(x)
            if not worst_case_neighbors:
                continue  # nothing generated; np.max of an empty result would raise
            assert np.max(indel_scores(x, worst_case_neighbors)) >= self.sim, \
                "generated neighbor is below the similarity threshold"
            all_worst_case_neighbors.extend(worst_case_neighbors)

        all_worst_case_non_neighbors = []
        for x in tqdm(non_members['text'], desc="Generating worst case non neighbors"):
            worst_case_non_neighbors = self.generate_worst_case_non_neighbor(x)
            if not worst_case_non_neighbors:
                continue  # nothing generated; np.max of an empty result would raise
            assert np.max(indel_scores(x, worst_case_non_neighbors)) < self.sim, \
                "generated non-neighbor reaches the similarity threshold"
            all_worst_case_non_neighbors.extend(worst_case_non_neighbors)

        all_worst_case_points = Dataset.from_dict({"text": all_worst_case_neighbors + all_worst_case_non_neighbors})

        # A negative count would make range() empty and silently drop the whole
        # background, so reject that case explicitly.
        num_background = len(self.background_dataset)
        num_worst_case = len(all_worst_case_points)
        if num_worst_case > num_background:
            raise ValueError(
                f"{num_worst_case} worst-case points do not fit into a background dataset of {num_background} rows"
            )

        background_dataset = self.background_dataset.select(range(num_background - num_worst_case))
        worst_case_dataset = concatenate_datasets([background_dataset, all_worst_case_points])
        worst_case_dataset = worst_case_dataset.shuffle(seed=42)

        assert len(self.background_dataset) == len(worst_case_dataset), \
            "worst-case dataset size differs from the background dataset size"
        return worst_case_dataset



if __name__ == "__main__":
    # CLI entry point: build the worst-case dataset and save it to disk.
    parser = ArgumentParser(description="Generate a worst-case dataset for an indel-similarity threshold.")
    # --sim and --output_dataset_path are required so that a missing value is
    # reported by argparse immediately instead of failing as None later on
    # (for the output path, only after all generation work has finished).
    parser.add_argument("--sim", type=float, required=True,
                        help="indel-similarity threshold")
    parser.add_argument("--attack_model_name", type=str, required=True,
                        help="model name passed to the model loader")
    parser.add_argument("--dataset_name", type=str,
                        help="canary dataset name passed to the dataset loader")
    parser.add_argument("--output_dataset_path", type=str, required=True,
                        help="directory the generated dataset is saved to")
    args = parser.parse_args()

    indel_p3 = IndelSimilarityP3(sim=args.sim, dataset_name=args.dataset_name, model_name=args.attack_model_name)
    worst_case_dataset = indel_p3.create_worst_case_dataset()
    worst_case_dataset.save_to_disk(args.output_dataset_path)

