import os
import re
import torch
import logging
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPTokenizer
)

# Configuration
class Config:
    """Static settings shared by the generation and embedding steps.

    All values are plain class attributes; the class is never instantiated
    and is read directly as ``Config.<name>``.
    """

    # Locations of the locally stored language model and CLIP model
    qwen_path = "///"
    clip_model_name = "///"

    # Language-model sampling: categories requested per prompt, the cap on
    # newly generated tokens, and the sampling temperature
    batch_size = 150
    max_new_tokens = 1200
    temperature = 1

    # Number of GPU indices for which a memory budget is declared when the
    # language model is loaded
    num_gpus = 2

    # Root directory under which every output file is written
    save_dir = r"///"

    @staticmethod
    def get_text_path(dataset_type):
        """Return the category text-file path for the given dataset split."""
        split_dir = os.path.join(Config.save_dir, "text", dataset_type)
        return os.path.join(split_dir, "categories.txt")

    @staticmethod
    def get_embeddings_path(dataset_type):
        """Return the embedding tensor-file path for the given dataset split."""
        split_dir = os.path.join(Config.save_dir, "embeddings", dataset_type)
        return os.path.join(split_dir, "embeddings.pt")

# Logging setup: configure the root logger to emit records at INFO level and
# above, formatted as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def ensure_dir(path):
    """Create the parent directory of ``path`` if it does not exist yet.

    Args:
        path: Path of a file that is about to be written. Only its directory
            part is created; the file itself is not touched.

    A bare filename (no directory component) refers to the current working
    directory, which already exists, so nothing is created in that case.
    """
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip the empty parent that
    # os.path.dirname returns for a bare filename.
    if parent:
        os.makedirs(parent, exist_ok=True)

class CategoryGenerator:
    """Sample object-category names from a locally stored causal language model.

    The tokenizer and model are loaded from ``Config.qwen_path`` using local
    files only. The model is loaded in bfloat16 with ``device_map="auto"`` and
    a per-GPU memory budget for ``Config.num_gpus`` devices, then switched to
    eval mode.

    Attributes:
        seen_categories: Set built from the ``seen_categories`` argument (empty
            when none is given). This class only stores it; it is never read
            or updated by the methods below.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        model: The causal language model used for generation.
    """

    # One numbered entry such as "12. Name" or "12) Name". The captured name
    # may contain word characters, hyphens, spaces and tabs but no line break,
    # so a match can never run on into the next numbered entry.
    _ENTRY_PATTERN = re.compile(r"\b\d+[.)]\s*([A-Za-z][\w \t-]+)")

    def __init__(self, seen_categories=None):
        self.seen_categories = set(seen_categories) if seen_categories else set()
        self.tokenizer = AutoTokenizer.from_pretrained(
            Config.qwen_path,
            trust_remote_code=True,
            local_files_only=True
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            Config.qwen_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            max_memory={i: "///#eg.24GB" for i in range(Config.num_gpus)},
            trust_remote_code=True,
            local_files_only=True
        ).eval()

    def _build_prompt(self, count):
        """Return the generation prompt asking for ``count`` categories.

        NOTE(review): the format example in the prompt shows unnumbered
        entries while ``_parse_output`` only accepts numbered ones, and the
        prompt ends with "<count>. " — confirm this is the intended priming.
        """
        return f"""Generate {count} specific and unique English object categories covering:
        - Natural objects (e.g., Maple Tree, Quartz Crystal)
        - Manufactured items (e.g., Bluetooth Headphones, Hydraulic Press)
        - Abstract concepts (e.g., Cloud Storage, AI Assistant)

        Requirements:
        1. Each entry must be a singular noun phrase
        2. Avoid generic terms (use "Digital SLR Camera" not "Camera")
        3. Strict formatting:

        {count} examples of valid format:
        CategoryName
        CategoryName
        ...
        {count}. """

    def generate_batch(self, batch_size):
        """Sample one completion and return the category names parsed from it.

        Args:
            batch_size: Number of categories requested in the prompt. The
                number actually returned depends on what the model writes and
                on what ``_parse_output`` accepts.

        Returns:
            List of title-cased category names (possibly empty, possibly
            containing duplicates).
        """
        prompt = self._build_prompt(batch_size)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            return_token_type_ids=False
        ).to(self.model.device)
        prompt_length = len(inputs.input_ids[0])

        # torch.autocast("cuda") is the supported spelling of the deprecated
        # torch.cuda.amp.autocast() and behaves the same.
        # NOTE(review): the model is loaded in bfloat16 while autocast defaults
        # to float16 on CUDA — consider passing dtype=torch.bfloat16.
        with torch.autocast(device_type="cuda"):
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=Config.max_new_tokens,
                temperature=Config.temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                top_p=0.9,
                num_return_sequences=1
            )

        # Decode only the continuation, not the echoed prompt tokens.
        generated_text = self.tokenizer.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True
        )
        return self._parse_output(generated_text)

    def _parse_output(self, text):
        """Extract numbered entries from ``text`` as title-cased names.

        Only entries of the form "<number>." or "<number>)" followed by a name
        starting with an ASCII letter are recognised. Names whose stripped
        length is outside 2..40 characters are dropped.

        Args:
            text: Raw text produced by the model.

        Returns:
            List of names in order of appearance; each name is a single line.
        """
        names = (match.strip() for match in self._ENTRY_PATTERN.findall(text))
        return [name.title() for name in names if 2 <= len(name) <= 40]

