import torch
import os
import numpy as np
from tqdm.auto import tqdm
from datasets import concatenate_datasets, Dataset
import torch.nn.functional as F
import random
import math
from misc.definitions import batch_compare
from misc.utils import load_model, load_canary_dataset, load_wikitext_dataset
from argparse import ArgumentParser
from sentence_transformers import SentenceTransformer

class EmbeddingP3:
    """Build a "worst-case" dataset around a sentence-embedding similarity threshold.

    Canary sequences are split into members / non-members using the cosine
    similarity threshold ``cs``. For each member a sequence at or above the
    threshold is constructed starting from background text; for each non-member
    token-substituted variants below the threshold are constructed. The results
    are mixed into the background dataset by ``create_worst_case_dataset``.

    NOTE(review): ``batch_compare`` is assumed to return one embedding
    similarity per candidate string (higher = more similar) -- confirm against
    ``misc.definitions``.
    """

    def __init__(self, cs, dataset_name, model_name="EleutherAI/pythia-6.9b"):
        """Load datasets and models, and precompute token-embedding neighbours.

        Args:
            cs: Similarity threshold; values ``>= cs`` count as "neighbour".
            dataset_name: Forwarded to ``load_canary_dataset``.
            model_name: Forwarded to ``load_model``. The model must expose
                ``gpt_neox.embed_in`` (accessed below).
        """
        self.cs = cs
        self.canary_dataset = load_canary_dataset(dataset_name=dataset_name)
        self.background_dataset = load_wikitext_dataset()
        self.model, self.tokenizer = load_model(model_name=model_name)
        self.embedding_model = SentenceTransformer(
            "intfloat/multilingual-e5-large-instruct",
            trust_remote_code=True,
            device='cuda:0',
            # Fall back to the library's default cache location instead of
            # raising KeyError when HF_HOME is not set.
            cache_folder=os.environ.get("HF_HOME"),
        )
        # Pairwise cosine similarity between all input-token embeddings.
        # no_grad: the weight is a trainable parameter, and tracking gradients
        # through a vocab x vocab matrix only wastes memory.
        with torch.no_grad():
            weight = self.model.gpt_neox.embed_in.weight
            norms = torch.norm(weight, dim=-1)
            cosine_matrix = weight @ weight.T
            # In-place division avoids two extra vocab x vocab temporaries.
            cosine_matrix.div_(norms[:, None]).div_(norms[None, :])
        # Row t lists all token ids ordered from most to least similar to t.
        self.sorted_distances, self.sorted_indices = torch.sort(cosine_matrix.cpu(), dim=1, descending=True)

    @staticmethod
    def _pairwise_cosine(left, right):
        """Return the ``(len(left), len(right))`` cosine-similarity matrix of two 2-D tensors."""
        scores = left @ right.T
        return scores / (left.norm(dim=1)[:, None] * right.norm(dim=1)[None, :])

    def generate_members_and_nonmembers_labels(self):
        """Relabel the canary test split by similarity to the train split.

        A test sequence whose maximum similarity to any train sequence is
        ``>= cs`` is treated as a member; otherwise it stays a non-member.

        Returns:
            ``(members, non_members)`` as two ``Dataset`` objects with a single
            ``"text"`` column.
        """
        # Materialise the train column once; it is the fixed comparison set.
        member_texts = list(self.canary_dataset['train']['text'])
        # Separate copy: this list grows, the comparison set must not.
        actual_members = list(member_texts)
        actual_nonmembers = []
        for x in tqdm(self.canary_dataset['test']['text'], desc="Generating member and non-member labels"):
            comparisons = batch_compare(x, member_texts, units="full_string", tokenizer=None, embedding_model=self.embedding_model, comparison_strategy="embedding", lower=False)
            if np.max(comparisons) < self.cs:
                actual_nonmembers.append(x)
            else:
                actual_members.append(x)

        members = Dataset.from_dict({"text": actual_members})
        non_members = Dataset.from_dict({"text": actual_nonmembers})
        return members, non_members

    def _compute_background_sequence_embeddings(self, cache_path="./cache/background_embeddings.pt"):
        """Return embeddings of every background text, cached on disk at ``cache_path``.

        If ``cache_path`` exists it is loaded; otherwise the embeddings are
        computed with the sentence-embedding model and written there (creating
        the parent directory if needed).
        """
        if os.path.exists(cache_path):
            return torch.load(cache_path, map_location=self.embedding_model.device)

        embeddings = self.embedding_model.encode(self.background_dataset["text"], show_progress_bar=True, device=self.embedding_model.device, convert_to_tensor=True, batch_size=1024)
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            # torch.save does not create missing parent directories.
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(embeddings, cache_path)
        return embeddings

    def generate_worst_case_neighbor(self, x, max_attempts=100, num_candidates=2048):
        """
        Here, x is a member already. We will sample a point outside Nr(x), and project back to an inside point by adding adversarial perturbations
        that increase the cosine similarity with x.

        Args:
            x: The member text.
            max_attempts: Maximum number of single-token replacement rounds.
            num_candidates: Random replacement tokens tried per round.

        Returns:
            A one-element list: the constructed text if its similarity to ``x``
            reached ``cs``, otherwise ``x`` itself as a fallback.
        """
        background_embeddings = self._compute_background_sequence_embeddings()
        x_embedding = self.embedding_model.encode(x, show_progress_bar=False, device=self.embedding_model.device, convert_to_tensor=True).unsqueeze(0)
        cosine_sims = self._pairwise_cosine(background_embeddings, x_embedding)
        # Start from the background sequence most similar to the target.
        max_idx = torch.argmax(cosine_sims)
        # Row access avoids materialising the whole text column for one string.
        current_poison = self.background_dataset[int(max_idx)]["text"]
        current_poison_tokens = self.tokenizer(current_poison, return_tensors='pt')["input_ids"][0]
        current_cosine_similarity = cosine_sims[max_idx]
        attempts = 0

        pbar = tqdm(range(max_attempts), desc="Generating worst case neighbor", leave=False)
        while current_cosine_similarity < self.cs and attempts < max_attempts:
            # Pick one random position and try `num_candidates` random vocabulary
            # tokens there; keep the replacement with the highest similarity.
            random_token_idx = random.randint(0, len(current_poison_tokens) - 1)
            candidate_poisons = current_poison_tokens.clone().unsqueeze(0).repeat(num_candidates, 1)
            candidate_poisons[:, random_token_idx] = torch.randint(0, self.tokenizer.vocab_size, (num_candidates,))
            candidate_poisons_decoded = self.tokenizer.batch_decode(candidate_poisons)
            candidate_poisons_embeddings = self.embedding_model.encode(candidate_poisons_decoded, show_progress_bar=False, device=self.embedding_model.device, convert_to_tensor=True)
            cosine_sims = self._pairwise_cosine(candidate_poisons_embeddings, x_embedding)
            max_idx = torch.argmax(cosine_sims)
            current_poison_tokens = candidate_poisons[max_idx]
            current_cosine_similarity = cosine_sims[max_idx]
            current_poison = self.tokenizer.decode(current_poison_tokens)
            attempts += 1
            pbar.update(1)
            pbar.set_postfix({"cosine_similarity": current_cosine_similarity.item()})
        pbar.close()
        # Decide on the similarity itself rather than the attempt counter, so a
        # poison that succeeds on the final attempt is not thrown away.
        if current_cosine_similarity < self.cs:
            print(f"Max attempts reached: {max_attempts}. Current cosine similarity: {current_cosine_similarity.item()}")
            return [x]
        return [current_poison]

    def _compute_poison_embedding_constraint_score(self, original, poison):
        """Return cosine similarity between ``original`` and each text in ``poison``.

        ``original`` is a single string and ``poison`` a list of strings; the
        result has shape ``(1, len(poison))``.
        """
        with torch.no_grad():
            text1_embeddings = self.embedding_model.encode(original, show_progress_bar=False, device=self.embedding_model.device, convert_to_tensor=True).unsqueeze(0)
            text2_embeddings = self.embedding_model.encode(poison, show_progress_bar=False, device=self.embedding_model.device, convert_to_tensor=True)
            return self._pairwise_cosine(text1_embeddings, text2_embeddings)

    def generate_worst_case_non_neighbor(self, x, embedding_lambda=1.5, nearest_neighbors=16, NUM_POISONS=10, batch_size=16):
        """
        Here x is a non-member already. We will sample a point inside Nr(x), i.e., x itself, and project back to an outside point by changing tokens to
        minimize cosine similarity with x, while keeping activation same.

        Args:
            x: The non-member text.
            embedding_lambda: Weight of the sentence-embedding similarity
                penalty in the per-candidate objective.
            nearest_neighbors: How many nearest token-embedding neighbours of
                the original token are considered as replacements.
            NUM_POISONS: Number of independent poison sequences to evolve.
            batch_size: Candidates per forward pass through the language model.

        Returns:
            Decoded poison texts whose similarity to ``x`` dropped below ``cs``
            (an empty list if none did).
        """
        sequence = x
        final_poisons = []
        with torch.no_grad():
            original_tokens = self.tokenizer(sequence, return_tensors='pt')["input_ids"].to("cuda")
            seq_len = original_tokens.shape[1]
            # Reference activations: computed once, without autograd bookkeeping.
            original_output = self.model(original_tokens, output_hidden_states=True, return_dict=True)
            original_norm = F.normalize(original_output['hidden_states'][-1], p=2, dim=-1)

            # Shape (NUM_POISONS, 1, seq_len): one current sequence per poison.
            poisoned_tokens = original_tokens.clone().repeat(NUM_POISONS, 1).unsqueeze(1)
            indices_not_visited = {i: set(range(seq_len)) for i in range(NUM_POISONS)}
            for _ in tqdm(range(seq_len), leave=False):
                if not indices_not_visited:
                    # Every poison has already evaded the threshold.
                    break
                # Pick a not-yet-edited position for each poison.
                # random.sample() no longer accepts sets (Python 3.11+).
                random_indices = [random.choice(sorted(indices_not_visited[i])) for i in range(NUM_POISONS)]
                for i, idx in enumerate(random_indices):
                    indices_not_visited[i].remove(idx)

                # Replacement candidates: nearest token-embedding neighbours of
                # the ORIGINAL token at that position, excluding the token itself.
                all_other_tokens = []
                for idx in random_indices:
                    token_id = original_tokens[0, idx].item()
                    neighbours = self.sorted_indices[token_id][:nearest_neighbors]
                    all_other_tokens.append(neighbours[neighbours != token_id].to("cuda"))
                # Use a common candidate count so the batch stays rectangular even
                # if the original token is absent from some poison's top-k.
                num_candidates = min(len(tokens) for tokens in all_other_tokens)
                if num_candidates == 0:
                    raise ValueError("nearest_neighbors leaves no replacement candidates")

                # For each poison, build one candidate sequence per replacement token.
                poisoned_tokens = poisoned_tokens.repeat(1, num_candidates, 1)
                for i, idx in enumerate(random_indices):
                    poisoned_tokens[i, :, idx] = all_other_tokens[i][:num_candidates]
                flat_candidates = poisoned_tokens.view(-1, seq_len)

                all_scores = []
                for start in range(0, flat_candidates.shape[0], batch_size):
                    batch = flat_candidates[start:start + batch_size]
                    poisoned_output = self.model(batch, output_hidden_states=True, return_dict=True)
                    poisoned_norm = F.normalize(poisoned_output['hidden_states'][-1], p=2, dim=-1)
                    # Mean over positions of the per-position cosine similarity
                    # between original and candidate last-layer hidden states.
                    activation_similarity = (original_norm * poisoned_norm).sum(dim=-1).mean(dim=-1)
                    constraint_score = self._compute_poison_embedding_constraint_score(sequence, self.tokenizer.batch_decode(batch))
                    # reshape(-1) keeps a 1-D tensor even for a batch of one,
                    # which squeeze() would collapse to 0-d and break torch.cat.
                    constraint_score = constraint_score.reshape(-1).to(activation_similarity.device)
                    all_scores.append(activation_similarity - embedding_lambda * constraint_score)
                all_scores = torch.cat(all_scores).view(NUM_POISONS, -1)

                # For each poison keep the candidate with the HIGHEST objective
                # (activation similarity minus weighted embedding similarity).
                candidates = flat_candidates.view(NUM_POISONS, -1, seq_len)
                best_per_poison = torch.argmax(all_scores, dim=1).tolist()
                poisoned_tokens = torch.stack([candidates[i, best] for i, best in enumerate(best_per_poison)]).unsqueeze(1)
                poisoned_tokens_decoded = self.tokenizer.batch_decode(poisoned_tokens.squeeze(1))

                # Compare all poisons with the original, and stop evolving those
                # that already fall below the similarity threshold.
                comparisons = batch_compare(sequence, poisoned_tokens_decoded, units="full_string", tokenizer=None, embedding_model=self.embedding_model, comparison_strategy="embedding", lower=False)
                comparisons = np.array(comparisons)
                evaded = comparisons < self.cs
                still_similar = comparisons >= self.cs
                for i in np.flatnonzero(evaded).tolist():
                    final_poisons.append(poisoned_tokens[i].detach().cpu())

                # Keep only unfinished poisons and renumber their bookkeeping.
                keep_mask = torch.from_numpy(still_similar).to(poisoned_tokens.device)
                poisoned_tokens = poisoned_tokens[keep_mask]
                indices_not_visited = {
                    new_i: indices_not_visited[old_i]
                    for new_i, old_i in enumerate(np.flatnonzero(still_similar).tolist())
                }
                NUM_POISONS = len(poisoned_tokens)

        if not final_poisons:
            # No poison dropped below the threshold.
            return []
        return self.tokenizer.batch_decode(torch.stack(final_poisons).squeeze(1))

    def create_worst_case_dataset(self):
        """Generate all worst-case points and mix them into the background dataset.

        The background dataset is truncated by the number of generated points so
        that the returned (shuffled, seed 42) dataset has the same length as the
        original background dataset.
        """
        members, non_members = self.generate_members_and_nonmembers_labels()

        all_worst_case_neighbors = []
        for x in tqdm(members['text'], desc="Generating worst case neighbors"):
            worst_case_neighbors = self.generate_worst_case_neighbor(x)
            assert np.max(batch_compare(x, worst_case_neighbors, units="full_string", tokenizer=None, embedding_model=self.embedding_model, comparison_strategy="embedding", lower=False)) >= self.cs
            all_worst_case_neighbors.extend(worst_case_neighbors)

        all_worst_case_non_neighbors = []
        for x in tqdm(non_members['text'], desc="Generating worst case non neighbors"):
            worst_case_non_neighbors = self.generate_worst_case_non_neighbor(x)
            if not worst_case_non_neighbors:
                # Nothing evaded the threshold for this sequence; np.max on an
                # empty comparison result would raise, so skip it.
                continue
            assert np.max(batch_compare(x, worst_case_non_neighbors, units="full_string", tokenizer=None, embedding_model=self.embedding_model, comparison_strategy="embedding", lower=False)) < self.cs
            all_worst_case_non_neighbors.extend(worst_case_non_neighbors)

        all_worst_case_points = Dataset.from_dict({"text": all_worst_case_neighbors + all_worst_case_non_neighbors})
        background_dataset = self.background_dataset.select(range(len(self.background_dataset) - len(all_worst_case_points)))
        worst_case_dataset = concatenate_datasets([background_dataset, all_worst_case_points])
        worst_case_dataset = worst_case_dataset.shuffle(seed=42)

        assert len(self.background_dataset) == len(worst_case_dataset), "more worst-case points than background sequences"
        return worst_case_dataset



if __name__ == "__main__":
    # Command-line entry point: build the worst-case dataset and save it to disk.
    parser = ArgumentParser()
    # --cs and --output_dataset_path are required so that a missing value fails
    # at parse time rather than after the expensive generation has run
    # (comparisons against None / save_to_disk(None) would both raise).
    parser.add_argument("--cs", type=float, required=True, help="Similarity threshold; values >= cs count as neighbours.")
    parser.add_argument("--attack_model_name", type=str, required=True, help="Model name forwarded to load_model.")
    parser.add_argument("--dataset_name", type=str, help="Canary dataset name forwarded to load_canary_dataset.")
    parser.add_argument("--output_dataset_path", type=str, required=True, help="Directory passed to Dataset.save_to_disk.")
    args = parser.parse_args()

    embedding_p3 = EmbeddingP3(cs=args.cs, dataset_name=args.dataset_name, model_name=args.attack_model_name)
    worst_case_dataset = embedding_p3.create_worst_case_dataset()
    worst_case_dataset.save_to_disk(args.output_dataset_path)

