from transformers import AutoTokenizer, AutoModelForCausalLM
from tokenizers import Tokenizer
from methodtools import lru_cache
from collections import defaultdict
from tqdm import tqdm

import torch
import pickle as pkl
import random
import numpy as np
import math

seed = 42
# Seed every RNG this script draws from (PyTorch, Python's `random`, NumPy), in that
# order, so prompt construction and misspelling generation are reproducible.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)
del _seed_fn

# Generates random segmentations of a word, guaranteeing each segment exists within vocabulary.
# Each step is weighted by the number of complete segmentations it leads to, so every valid
# segmentation is equally likely to be produced by generate().
class generateSegments:

    # Pre: string word representing the word to be segmented, vocabulary representing the
    #       vocabulary to check the segments against (any container supporting `in`)
    # Post: creates a new generateSegments Object, does not output.
    def __init__(self, word, vocabulary):
        self.word = word
        self.vocab = vocabulary
        self.wordlen = len(word)
        # Per-instance table of suffix segmentation counts, built lazily by _countTable().
        # Stored on the instance (instead of a shared, size-bounded LRU cache) so it is
        # freed together with the object and no entry is ever evicted.
        self._segmentCounts = None

    # Post: returns a list `counts` of length self.wordlen + 1 where counts[i] is the number
    #       of ways to split self.word[i:] into segments that all belong to self.vocab.
    #       counts[self.wordlen] is 1 (the empty suffix has exactly one, empty, segmentation).
    #       Built once, bottom-up, so long words cannot exceed Python's recursion limit.
    def _countTable(self):
        counts = getattr(self, "_segmentCounts", None)
        if counts is None:
            n = self.wordlen
            counts = [0] * (n + 1)
            counts[n] = 1
            for start in range(n - 1, -1, -1):
                total = 0
                for end in range(start + 1, n + 1):
                    # Skip the slice + membership test when the remainder is unsegmentable.
                    if counts[end] and self.word[start:end] in self.vocab:
                        total += counts[end]
                counts[start] = total
            self._segmentCounts = counts
        return counts

    # Pre: accepts an integer start representing the index to start from
    # Post: returns the number of ways self.word[start:] can be split into in-vocabulary
    #       segments: 1 when start == self.wordlen, 0 when start lies outside [0, wordlen].
    def countSegments(self, start):
        if not 0 <= start <= self.wordlen:
            return 0
        return self._countTable()[start]

    # Pre: accepts an integer start representing the starting index that we want to build
    #       a segmentation from
    # Post: returns a random segmentation of self.word[start:] as a list of in-vocabulary
    #       strings. At each step the next segment is drawn with weight equal to the number
    #       of complete segmentations continuing from it (the size of that child's subtree).
    #       If the remaining suffix cannot be segmented, the segments built so far are
    #       returned (possibly []). Iterative, so long words are safe.
    def buildSegments(self, start):
        if start < 0:
            return []
        counts = self._countTable()
        segments = []
        pos = start
        while pos < self.wordlen:
            choices = []
            weights = []
            for end in range(pos + 1, self.wordlen + 1):
                segment = self.word[pos:end]
                if segment in self.vocab and counts[end] > 0:
                    choices.append(segment)
                    weights.append(counts[end])
            if not choices:
                break
            # Same draw as the recursive version, so seeded runs produce identical output.
            nextSegment = random.choices(choices, weights=weights, k=1)[0]
            segments.append(nextSegment)
            pos += len(nextSegment)
        return segments

    # Post: generates a random segmentation where each segment exists in self.vocab and
    #       returns it; returns [] when self.word cannot be fully covered by vocabulary
    #       segments (and for an empty word).
    def generate(self):
        if self.countSegments(0) == 0:
            return []  # no valid way to segment self.word
        return self.buildSegments(0)

