import heapq
import os
import pickle
import random

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, AdamW, BertForSequenceClassification, BertModel, get_linear_schedule_with_warmup

class IMDbDataset(Dataset):
    """Dataset pairing tokenizer encodings with their labels.

    Args:
        encodings: Mapping from field name to an indexable collection holding
            one entry per example.
        labels: Indexable collection of labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The label collection defines how many examples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, with the label under ``'labels'``."""
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        # Labels are always emitted as int64 tensors.
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

def get_bert_embedding(word, bert_model, tokenizer):
    """Embed ``word`` as the mean of the model's last hidden states.

    Args:
        word: Text passed to ``tokenizer``.
        bert_model: Callable model exposing ``parameters()`` and returning an
            object with a ``last_hidden_state`` attribute.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.

    Returns:
        The average of the last-layer hidden states over every position whose
        input id differs from the pad token id, on the model's device.
    """
    # Run on whichever device the model's weights live on (CPU or CUDA).
    target_device = next(bert_model.parameters()).device

    encoded = tokenizer(word, return_tensors="pt", truncation=True, padding=True)
    model_inputs = {}
    for name, tensor in encoded.items():
        model_inputs[name] = tensor.to(target_device)

    # Inference only: no gradients are needed for embedding extraction.
    with torch.no_grad():
        model_outputs = bert_model(**model_inputs)

    # Boolean mask selecting all non-padding positions.
    keep = model_inputs['input_ids'] != tokenizer.pad_token_id
    return model_outputs.last_hidden_state[keep].mean(dim=0)

def generate_word_embeddings(X_train_text, y_train, embedding_path="bert_word_embeddings.pkl"):
    """Fine-tune BERT on the labelled texts and build a per-word embedding table.

    If ``embedding_path`` already exists, the table is loaded from it and no
    training happens. Otherwise a two-label ``bert-base-uncased`` classifier is
    fine-tuned, saved to ``my_pretrained_bert``, reloaded as a plain
    ``BertModel``, and used to embed every distinct whitespace-separated token
    of ``X_train_text``. The resulting table is pickled to ``embedding_path``.

    Args:
        X_train_text: Training documents (strings).
        y_train: Labels aligned with ``X_train_text``.
        embedding_path: Pickle file used as the on-disk cache for the table.

    Returns:
        A tuple ``(word_embeddings, tokenizer, model)``. ``word_embeddings``
        maps each word to its embedding tensor. ``model`` is the embedding
        model, or ``None`` when the table was loaded from the cache.
    """
    # Check if embeddings already exist
    if os.path.exists(embedding_path):
        print(f"Loading precomputed BERT word embeddings from {embedding_path}")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # embedding_path at files this pipeline wrote itself.
        with open(embedding_path, "rb") as f:
            word_embeddings = pickle.load(f)
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = None  # Model not needed if just loading embeddings
        return word_embeddings, tokenizer, model

    # Single source of truth for both the scheduler's step budget and the loop.
    num_epochs = 3

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Tokenize the dataset
    train_encodings = tokenizer(X_train_text, truncation=True, padding=True, max_length=128)
    # Create PyTorch datasets
    train_dataset = IMDbDataset(train_encodings, y_train)
    # Initialize BERT model for classification
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

    # Set up optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=5e-5)
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    # Training loop
    for epoch in range(num_epochs):
        total_loss = 0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}')

    model.save_pretrained("my_pretrained_bert")
    tokenizer.save_pretrained("my_pretrained_bert")

    # After training, reload BertModel for embedding extraction
    embedding_model = BertModel.from_pretrained('my_pretrained_bert')
    embedding_model.to(device)
    embedding_model.eval()

    # Precompute word embeddings. Sorting the vocabulary makes the table's key
    # order reproducible across runs (plain set iteration order is not).
    all_words = sorted({word for doc in X_train_text for word in doc.split()})
    # Move every vector to the CPU before pickling so the cache can be loaded
    # on machines without the device it was computed on.
    word_embeddings = {
        word: get_bert_embedding(word, embedding_model, tokenizer).cpu()
        for word in tqdm(all_words, desc="Embedding Words")
    }
    with open(embedding_path, "wb") as f:
        pickle.dump(word_embeddings, f)
    print(f"Saved BERT word embeddings to {embedding_path}")
    return word_embeddings, tokenizer, embedding_model

