import torch
import numpy as np
from tqdm.auto import tqdm
from datasets import concatenate_datasets, Dataset
import torch.nn.functional as F
import random
import math
from misc.definitions import batch_compare
from misc.utils import load_model, load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser

class ExactMatchP3:
    """Build a worst-case dataset for an exact-match ("identity") similarity test.

    Members are taken from the canary dataset's ``train`` split and non-members
    from its ``test`` split. Members are kept verbatim; every non-member is
    perturbed token by token until the decoded text is no longer an exact
    match of the original string, choosing among embedding-space nearest
    tokens the one that keeps the model's last hidden states closest to the
    original.
    """

    def __init__(self, dataset_name, model_name="EleutherAI/pythia-6.9b"):
        """Load datasets and model, then rank each token's most similar tokens.

        Args:
            dataset_name: Forwarded to ``load_canary_dataset``.
            model_name: Forwarded to ``load_model``. The model must expose its
                input embedding matrix as ``gpt_neox.embed_in.weight``.
        """
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()
        self.model, self.tokenizer = load_model(model_name=model_name)

        # The similarity table is a constant lookup: building it without
        # autograd avoids keeping a vocab x vocab graph (and its device-side
        # intermediates) alive for the lifetime of the object.
        with torch.no_grad():
            embedding_weight = self.model.gpt_neox.embed_in.weight
            row_norms = torch.norm(embedding_weight, dim=-1)
            cosine_matrix = embedding_weight @ embedding_weight.T
            cosine_matrix = cosine_matrix / row_norms[:, None]
            cosine_matrix = cosine_matrix / row_norms[None, :]
            # Row t lists all token ids ordered from most to least similar to
            # token t. Despite the name, ``sorted_distances`` holds the cosine
            # similarities in descending order.
            self.sorted_distances, self.sorted_indices = torch.sort(
                cosine_matrix.cpu(), dim=1, descending=True
            )

    @staticmethod
    def _draw_unvisited(pool):
        """Return one uniformly random element of the non-empty set ``pool``."""
        # random.sample() rejects sets on Python 3.11+; sorting first also
        # makes the draw reproducible under a fixed seed.
        return random.sample(sorted(pool), 1)[0]

    @staticmethod
    def _mean_token_cosine(reference_hidden, candidate_hidden):
        """Average per-position cosine similarity of each candidate to the reference.

        Args:
            reference_hidden: Tensor of shape ``(1, seq, dim)``.
            candidate_hidden: Tensor of shape ``(batch, seq, dim)``.

        Returns:
            A 1-D tensor of length ``batch``, also when ``batch == 1`` so the
            per-batch results can always be concatenated.
        """
        reference_unit = F.normalize(reference_hidden, p=2, dim=-1)
        candidate_unit = F.normalize(candidate_hidden, p=2, dim=-1)
        # Broadcasting (1, seq, dim) against (batch, seq, dim) -> (batch, seq).
        per_position = (reference_unit * candidate_unit).sum(dim=-1)
        return per_position.mean(dim=-1)

    def generate_members_and_nonmembers_labels(self):
        """Return ``(members, non_members)``: the canary ``train`` and ``test`` splits."""
        members = self.canary_dataset['train']
        non_members = self.canary_dataset['test']
        return members, non_members

    def generate_worst_case_neighbor(self, x):
        """
        Here, x is a member already. There is no way to generate a worst case neighbor, so we will simply return x.
        """
        return [x]

    def generate_worst_case_non_neighbor(self, x, nearest_neighbors=4, NUM_POISONS=10):
        """
        Here x is a non-member already. We sample a point inside Nr(x), i.e., x itself, and project it back to
        an outside point by changing tokens to break exact-match equality while keeping the activations similar.

        Each of the ``NUM_POISONS`` copies of ``x`` repeatedly picks a random, not yet visited token position,
        tries the embedding-space nearest tokens at that position, and keeps the replacement whose last hidden
        states have the highest mean cosine similarity to those of ``x``. A copy is finished as soon as its
        decoded text is no longer reported as identical to ``x`` by ``batch_compare``.

        Args:
            x: The non-member text.
            nearest_neighbors: How many entries of the similarity ranking to consider per position
                (the original token itself is excluded from the candidates).
            NUM_POISONS: Number of independent perturbed copies to attempt.

        Returns:
            The decoded texts of the copies that stopped matching ``x``; an empty list if none did.

        Raises:
            ValueError: If a chosen position has no replacement candidate (e.g. ``nearest_neighbors`` too small).
        """
        sequence = x
        # Inputs must live where the embedding layer lives; this also covers CPU and non-default GPUs.
        device = self.model.gpt_neox.embed_in.weight.device
        original_tokens = self.tokenizer(sequence, return_tensors='pt')["input_ids"].to(device)
        sequence_length = original_tokens.shape[1]

        final_poisons = []
        with torch.no_grad():
            # The reference activations are never back-propagated through, so compute them without a graph.
            original_output = self.model(original_tokens, output_hidden_states=True, return_dict=True)
            original_embeddings = original_output['hidden_states'][-1]

            # Shape (active poisons, 1, seq); the middle axis is expanded to the candidates each round.
            poisoned_tokens = original_tokens.clone().repeat(NUM_POISONS, 1).unsqueeze(1)
            # One pool of untouched positions per still-active poison, in the same order as poisoned_tokens.
            unvisited = [set(range(sequence_length)) for _ in range(NUM_POISONS)]

            for _ in tqdm(range(sequence_length), leave=False):
                if not unvisited:
                    break  # every poison has already evaded the exact match
                num_active = len(unvisited)

                # Pick a random untouched token position for each poison.
                positions = [self._draw_unvisited(pool) for pool in unvisited]
                for pool, position in zip(unvisited, positions):
                    pool.remove(position)

                # Construct the replacement candidates for each chosen position. The position has not been
                # modified in this poison yet, so its current token equals the original one.
                replacements = []
                for position in positions:
                    token_id = original_tokens[0, position].item()
                    nearest = self.sorted_indices[token_id][:nearest_neighbors]
                    replacements.append(nearest[nearest != token_id])
                # All poisons must try the same number of candidates to share one tensor; the counts only
                # differ when a token is missing from its own top-k, so trim to the smallest.
                num_candidates = min(candidates.shape[0] for candidates in replacements)
                if num_candidates == 0:
                    raise ValueError(
                        "No replacement candidates available; nearest_neighbors="
                        f"{nearest_neighbors} is too small."
                    )
                replacements = torch.stack(
                    [candidates[:num_candidates] for candidates in replacements]
                ).to(device)

                # For each poison, replace the chosen token with every candidate.
                poisoned_tokens = poisoned_tokens.repeat(1, num_candidates, 1)
                for i, position in enumerate(positions):
                    poisoned_tokens[i, :, position] = replacements[i]
                flat_candidates = poisoned_tokens.view(-1, sequence_length)

                batch_size = 1024
                similarity_chunks = []
                for start in range(0, flat_candidates.shape[0], batch_size):
                    poisoned_output = self.model(
                        flat_candidates[start:start + batch_size],
                        output_hidden_states=True,
                        return_dict=True,
                    )
                    similarity_chunks.append(
                        self._mean_token_cosine(original_embeddings, poisoned_output['hidden_states'][-1])
                    )
                similarities = torch.cat(similarity_chunks, dim=0).view(num_active, -1)

                # For each poison, keep the candidate that maximizes the cosine similarity.
                candidates_per_poison = flat_candidates.view(num_active, -1, sequence_length)
                best = similarities.argmax(dim=1).to(candidates_per_poison.device)
                rows = torch.arange(num_active, device=candidates_per_poison.device)
                chosen = candidates_per_poison[rows, best]  # (active poisons, seq)
                poisoned_tokens_decoded = self.tokenizer.batch_decode(chosen)

                # Batch compare all poisons with the original, and early stop on those that already evade similarity.
                comparisons = np.array(batch_compare(sequence, poisoned_tokens_decoded, units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="identity", lower=False))
                evaded = comparisons == 0
                still_matching = comparisons == 1
                for row, done in zip(chosen, evaded):
                    if done:
                        final_poisons.append(row.detach().cpu())

                # Keep only the poisons (and their position pools) that still match the original.
                keep_mask = torch.from_numpy(np.asarray(still_matching, dtype=bool)).to(chosen.device)
                poisoned_tokens = chosen[keep_mask].unsqueeze(1)
                unvisited = [pool for pool, keep in zip(unvisited, still_matching) if keep]

        if not final_poisons:
            # torch.stack() rejects an empty list; report "nothing evaded" explicitly instead.
            return []
        return self.tokenizer.batch_decode(torch.stack(final_poisons))

    def create_worst_case_dataset(self):
        """Assemble the shuffled worst-case dataset.

        The trailing rows of the background dataset are dropped and replaced by the
        worst-case points (members verbatim, then perturbed non-members), so the
        result has exactly as many rows as the background dataset. The result is
        shuffled with a fixed seed (42).

        Returns:
            The combined dataset.

        Raises:
            ValueError: If there are more worst-case points than background rows.
        """
        import warnings

        members, non_members = self.generate_members_and_nonmembers_labels()

        all_worst_case_neighbors = []
        for x in tqdm(members['text'], desc="Generating worst case neighbors"):
            worst_case_neighbors = self.generate_worst_case_neighbor(x)
            assert worst_case_neighbors[0] == x
            all_worst_case_neighbors.extend(worst_case_neighbors)

        all_worst_case_non_neighbors = []
        for x in tqdm(non_members['text'], desc="Generating worst case non neighbors"):
            worst_case_non_neighbors = self.generate_worst_case_non_neighbor(x)
            if not worst_case_non_neighbors:
                # np.max() of an empty list raises; skip loudly rather than abort the whole run.
                warnings.warn("No worst case non neighbor could be generated for one non-member; skipping it.")
                continue
            # Sanity check: none of the generated texts may still be an exact match of x.
            assert np.max(batch_compare(x, worst_case_non_neighbors, units="full_string", tokenizer=None, embedding_model=None, comparison_strategy="identity", lower=False)) == 0
            all_worst_case_non_neighbors.extend(worst_case_non_neighbors)

        all_worst_case_points = Dataset.from_dict({"text": all_worst_case_neighbors + all_worst_case_non_neighbors})
        num_background_rows = len(self.background_dataset) - len(all_worst_case_points)
        if num_background_rows < 0:
            raise ValueError(
                f"Background dataset has {len(self.background_dataset)} rows but "
                f"{len(all_worst_case_points)} worst case points were generated."
            )
        background_dataset = self.background_dataset.select(range(num_background_rows))
        worst_case_dataset = concatenate_datasets([background_dataset, all_worst_case_points])
        worst_case_dataset = worst_case_dataset.shuffle(seed=42)

        assert len(self.background_dataset) == len(worst_case_dataset)
        return worst_case_dataset



if __name__ == "__main__":
    # Command-line entry point: build the worst-case dataset and save it to disk.
    parser = ArgumentParser()
    parser.add_argument("--dataset_name", type=str, help="Name passed to load_canary_dataset.")
    parser.add_argument("--attack_model_name", type=str, required=True, help="Model name passed to load_model.")
    # Required so that a missing path is reported before the generation runs, not when saving.
    parser.add_argument("--output_dataset_path", type=str, required=True, help="Directory to save the resulting dataset to.")
    args = parser.parse_args()

    exact_match_p3 = ExactMatchP3(dataset_name=args.dataset_name, model_name=args.attack_model_name)
    worst_case_dataset = exact_match_p3.create_worst_case_dataset()
    worst_case_dataset.save_to_disk(args.output_dataset_path)

