"""
WTGIA Text Generation Utilities
Local Llama generation with strict word restrictions and no-topic prompts
"""

import json
import pickle
import re
from pathlib import Path
from typing import Dict, List, Optional
import os

import numpy as np
import torch
from tqdm import tqdm
from sklearn.feature_extraction.text import CountVectorizer
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor


class RestrictProcessor(LogitsProcessor):
    """LogitsProcessor that forbids specific token ids during generation.

    Every column of ``scores`` listed in ``non_target_tokens`` is set to
    ``-inf`` on each call.

    Args:
        tokenizer: Tokenizer whose ``vocab_size`` defines the id range used
            when ``stopwords`` is computed.
        non_target_tokens: Token ids that must never be generated.
    """

    def __init__(self, tokenizer, non_target_tokens):
        self.tokenizer = tokenizer
        self.non_target_tokens = non_target_tokens
        # The complement list is built on first access instead of eagerly:
        # it requires a scan over the whole vocabulary and __call__ never
        # reads it, so paying for it on every construction is wasted work.
        self._stopwords = None

    @property
    def stopwords(self):
        """Token ids in ``range(tokenizer.vocab_size)`` that are not forbidden."""
        if self._stopwords is None:
            forbidden = set(self.non_target_tokens)
            self._stopwords = [
                token_id
                for token_id in range(self.tokenizer.vocab_size)
                if token_id not in forbidden
            ]
        return self._stopwords

    @stopwords.setter
    def stopwords(self, value):
        # Keep the attribute assignable, as it was when it was a plain field.
        self._stopwords = value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Mask the forbidden token columns of ``scores`` in place and return it."""
        # Nothing to mask when no tokens are forbidden.
        if len(self.non_target_tokens) > 0:
            scores[:, self.non_target_tokens] = -float('Inf')
        return scores


def load_bow_config(dataset_name: str) -> Dict[str, int]:
    """Return the generation limits for a dataset.

    Args:
        dataset_name: Dataset name; matched case-insensitively.

    Returns:
        A dict with ``'max_tokens'`` and ``'max_words'``. Names other than
        ``cora``, ``citeseer`` and ``pubmed`` fall back to the ``cora``
        settings rather than raising.
    """
    configs = {
        'cora': {'max_tokens': 512, 'max_words': 300},
        'citeseer': {'max_tokens': 512, 'max_words': 300},
        'pubmed': {'max_tokens': 550, 'max_words': 400}
    }
    return configs.get(dataset_name.lower(), configs['cora'])


def clear_text(raw_text: str) -> str:
    """Strip title/abstract labels, newlines and double quotes from text.

    The label patterns are matched case-insensitively and removed in the
    listed order (the colon-suffixed form before the bare word). Newlines
    become single spaces and double quotes are dropped; the whitespace left
    behind by a removed label is not collapsed.
    """
    label_patterns = (
        r"\btitle:\s*",
        r"\btitle\b",
        r"\babstract:\s*",
        r"\babstract\b",
    )

    cleaned = raw_text
    for pattern in label_patterns:
        cleaned = re.sub(pattern, "", cleaned, flags=re.IGNORECASE)

    return cleaned.replace("\n", " ").replace("\"", "")


class WTGIATextGenerator:
    """Generate raw text for injected nodes from their bag-of-words features.

    For every feature row the words marked ``1`` are put into a no-topic
    prompt as required words, the words marked ``0`` are passed to the LLM
    client as forbidden words, and the response that covers the largest
    share of required words is kept.
    """

    # Number of feedback rounds attempted after the initial generation.
    _CORRECTION_ROUNDS = 1

    def __init__(self, dataset_name: str, base_path: str = "/path/to/GraphAD_data"):
        """Load the generation config and the BoW vocabulary for a dataset.

        Args:
            dataset_name: Dataset name; stored lower-cased.
            base_path: Root directory containing ``datasets/vocab/<dataset>``.

        Raises:
            FileNotFoundError: If the vocabulary pickle does not exist.
        """
        self.dataset_name = dataset_name.lower()
        self.base_path = Path(base_path)
        self.config = load_bow_config(self.dataset_name)
        self.vocabulary = self._load_vocabulary()

    def _load_vocabulary(self) -> np.ndarray:
        """Return the feature names of the pickled vectorizer for this dataset."""
        vocab_path = self.base_path / "datasets" / "vocab" / self.dataset_name / "bow_vocabulary.pkl"

        if not vocab_path.exists():
            raise FileNotFoundError(f"Vocabulary not found at {vocab_path}")

        # NOTE: pickle.load can execute arbitrary code; the vocabulary file
        # must come from a trusted source.
        with open(vocab_path, 'rb') as f:
            vectorizer = pickle.load(f)

        return vectorizer.get_feature_names_out()

    def extract_words_from_features(self, features: torch.Tensor) -> tuple:
        """Split the vocabulary into used / not-used words for every row.

        Args:
            features: Iterable of per-node feature rows, indexed in
                vocabulary order.

        Returns:
            ``(used_words, not_used_words)``: two lists with one word list
            per row. A word is "used" where the feature equals 1 and
            "not used" where it equals 0; any other value puts the word in
            neither list.
        """
        used_words = []
        not_used_words = []

        for node_features in features:
            # Convert the row once; indexing a tensor element by element
            # for every vocabulary word is far slower.
            if hasattr(node_features, "tolist"):
                values = node_features.tolist()
            else:
                values = list(node_features)

            used = []
            not_used = []
            for i, word in enumerate(self.vocabulary):
                value = values[i]
                if value == 1:
                    used.append(word)
                elif value == 0:
                    not_used.append(word)

            used_words.append(used)
            not_used_words.append(not_used)

        return used_words, not_used_words

    def create_no_topic_prompt(self, used_words: List[str]) -> List[Dict[str, str]]:
        """Build the chat messages asking for a title and abstract.

        Args:
            used_words: Words the generated text is required to contain.

        Returns:
            A two-message chat (system, user) listing every required word
            and the configured ``max_words`` length limit.
        """
        max_words = self.config['max_words']

        required_words_str = ", ".join(f"'{word}'" for word in used_words)
        word_count = len(used_words)

        content = (
            "Generate a title and an abstract for an academic article.\n" +
            f"CRITICAL REQUIREMENT: You MUST include ALL {word_count} of these specific words: {required_words_str}\n" +
            "Use each word exactly as written - no synonyms, plural forms, or variants allowed.\n" +
            "PRIORITY: Including all required words is more important than natural language flow.\n" +
            "You must use all the required words !!!.\n" +
            f"Length limit: {max_words} words.\n" +
            "Output format:\nTITLE: [your title]\nABSTRACT: [your abstract]"
        )

        return [
            {"role": "system", "content": "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful answers."},
            {"role": "user", "content": content}
        ]

    def calculate_usage_rates(self, text: str, should_use_words: List[str], should_not_use_words: List[str]) -> tuple:
        """Measure how many required / forbidden words occur in ``text``.

        The text is lower-cased, split on whitespace and then on ``'-'``;
        words are matched exactly against those pieces (punctuation is not
        stripped).

        Returns:
            ``(should_use_rate, should_not_use_rate, non_use)``: both rates
            are percentages (0 when the corresponding word list is empty)
            and ``non_use`` lists the required words that are missing.
        """
        # A set keeps each membership test O(1); the forbidden list can be
        # as large as the whole vocabulary.
        text_words = {
            piece
            for token in text.lower().split()
            for piece in token.split('-')
        }

        non_use = [word for word in should_use_words if word not in text_words]
        should_use_count = len(should_use_words) - len(non_use)
        should_not_use_count = sum(1 for word in should_not_use_words if word in text_words)

        should_use_rate = (should_use_count / len(should_use_words)) * 100 if len(should_use_words) > 0 else 0
        should_not_use_rate = (should_not_use_count / len(should_not_use_words)) * 100 if len(should_not_use_words) > 0 else 0

        return should_use_rate, should_not_use_rate, non_use

    def generate_texts(self, features_attack: torch.Tensor, llama_client) -> List[str]:
        """Generate one cleaned text per feature row.

        Args:
            features_attack: Per-node BoW feature rows.
            llama_client: Object exposing
                ``generate(messages, max_tokens, not_used_words) -> str``.

        Returns:
            One text per row. Rows without any used word get the
            placeholder ``"Generated academic text."``.
        """
        used_words, not_used_words = self.extract_words_from_features(features_attack)
        generated_texts = []
        max_tokens = self.config['max_tokens']

        print(f"Generating {len(used_words)} texts with no-topic prompts...")

        for i, (used_word, not_used_word) in enumerate(tqdm(zip(used_words, not_used_words), desc="Generating", total=len(used_words))):
            if not used_word:
                generated_texts.append("Generated academic text.")
                continue

            # Initial generation with restrictions
            messages = self.create_no_topic_prompt(used_word)
            response = llama_client.generate(messages, max_tokens, not_used_word)
            use_rate, not_use_rate, missing_words = self.calculate_usage_rates(response, used_word, not_used_word)

            best_text = response
            best_use_rate = use_rate
            # Tracked with the best text so the log line below reports the
            # rates of the text that is actually returned.
            best_not_use_rate = not_use_rate

            for _ in range(self._CORRECTION_ROUNDS):
                # All required words are present: a correction round would
                # only spend another LLM call on a "missed 0 words" prompt.
                if not missing_words:
                    break

                missing_count = len(missing_words)
                feedback = (f"CRITICAL ERROR: You missed {missing_count} required words: " +
                            ', '.join(f"'{word}'" for word in missing_words) + ".\n" +
                            "You MUST include every single required word. Rewrite to include ALL missing words.\n" +
                            "TITLE: [rewrite title]\nABSTRACT: [rewrite abstract]")

                messages.append({"role": "assistant", "content": response})
                messages.append({"role": "user", "content": feedback})

                response = llama_client.generate(messages, max_tokens, not_used_word)
                use_rate, not_use_rate, missing_words = self.calculate_usage_rates(response, used_word, not_used_word)

                # Ties go to the later response.
                if use_rate >= best_use_rate:
                    best_text = response
                    best_use_rate = use_rate
                    best_not_use_rate = not_use_rate

            # pubmed texts only have newlines flattened; other datasets also
            # have title/abstract labels and quotes stripped.
            if self.dataset_name != 'pubmed':
                cleaned_text = clear_text(best_text)
            else:
                cleaned_text = best_text.replace("\n", " ")

            # Log every tenth node only
            if i % 10 == 0:
                print(f"Node {i}: Use Rate {best_use_rate:.1f}%, Not Use Rate {best_not_use_rate:.1f}%")

            generated_texts.append(cleaned_text)

        return generated_texts


class LlamaClient:
    """Local Llama chat client that can forbid words during generation."""

    def __init__(self, model_path: str = "/path/to/models/llama-3.1-8B-Instruct/"):
        """Load the tokenizer and model found at ``model_path``.

        Args:
            model_path: Directory or identifier passed to ``from_pretrained``.
        """
        self.model_path = model_path
        print(f"Loading Llama model from {model_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Set pad_token to avoid attention mask warnings
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

        # Drop ids the tokenizer could not resolve so that a None entry is
        # never handed to generate() as an end-of-sequence id.
        candidate_terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
        self.terminators = [
            token_id for token_id in candidate_terminators if token_id is not None
        ]

        print(f"✓ Llama model loaded on device: {self.model.device}")

    def _get_forbidden_tokens(self, not_used_words: List[str]) -> List[int]:
        """Collect the token ids that make up the forbidden words.

        Each word is encoded as written and in capitalized form; every
        resulting token id is forbidden, including the ids of sub-word
        pieces when a word encodes to several tokens.

        NOTE(review): only the bare and capitalized spellings are encoded.
        Some tokenizers encode a word that follows a space as a different
        token; such variants would not be covered here — confirm against
        the tokenizer in use.

        Args:
            not_used_words: Words that must not appear in the output.

        Returns:
            The forbidden token ids, without duplicates (order unspecified).
        """
        if not not_used_words:
            return []

        not_used_tokens = set()

        for word in not_used_words:
            try:
                # Regular word
                tokens = self.tokenizer.encode(word, add_special_tokens=False)
                not_used_tokens.update(tokens)

                # Capitalized word
                cap_word = word.capitalize()
                cap_tokens = self.tokenizer.encode(cap_word, add_special_tokens=False)
                not_used_tokens.update(cap_tokens)
            except Exception:
                # Best effort: a word that cannot be encoded is skipped.
                # KeyboardInterrupt / SystemExit are no longer swallowed.
                continue

        return list(not_used_tokens)

    def generate(self, messages: List[Dict[str, str]], max_tokens: int = 250, not_used_words: Optional[List[str]] = None) -> str:
        """Generate a reply to ``messages`` while forbidding ``not_used_words``.

        Args:
            messages: Chat messages passed to the tokenizer's chat template.
            max_tokens: Maximum number of new tokens to generate.
            not_used_words: Words whose tokens are masked out; ``None`` or
                an empty list means no restriction.

        Returns:
            The decoded, stripped reply. If anything raises during
            generation the error is printed and the placeholder
            ``"Generated academic text"`` is returned instead.
        """
        try:
            # Prepare token restrictions
            forbidden_tokens = self._get_forbidden_tokens(not_used_words or [])
            custom_processor = RestrictProcessor(self.tokenizer, forbidden_tokens)
            logits_processor = LogitsProcessorList([custom_processor])

            # Apply chat template
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(self.model.device)

            # Create attention mask to avoid warnings
            attention_mask = torch.ones_like(input_ids).to(self.model.device)

            # Sampling with a low temperature
            with torch.inference_mode():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_tokens,
                    eos_token_id=self.terminators,
                    pad_token_id=self.tokenizer.pad_token_id,
                    do_sample=True,
                    temperature=0.3,
                    logits_processor=logits_processor,
                    use_cache=True
                )

            # Decode only the new tokens (everything after the prompt)
            response = outputs[0][input_ids.shape[-1]:]
            text = self.tokenizer.decode(response, skip_special_tokens=True)

            return text.strip()

        except Exception as e:
            # Deliberate best effort: callers always receive a string.
            print(f"Generation error: {e}")
            return "Generated academic text"


# Main API function
def generate_wtgia_texts(
    features_attack: torch.Tensor,
    dataset_name: str,
    base_path: str = "/path/to/GraphAD_data",
    model_path: str = "/path/to/models/llama-3.1-8B-Instruct/",
    save_dir: Optional[str] = None
) -> List[str]:
    """
    Generate one text per injected node with a local Llama model.

    Builds a WTGIATextGenerator and a LlamaClient, generates the texts,
    optionally writes them to disk, then drops the model and empties the
    CUDA cache.

    Args:
        features_attack: BoW feature rows of the injected nodes
        dataset_name: Dataset name
        base_path: Base data path handed to the generator
        model_path: Path to the local Llama model
        save_dir: If truthy, directory receiving ``node_<i>.txt`` files and
            a consolidated ``generated_texts.json``

    Returns:
        List of generated texts
    """
    print(f"Generating texts for {len(features_attack)} injected nodes using Llama with restrictions...")

    text_generator = WTGIATextGenerator(dataset_name, base_path)
    llm_client = LlamaClient(model_path)
    texts = text_generator.generate_texts(features_attack, llm_client)

    if save_dir:
        out_dir = Path(save_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        # One plain-text file per node
        for index, node_text in enumerate(texts):
            with open(out_dir / f"node_{index}.txt", 'w', encoding='utf-8') as handle:
                handle.write(node_text)

        # Everything again in a single JSON document, with run metadata
        summary = {
            'texts': texts,
            'dataset': dataset_name,
            'num_texts': len(texts),
            'config': text_generator.config,
            'model_path': model_path
        }
        with open(out_dir / "generated_texts.json", 'w', encoding='utf-8') as handle:
            json.dump(summary, handle, indent=2, ensure_ascii=False)

        print(f"✓ Saved {len(texts)} texts to {out_dir}")

    # Release the model before returning so GPU memory can be reclaimed
    if hasattr(llm_client, 'model'):
        del llm_client.model
    del llm_client
    torch.cuda.empty_cache()

    print(f"✓ Successfully generated {len(texts)} texts with restrictions")
    return texts


def compute_bow_embeddings_from_texts(texts, dataset_name, vocab_path, original_features=None):
    """
    Compute BoW embeddings from generated texts using an existing vocabulary

    Args:
        texts (List[str]): List of generated texts
        dataset_name (str): Name of the dataset (accepted for interface
            compatibility; not used by this function)
        vocab_path (str): Path to the pickled vectorizer
        original_features (torch.Tensor, optional): Original features; when
            given, a validation report comparing them with the re-computed
            embeddings is printed

    Returns:
        torch.Tensor: BoW embeddings tensor

    Raises:
        FileNotFoundError: If ``vocab_path`` does not exist
    """
    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"Vocabulary file not found: {vocab_path}")

    # NOTE: pickle.load can execute arbitrary code; vocab_path must point
    # to a trusted file.
    with open(vocab_path, 'rb') as f:
        vectorizer = pickle.load(f)

    # Transform texts to BoW vectors using the existing vocabulary
    bow_matrix = vectorizer.transform(texts)
    embeddings = torch.FloatTensor(bow_matrix.toarray())

    print(f"Computed BoW embeddings for {len(texts)} texts: {embeddings.shape}")

    if original_features is not None:
        print("\nValidation Check:")

        # 1. Text-to-embedding consistency on (at most) the first 3 texts
        vocab = vectorizer.get_feature_names_out()
        sample_size = min(3, len(texts))
        text_match_rate = 0
        for i in range(sample_size):
            # Same tokenisation as the usage-rate check: whitespace split,
            # then split on '-'. A set keeps each lookup O(1).
            text_words = {
                piece
                for part in texts[i].lower().split()
                for piece in part.split('-')
            }
            # One conversion per row instead of one .item() per feature.
            feature_present = (embeddings[i] > 0).tolist()

            matches = sum(
                1
                for word, present in zip(vocab, feature_present)
                if (word in text_words) == present
            )
            if len(vocab) > 0:
                text_match_rate += matches / len(vocab)

        # Guard: with no texts there is nothing to average over.
        if sample_size > 0:
            print(f"  Text-embedding consistency: {text_match_rate/sample_size*100:.1f}%")
        else:
            print("  Text-embedding consistency: n/a (no texts)")

        # 2. Original vs re-computed similarity (compared on CPU, in the
        # embeddings' dtype so mixed dtypes do not break the comparison)
        orig_cpu = original_features.cpu().to(embeddings.dtype)
        similarity = torch.cosine_similarity(orig_cpu, embeddings).mean().item()

        # Feature overlap
        orig_ones = (orig_cpu > 0).float().mean().item()
        recomp_ones = (embeddings > 0).float().mean().item()
        overlap = ((orig_cpu > 0) & (embeddings > 0)).float().mean().item()
        # Guard: original features without any positive entry.
        kept_pct = overlap / orig_ones * 100 if orig_ones > 0 else 0.0

        print(f"  Cosine similarity with original: {similarity:.3f}")
        print(f"  Feature density - Original: {orig_ones:.3f}, Re-computed: {recomp_ones:.3f}")
        print(f"  Feature overlap rate: {kept_pct:.1f}% of original words kept")

    return embeddings


def create_bow_vectorizer_from_texts(texts, vocab_size=None, min_df=1, max_df=1.0):
    """
    Fit a fresh CountVectorizer on texts and return it with the embeddings

    The vectorizer lower-cases its input and drops English stop words.

    Args:
        texts (List[str]): Texts the vocabulary is built from
        vocab_size (int, optional): Upper bound on the vocabulary size
        min_df (int): Minimum document frequency
        max_df (float): Maximum document frequency

    Returns:
        CountVectorizer: Fitted vectorizer
        torch.Tensor: BoW embeddings tensor
    """
    vectorizer_options = {
        'max_features': vocab_size,
        'min_df': min_df,
        'max_df': max_df,
        'lowercase': True,
        'stop_words': 'english',
    }
    vectorizer = CountVectorizer(**vectorizer_options)

    # Fit and transform in one pass, then densify for torch
    dense_counts = vectorizer.fit_transform(texts).toarray()
    embeddings = torch.FloatTensor(dense_counts)

    vocabulary_size = len(vectorizer.vocabulary_)
    print(f"Created BoW vectorizer with {vocabulary_size} vocabulary terms")
    print(f"Embeddings shape: {embeddings.shape}")

    return vectorizer, embeddings


def save_wtgia_words(features_attack: torch.Tensor, dataset_name: str, save_dir: str, vocab_path: str, ptb_rate: float, seed: int):
    """
    Save used_words and not_used_words arrays like in generate_raw_text.py

    Files are written to ``<save_dir>/raw`` as
    ``<dataset>_wtgia_<percent>_<seed>_used.npy`` and ``..._not_used.npy``,
    where ``<percent>`` is ``ptb_rate`` expressed as a whole percentage.

    Args:
        features_attack: BoW feature rows, indexed in vocabulary order
        dataset_name: Dataset name
        save_dir: Directory to save arrays
        vocab_path: Path to the pickled vectorizer
        ptb_rate: Perturbation rate
        seed: Random seed

    Returns:
        ``(used_words, not_used_words)`` as object-dtype numpy arrays with
        one entry per feature row.
    """
    # NOTE: pickle.load can execute arbitrary code; vocab_path must point
    # to a trusted file.
    with open(vocab_path, 'rb') as f:
        vectorizer = pickle.load(f)
    words = vectorizer.get_feature_names_out()

    used_words = []
    not_used_words = []

    for doc in features_attack:
        # Convert the row once; indexing a tensor element by element for
        # every vocabulary word is far slower.
        values = doc.tolist() if hasattr(doc, "tolist") else list(doc)

        used = []
        not_used = []
        for i, word in enumerate(words):
            value = values[i]
            # Values other than exactly 1 or 0 put the word in neither list.
            if value == 1:
                used.append(word)
            elif value == 0:
                not_used.append(word)
        used_words.append(used)
        not_used_words.append(not_used)

    # NOTE(review): when every row yields the same number of words numpy
    # builds a 2-D object array instead of a 1-D array of lists — confirm
    # that consumers of the .npy files handle both layouts.
    used_words = np.array(used_words, dtype=object)
    not_used_words = np.array(not_used_words, dtype=object)

    # Create raw directory
    raw_dir = Path(save_dir) / "raw"
    raw_dir.mkdir(parents=True, exist_ok=True)

    # Rounding away float noise before truncating keeps e.g. 0.29 -> 29
    # (0.29 * 100 == 28.999999999999996, which plain int() turns into 28).
    ptb_percent = int(round(ptb_rate * 100, 6))
    file_prefix = f"{dataset_name}_wtgia_{ptb_percent}_{seed}"
    np.save(raw_dir / f"{file_prefix}_used.npy", used_words)
    np.save(raw_dir / f"{file_prefix}_not_used.npy", not_used_words)

    print(f"✓ Saved word arrays to {raw_dir}")
    print(f"  - Used words: {file_prefix}_used.npy")
    print(f"  - Not used words: {file_prefix}_not_used.npy")

    return used_words, not_used_words