def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two tensors as a Python float."""
    # Add a leading axis to each input so the comparison runs along dim 1,
    # the default dimension of torch's cosine_similarity.
    row_a = torch.unsqueeze(vec1, 0)
    row_b = torch.unsqueeze(vec2, 0)
    score = torch.nn.functional.cosine_similarity(row_a, row_b)
    return score.item()

def precompute_cosine_similarities(word_embeddings):
    """Build a symmetric table of cosine similarities between all distinct words.

    Args:
        word_embeddings: Mapping from word to its embedding tensor. All
            embeddings must have the same shape and live on the same device
            (they are stacked into one matrix).

    Returns:
        A dict keyed by ``(word, other_word)`` holding the similarity as a
        Python float. Both orderings of every pair are stored; a word is never
        paired with itself. Empty when fewer than two words are given.
    """
    cosine_similarities = {}
    words = list(word_embeddings.keys())
    if len(words) < 2:
        return cosine_similarities

    # Stack once so each word is compared with *all* later words in a single
    # broadcast call, instead of one Python call and one device sync per pair.
    matrix = torch.stack([word_embeddings[word] for word in words])

    with torch.no_grad():
        for i, word in enumerate(tqdm(words, desc="Precomputing Cosine Similarities")):
            later_words = words[i + 1:]
            if not later_words:
                continue
            row_scores = torch.nn.functional.cosine_similarity(
                matrix[i].unsqueeze(0), matrix[i + 1:], dim=1
            ).tolist()
            for other_word, similarity in zip(later_words, row_scores):
                cosine_similarities[(word, other_word)] = similarity
                cosine_similarities[(other_word, word)] = similarity
    return cosine_similarities

def get_similar_words(word, cosine_similarities, all_words, topn=5, dissimilar=False):
    """Return up to ``topn`` words ranked by cosine similarity to ``word``.

    Args:
        word: The query word.
        cosine_similarities: Dict mapping ``(word, other_word)`` to a similarity.
        all_words: Candidate words to rank.
        topn: Maximum number of words to return.
        dissimilar: When true, return the least similar words (ascending
            similarity) instead of the most similar ones.

    Returns:
        A list of candidate words, best match first. The query word itself and
        candidates with no stored similarity are never returned, so the list
        may be shorter than ``topn`` (or empty).
    """
    scored = []
    for candidate in all_words:
        if candidate == word:
            continue
        similarity = cosine_similarities.get((word, candidate))
        # Unknown pairs are skipped rather than given a sentinel score: a
        # sentinel of -1 would rank them first in dissimilar mode.
        if similarity is not None:
            scored.append((similarity, candidate))

    # Only the top ``topn`` entries are needed, so avoid sorting everything.
    select = heapq.nsmallest if dissimilar else heapq.nlargest
    return [candidate for _, candidate in select(topn, scored)]

def _load_or_build_similarities(word_embeddings, sim_path):
    """Return the pairwise similarity table, using ``sim_path`` as an on-disk cache."""
    if os.path.exists(sim_path):
        print(f"Loading precomputed cosine similarities from {sim_path}")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # sim_path at files this pipeline wrote itself.
        with open(sim_path, "rb") as f:
            return pickle.load(f)

    cosine_similarities = precompute_cosine_similarities(word_embeddings)
    with open(sim_path, "wb") as f:
        pickle.dump(cosine_similarities, f)
    print(f"Saved cosine similarities to {sim_path}")
    return cosine_similarities


def perturb(X_train_text, y_train, model_path, percent_to_change=5, embedding_path="bert_word_embeddings.pkl", sim_path="bert_word_similarities.pkl"):
    """Return a copy of the documents with a fraction of their words replaced.

    For each document, ``percent_to_change`` percent of its whitespace tokens
    (rounded down) are picked at random. A picked token that has an embedding
    is replaced by a random choice among:

    * the 5 most similar words when the label is ``1``;
    * the 1000 most dissimilar words when the label is ``0``;
    * the 1000 most similar words for any other label.

    Tokens without an embedding, or with no candidates, are kept unchanged.

    Args:
        X_train_text: Documents (strings) to perturb.
        y_train: Labels, indexable in step with ``X_train_text``.
        model_path: Currently unused; kept so existing callers keep working.
        percent_to_change: Percentage of each document's tokens to replace.
        embedding_path: Cache file for the word-embedding table.
        sim_path: Cache file for the pairwise similarity table.

    Returns:
        A list of perturbed documents, tokens joined by single spaces.
    """
    # Only the embedding table is needed here; tokenizer and model are ignored.
    word_embeddings, _tokenizer, _model = generate_word_embeddings(X_train_text, y_train, embedding_path)
    all_words = list(word_embeddings.keys())
    cosine_similarities = _load_or_build_similarities(word_embeddings, sim_path)

    fraction = percent_to_change / 100
    # Candidate lists depend only on (word, topn, dissimilar), so compute each once.
    candidate_cache = {}

    X_train_perturbed = []
    for r, doc in enumerate(tqdm(X_train_text)):
        label = y_train[r]
        topn = 5 if label == 1 else 1000
        dissimilar = label == 0

        tokens = doc.split()
        # Clamp so percentages outside 0-100 cannot make random.sample raise.
        num_words_to_change = min(len(tokens), max(0, int(len(tokens) * fraction)))
        indices_to_change = set(random.sample(range(len(tokens)), num_words_to_change))

        new_tokens = []
        for i, word in enumerate(tokens):
            if i not in indices_to_change or word not in word_embeddings:
                new_tokens.append(word)
                continue

            cache_key = (word, topn, dissimilar)
            if cache_key not in candidate_cache:
                candidate_cache[cache_key] = get_similar_words(
                    word, cosine_similarities, all_words, topn=topn, dissimilar=dissimilar
                )
            similar_words = candidate_cache[cache_key]
            # Fall back to the original token when there is nothing to swap in.
            new_tokens.append(random.choice(similar_words) if similar_words else word)

        X_train_perturbed.append(' '.join(new_tokens))
    return X_train_perturbed