#!/usr/bin/env python3
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import sys
import numpy as np

from sentence_transformers import SentenceTransformer

class QwenChatBot:
    def __init__(self, model_name="Qwen/Qwen2.5-7B-Instruct"):
        """Initialize the Qwen chat bot.

        Only records configuration and creates empty state; nothing is loaded
        until ``load_model()`` is called.

        Args:
            model_name: Model identifier handed to ``from_pretrained`` by
                ``load_model()``.
        """
        self.model_name = model_name
        # Filled in by load_model().
        self.model = None
        self.tokenizer = None
        self.device = None
        # Filled in by load_sbert().
        self.sbert = None
        # Chat turns stored as {"role": ..., "content": ...} dicts.
        self.conversation_history = []
        # Set by the interactive 'initial' command. Defined up front so code
        # that reads it reaches its "is None" check instead of raising
        # AttributeError on a fresh instance.
        self.initial_prompt = None
    
    def load_sbert(self):
        """Create the "all-MiniLM-L6-v2" sentence encoder and keep it on ``self.sbert``."""
        print("Loading SBERT model...")
        encoder = SentenceTransformer("all-MiniLM-L6-v2")
        self.sbert = encoder
        print("✅ SBERT model loaded successfully!")
    
    def calculate_entropy_of_similarity_of_current_reply(self, response):
        """Measure how evenly a reply's lines relate to the initial prompt.

        The reply is split on newlines; for each line the text up to and
        including the first space is dropped (this removes list numbering such
        as "1. ") and blank lines are discarded. The remaining lines are
        embedded with ``self.sbert`` and compared with ``self.initial_prompt``
        and with each other. Similarities, entropies and a combined score are
        printed as diagnostics.

        Args:
            response: Multi-line text, typically a numbered list.

        Returns:
            The entropy (natural log) of the softmax-normalized similarities
            between the initial prompt and each parsed line, or 0.0 when the
            reply contains no usable lines.
        """
        raw_lines = response.strip().split("\n")
        # Drop the leading token (list numbering). A line without any space is
        # kept whole because find() returns -1 there.
        sentences = [line[line.find(" ") + 1:] for line in raw_lines]
        sentences = [sentence for sentence in sentences if sentence.strip()]
        print(f"Parsed sentences: {sentences}")
        if not sentences:
            # Nothing to embed or compare.
            return 0.0
        count = len(sentences)

        # calculate the embeddings
        initial_prompt_emb = self.sbert.encode(self.initial_prompt, convert_to_tensor=True).to(self.device)
        sentences_emb = self.sbert.encode(sentences, convert_to_tensor=True).to(self.device)

        similarity_to_previous_prompt = self.sbert.similarity(initial_prompt_emb, sentences_emb)[0]
        # One matrix call instead of one call per pair. The strict upper
        # triangle holds every unordered pair exactly once, in (i, j > i) order.
        pair_matrix = self.sbert.similarity(sentences_emb, sentences_emb)
        rows, cols = torch.triu_indices(count, count, offset=1, device=pair_matrix.device)
        similarity_to_each_other = pair_matrix[rows, cols].cpu()
        pair_count = similarity_to_each_other.numel()

        print(f"Similarities to previous prompt: {similarity_to_previous_prompt.tolist()}")
        print(f"Similarities to each other: {similarity_to_each_other}")

        mean_similarity_to_previous_prompt = torch.mean(similarity_to_previous_prompt).item()
        # A single line has no pairs; report 0.0 rather than the NaN that the
        # mean of an empty tensor would give.
        mean_similarity_to_each_other = torch.mean(similarity_to_each_other).item() if pair_count else 0.0

        # Keep the strongest pairs. There are count*(count-1)/2 pairs, which is
        # fewer than `count` for two-line replies, so k must be clamped or
        # topk raises.
        top_k = min(count, pair_count)
        if top_k:
            similarity_to_each_other = torch.topk(similarity_to_each_other, k=top_k, dim=0)[0]
            similarity_to_each_other = torch.softmax(similarity_to_each_other, dim=0)
        similarity_to_previous_prompt = torch.softmax(similarity_to_previous_prompt, dim=0)

        print(f"Normalized similarity to previous prompt: {similarity_to_previous_prompt.tolist()}")
        print(f"Normalized similarity to each other: {similarity_to_each_other}")

        # xlogy(p, p) is p*log(p) with the 0*log(0) case defined as 0, so no
        # epsilon is needed. max() also turns a "-0.0" result into 0.0.
        entropy_to_previous_prompt = max(
            0.0, -torch.xlogy(similarity_to_previous_prompt, similarity_to_previous_prompt).sum().item()
        )
        entropy_to_each_other = 0.0
        if top_k:
            entropy_to_each_other = max(
                0.0, -torch.xlogy(similarity_to_each_other, similarity_to_each_other).sum().item()
            )

        # log(1) == 0, so a one-element distribution has no meaningful
        # normalized entropy; report 0.0 instead of dividing by zero.
        normalized_to_previous_prompt = entropy_to_previous_prompt / np.log(count) if count > 1 else 0.0
        normalized_to_each_other = entropy_to_each_other / np.log(top_k) if top_k > 1 else 0.0

        print(f"Entropy of similarity to previous prompt: {entropy_to_previous_prompt:.3f}, Normalized: {normalized_to_previous_prompt:.3f}, Mean: {mean_similarity_to_previous_prompt:.3f}")
        print(f"Entropy of similarity to each other: {entropy_to_each_other:.3f}, Normalized: {normalized_to_each_other:.3f}, Mean: {mean_similarity_to_each_other:.3f}")
        print(f"Score: {(entropy_to_previous_prompt - entropy_to_each_other) - (mean_similarity_to_each_other) + 1 * (mean_similarity_to_previous_prompt):.3f}")
        return entropy_to_previous_prompt


    def calculate_split_score(self, node_prompt, response):
        """Score a candidate split of ``node_prompt`` into sub-tasks.

        ``response`` is parsed as one sub-task per line: the text up to and
        including the first space of each line is dropped (list numbering) and
        blank lines are discarded. The score rewards sub-tasks that stay close
        to the node prompt and penalizes sub-tasks that resemble each other:

            score = mean(sim(prompt, sub_task)) - mean(sim(sub_task_i, sub_task_j))

        Args:
            node_prompt: The task being split.
            response: The model's multi-line list of proposed sub-tasks.

        Returns:
            A ``(score, similarities)`` tuple. ``similarities`` lists the
            similarity of each parsed sub-task to ``node_prompt``. When fewer
            than two sub-tasks are parsed there is nothing to compare, so the
            score is ``float("-inf")`` (and ``similarities`` is empty if no
            sub-task was parsed at all).
        """
        raw_lines = response.strip().split("\n")
        # Drop the leading token (list numbering). A line without any space is
        # kept whole because find() returns -1 there.
        sentences = [line[line.find(" ") + 1:] for line in raw_lines]
        sentences = [sentence for sentence in sentences if sentence.strip()]
        print(f"Parsed sentences: {sentences}")
        if not sentences:
            # Nothing was proposed, so there is nothing to embed.
            return float("-inf"), []

        # calculate the embeddings
        node_prompt_emb = self.sbert.encode(node_prompt, convert_to_tensor=True).to(self.device)
        sentences_emb = self.sbert.encode(sentences, convert_to_tensor=True).to(self.device)

        similarity_to_previous_prompt = self.sbert.similarity(node_prompt_emb, sentences_emb)[0]
        print(f"Similarities to previous prompt: {similarity_to_previous_prompt.tolist()}")

        count = len(sentences)
        if count < 2:
            # A single sub-task is not a split; the pairwise mean would be the
            # mean of an empty tensor (NaN). Return the worst possible score.
            return float("-inf"), similarity_to_previous_prompt.tolist()

        # One matrix call instead of one call per pair. The strict upper
        # triangle holds every unordered pair exactly once, in (i, j > i) order.
        pair_matrix = self.sbert.similarity(sentences_emb, sentences_emb)
        rows, cols = torch.triu_indices(count, count, offset=1, device=pair_matrix.device)
        similarity_to_each_other = pair_matrix[rows, cols].cpu()
        print(f"Similarities to each other: {similarity_to_each_other}")

        mean_similarity_to_previous_prompt = torch.mean(similarity_to_previous_prompt).item()
        mean_similarity_to_each_other = torch.mean(similarity_to_each_other).item()

        score = mean_similarity_to_previous_prompt - mean_similarity_to_each_other
        print(f"Score: {score:.3f}")
        return score, similarity_to_previous_prompt.tolist()

    def calculate_growth_score(self, initial_prompt, node_prompt, response):
        """Decide whether splitting ``node_prompt`` moves away from the root task.

        ``response`` is parsed as one sub-task per line (text up to and
        including the first space of each line is dropped, blank lines are
        discarded). The sub-tasks count as growth when, on average, they are
        less similar to ``initial_prompt`` than ``node_prompt`` itself is.

        Args:
            initial_prompt: The root task of the decomposition.
            node_prompt: The task whose proposed split is being judged.
            response: The model's multi-line list of proposed sub-tasks.

        Returns:
            1 if the sub-tasks are on average less similar to the initial
            prompt than the node prompt is, otherwise 0. Also 0 when no
            sub-task could be parsed from ``response``.
        """
        raw_lines = response.strip().split("\n")
        # Drop the leading token (list numbering). A line without any space is
        # kept whole because find() returns -1 there.
        sentences = [line[line.find(" ") + 1:] for line in raw_lines]
        sentences = [sentence for sentence in sentences if sentence.strip()]
        if not sentences:
            # No children were proposed, so there is nothing to grow into.
            return 0

        # calculate the embeddings
        initial_prompt_emb = self.sbert.encode(initial_prompt, convert_to_tensor=True).to(self.device)
        node_prompt_emb = self.sbert.encode(node_prompt, convert_to_tensor=True).to(self.device)
        sentences_emb = self.sbert.encode(sentences, convert_to_tensor=True).to(self.device)

        mean_similarity_to_initial_prompt = torch.mean(
            self.sbert.similarity(initial_prompt_emb, sentences_emb)[0]
        ).item()
        initial_prompt_similarity_to_node_prompt = self.sbert.similarity(
            initial_prompt_emb, node_prompt_emb
        )[0].item()

        # Children that sit further from the root than their parent does are
        # treated as more diverse, i.e. worthwhile growth.
        return 1 if mean_similarity_to_initial_prompt < initial_prompt_similarity_to_node_prompt else 0

    def parse_sub_task_tree(self, node_prompt=None, depth=0, max_depth=3):
        """
        Recursively decompose a prompt into a tree of sub-tasks.

        For each node the model is asked to split the task into k sub-tasks for
        k = 2..8. The search keeps the best-scoring k and stops at the first k
        whose split score drops below the best seen so far. The winning split
        is then checked with the growth score: 0 turns the node into a leaf,
        1 makes each sub-task a child that is decomposed recursively.

        Args:
            node_prompt: Task to decompose. ``None`` means "start from
                ``self.initial_prompt``" (the root call).
            depth: Depth of this node; the root is 0.
            max_depth: Nodes at this depth or deeper are always leaves.

        Returns:
            A dict with keys ``"prompt"``, ``"children"`` (list of dicts of the
            same shape), ``"is_leaf"`` and ``"depth"``.

        Raises:
            ValueError: If no initial prompt has been loaded.
        """
        import math  # local import: only used for the finite-score check below

        initial_prompt = getattr(self, "initial_prompt", None)
        if initial_prompt is None:
            # Raised explicitly (not asserted) so the check survives `python -O`.
            raise ValueError("Initial prompt is not loaded")

        # Initialize with initial prompt if this is the root call
        if node_prompt is None:
            node_prompt = initial_prompt
            print("\n🌳 Starting sub-task tree parsing...")
            print(f"Initial prompt: {node_prompt}")
            print("-" * 80)

        indent = "  " * depth

        def make_leaf():
            return {"prompt": node_prompt, "children": [], "is_leaf": True, "depth": depth}

        # Prevent infinite recursion
        if depth >= max_depth:
            print(f"{indent}📍 Max depth ({max_depth}) reached, creating leaf node")
            return make_leaf()

        print(f"{indent}🔍 Processing node at depth {depth}: {node_prompt[:80]}...")

        best_score = float("-inf")
        best_k = None
        best_response = None

        # Iterate children number from 2 to 8
        for k in range(2, 9):
            print(f"{indent}  Testing k={k} children...")

            split_prompt = f"""Given the task: "{node_prompt}", divide it into {k} distinct, non-overlapping sub-tasks covering different semantic dimensions. Each sub-task must be clearly different from the others. Output should be a numberd list with {k} sub-tasks only in short sentences or key words."""

            # Every candidate split is requested in a fresh conversation so
            # earlier candidates cannot influence it.
            # NOTE: this discards whatever chat history existed before.
            self.conversation_history = []
            response = self.generate_response(split_prompt, max_new_tokens=256, temperature=0.5)

            # calculate_split_score returns (score, per-sub-task similarities);
            # only the scalar score takes part in the search.
            score, _ = self.calculate_split_score(node_prompt, response)

            print(f"{indent}    k={k}, score={score:.3f}")

            if not math.isfinite(score):
                # NaN / -inf means the reply held no comparable sub-tasks;
                # it can neither win nor signal a genuine score decrease.
                continue

            if best_k is None or score > best_score:
                best_score = score
                best_k = k
                best_response = response
            elif score < best_score:
                # Score decreased, so the previous best k is kept.
                print(f"{indent}    Score decreased, stopping at k={best_k}")
                break

        # Check if we found a valid split
        if best_k is None:
            print(f"{indent}❌ No good split found, creating leaf node")
            return make_leaf()

        print(f"{indent}✅ Best split: k={best_k}, score={best_score:.3f}")

        # Calculate growth score to determine if we should actually split
        growth_score = self.calculate_growth_score(initial_prompt, node_prompt, best_response)

        if growth_score == 0:
            print(f"{indent}🛑 Growth score is 0, creating leaf node")
            return make_leaf()

        print(f"{indent}🚀 Growth score is 1, proceeding with split")

        # Parse the response into sub-tasks: drop the leading token (numbering)
        # of each line when it has one, then discard blank lines.
        sentences = []
        for line in best_response.strip().split("\n"):
            first_space = line.find(" ")
            text = (line[first_space + 1:] if first_space != -1 else line).strip()
            if text:
                sentences.append(text)

        print(f"{indent}📝 Extracted {len(sentences)} sub-tasks:")
        # The model may return more lines than requested; keep at most best_k.
        sub_tasks = sentences[:best_k]
        for i, task in enumerate(sub_tasks):
            print(f"{indent}    {i+1}. {task[:60]}{'...' if len(task) > 60 else ''}")

        if not sub_tasks:
            # Never emit an inner node that has no children.
            print(f"{indent}❌ No good split found, creating leaf node")
            return make_leaf()

        # Create child nodes recursively
        children = []
        for i, sub_task in enumerate(sub_tasks):
            print(f"{indent}🔄 Recursively processing child {i+1}...")
            children.append(self.parse_sub_task_tree(sub_task, depth + 1, max_depth))

        return {"prompt": node_prompt, "children": children, "is_leaf": False, "depth": depth}
    
    def print_task_tree(self, tree_node=None, depth=0):
        """Write a task tree to stdout, one node per line, indented two spaces per level.

        When ``tree_node`` is None the tree stored on ``self.task_tree`` is
        printed; if there is none, a hint is printed instead.
        """
        node = tree_node
        if node is None:
            try:
                node = self.task_tree
            except AttributeError:
                print("❌ No task tree found. Run parse_sub_task_tree first.")
                return

        pad = depth * "  "
        if not node["is_leaf"]:
            # Inner node: print it, then descend into every child one level deeper.
            print(f"{pad}🌿 {node['prompt']}")
            for sub_node in node["children"]:
                self.print_task_tree(sub_node, depth + 1)
            return

        print(f"{pad}🍃 {node['prompt']}")
    
    def run_task_decomposition(self, max_depth=4):
        """Decompose the loaded initial prompt, store the tree and print it.

        Returns the tree (also kept on ``self.task_tree``), or None when no
        initial prompt has been loaded.
        """
        if getattr(self, 'initial_prompt', None) is None:
            print("❌ No initial prompt loaded. Use the 'initial' command first.")
            return

        print("🎯 Starting task decomposition...")
        tree = self.parse_sub_task_tree(max_depth=max_depth)
        self.task_tree = tree

        # Frame the final tree between horizontal rules.
        rule = "=" * 80
        print("\n" + rule)
        print("🌳 FINAL TASK TREE:")
        print(rule)
        self.print_task_tree()
        print(rule)

        return tree


    @staticmethod
    def calculate_entropy(similarity):
        """Calculate the entropy of the similarity"""
        return -torch.sum(similarity * torch.log(similarity)).item()
    

    def load_model(self):
        """Load SBERT, choose the compute device, then load the Qwen model and tokenizer.

        Progress is printed along the way. If loading the model or tokenizer
        raises, the error is printed and the process exits with status 1.
        """
        print(f"Loading {self.model_name}...")

        self.load_sbert()

        # Probe CUDA once; the answer drives device, dtype and placement below.
        cuda_ready = torch.cuda.is_available()
        if not cuda_ready:
            self.device = torch.device("cpu")
            print("⚠️  CUDA not available, using CPU")
        else:
            self.device = torch.device("cuda")
            print(f"🚀 Using CUDA: {torch.cuda.get_device_name()}")
            gpu_bytes = torch.cuda.get_device_properties(0).total_memory
            print(f"GPU Memory: {gpu_bytes / 1024**3:.1f} GB")

        try:
            dtype = torch.float16 if cuda_ready else torch.float32
            placement = "auto" if cuda_ready else None
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=dtype,
                device_map=placement,
                trust_remote_code=True,
            )

            # Move the model to CUDA only when no device map was applied.
            needs_move = cuda_ready and not hasattr(self.model, 'hf_device_map')
            if needs_move:
                self.model = self.model.to(self.device)

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            if cuda_ready:
                # CUDA-only tuning switches.
                torch.backends.cudnn.benchmark = True
                if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
                    torch.backends.cuda.enable_flash_sdp(True)

            print("✅ Model loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            sys.exit(1)
    
    def generate_response(self, user_input, max_new_tokens=512, temperature=0.7, do_sample=True):
        """Generate a reply to ``user_input`` in the context of the conversation so far.

        On success the user turn and the assistant reply are both appended to
        ``self.conversation_history``. On any failure (prompt building,
        tokenization, generation or decoding) the error is printed, the
        history is restored to what it was before the call, and a fixed
        apology string is returned.

        Args:
            user_input: The user's message.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature passed to ``generate``.
            do_sample: Whether to sample instead of decoding greedily.

        Returns:
            The decoded assistant reply, or the apology string on failure.
        """
        turns_before = len(self.conversation_history)
        # Add user message to conversation history
        self.conversation_history.append({"role": "user", "content": user_input})

        try:
            # Prepare messages with system prompt and conversation history
            messages = [
                {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}
            ] + self.conversation_history

            # Apply chat template
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize input
            model_inputs = self.tokenizer([text], return_tensors="pt")
            if torch.cuda.is_available():
                model_inputs = model_inputs.to(self.device)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    **model_inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    temperature=temperature,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # generate() returns prompt + continuation; keep only the new tokens.
            generated_ids = [
                output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
            ]
            response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

            # Add assistant response to conversation history
            self.conversation_history.append({"role": "assistant", "content": response})

            return response

        except Exception as e:
            # Roll back the unanswered user turn so the history never ends up
            # with a user message that has no assistant reply.
            del self.conversation_history[turns_before:]
            print(f"Error generating response: {e}")
            return "Sorry, I encountered an error while generating a response."
    
    def clear_history(self):
        """Forget all previous turns by switching to a new, empty history list."""
        self.conversation_history = list()
        print("Conversation history cleared.")
    
    def clear_gpu_memory(self):
        """Call ``torch.cuda.empty_cache()`` when CUDA is available; otherwise just say so."""
        if not torch.cuda.is_available():
            print("No GPU available to clear memory.")
            return

        torch.cuda.empty_cache()
        print("GPU memory cache cleared.")
    
    def show_gpu_info(self):
        """Print allocated, reserved and total memory of GPU 0 in GB, if CUDA is available."""
        if not torch.cuda.is_available():
            print("No GPU available.")
            return

        gib = 1024**3
        allocated = torch.cuda.memory_allocated() / gib
        cached = torch.cuda.memory_reserved() / gib
        total = torch.cuda.get_device_properties(0).total_memory / gib

        print("\n🔧 GPU Memory Usage:")
        # "Free" is reported relative to allocated (not reserved) memory.
        report = (
            ("Allocated", allocated),
            ("Cached", cached),
            ("Total", total),
            ("Free", total - allocated),
        )
        for label, amount in report:
            print(f"  {label}: {amount:.2f} GB")
    
    def show_history(self):
        """Print every stored turn as "<n>. <Role>: <content>", or a notice if there are none."""
        turns = self.conversation_history
        if turns:
            print("\n--- Conversation History ---")
            for number, turn in enumerate(turns, start=1):
                speaker = turn["role"].title()
                text = turn["content"]
                print(f"{number}. {speaker}: {text}")
            print("--- End of History ---\n")
        else:
            print("No conversation history.")
    
    def start_chat(self, max_new_tokens=512, temperature=0.7):
        """Start the interactive chat session.

        Reads lines from standard input until the user types 'quit'/'exit',
        presses Ctrl-C, or standard input reaches end-of-file. Lines matching
        a command name (case-insensitive) are handled locally; everything else
        is sent to the model and the reply is printed.

        Args:
            max_new_tokens: Token budget for each chat reply.
            temperature: Sampling temperature for each chat reply.
                Both defaults equal the defaults of ``generate_response``.
        """
        print("🤖 Qwen Chat Bot")
        print("Type 'quit' or 'exit' to end the conversation")
        print("Type 'clear' to clear conversation history")
        print("Type 'history' to show conversation history")
        print("Type 'gpu' to show GPU memory info")
        print("Type 'cleargpu' to clear GPU memory cache")
        print("Type 'help' to show available commands")
        print("Type 'initial' to load the initial prompt for similarity calculation")
        print("Type 'decompose' to run task decomposition on the initial prompt")
        print("Type 'tree' to print the current task tree")
        print("-" * 50)

        # Show initial GPU info
        if torch.cuda.is_available():
            self.show_gpu_info()

        while True:
            try:
                user_input = input("\n👤 You: ").strip()

                if not user_input:
                    continue

                command = user_input.lower()

                # Handle special commands
                if command in ('quit', 'exit'):
                    print("👋 Goodbye!")
                    break
                elif command == 'clear':
                    self.clear_history()
                    continue
                elif command == 'history':
                    self.show_history()
                    continue
                elif command == 'gpu':
                    self.show_gpu_info()
                    continue
                elif command == 'cleargpu':
                    self.clear_gpu_memory()
                    continue
                elif command == 'help':
                    print("\nAvailable commands:")
                    print("  quit/exit - End the conversation")
                    print("  clear     - Clear conversation history")
                    print("  history   - Show conversation history")
                    print("  gpu       - Show GPU memory info")
                    print("  cleargpu  - Clear GPU memory cache")
                    print("  help      - Show this help message")
                    print("  initial   - Load the initial prompt for similarity calculation")
                    print("  decompose - Run task decomposition on the initial prompt")
                    print("  tree      - Print the current task tree")
                    continue
                elif command == 'initial':
                    self.initial_prompt = input("Please enter the initial prompt: ")
                    print(f"✅ Initial prompt loaded: {self.initial_prompt[:100]}{'...' if len(self.initial_prompt) > 100 else ''}")
                    continue
                elif command == 'decompose':
                    if getattr(self, 'initial_prompt', None) is None:
                        print("❌ Please load an initial prompt first using the 'initial' command.")
                    else:
                        max_depth = input("Enter max depth for decomposition (default 4): ").strip()
                        try:
                            max_depth = int(max_depth) if max_depth else 4
                        except ValueError:
                            # Anything that is not an integer falls back to the default.
                            max_depth = 4
                        self.run_task_decomposition(max_depth=max_depth)
                    continue
                elif command == 'tree':
                    if hasattr(self, 'task_tree'):
                        print("\n🌳 Current Task Tree:")
                        print("-" * 50)
                        self.print_task_tree()
                    else:
                        print("❌ No task tree found. Run 'decompose' command first.")
                    continue

                # Generate and display response
                print("\n🤖 Qwen: ", end="", flush=True)
                response = self.generate_response(
                    user_input, max_new_tokens=max_new_tokens, temperature=temperature
                )
                print(response)

            except (KeyboardInterrupt, EOFError):
                # EOFError means stdin is exhausted (Ctrl-D or a closed pipe).
                # Every further input() would fail the same way, so it must end
                # the session instead of being retried by the handler below.
                print("\n\n👋 Goodbye!")
                break
            except Exception as e:
                print(f"\nError: {e}")
                continue

def main():
    """Parse command-line options, load the model and run the interactive chat."""
    parser = argparse.ArgumentParser(description="Interactive chat with Qwen model")
    # %(default)s makes argparse print the actual default, so the help text
    # cannot drift away from the value that is really used.
    parser.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct",
                        help="Model name to use (default: %(default)s)")
    parser.add_argument("--temperature", type=float, default=1,
                        help="Temperature for response generation (default: %(default)s)")
    parser.add_argument("--max_tokens", type=int, default=1024,
                        help="Maximum tokens for response (default: %(default)s)")

    args = parser.parse_args()
    # NOTE(review): --temperature and --max_tokens are parsed but not forwarded;
    # the chat below is started without them. Confirm the intended defaults
    # before wiring them through.

    # Create and initialize chat bot
    chatbot = QwenChatBot(model_name=args.model)
    chatbot.load_model()

    # Start interactive chat
    chatbot.start_chat()

# Run the chat CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
