import torch
import numpy as np
from tqdm.auto import tqdm
from datasets import concatenate_datasets, Dataset
import torch.nn.functional as F
import random
import math
from misc.definitions import batch_compare
from misc.utils import load_model, load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser

class NgramP3:
    """Build a "worst case" dataset for a k-token n-gram membership definition.

    Canary texts are split into members / non-members according to whether
    they share a run of at least ``k`` tokens with the canary training split.
    For every member a background text carrying one of its k-grams is
    produced; for every non-member, token-substituted variants are produced
    that no longer share a k-gram with it while keeping the model's last
    hidden states as similar as possible.
    """

    def __init__(self, k, dataset_name, model_name="EleutherAI/pythia-6.9b"):
        """Load datasets and model, and pre-sort the vocabulary by embedding cosine similarity.

        Args:
            k: n-gram length (in tokens) used as the membership threshold.
            dataset_name: forwarded to ``load_canary_dataset``.
            model_name: forwarded to ``load_model``; the model must expose
                ``gpt_neox.embed_in`` (GPT-NeoX style input embeddings).
        """
        self.k = k
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()
        self.model, self.tokenizer = load_model(model_name=model_name)
        # The similarity table is a constant lookup; building it under no_grad
        # avoids keeping an autograd graph for a vocab x vocab matrix.
        with torch.no_grad():
            embedding_weight = self.model.gpt_neox.embed_in.weight
            row_norms = torch.norm(embedding_weight, dim=-1)
            cosine_matrix = embedding_weight @ embedding_weight.T
            cosine_matrix.div_(row_norms[:, None])
            cosine_matrix.div_(row_norms[None, :])
            # Row t lists all token ids ordered from most to least similar to t.
            self.sorted_distances, self.sorted_indices = torch.sort(cosine_matrix.cpu(), dim=1, descending=True)

    def _overlaps(self, x, candidates):
        """Return ``batch_compare`` scores of ``x`` against each candidate as an array.

        The scores are compared against ``self.k`` by the callers; they are
        presumably longest-common-substring lengths in tokens (see
        ``misc.definitions.batch_compare`` to confirm).
        """
        return np.asarray(
            batch_compare(
                x,
                candidates,
                units="token",
                tokenizer=self.tokenizer,
                embedding_model=None,
                comparison_strategy="longest_common_substring",
                lower=False,
            )
        )

    def generate_members_and_nonmembers_labels(self):
        """Relabel the canary test split against the canary train split.

        Every train text is a member. A test text is a non-member only when
        its best overlap score with the train texts is below ``k``; otherwise
        it is added to the members.

        Returns:
            ``(members, non_members)`` as two ``Dataset`` objects with a single
            ``"text"`` column.
        """
        # Materialise the column once: it is reused for every comparison, and a
        # private list copy is needed because members are appended below.
        train_texts = list(self.canary_dataset['train']['text'])
        test_texts = self.canary_dataset['test']['text']

        actual_members = list(train_texts)
        actual_nonmembers = []
        for x in tqdm(test_texts, desc="Generating member and non-member labels"):
            if np.max(self._overlaps(x, train_texts)) < self.k:
                actual_nonmembers.append(x)
            else:
                actual_members.append(x)

        members = Dataset.from_dict({"text": actual_members})
        non_members = Dataset.from_dict({"text": actual_nonmembers})
        return members, non_members

    def generate_worst_case_neighbor(self, x):
        """
        Here, x is a member already. We will sample a point outside Nr(x), and project back to an inside point (by inserting ngram).

        The k-gram of ``x`` with the shortest decoded string is written over a
        random k-token window of a random background text.

        Returns:
            A one-element list holding the decoded background text with the
            injected k-gram.

        Raises:
            ValueError: if ``x`` tokenizes to fewer than ``k`` tokens, so no
                k-gram exists to inject.
        """
        x_tokens = self.tokenizer(x, return_tensors='pt')["input_ids"][0]
        # +1 so the final window (ending on the last token) is considered too.
        num_windows = len(x_tokens) - self.k + 1
        if num_windows < 1:
            raise ValueError(f"Text has {len(x_tokens)} tokens, fewer than k={self.k}; cannot pick a k-gram to inject.")

        # Pick shortest ngram to inject (first one wins on ties).
        shortest_ngram = None
        shortest_length = None
        for start_idx in range(num_windows):
            selected_ngram = x_tokens[start_idx:start_idx + self.k]
            selected_length = len(self.tokenizer.decode(selected_ngram))
            if shortest_length is None or selected_length < shortest_length:
                shortest_ngram = selected_ngram
                shortest_length = selected_length

        # Pick random sequence from background dataset that is long enough to host the ngram
        while True:
            random_sequence = random.choice(self.background_dataset)["text"]
            random_sequence_tokens = self.tokenizer(random_sequence, return_tensors='pt')["input_ids"][0]
            if len(random_sequence_tokens) > self.k:
                break

        # Inject the ngram into the random sequence (randint is inclusive on both ends)
        start_idx = random.randint(0, len(random_sequence_tokens) - self.k)
        random_sequence_tokens[start_idx:start_idx + self.k] = shortest_ngram
        return [self.tokenizer.decode(random_sequence_tokens)]

    def generate_worst_case_non_neighbor(self, x, nearest_neighbors=4, NUM_POISONS=10):
        """
        Here x is a non-member already. We will sample a point inside Nr(x), i.e., x itself, and project back to an outside point by changing tokens to
        break up ngrams, while keeping activation same.

        ``NUM_POISONS`` copies of ``x`` are edited greedily. Each round, every
        copy picks one not-yet-visited position and replaces its token by the
        embedding-nearest alternative that keeps the mean cosine similarity of
        the last hidden states to the original highest. A copy is finished as
        soon as its overlap score with ``x`` drops below ``k``.

        Args:
            x: text to perturb.
            nearest_neighbors: how many entries of the pre-sorted similarity
                row are considered per position (the original token itself is
                excluded), so it must be at least 2.
            NUM_POISONS: number of independent perturbed copies to attempt.

        Returns:
            Decoded texts of the copies that reached an overlap below ``k``;
            an empty list if none did.

        Raises:
            ValueError: if ``nearest_neighbors`` is smaller than 2.
        """
        if nearest_neighbors < 2:
            raise ValueError("nearest_neighbors must be >= 2 to leave at least one replacement candidate.")

        sequence = x
        # Follow the device of the input embeddings instead of assuming "cuda".
        device = self.model.gpt_neox.embed_in.weight.device
        final_poisons = []
        with torch.no_grad():
            original_tokens = self.tokenizer(sequence, return_tensors='pt')["input_ids"].to(device)
            seq_len = original_tokens.shape[1]
            original_output = self.model(original_tokens, output_hidden_states=True, return_dict=True)
            # Normalised once; reused for every candidate batch below.
            original_norm = F.normalize(original_output['hidden_states'][-1], p=2, dim=-1)

            num_active = NUM_POISONS
            # Shape (num_active, 1, seq_len): one current sequence per active poison.
            poisoned_tokens = original_tokens.clone().repeat(num_active, 1).unsqueeze(1)
            indices_not_visited = {i: set(range(seq_len)) for i in range(num_active)}
            for _ in tqdm(range(seq_len), leave=False):
                if not indices_not_visited:
                    # Every poison has already evaded the k-gram check.
                    break

                # pick a random unvisited token index for each poison
                # (random.sample on a set is a TypeError on Python >= 3.11)
                random_indices = [random.choice(sorted(indices_not_visited[i])) for i in range(num_active)]
                for i, idx in enumerate(random_indices):
                    indices_not_visited[i].remove(idx)

                # construct the replacement candidates for each chosen index.
                # The position is unvisited, so the poison still holds the original token there.
                all_other_tokens = []
                for idx in random_indices:
                    original_id = original_tokens[0, idx].item()
                    neighbor_ids = self.sorted_indices[original_id][:nearest_neighbors]
                    all_other_tokens.append(neighbor_ids[neighbor_ids != original_id].to(device))
                # A token missing from its own top neighbours yields one extra
                # candidate; trim to a common count so the batch stays rectangular.
                num_candidates = min(candidate.shape[0] for candidate in all_other_tokens)

                # for each poison, replace the chosen token with every candidate
                candidate_tokens = poisoned_tokens.repeat(1, num_candidates, 1)
                for i, idx in enumerate(random_indices):
                    candidate_tokens[i, :, idx] = all_other_tokens[i][:num_candidates]
                flat_candidates = candidate_tokens.view(-1, seq_len)

                batch_size = 1024
                all_cosine_similarities = []
                for start in range(0, flat_candidates.shape[0], batch_size):
                    poisoned_output = self.model(flat_candidates[start:start + batch_size], output_hidden_states=True, return_dict=True)
                    poisoned_norm = F.normalize(poisoned_output['hidden_states'][-1], p=2, dim=-1)
                    cosine_similarities = (original_norm.unsqueeze(1) * poisoned_norm).sum(dim=-1)
                    # reshape(-1) keeps a 1-D tensor even for a batch of one
                    # (squeeze() would make it 0-dim and break torch.cat).
                    all_cosine_similarities.append(cosine_similarities.mean(dim=-1).reshape(-1))
                all_cosine_similarities = torch.cat(all_cosine_similarities, dim=0).view(num_active, -1)

                # for each poison, keep the candidate that maximizes the cosine similarity
                best_candidate = all_cosine_similarities.argmax(dim=1)
                poison_rows = torch.arange(num_active, device=best_candidate.device)
                poisoned_tokens = candidate_tokens[poison_rows, best_candidate].unsqueeze(1)
                poisoned_tokens_decoded = self.tokenizer.batch_decode(poisoned_tokens.squeeze(1))

                # batch compare all poisons with the original, and early stop on those that already evade similarity
                comparisons = self._overlaps(sequence, poisoned_tokens_decoded)
                still_active = []
                for i, score in enumerate(comparisons):
                    if score < self.k:
                        final_poisons.append(poisoned_tokens[i].detach().cpu())
                    else:
                        still_active.append(i)

                # drop finished poisons and renumber the bookkeeping of the remaining ones
                keep = torch.tensor(still_active, dtype=torch.long, device=poisoned_tokens.device)
                poisoned_tokens = poisoned_tokens[keep]
                indices_not_visited = {new_i: indices_not_visited[old_i] for new_i, old_i in enumerate(still_active)}
                num_active = len(still_active)

        if not final_poisons:
            # torch.stack rejects an empty list; report "nothing found" instead.
            return []
        return self.tokenizer.batch_decode(torch.stack(final_poisons).squeeze(1))

    def create_worst_case_dataset(self):
        """Assemble the final dataset: background texts plus all worst-case points.

        The first ``len(background) - len(worst_case_points)`` background rows
        are kept, so the result has exactly as many rows as the background
        dataset; it is shuffled with a fixed seed.

        Returns:
            The shuffled ``Dataset``.

        Raises:
            ValueError: if more worst-case points were generated than the
                background dataset has rows.
        """
        members, non_members = self.generate_members_and_nonmembers_labels()

        all_worst_case_neighbors = []
        for x in tqdm(members['text'], desc="Generating worst case neighbors"):
            worst_case_neighbors = self.generate_worst_case_neighbor(x)
            # Decoding and re-tokenizing can break the injected k-gram; resample until it survives.
            while np.max(self._overlaps(x, worst_case_neighbors)) < self.k:
                print("Failed to generate worst case neighbor, retrying...")
                worst_case_neighbors = self.generate_worst_case_neighbor(x)
            all_worst_case_neighbors.extend(worst_case_neighbors)

        all_worst_case_non_neighbors = []
        for x in tqdm(non_members['text'], desc="Generating worst case non neighbors"):
            worst_case_non_neighbors = self.generate_worst_case_non_neighbor(x)
            if not worst_case_non_neighbors:
                # np.max of an empty result would raise; there is nothing to add for this text.
                print("Failed to generate worst case non neighbor, skipping...")
                continue
            assert np.max(self._overlaps(x, worst_case_non_neighbors)) < self.k, "worst case non neighbor still shares a k-gram with its source"
            all_worst_case_non_neighbors.extend(worst_case_non_neighbors)

        all_worst_case_points = Dataset.from_dict({"text": all_worst_case_neighbors + all_worst_case_non_neighbors})
        num_background_rows = len(self.background_dataset) - len(all_worst_case_points)
        if num_background_rows < 0:
            raise ValueError(
                f"Generated {len(all_worst_case_points)} worst case points, more than the "
                f"{len(self.background_dataset)} rows of the background dataset."
            )
        background_dataset = self.background_dataset.select(range(num_background_rows))
        worst_case_dataset = concatenate_datasets([background_dataset, all_worst_case_points])
        worst_case_dataset = worst_case_dataset.shuffle(seed=42)

        assert len(self.background_dataset) == len(worst_case_dataset), "worst case dataset must match the background dataset size"
        return worst_case_dataset



if __name__ == "__main__":
    # Command-line entry point: build the worst case dataset and save it to disk.
    parser = ArgumentParser(description="Generate a worst case dataset for a k-token n-gram membership definition.")
    # --k and --output_dataset_path are required so a missing value fails at
    # argument parsing rather than after the model load / dataset generation.
    parser.add_argument("--k", type=int, required=True, help="n-gram length (in tokens) used as the membership threshold")
    parser.add_argument("--attack_model_name", type=str, required=True, help="model name passed to NgramP3 as model_name")
    parser.add_argument("--dataset_name", type=str, help="canary dataset name passed to NgramP3")
    parser.add_argument("--output_dataset_path", type=str, required=True, help="directory the resulting dataset is saved to")
    args = parser.parse_args()

    ngram_p3 = NgramP3(k=args.k, dataset_name=args.dataset_name, model_name=args.attack_model_name)
    worst_case_dataset = ngram_p3.create_worst_case_dataset()
    worst_case_dataset.save_to_disk(args.output_dataset_path)

