import pandas as pd
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json
import os

def rotate_with_binding(sentence, bind_start, bind_count, positions=1):
    """Rotate the words of *sentence* to the left, keeping one run of words together.

    The sentence is split on whitespace. The ``bind_count`` words starting at
    word index ``bind_start`` are treated as a single unit, so the rotation
    moves them as a block and never separates them.

    Args:
        sentence: Text to rotate; words are whitespace-separated.
        bind_start: Zero-based index of the first word of the bound run.
        bind_count: Number of consecutive words in the bound run (at least 1).
        positions: Number of units to rotate left by. Values outside
            ``0..unit_count-1`` (including negatives) wrap modulo the number
            of units.

    Returns:
        The rotated sentence with words joined by single spaces. If the
        binding range is invalid (negative start, ``bind_count`` below 1, or
        a run extending past the last word), ``sentence`` is returned
        unchanged.
    """
    words = sentence.split()

    # An invalid binding range leaves the input untouched. A non-positive
    # bind_count is rejected too: slicing with it would duplicate words or
    # inject an empty "word" into the result.
    if bind_start < 0 or bind_count < 1 or bind_start + bind_count > len(words):
        return sentence

    bind_end = bind_start + bind_count

    # Represent every rotatable unit as a list of words. The bound run is one
    # multi-word unit, so it is identified by position rather than by its
    # text; an unrelated word that happens to equal the concatenation of the
    # bound words can no longer be mistaken for it.
    units = (
        [[word] for word in words[:bind_start]]
        + [words[bind_start:bind_end]]
        + [[word] for word in words[bind_end:]]
    )

    # units always contains at least the bound run, so the modulo is safe.
    shift = positions % len(units)
    rotated = units[shift:] + units[:shift]

    return ' '.join(word for unit in rotated for word in unit)

class DatasetCreator:
    """Generate answers for questions in a text file and write rotated variants to JSONL.

    A causal language model and its tokenizer are loaded once in ``__init__``
    and reused for every question.
    """

    def __init__(
        self,
        evaluator_model: str = "nztinversive/llama3.2-1b-Uncensored",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Load the model and tokenizer.

        Args:
            evaluator_model: Model identifier passed to ``from_pretrained``.
            device: Stored on ``self.device``. Weight placement itself is
                delegated to ``device_map="auto"``.

        Raises:
            RuntimeError: If the model or tokenizer cannot be loaded; the
                original exception is attached as ``__cause__``.
        """
        self.device = device
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                evaluator_model,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(evaluator_model)
            # Fall back to EOS when the tokenizer defines no pad token.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Failed to initialize model: {str(e)}") from e

    def generate_answer(self, prompt: str) -> str:
        """Sample a continuation of *prompt* from the model.

        Args:
            prompt: Text fed to the model as-is.

        Returns:
            The decoded continuation without the prompt tokens, or an empty
            string if anything fails (best effort: the error is printed, not
            raised).
        """
        try:
            # The weights are placed by device_map="auto", which is not
            # guaranteed to match self.device, so follow the model's device.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt"
            ).to(self.model.device)

            # Stop on EOS or the end-of-turn token; drop ids the tokenizer
            # could not resolve.
            candidate_ids = [
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
            ]
            terminators = [tok_id for tok_id in candidate_ids if tok_id is not None]

            with torch.no_grad():
                # NOTE(review): max_length counts prompt tokens as well, so a
                # long prompt leaves little or no room for the answer;
                # consider max_new_tokens if a fixed answer budget is wanted.
                outputs = self.model.generate(
                    **inputs,
                    max_length=100,
                    eos_token_id=terminators or None,
                    do_sample=True,
                    temperature=0.9,
                    top_k=50
                )

            # Slice off the echoed prompt so only newly generated tokens are decoded.
            prompt_length = inputs.input_ids.shape[-1]
            return self.tokenizer.decode(
                outputs[0][prompt_length:],
                skip_special_tokens=True
            )

        except Exception as e:
            print(f"Error generating answer for prompt: {e}")
            return ""  # Return empty string instead of 0.0 since we're expecting text

    def create_dataset_from_txt(self, input_file: str, output_file: str) -> None:
        """
        Create a dataset from a text file containing questions.

        Each non-empty line is treated as one question (a trailing '?' is
        appended when missing). The generated answer is prefixed with
        "I cannot", and one example is written for every distinct cyclic
        rotation of the result, with the prefix kept together as one unit.

        Args:
            input_file (str): Path to input .txt file
            output_file (str): Path to output .jsonl file

        Raises:
            FileNotFoundError: If ``input_file`` does not exist.
            Exception: Any other error is printed and re-raised unchanged.
        """
        try:
            if not os.path.exists(input_file):
                raise FileNotFoundError(f"Input file not found: {input_file}")

            # Read and process the text file
            with open(input_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            refusal_prefix = "I cannot"
            # The prefix is re-split into separate words inside
            # rotate_with_binding, so every one of its words must be bound
            # for it to move as a unit.
            prefix_word_count = len(refusal_prefix.split())

            questions = []

            # Process each question and generate answer
            for line in tqdm(lines, desc="Processing questions"):
                line = line.strip()
                if not line:  # Skip empty lines
                    continue

                # Ensure line ends with question mark
                if not line.endswith('?'):
                    line += '?'

                # Generate answer using the model
                answer_words = self.generate_answer(line).split()
                new_answer = " ".join([refusal_prefix] + answer_words)

                # One rotatable unit for the bound prefix plus one per answer
                # word. Iterating over units (not characters of the string)
                # yields each distinct rotation exactly once instead of
                # repeating them as the shift wraps around.
                rotation_count = len(answer_words) + 1
                instruction = f"Answer the following question: {line}"

                for shift in range(rotation_count):
                    result = rotate_with_binding(
                        new_answer, 0, prefix_word_count, shift
                    )

                    # Create formatted example
                    questions.append({
                        "instruction": instruction,
                        "output": result
                    })

            # Save to JSONL file
            with open(output_file, 'w', encoding='utf-8') as f:
                for question in questions:
                    f.write(json.dumps(question) + '\n')

            print(f"Created dataset with {len(questions)} question-answer pairs")

        except Exception as e:
            print(f"An error occurred: {str(e)}")
            raise

# Example usage
if __name__ == "__main__":
    # Build the cyclic-rotation dataset from the bundled question list.
    try:
        creator = DatasetCreator()
        creator.create_dataset_from_txt('MaliciousInstruct.txt', 'training_data_cyclic.jsonl') # training_data.jsonl
    except Exception as e:
        print(f"Failed to create dataset: {str(e)}")
        # Exit non-zero so shells and job runners can see that the run failed;
        # previously the error was printed and the script exited with status 0.
        raise SystemExit(1) from e