from transformers import pipeline, BitsAndBytesConfig
import transformers
import gc
import json
import os
import torch
from tqdm import tqdm
import pandas as pd
import argparse

# NOTE: This script is intended only for the phi-4 and qwen3-8b models.


def load_config(config_path="config.json"):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path of the JSON file to read (UTF-8 encoded).

    Returns:
        The parsed JSON value, or ``None`` when the file does not exist or
        does not contain valid JSON.  A message is printed in either case.
    """
    # Guard 1: the file has to exist.
    try:
        handle = open(config_path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"Config file {config_path} not found. Using default values.")
        return None

    # Guard 2: the file has to contain well-formed JSON.
    with handle:
        try:
            loaded = json.load(handle)
        except json.JSONDecodeError as parse_error:
            failure = parse_error
        else:
            failure = None

    if failure is not None:
        print(f"Error parsing config file: {failure}. Using default values.")
        return None

    print(f"Configuration loaded from {config_path}")
    return loaded


class ResponseGenerator:
    """Generate chat responses with a 4-bit quantized text-generation pipeline.

    The model is loaded once, at construction time, through
    ``transformers.pipeline`` with a ``BitsAndBytesConfig`` built from
    ``config["model_params"]["quantization"]``.  Responses are then produced
    one at a time; a generation is retried when it raises or when the result
    is empty or shorter than the configured minimum token count.

    Attributes:
        model_name: Identifier passed to ``transformers.pipeline`` as ``model``.
        config: Configuration dictionary.  The keys read by this class are
            ``model_params``, ``file_paths``, ``generation_params``,
            ``system_message`` and ``prompt_template``.
        device: ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
            Used for logging and memory clean-up only; the pipeline itself is
            created with ``device_map="auto"``.
        pipeline: The loaded text-generation pipeline (removed by ``cleanup``).
    """

    # Maximum number of generation attempts for a single response.
    MAX_RETRIES = 5

    # Placeholder stored when the last allowed attempt raised an exception.
    ERROR_RESPONSE = "Error generating response"

    # Generation settings copied verbatim from the config to the pipeline call.
    _PASSTHROUGH_PARAMS = ("temperature", "top_p", "top_k", "repetition_penalty", "do_sample")

    def __init__(self, model_name, config):
        """Store the settings, report the device and load the pipeline.

        Args:
            model_name: Model identifier handed to ``transformers.pipeline``.
            config: Configuration dictionary (see the class docstring).
        """
        self.model_name = model_name
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device} for model: {model_name}")

        if self.device == "cuda":
            print(f"GPU Memory Available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

        self._init_pipeline()

    @staticmethod
    def _resolve_compute_dtype(dtype_name):
        """Translate a config dtype name into a ``torch`` dtype.

        Args:
            dtype_name: ``"float16"``, ``"bfloat16"`` or ``"float32"``;
                ``None`` when the setting is absent.

        Returns:
            The matching ``torch.dtype``.  Missing or unrecognised names fall
            back to ``torch.float32``.
        """
        supported = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        if isinstance(dtype_name, str) and dtype_name in supported:
            return supported[dtype_name]
        if dtype_name is not None:
            # Make the fallback visible instead of silently ignoring a typo.
            print(f"Unknown bnb_4bit_compute_dtype {dtype_name!r}; falling back to float32.")
        return torch.float32

    def _release_gpu_memory(self):
        """Free unreferenced objects and cached CUDA memory (no-op on CPU)."""
        if self.device == "cuda":
            # Collect first: memory still held by garbage objects cannot be
            # handed back by empty_cache().
            gc.collect()
            torch.cuda.empty_cache()

    def _init_pipeline(self):
        """Initialize the text-generation pipeline with quantization."""
        model_params = self.config.get("model_params", {})

        # Setup quantization config
        quant_config = model_params.get("quantization", {})
        compute_dtype = self._resolve_compute_dtype(quant_config.get("bnb_4bit_compute_dtype"))

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=quant_config.get("load_in_4bit", True),
            bnb_4bit_quant_type=quant_config.get("bnb_4bit_quant_type", "nf4"),
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=quant_config.get("bnb_4bit_use_double_quant", True),
        )

        print("Loading model with 4-bit quantization...")

        # The access token is optional, so a config without a "file_paths"
        # section must not fail here.
        access_token = self.config.get("file_paths", {}).get("access_token")

        # Create pipeline with quantization
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={
                "quantization_config": bnb_config,
                "torch_dtype": "auto",
            },
            device_map="auto",
            token=access_token,
        )

        # Set padding side if the tokenizer exposes it
        if hasattr(self.pipeline.tokenizer, 'padding_side'):
            self.pipeline.tokenizer.padding_side = model_params.get("padding_side", "left")

        print("Model loaded successfully!")

        self._release_gpu_memory()

    def _build_pipeline_params(self, gen_params):
        """Map ``generation_params`` from the config to pipeline keyword arguments.

        Args:
            gen_params: The ``generation_params`` dictionary from the config.

        Returns:
            A dict of keyword arguments for the pipeline call.  ``max_tokens``
            becomes ``max_new_tokens``; ``min_new_tokens`` (or, failing that,
            ``min_tokens``) becomes ``min_new_tokens``; the sampling settings
            are copied as-is; unrelated keys are ignored.
        """
        pipeline_params = {}

        if "max_tokens" in gen_params:
            pipeline_params["max_new_tokens"] = gen_params["max_tokens"]

        if "min_new_tokens" in gen_params:
            pipeline_params["min_new_tokens"] = gen_params["min_new_tokens"]
        elif "min_tokens" in gen_params:
            pipeline_params["min_new_tokens"] = gen_params["min_tokens"]

        for param in self._PASSTHROUGH_PARAMS:
            if param in gen_params:
                pipeline_params[param] = gen_params[param]

        # Always set pad_token_id to avoid warnings
        pipeline_params["pad_token_id"] = self.pipeline.tokenizer.eos_token_id
        return pipeline_params

    def _generate_one(self, question, response_idx, pipeline_params, min_new_tokens, max_retries=None):
        """Generate a single response, retrying on errors and short outputs.

        Args:
            question: The question inserted into the prompt template.
            response_idx: Zero-based index of this response (for log messages).
            pipeline_params: Keyword arguments for the pipeline call.
            min_new_tokens: Minimum number of tokens an accepted response must
                have when re-encoded without special tokens.
            max_retries: Maximum number of attempts; defaults to ``MAX_RETRIES``.

        Returns:
            The first non-empty response that is long enough.  If every attempt
            is used up, the text of the last attempt when that attempt
            succeeded but was too short, or ``ERROR_RESPONSE`` when it raised.
        """
        if max_retries is None:
            max_retries = self.MAX_RETRIES

        system_message = self.config.get("system_message", "You are a helpful assistant.")
        prompt_template = self.config.get("prompt_template", "{question}")
        response = ""

        for attempt in range(1, max_retries + 1):
            try:
                # Formatting stays inside the try so that a template problem is
                # handled like any other failed attempt.
                messages = [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": prompt_template.format(question=question)},
                ]
                outputs = self.pipeline(messages, **pipeline_params)

                # The generated reply is the last message of the conversation.
                response = outputs[0]["generated_text"][-1]["content"]
                token_count = len(self.pipeline.tokenizer.encode(response, add_special_tokens=False))
            except Exception as e:
                if attempt >= max_retries:
                    # str() keeps this handler safe for non-string questions
                    # (e.g. NaN coming from an empty spreadsheet cell).
                    print(f"\nError generating response {response_idx + 1} for question: {str(question)[:50]}...")
                    print(f"Error: {str(e)}")
                    return self.ERROR_RESPONSE
                print(f"\nError occurred. Retry {attempt}/{max_retries}...")
                continue
            finally:
                # Also runs after a failed attempt, so that a retry following
                # an out-of-memory error starts from a cleared cache.
                self._release_gpu_memory()

            if response and token_count >= min_new_tokens:
                return response

            if attempt < max_retries:
                print(f"\nResponse too short or empty (length: {token_count} tokens, required: {min_new_tokens}). Retry {attempt}/{max_retries}...")
            else:
                print(f"\nMax retries reached. Using last response (length: {token_count} tokens).")

        return response

    def generate_responses(self, questions, responses_per_question=1):
        """Generate several responses for each question, with a progress bar.

        Args:
            questions: Sequence of questions (its length sizes the progress bar).
            responses_per_question: Number of responses to produce per question.

        Returns:
            A flat list of ``len(questions) * responses_per_question`` strings,
            grouped by question in input order.

        Raises:
            KeyError: If the config has no ``generation_params`` section.
        """
        gen_params = self.config["generation_params"]
        min_new_tokens = gen_params.get("min_new_tokens", gen_params.get("min_tokens", 10))

        # These do not depend on the question, so build them once.
        pipeline_params = self._build_pipeline_params(gen_params)

        all_responses = []
        total_responses = len(questions) * responses_per_question

        with tqdm(total=total_responses, desc=f"Generating responses for {self.model_name}", unit="response") as pbar:
            for question in questions:
                for response_idx in range(responses_per_question):
                    all_responses.append(
                        self._generate_one(question, response_idx, pipeline_params, min_new_tokens)
                    )
                    pbar.update(1)

        return all_responses

    def cleanup(self):
        """Drop the pipeline and release GPU memory; safe to call repeatedly."""
        if hasattr(self, "pipeline"):
            del self.pipeline
        self._release_gpu_memory()


def read_prompts_from_excel(excel_file):
    """Read prompts and their identifiers from an Excel file.

    The sheet must provide a ``question`` column (the prompt texts) and an
    ``id`` column (one identifier per prompt).

    Args:
        excel_file: Path of the Excel file, passed to ``pandas.read_excel``.

    Returns:
        A ``(prompts, ids)`` tuple of two equally long lists.  On any failure
        (unreadable file, missing column, ...) the error is printed and
        ``([], [])`` is returned instead.
    """
    try:
        df = pd.read_excel(excel_file)
        # Validate both columns up front so that a missing one is reported
        # with a clear message rather than a bare KeyError text.
        if 'question' not in df.columns:
            raise ValueError("Excel file must contain a 'question' column")
        if 'id' not in df.columns:
            raise ValueError("Excel file must contain an 'id' column")
        prompts = df['question'].tolist()
        prompt_ids = df['id'].tolist()
        return prompts, prompt_ids
    except Exception as e:
        print(f"Error reading prompts from Excel: {str(e)}")
        return [], []


def generate_dataset(model_name, num_responses, config_path="config.json", model_ids_path="model_ids.json",
                     prompts_file="prompts.xlsx", output_file="responses.parquet"):
    """Generate responses for one model and save them incrementally per prompt.

    For every prompt in ``prompts_file``, ``num_responses`` responses are
    generated and appended as JSON Lines records to
    ``responses/<model name with "/" replaced by "_">/dataset.json``.  Each
    record holds ``model_id``, ``prompt_id``, ``year``, ``response_index``,
    ``response``, ``prompt`` and ``temperature``.

    The function prints a message and returns early when the configuration,
    the model-ID mapping or the prompts cannot be loaded, or when
    ``model_name`` has no entry in the mapping.

    Args:
        model_name: Model to load; must be a key of the ``"model_id"`` mapping
            in ``model_ids_path``.
        num_responses: Number of responses to generate per prompt.
        config_path: Path of the JSON configuration file.
        model_ids_path: Path of the JSON file holding the ``"model_id"`` mapping.
        prompts_file: Path of the Excel file with the prompts.
        output_file: Currently unused; results always go to ``dataset.json``
            as described above.  Kept so existing callers keep working.

    Returns:
        None.
    """
    # Load configuration
    config = load_config(config_path)
    if config is None:
        print("Failed to load configuration. Exiting.")
        return

    # Load model IDs mapping
    try:
        with open(model_ids_path, 'r', encoding='utf-8') as f:
            model_ids = json.load(f).get("model_id", {})
    except Exception as e:
        print(f"Error loading model IDs: {str(e)}")
        return

    # Get model ID
    if model_name not in model_ids:
        print(f"Model '{model_name}' not found in {model_ids_path}")
        print(f"Available models: {list(model_ids.keys())}")
        return

    model_id = model_ids[model_name]
    print(f"Model ID for {model_name}: {model_id}")

    # Read prompts from Excel
    prompts, ids = read_prompts_from_excel(prompts_file)
    if not prompts:
        print("No prompts found. Exiting.")
        return

    print(f"Found {len(prompts)} prompts")
    print(f"Generating {num_responses} responses per prompt")

    # The temperature is optional for generation, so it is optional here as
    # well; None is recorded when the config does not set it.
    temperature = config.get("generation_params", {}).get("temperature")

    # Create output directory structure
    model_safe_name = model_name.replace("/", "_")
    output_dir = os.path.join("responses", model_safe_name)
    os.makedirs(output_dir, exist_ok=True)
    dataset_path = os.path.join(output_dir, "dataset.json")
    print(f"Saving responses to directory: {output_dir}")

    # Initialize response generator
    generator = ResponseGenerator(model_name=model_name, config=config)

    print("Generating and saving responses incrementally...")

    try:
        for id_val, prompt in zip(ids, prompts):
            print(f"\nProcessing prompt {id_val}")

            # Tag the prompt with the year it mentions ('2020' wins over
            # '2022'); prompts mentioning neither, or non-text cells, get None.
            year = None
            if isinstance(prompt, str):
                year = next((y for y in ("2020", "2022") if y in prompt), None)

            # Generate responses for this single prompt
            prompt_responses = generator.generate_responses([prompt], num_responses)

            prompt_data = [
                {
                    "model_id": model_id,
                    "prompt_id": id_val,
                    "year": year,
                    "response_index": resp_id,
                    "response": response,
                    "prompt": prompt,
                    "temperature": temperature,
                }
                for resp_id, response in enumerate(prompt_responses)
            ]

            # Append immediately so finished prompts survive a later failure.
            df = pd.DataFrame(prompt_data)
            df.to_json(dataset_path, orient="records", force_ascii=False,
                       default_handler=None, lines=True, mode="a")

            print(f"Saved {len(prompt_data)} responses")
    finally:
        # Release the model even when generation or saving fails part-way.
        generator.cleanup()


def main():
    """Command-line entry point: parse the options and build the dataset.

    Errors raised while generating are printed together with a traceback, and
    a keyboard interrupt ends the run with a short message.  In every case
    the CUDA cache is emptied at the end when CUDA is available.
    """
    # (flag, type, default, help) for every supported option.
    option_table = (
        ('--model', str, 'mistralai/Mistral-7B-v0.1',
         'Model name (e.g., mistralai/Mistral-7B-v0.1)'),
        ('--num_responses', int, 5,
         'Number of responses to generate per prompt'),
        ('--config', str, 'config.json',
         'Path to config.json file (default: config.json)'),
        ('--model_ids', str, 'model_ids.json',
         'Path to model_ids.json file (default: model_ids.json)'),
        ('--prompts', str, 'prompts.xlsx',
         'Path to prompts Excel file (default: prompts.xlsx)'),
        ('--output', str, 'responses.parquet',
         'Output parquet file name (default: responses.parquet)'),
    )

    cli = argparse.ArgumentParser(description='Generate responses from a language model')
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    options = cli.parse_args()

    job = {
        "model_name": options.model,
        "num_responses": options.num_responses,
        "config_path": options.config,
        "model_ids_path": options.model_ids,
        "prompts_file": options.prompts,
        "output_file": options.output,
    }

    try:
        generate_dataset(**job)
    except KeyboardInterrupt:
        print("\nExiting...")
    except Exception as failure:
        print(f"\nAn error occurred: {str(failure)}")
        import traceback
        traceback.print_exc()
    finally:
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.empty_cache()
            gc.collect()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()