# Class that turns a tokenizer's normal tokenization into an alternative ("jumbled") one
class TokenJumbler:

    # Pre: accepts a vocabulary container (kept for segment checks) and a tokenizer object
    #       exposing a tokenize(str) method
    # Post: initializes a TokenJumbler object, does not output.
    def __init__(self, vocabulary, tokenizer):
        self.tokenizer = tokenizer
        self.vocab = vocabulary

    # Pre: accepts a string sentence to be tokenized
    # Post: returns the list of token strings produced by self.tokenizer.tokenize(sentence)
    def generateDefaultTokenization(self, sentence):
        default_tokens = self.tokenizer.tokenize(sentence)
        return default_tokens

    # Pre: accepts an iterable tokens of token strings
    # Post: returns a NEW list containing every character of every token, in order; i.e.
    #       each token is split into single-character pieces. The input is not modified,
    #       and the pieces are not checked against self.vocab (they may not exist in it).
    def jumble(self, tokens):
        return [piece for token in tokens for piece in token]
    

# Interface class for handling model inference and token processing
class TransformersInterface:

    # QWERTY neighbours of each lowercase letter, used for keyboard-slip typos.
    # List order matters for seeded reproducibility (random.choice picks by index).
    _KEYBOARD_ADJACENCY = {
        'q': ['w', 'a', 's'], 'w': ['q', 'e', 'a', 's', 'd'], 'e': ['w', 'r', 's', 'd', 'f'],
        'r': ['e', 't', 'd', 'f', 'g'], 't': ['r', 'y', 'f', 'g', 'h'], 'y': ['t', 'u', 'g', 'h', 'j'],
        'u': ['y', 'i', 'h', 'j', 'k'], 'i': ['u', 'o', 'j', 'k', 'l'], 'o': ['i', 'p', 'k', 'l'],
        'p': ['o', 'l'], 'a': ['q', 'w', 's', 'z', 'x'], 's': ['q', 'w', 'e', 'a', 'd', 'z', 'x', 'c'],
        'd': ['w', 'e', 'r', 's', 'f', 'x', 'c', 'v'], 'f': ['e', 'r', 't', 'd', 'g', 'c', 'v', 'b'],
        'g': ['r', 't', 'y', 'f', 'h', 'v', 'b', 'n'], 'h': ['t', 'y', 'u', 'g', 'j', 'b', 'n', 'm'],
        'j': ['y', 'u', 'i', 'h', 'k', 'n', 'm'], 'k': ['u', 'i', 'o', 'j', 'l', 'm'], 'l': ['i', 'o', 'p', 'k'],
        'z': ['a', 's', 'x'], 'x': ['z', 's', 'd', 'c'], 'c': ['x', 'd', 'f', 'v'],
        'v': ['c', 'f', 'g', 'b'], 'b': ['v', 'g', 'h', 'n'], 'n': ['b', 'h', 'j', 'm'],
        'm': ['n', 'j', 'k']
    }

    # Common phonetic confusions. Insertion order matters: natural_misspelling tries the
    # patterns in this order and draws one random number per pattern matching at pos.
    _PHONETIC_CONFUSIONS = {
        'c': ['k', 's'], 'k': ['c'], 's': ['c', 'z'], 'z': ['s'],
        'f': ['ph'], 'ph': ['f'], 'i': ['y'], 'y': ['i'],
        'j': ['g'], 'g': ['j'], 'au': ['o'], 'o': ['au'],
        'tion': ['shun'], 'x': ['cks'], 'v': ['w'], 'w': ['v']
    }

    # Few-shot prompt that precedes the first answer option of every question; it ends
    # with "A. " so the caller appends option A, then "\nB. " and option B, then the suffix.
    _MISSPELLING_PROMPT_PREFIX = """
Question: Which of the two words contains a misspelling? Respond directly with the answer option.

Question: 

A. guarantees 
B. garantees

Answer: B
                                                 
Question: 

A. pper
B. paper

Answer: A
                                                 
Question: 

A. research 
B. reesearch

Answer: B
                                                 
Question: 

A. processing 
B. procesing

Answer: B
                                                 
Question: 

A. foood 
B. food

Answer: A
                                                 
Question: 

A. deadline
B. daedline

Answer: B
                                                 
Question: 

A. chicken 
B. cicken

Answer: B
                                                 
Question: 

A. duckling
B. dduckling

Answer: B
                                                 
Question: 

A. whiskye
B. whiskey

Answer: A
                                                 
Question: 

A. farmer
B. farme

Answer: B

Question:

A. """

    # Pre: model exposes eval() and generate(); vocabulary and tokenizer are handed to a
    #       TokenJumbler; device is stored but not otherwise used by this class.
    # Post: stores the collaborators and puts the model in evaluation mode. The model is
    #       NOT moved to any device; the caller is responsible for placing it.
    def __init__(self, model, vocabulary, tokenizer, device='cuda'):
        self.jumbler = TokenJumbler(vocabulary, tokenizer)
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()  # Set model to evaluation mode

    # Pre: accepts a string word; lang is accepted for API compatibility but unused.
    # Post: returns word with one typo-like edit applied at a position biased towards the
    #       middle of the word: insert (20%), delete (30%), substitute (40%, preferring
    #       phonetic then keyboard-adjacent replacements) or transpose (10%). An empty word
    #       yields one random lowercase letter. The result always differs from word, so it
    #       can safely serve as the "misspelled" answer option.
    def natural_misspelling(self, word, lang='en'):
        if not word:
            return chr(random.choice(range(97, 123)))  # a-z

        keyboard_adjacency = self._KEYBOARD_ADJACENCY

        # Choose position weighted towards middle characters (more natural errors)
        pos = min(
            len(word) - 1,
            max(0, int(random.gauss(len(word) / 2, len(word) / 3))))

        operation = random.choices(
            [0, 1, 2, 3],  # 0=insert, 1=delete, 2=substitute, 3=transpose
            weights=[0.2, 0.3, 0.4, 0.1],  # More likely to delete/substitute
            k=1
        )[0]

        if operation == 0:  # insert before word[pos]
            # 50% chance to insert a keyboard-adjacent character
            if random.random() < 0.5 and word[pos] in keyboard_adjacency:
                new_char = random.choice(keyboard_adjacency[word[pos]])
            else:
                new_char = chr(random.choice(range(97, 123)))
            return word[:pos] + new_char + word[pos:]

        if operation == 1:  # delete word[pos]
            # Deleting word[pos] is the outcome either way; for double letters one random
            # draw is still consumed so seeded runs keep their historical random stream.
            if pos < len(word) - 1 and word[pos] == word[pos + 1]:
                random.random()
            return word[:pos] + word[pos + 1:]

        if operation == 2:  # substitute
            # First try phonetic substitutions
            for pattern, replacements in self._PHONETIC_CONFUSIONS.items():
                if word[pos:pos + len(pattern)] == pattern and random.random() < 0.4:
                    return word[:pos] + random.choice(replacements) + word[pos + len(pattern):]
            # Then try keyboard adjacency
            original_char = word[pos]
            if original_char in keyboard_adjacency and random.random() < 0.7:
                return word[:pos] + random.choice(keyboard_adjacency[original_char]) + word[pos + 1:]
            return self._random_substitution(word, pos)

        # operation == 3: transpose two adjacent letters
        if len(word) > 1:
            swap = random.randint(0, len(word) - 2)
            if word[swap] == word[swap + 1]:
                # Swapping identical letters (e.g. "ff" in "coffee") would return word
                # unchanged and mislabel the ground truth; pick a pair that differs.
                candidates = [i for i in range(len(word) - 1) if word[i] != word[i + 1]]
                swap = random.choice(candidates) if candidates else None
            if swap is not None:
                return word[:swap] + word[swap + 1] + word[swap] + word[swap + 2:]
        # One-letter words, or words of a single repeated letter, cannot be transposed into
        # anything different: fall back to a random substitution.
        return self._random_substitution(word, pos)

    # Pre: word is non-empty and 0 <= pos < len(word)
    # Post: returns word with word[pos] replaced by a different random lowercase letter.
    @staticmethod
    def _random_substitution(word, pos):
        original_char = word[pos]
        new_char = original_char
        while new_char == original_char:
            new_char = chr(random.choice(range(97, 123)))
        return word[:pos] + new_char + word[pos + 1:]

    # Post: returns the ids of text obtained via self.tokenizer.tokenize followed by
    #       convert_tokens_to_ids (the same path used for the answer options).
    def _text_to_ids(self, text):
        return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

    # Pre: batch is an iterable of words (a list element contributes its first item);
    #       max_new_tokens and batch_size control generation; perplexity_outfile is
    #       created/truncated but nothing is written to it.
    # Post: for each word builds a two-option "which word is misspelled?" few-shot prompt
    #       in which one option is the word split into single-character tokens (via
    #       self.jumbler) and the other is natural_misspelling(word). The misspelling's
    #       option letter (A or B, chosen at random) is written, one per line, to
    #       "groundtruth.txt" and each prompt is printed as tokens. Prompts are left-padded,
    #       run through self.model.generate with greedy decoding in chunks of batch_size,
    #       and decoded with special tokens kept. Returns (normal_responses,
    #       perturbed_responses); an empty batch returns ([], []).
    #       NOTE(review): only the perturbed prompts are generated, but the split uses the
    #       number of normal prompts, so ALL responses land in normal_responses and
    #       perturbed_responses is always empty. Left unchanged because callers may rely
    #       on it.
    def generate_response(self, batch, max_new_tokens=512, batch_size=2, perplexity_outfile="out.out"):
        # Preserve the historical side effect (file created/truncated) without leaking
        # an open handle; nothing is ever written to it.
        with open(perplexity_outfile, "w"):
            pass

        input_ids_prefix = self.tokenizer.encode(self._MISSPELLING_PROMPT_PREFIX, add_special_tokens=False)
        input_ids_suffix = self.tokenizer.encode("\n\nAnswer:", add_special_tokens=False)

        normal_batch_inputs = []
        perturbed_batch_inputs = []
        with open("groundtruth.txt", "w") as groundtruth:
            for sentence in tqdm(batch, desc="Processing Sentences", unit="sentence"):
                if isinstance(sentence, list):
                    sentence = sentence[0]
                normal_tokenized = self.tokenizer.tokenize(sentence)
                jumble_tokenized = self.jumbler.jumble(normal_tokenized.copy())
                normal_input_ids = self.tokenizer.convert_tokens_to_ids(normal_tokenized)
                jumbled_input_ids = self.tokenizer.convert_tokens_to_ids(jumble_tokenized)
                normal_sequence = input_ids_prefix + normal_input_ids + input_ids_suffix
                if random.choice([0, 1]) == 0:
                    # Option A: misspelling, option B: character-split correct word
                    groundtruth.write("A\n")
                    misspelled_ids = self._text_to_ids(self.natural_misspelling(sentence) + "\nB. ")
                    jumbled_sequence = input_ids_prefix + misspelled_ids + jumbled_input_ids + input_ids_suffix
                else:
                    # Option A: character-split correct word, option B: misspelling
                    groundtruth.write("B\n")
                    misspelled_ids = self._text_to_ids("\nB. " + self.natural_misspelling(sentence))
                    jumbled_sequence = input_ids_prefix + jumbled_input_ids + misspelled_ids + input_ids_suffix
                print(self.tokenizer.convert_ids_to_tokens(jumbled_sequence))
                normal_batch_inputs.append(normal_sequence)
                perturbed_batch_inputs.append(jumbled_sequence)

        if not perturbed_batch_inputs:
            return [], []

        # Only the perturbed prompts are sent to the model.
        combined_input_ids = perturbed_batch_inputs

        # Some tokenizers define no pad token; padded positions are masked out, so any
        # valid id works as filler.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 0

        # Left-pad every sequence to the longest one and build matching attention masks.
        max_length = max(len(seq) for seq in combined_input_ids)
        padded_input_ids = []
        attention_masks = []
        for seq in combined_input_ids:
            padding_length = max_length - len(seq)
            padded_input_ids.append([pad_id] * padding_length + seq)
            attention_masks.append([0] * padding_length + [1] * len(seq))

        input_ids_tensor = torch.tensor(padded_input_ids, dtype=torch.long)
        attention_mask_tensor = torch.tensor(attention_masks, dtype=torch.long)

        # Put inputs on the model's device when it reports one; otherwise leave them as-is.
        target_device = getattr(self.model, "device", None)

        responses = []
        for i in range(0, len(combined_input_ids), batch_size):
            batch_input_ids = input_ids_tensor[i:i + batch_size]
            batch_attention_mask = attention_mask_tensor[i:i + batch_size]
            if target_device is not None:
                batch_input_ids = batch_input_ids.to(target_device)
                batch_attention_mask = batch_attention_mask.to(target_device)
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=batch_input_ids,
                    attention_mask=batch_attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False    # greedy decoding -> deterministic outputs
                )
            responses.extend(self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False))

        half = len(normal_batch_inputs)
        normal_responses = responses[:half]
        perturbed_responses = responses[half:]
        return normal_responses, perturbed_responses