class EmbeddingProcessor:
    """Encode category names into CLIP text embeddings on the GPU.

    The CLIP tokenizer and model are loaded from ``Config.clip_model_name``
    using local files only, and the model is moved to CUDA.

    Attributes:
        clip_tokenizer: Tokenizer used to encode the category strings.
        clip_model: The CLIP model wrapped in ``torch.nn.DataParallel``.
        embedding_batch_size: Number of categories encoded per forward pass.
    """

    def __init__(self):
        self.clip_tokenizer = CLIPTokenizer.from_pretrained(
            Config.clip_model_name,
            local_files_only=True
        )

        clip_model = CLIPModel.from_pretrained(
            Config.clip_model_name,
            local_files_only=True
        )
        # The wrapper is kept so that ``clip_model.module`` stays available;
        # process() calls the wrapped module directly rather than going
        # through the DataParallel forward pass.
        self.clip_model = torch.nn.DataParallel(clip_model.cuda())
        self.embedding_batch_size = 512

    def process(self, categories):
        """Return the CLIP text features for ``categories`` as one CPU tensor.

        Args:
            categories: Sequence of category strings.

        Returns:
            Tensor with one row per category, in input order. For an empty
            input an empty tensor of shape ``(0, projection_dim)`` is returned
            instead of failing in ``torch.cat``.
        """
        if len(categories) == 0:
            embed_dim = self.clip_model.module.config.projection_dim
            return torch.empty((0, embed_dim))

        embeddings = []
        for start in tqdm(range(0, len(categories), self.embedding_batch_size),
                          desc="Processing embeddings",
                          unit="batch"):
            batch = categories[start:start + self.embedding_batch_size]
            # Inputs are padded to the longest entry in the batch and cut off
            # after 10 tokens.
            inputs = self.clip_tokenizer(
                batch,
                padding=True,
                return_tensors="pt",
                truncation=True,
                max_length=10
            ).to("cuda")

            # torch.autocast("cuda") is the supported spelling of the
            # deprecated torch.cuda.amp.autocast() and behaves the same.
            with torch.no_grad(), torch.autocast(device_type="cuda"):
                batch_embeds = self.clip_model.module.get_text_features(**inputs).cpu()
            embeddings.append(batch_embeds)
        return torch.cat(embeddings)

def generate_dataset(target_count, dataset_type, exclude_train=False):
    """Generate unique categories for one split and save them with embeddings.

    Categories are requested from a ``CategoryGenerator`` in batches of at
    most ``Config.batch_size`` until ``target_count`` names that are unique
    (compared case-insensitively) have been collected. The names are embedded
    with an ``EmbeddingProcessor``; the names are written one per line to
    ``Config.get_text_path(dataset_type)`` and the embedding tensor is saved
    to ``Config.get_embeddings_path(dataset_type)``.

    Args:
        target_count: Number of unique categories to collect.
        dataset_type: Split name used to build the output paths, e.g.
            ``"train"`` or ``"val"``.
        exclude_train: When true and ``dataset_type`` is ``"val"``, names
            already listed in the train split's text file are rejected. If
            that file does not exist a warning is logged instead.

    Note:
        The loop keeps requesting batches until enough unique names exist;
        there is no upper bound on the number of attempts.
    """
    # Resolve the output locations and make sure their directories exist.
    text_path = Config.get_text_path(dataset_type)
    embeddings_path = Config.get_embeddings_path(dataset_type)
    ensure_dir(text_path)
    ensure_dir(embeddings_path)

    initial_seen = set()
    if exclude_train and dataset_type == "val":
        train_text_path = Config.get_text_path("train")
        if os.path.exists(train_text_path):
            with open(train_text_path, "r", encoding="utf-8") as f:
                initial_seen = {line.strip().lower() for line in f}
        else:
            logging.warning("Training categories not found, validation set may contain duplicates")

    generator = CategoryGenerator(seen_categories=initial_seen)
    processor = EmbeddingProcessor()

    # Same set object as generator.seen_categories, so updates stay visible there.
    seen = generator.seen_categories
    all_categories = []

    with tqdm(total=target_count, desc=f"Generating {dataset_type} categories") as progress:
        while len(all_categories) < target_count:
            remaining = target_count - len(all_categories)
            new_candidates = generator.generate_batch(min(Config.batch_size, remaining))

            # Record each accepted name immediately so that a name repeated
            # within the same batch is rejected as well.
            accepted = 0
            for candidate in new_candidates:
                key = candidate.lower()
                if key in seen:
                    continue
                seen.add(key)
                all_categories.append(candidate)
                accepted += 1

            # A batch may yield more names than were still needed; cap the
            # update so the bar never runs past its total.
            progress.update(min(accepted, remaining))
            progress.set_postfix({
                "Unique": len(all_categories),
                "Duplicates": len(new_candidates) - accepted
            })

    # Drop any surplus from the last batch.
    final_categories = all_categories[:target_count]
    embeddings = processor.process(final_categories)

    with open(text_path, "w", encoding="utf-8") as f:
        f.write("\n".join(final_categories))

    torch.save(embeddings, embeddings_path)
    logging.info(f"{dataset_type.upper()} dataset: {len(final_categories)} categories")
    logging.info(f"Embedding shape: {embeddings.shape}")
    logging.info(f"Saved to: {text_path}")

def main():
    """Generate the train split, then the validation split.

    The train split gets 3000 categories. The validation split gets 2000
    categories and is asked to exclude names already present in the train
    split.
    """
    torch.backends.cudnn.benchmark = True

    rule = "=" * 40
    # (banner label, number of categories, split name, exclude train names)
    jobs = (
        ("TRAIN", 3000, "train", False),
        ("VALIDATION", 2000, "val", True),
    )
    for label, count, split, exclude in jobs:
        logging.info("\n" + rule + f" Generating {label} set " + rule)
        generate_dataset(count, split, exclude_train=exclude)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()