import pandas as pd
import torch
import argparse
from tqdm import tqdm
from vllm import LLM, SamplingParams
import os
import re


class CSVInference:
    def __init__(self, models_dir: str, csv_path: str, output_dir: str):
        self.models_dir = models_dir
        self.csv_path = csv_path
        self.output_dir = output_dir
        self.df = self.load_csv()
        self.model_paths = self.discover_models()
        os.makedirs(output_dir, exist_ok=True)
        
    def discover_models(self) -> list:
        """Discover all model directories in the models folder"""
        model_paths = []
        
        if not os.path.exists(self.models_dir):
            raise ValueError(f"Models directory does not exist: {self.models_dir}")
        
        # Look for subdirectories that contain config.json (indicating a model)
        for item in os.listdir(self.models_dir):
            item_path = os.path.join(self.models_dir, item)
            if os.path.isdir(item_path):
                config_path = os.path.join(item_path, 'config.json')
                if os.path.exists(config_path):
                    model_paths.append(item_path)
                    print(f"Found model: {item}")
        
        if not model_paths:
            raise ValueError(f"No valid models found in {self.models_dir}. Each model directory should contain config.json")
        
        print(f"Discovered {len(model_paths)} models")
        return model_paths
        
    def load_csv(self) -> pd.DataFrame:
        """Load CSV dataset"""
        df = pd.read_csv(self.csv_path)
        
        # Ensure required columns exist - case insensitive check
        df.columns = df.columns.str.strip()  # Remove any whitespace
        
        # Map column names (case insensitive)
        column_mapping = {}
        for col in df.columns:
            col_lower = col.lower()
            if col_lower in ['question']:
                column_mapping[col] = 'Question'
            elif col_lower in ['answer']:
                column_mapping[col] = 'Answer'
            elif col_lower in ['no', 'number', 'id']:
                column_mapping[col] = 'No'
        
        # Rename columns to standard format
        df = df.rename(columns=column_mapping)
        
        # Ensure required columns exist
        required_columns = ['Question']
        for col in required_columns:
            if col not in df.columns:
                raise ValueError(f"Required column '{col}' not found in CSV. Available columns: {list(df.columns)}")
        
        # Ensure Answer column exists
        if 'Answer' not in df.columns:
            df['Answer'] = ''
        
        # Ensure No column exists (for indexing)
        if 'No' not in df.columns:
            df['No'] = range(1, len(df) + 1)
            
        return df
    
    def create_prompt(self, question: str) -> str:
        """Wrap ``question`` in the ``[INST] ... [/INST]`` template used for generation.

        The returned text ends with ``Output:`` directly after the closing
        instruction tag.
        """
        template = "<s> [INST] {} [/INST] Output:"
        return template.format(question)
    
    def flip_charlevel(self, sentence):
        """Return ``sentence`` with its characters in reverse order."""
        backwards = slice(None, None, -1)
        return sentence[backwards]
    
    def flip_wordlevel(self, sentence):
        """Return ``sentence`` with its whitespace-separated tokens in reverse order.

        Before splitting, a punctuation character that sits directly before
        whitespace or the end of the text is detached from the preceding
        character, and every opening bracket is detached from what follows
        it, so these symbols become tokens of their own. The tokens are
        joined back with single spaces.
        """
        # Punctuation not preceded by whitespace and followed by whitespace/end.
        trailing_punct = r"(?<!\s)([!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~])(?=(\s|$))"
        # Any opening bracket.
        opening_bracket = r"([\(\[\{])"

        spaced = re.sub(trailing_punct, r" \1", sentence)
        spaced = re.sub(opening_bracket, r"\1 ", spaced)

        return " ".join(reversed(spaced.split()))
    
    def flip_response_based_on_model(self, model_answer: str, model_name: str) -> str:
        """
        Flip the model response according to the suffix of the model name.

        Names ending in ``charflipped`` (any letter case) get a
        character-level flip, names ending in ``wordflipped`` a word-level
        flip; for every other name the answer is returned unchanged.
        """
        lowered_name = model_name.lower()

        # suffix -> transformation, checked in this order
        suffix_handlers = (
            ('charflipped', self.flip_charlevel),
            ('wordflipped', self.flip_wordlevel),
        )
        for suffix, handler in suffix_handlers:
            if lowered_name.endswith(suffix):
                return handler(model_answer)

        return model_answer
    
    def clean_generated_text(self, generated_text: str, prompt: str) -> str:
        """Strip the prompt and special tokens from a generated completion.

        Every occurrence of ``prompt`` is removed first, then each special
        token is deleted in turn. The result carries no leading or trailing
        whitespace.
        """
        text = generated_text.replace(prompt, "").strip()

        special_tokens = ('<s>', '[INST]', '[/INST]', '</s>', '<|im_end|>', '<|endoftext|>')
        for marker in special_tokens:
            text = text.replace(marker, '')

        # Deleting tokens can expose new edge whitespace; trim once more.
        return text.strip()
    
    def run_inference_single_model(self, model_path: str, batch_size: int = 8, max_new_tokens: int = 1024,
                                   temperature: float = 0.5, top_p: float = 0.95):
        """Generate an answer for every dataset question with one model and save them.

        The model is loaded with vLLM using 4, 2 or 1 GPUs depending on how
        many are visible. If loading on several GPUs fails, one retry is made
        on a single GPU with otherwise identical engine settings. Prompts are
        generated in slices of ``batch_size``; a slice that fails is recorded
        as ``"ERROR"`` rows so the output keeps one row per question.

        Args:
            model_path: Folder of the model to load; its base name is used as
                the model name.
            batch_size: Number of prompts sent to the engine per call.
            max_new_tokens: Upper bound on generated tokens per prompt.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.

        Returns:
            list | None: One dict per question with the keys ``No``,
            ``Question``, ``Answer``, ``ModelResponse`` and ``ModelAnswer``;
            ``None`` when the model could not be loaded.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        import gc  # local import: only needed for the cleanup step at the end

        # Validate before the (expensive) model load rather than after it.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        model_name = os.path.basename(model_path)
        print(f"\nProcessing model: {model_name}")
        print(f"Model path: {model_path}")

        # Determine tensor parallel size based on available GPUs
        gpu_count = torch.cuda.device_count()
        if gpu_count >= 4:
            tensor_parallel_size = 4
        elif gpu_count >= 2:
            tensor_parallel_size = 2
        else:
            tensor_parallel_size = 1

        print(f"Using tensor_parallel_size={tensor_parallel_size}")

        # Shared by the first attempt and the single-GPU retry so that both
        # run with the same settings (including max_model_len).
        engine_kwargs = {
            "model": model_path,
            "trust_remote_code": False,
            "gpu_memory_utilization": 0.9,
            "dtype": "bfloat16",
            "max_model_len": 4096,
        }

        try:
            llm = LLM(tensor_parallel_size=tensor_parallel_size, **engine_kwargs)
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            if tensor_parallel_size == 1:
                return None
            print("Retrying with single GPU...")
            try:
                llm = LLM(tensor_parallel_size=1, **engine_kwargs)
            except Exception as e2:
                print(f"Failed: {e2}")
                return None

        prompts = [self.create_prompt(str(question)) for question in self.df['Question']]

        # Set up sampling parameters
        sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            top_p=top_p,
            stop=["</s>", "<|im_end|>", "<|endoftext|>"],
        )

        def make_result(row, response, answer):
            # One output row; the column set is the same for success and error.
            return {
                "No": row['No'],
                "Question": row['Question'],
                "Answer": row['Answer'],
                "ModelResponse": response,
                "ModelAnswer": answer,
            }

        results = []

        print(f"Running inference on {len(prompts)} questions...")

        for start in tqdm(range(0, len(prompts), batch_size), desc=f"Processing {model_name}"):
            batch_prompts = prompts[start:start + batch_size]
            batch_rows = self.df.iloc[start:start + batch_size]

            try:
                outputs = llm.generate(batch_prompts, sampling_params)

                # Collect into a scratch list first: if anything below fails
                # part-way, no half-finished batch leaks into `results`.
                batch_results = []
                for j, output in enumerate(outputs):
                    cleaned_response = self.clean_generated_text(
                        output.outputs[0].text, batch_prompts[j]
                    )
                    flipped_response = self.flip_response_based_on_model(
                        cleaned_response, model_name
                    )
                    batch_results.append(
                        make_result(batch_rows.iloc[j], cleaned_response, flipped_response)
                    )
            except Exception as e:
                print(f"Error in batch {start // batch_size}: {e}")
                # Exactly one ERROR row per question of the failed batch.
                batch_results = [
                    make_result(batch_rows.iloc[j], "ERROR", "ERROR")
                    for j in range(len(batch_rows))
                ]

            results.extend(batch_results)

        # Drop the engine and collect garbage before emptying the CUDA cache,
        # which gives the memory a better chance of being released before the
        # next model is loaded.
        del llm
        gc.collect()
        torch.cuda.empty_cache()

        # Save results for this model
        self.save_results(results, model_name)

        return results
    
    def run_inference_all_models(self, batch_size: int = 8, max_new_tokens: int = 1024,
                                temperature: float = 0.5, top_p: float = 0.95):
        """Run every discovered model over the dataset and print a summary.

        A model counts as successful when ``run_inference_single_model``
        returns a non-empty result; an empty/``None`` result or a raised
        exception counts as a failure and the loop moves on to the next model.

        Returns:
            dict: Model name (folder base name) mapped to that model's result
            list, for the successful models only.
        """
        total_models = len(self.model_paths)
        print(f"Starting inference on {total_models} models...")

        # Generation settings are identical for every model.
        generation_args = {
            "batch_size": batch_size,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        all_results = {}
        tally = {"ok": 0, "failed": 0}

        for path in self.model_paths:
            name = os.path.basename(path)

            try:
                outcome = self.run_inference_single_model(model_path=path, **generation_args)

                if not outcome:
                    tally["failed"] += 1
                    print(f"Failed to process {name}")
                else:
                    all_results[name] = outcome
                    tally["ok"] += 1
                    print(f"Successfully processed {name}")

            except Exception as e:
                tally["failed"] += 1
                print(f"Error processing {name}: {e}")

        print("\n=== SUMMARY ===")
        print(f"Total models: {total_models}")
        print(f"Successful: {tally['ok']}")
        print(f"Failed: {tally['failed']}")
        print(f"Results saved in: {self.output_dir}")

        return all_results
    
    def save_results(self, results: list, model_name: str):
        """Save results to CSV in output directory"""
        df_results = pd.DataFrame(results)
        
        # Create output filename with model name
        output_filename = f"Output_{model_name}.csv"
        output_path = os.path.join(self.output_dir, output_filename)
        
        # Save to CSV
        df_results.to_csv(output_path, index=False)
        print(f"Results saved to: {output_path}")
        
        # Print summary
        print(f"\nSummary:")
        print(f"Total questions processed: {len(results)}")
        print(f"Model: {model_name}")
        print(f"Output file: {output_filename}")
        print(f"Columns: {list(df_results.columns)}")

def main():
    """Parse the command line, run every model over the CSV and report the outcome.

    The three path options accept both the original underscore spelling
    (``--models_dir``) and a hyphenated alias (``--models-dir``) so that all
    options can be written in one style. The final message distinguishes
    between full success, partial success and complete failure.
    """
    parser = argparse.ArgumentParser(description="Run inference on CSV data using vLLM for multiple models")
    # The first long option determines the attribute name (args.models_dir, ...).
    parser.add_argument("--models_dir", "--models-dir", type=str, required=True,
                        help="Path to the directory containing multiple model folders")
    parser.add_argument("--csv_path", "--csv-path", type=str, required=True,
                        help="Path to the input CSV file")
    parser.add_argument("--output_dir", "--output-dir", type=str, default="inference_results",
                        help="Output directory for results")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Batch size for inference")
    parser.add_argument("--max-tokens", type=int, default=1024,
                        help="Maximum number of new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="Temperature for sampling")
    parser.add_argument("--top-p", type=float, default=0.95,
                        help="Top-p for nucleus sampling")

    args = parser.parse_args()

    # Initialize inference engine
    inference_engine = CSVInference(
        models_dir=args.models_dir,
        csv_path=args.csv_path,
        output_dir=args.output_dir
    )

    # Run inference on all models
    results = inference_engine.run_inference_all_models(
        batch_size=args.batch_size,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p
    )

    # `results` only holds the models that succeeded, so compare it with the
    # number of discovered models before claiming that everything worked.
    total_models = len(inference_engine.model_paths)
    if not results:
        print("No models were processed successfully!")
    elif len(results) < total_models:
        print(f"Inference completed for {len(results)} of {total_models} models; some models failed (see messages above).")
    else:
        print("All inference completed successfully!")

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()