import torch
import tiktoken
import os
import json
import tqdm
import random
import numpy as np
import pandas as pd
import torch.multiprocessing as mp
from model import GPT, inference, eval_loss
from utils import get_info
from dataclasses import dataclass
from typing import List, Dict, Tuple, Any, Optional, Union, Callable
from itertools import chain
from templates import get_accuracy_template
from generate_bios import pronoun_to_fullname
import re
import time
import copy
import concurrent.futures
import platform
import shutil
import threading
from queue import Queue
from collections import defaultdict
import hashlib

# Configure environment
# NOTE(review): presumably read by tiktoken to locate its encoding-file cache; the
# path is relative to the current working directory.
os.environ["TIKTOKEN_CACHE_DIR"] = "./tiktoken_cache"
# Directory (typically tmpfs on Linux) where EvalKitBase._get_local_model_path stores
# copies of checkpoints so repeated runs load them from memory-backed storage.
CACHE_DIR = "/dev/shm/model_cache"

@dataclass
class GPUResource:
    """A CUDA device index paired with the model replica placed on that device.

    Attributes:
        gpu_id: Ordinal of the CUDA device holding ``model``.
        model: The model replica resident on that device.
    """

    gpu_id: int  # CUDA device ordinal
    model: Any  # per-device replica of the loaded model

class EvalKitBase:
    """Load a GPT checkpoint once and fan per-item evaluation out across all visible GPUs.

    The checkpoint is read into CPU memory, deep-copied onto every device reported by
    ``torch.cuda.device_count()``, and the CPU copy is then dropped. Subclasses call
    :meth:`_process_func_wrapper` with a ``process_func(model, item)`` callable; items are
    split into contiguous chunks and each chunk runs in its own thread against the
    replica on its assigned GPU.
    """

    def __init__(self, model_path):
        """
        Initialize the evaluation kit with a model path.

        Args:
            model_path: Path to the model checkpoint (file or directory). On Linux it is
                first cached under CACHE_DIR, then passed to ``GPT.from_pretrained``.

        Raises:
            SystemExit: With status 1 if the checkpoint cannot be loaded.
        """
        self.model_path = model_path
        self.enc = tiktoken.get_encoding("gpt2")
        self.gpu_resources = {}
        self.n_gpus = torch.cuda.device_count()

        # If on Linux, use /dev/shm to cache the model for subsequent fast loads.
        self.local_model_path = self._get_local_model_path(model_path)

        # Load the model into CPU RAM first with timing.
        start_cpu = time.time()
        try:
            self.base_model_cpu = GPT.from_pretrained(self.local_model_path, device="cpu")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Exit non-zero so shells and job schedulers see the failure.
            raise SystemExit(1) from e
        end_cpu = time.time()
        print(f"Loading model to CPU took {end_cpu - start_cpu:.2f} seconds.")

        try:
            # Initialize models across available GPUs in parallel (all GPUs loaded concurrently)
            self._initialize_gpu_resources()
        finally:
            # Release the CPU master copy even if a GPU copy failed.
            del self.base_model_cpu

    def _get_local_model_path(self, original_model_path: str) -> str:
        """
        On Linux, copy the checkpoint into CACHE_DIR (once) and return the cached path.

        The cache entry is named by the MD5 of the checkpoint's absolute path. The copy is
        written to a private staging name and atomically renamed into place. As a result, an
        interrupted copy never leaves a truncated entry that later runs would trust.

        NOTE(review): the cache is keyed by path only, so overwriting a checkpoint in place
        keeps serving the stale cached copy until the cache entry is deleted.

        Args:
            original_model_path (str): Path of the model checkpoint (file or directory).

        Returns:
            str: The path to load from (the cached copy on Linux, else the original path).
        """
        if platform.system() != "Linux":
            return original_model_path

        abs_path = os.path.abspath(original_model_path)
        path_hash = hashlib.md5(abs_path.encode()).hexdigest()
        tmp_file_path = os.path.join(CACHE_DIR, path_hash)

        if os.path.exists(tmp_file_path):
            print(f"Using cached model file found at {tmp_file_path}.")
            return tmp_file_path

        print(f"Copying model file from {original_model_path} to {tmp_file_path} for faster loading...")
        os.makedirs(os.path.dirname(tmp_file_path), exist_ok=True)
        staging_path = f"{tmp_file_path}.partial-{os.getpid()}-{threading.get_ident()}"
        try:
            if os.path.isdir(abs_path):
                shutil.copytree(abs_path, staging_path)
            else:
                shutil.copy(abs_path, staging_path)
            try:
                os.replace(staging_path, tmp_file_path)
            except OSError:
                # Renaming a directory onto an existing non-empty one fails; if another
                # process published the cache meanwhile, just use theirs.
                if not os.path.exists(tmp_file_path):
                    raise
        finally:
            # Clean up the staging copy if it was not renamed into place.
            if os.path.isdir(staging_path):
                shutil.rmtree(staging_path, ignore_errors=True)
            elif os.path.exists(staging_path):
                os.remove(staging_path)

        return tmp_file_path

    def _initialize_gpu_resources(self):
        """Deep-copy ``self.base_model_cpu`` onto every GPU concurrently (including GPU 0).

        Populates ``self.gpu_resources`` with one GPUResource per device ordinal.
        """
        print(f"Initializing models across {self.n_gpus} GPUs...")
        if self.n_gpus == 0:
            print("Warning: no CUDA devices found; no model replicas will be created.")

        def load_model_to_gpu(gpu_id):
            # Clone the CPU model and then move it to the specified GPU.
            model_copy = copy.deepcopy(self.base_model_cpu)
            model_copy.to(f"cuda:{gpu_id}")
            return gpu_id, model_copy

        start_gpu = time.time()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(load_model_to_gpu, gpu_id) for gpu_id in range(self.n_gpus)]
            for future in concurrent.futures.as_completed(futures):
                gpu_id, model_gpu = future.result()
                self.gpu_resources[gpu_id] = GPUResource(gpu_id=gpu_id, model=model_gpu)
        end_gpu = time.time()
        print(f"Loading model to GPUs took {end_gpu - start_gpu:.2f} seconds.")

    def _distribute_workload(self, data_size: int, processes_per_gpu: int = 1) -> List[Tuple[int, int, int, int]]:
        """Split ``range(data_size)`` into contiguous chunks, one per (GPU, worker) slot.

        Args:
            data_size: Total number of samples to process.
            processes_per_gpu: Number of workers to run per GPU.

        Returns:
            List of (process_id, gpu_id, start_idx, end_idx) tuples; slots that would get
            no items are omitted. Empty when ``data_size <= 0``.

        Raises:
            RuntimeError: If there is work but no worker slots (no GPUs, or
                ``processes_per_gpu < 1``).
        """
        if data_size <= 0:
            return []

        total_processes = self.n_gpus * processes_per_gpu
        if total_processes <= 0:
            raise RuntimeError(
                f"Cannot distribute {data_size} items: n_gpus={self.n_gpus}, "
                f"processes_per_gpu={processes_per_gpu}"
            )
        chunk_size = (data_size + total_processes - 1) // total_processes  # ceil division

        distribution = []
        for gpu_id in range(self.n_gpus):
            for proc_idx in range(processes_per_gpu):
                process_id = gpu_id * processes_per_gpu + proc_idx
                start_idx = process_id * chunk_size
                end_idx = min(start_idx + chunk_size, data_size)

                if start_idx < data_size:
                    distribution.append((process_id, gpu_id, start_idx, end_idx))

        return distribution

    def _process_chunk_thread(self, process_func: Callable[["GPT", object], Dict], process_id: int, gpu_id: int,
                              texts: List[str], start_idx: int, end_idx: int, return_dict: dict,
                              lock: threading.Lock, progress_queue: Queue,
                              errors: Optional[Dict[int, BaseException]] = None):
        """Worker body: run ``process_func`` on ``texts[start_idx:end_idx]`` with GPU ``gpu_id``'s model.

        Puts ``1`` on ``progress_queue`` per finished item and always puts one ``None``
        when the worker ends (success or failure), so the consumer never blocks forever.
        On success stores the result list in ``return_dict[process_id]``. On failure
        stores the exception in ``errors[process_id]``, or re-raises it if ``errors`` is None.
        """
        try:
            model = self.gpu_resources[gpu_id].model
            results = []
            for i in range(start_idx, end_idx):
                results.append(process_func(model, texts[i]))
                progress_queue.put(1)

            with lock:
                return_dict[process_id] = results
        except Exception as exc:
            if errors is None:
                raise
            with lock:
                errors[process_id] = exc
        finally:
            progress_queue.put(None)  # completion sentinel

    def _process_func_wrapper(self, process_func: Callable[["GPT", object], Dict], queries: List[object],
                              processes_per_gpu: int, tqdm_desc: str = "Processing texts"):
        """Apply ``process_func(model, query)`` to every query using one thread per worker slot.

        Returns:
            List of results in the same order as ``queries``.

        Raises:
            RuntimeError: If any worker failed (chained from the lowest-id worker's exception).
        """
        distribution = self._distribute_workload(len(queries), processes_per_gpu)

        return_dict = {}
        errors = {}
        lock = threading.Lock()
        progress_queue = Queue()

        threads = []
        print(f"Start {len(distribution)} threads")

        start = time.time()
        for process_id, gpu_id, start_idx, end_idx in distribution:
            t = threading.Thread(
                target=self._process_chunk_thread,
                args=(process_func, process_id, gpu_id, queries, start_idx, end_idx,
                      return_dict, lock, progress_queue, errors)
            )
            threads.append(t)
            t.start()
            print(f"Start thread {process_id} on GPU {gpu_id} at {time.time() - start:.2f} seconds")

        # Items report 1; each worker reports None exactly once when it stops.
        with tqdm.tqdm(total=len(queries), desc=tqdm_desc) as pbar:
            finished_workers = 0
            while finished_workers < len(threads):
                if progress_queue.get() is None:
                    finished_workers += 1
                else:
                    pbar.update(1)

        for t in threads:
            t.join()

        if errors:
            first_failed = min(errors)
            raise RuntimeError(
                f"Worker {first_failed} failed: {errors[first_failed]!r}"
            ) from errors[first_failed]

        collected_results = []
        for process_id in sorted(return_dict.keys()):
            collected_results.extend(return_dict[process_id])

        return collected_results
    
    
class EvalKit(EvalKitBase):
    """Task-specific evaluations built on EvalKitBase's multi-GPU fan-out.

    Provides SFT QA accuracy, pretraining loss with keyword/syntax separation, and
    pretraining generation accuracy per profile attribute.
    """

    # Profile attributes whose generation accuracy is evaluated. Shared by the
    # per-profile worker and by the aggregation in eval_pretrain_by_accuracy.
    ACCURACY_KEYS = ("birth_date", "birth_city", "university", "major", "employer", "company_city")

    def __init__(self, model_path):
        super().__init__(model_path)

    @staticmethod
    def _derive_output_path(path: str, new_suffix: str) -> str:
        """Return ``path`` with its extension replaced by ``new_suffix``, never ``path`` itself.

        Operates on the extension only. ``str.replace`` would also hit matches elsewhere
        in the path, and would return the path unchanged (overwriting it) when the
        expected extension is absent.
        """
        root, _ = os.path.splitext(path)
        derived = root + new_suffix
        return derived if derived != path else root + "_details" + new_suffix

    @staticmethod
    def _load_profiles_and_counts(data_path: str, every_n: int):
        """Load ``profiles.jsonl`` and ``count.json`` from ``data_path`` and subsample them.

        ``cum_counts_ratio`` is computed over the full count list before subsampling
        (all zeros if the counts sum to 0). Profiles, counts and ratios are then taken
        every ``every_n`` entries and truncated to a common length.

        Returns:
            (profiles, counts, cum_counts_ratio), all of equal, non-zero length.

        Raises:
            FileNotFoundError: If either file is missing.
            ValueError: If nothing is left to evaluate.
        """
        profiles_path = os.path.join(data_path, "profiles.jsonl")
        counts_path = os.path.join(data_path, "count.json")
        if not os.path.exists(profiles_path):
            raise FileNotFoundError(f"Profile file not found: {profiles_path}")
        if not os.path.exists(counts_path):
            raise FileNotFoundError(f"Count file not found: {counts_path}")

        with open(profiles_path, "r", encoding="utf-8") as f:
            profiles = [json.loads(line) for line in f if line.strip()]
        with open(counts_path, "r", encoding="utf-8") as f:
            counts = json.load(f)

        total = sum(counts)
        cum_counts_ratio = (np.cumsum(counts) / total) if total > 0 else np.zeros(len(counts), dtype=float)

        profiles = profiles[::every_n]
        counts = counts[::every_n]
        cum_counts_ratio = cum_counts_ratio[::every_n]

        # Ensure profiles and counts align after sampling
        n = min(len(profiles), len(counts))
        if n == 0:
            raise ValueError(f"No profiles/counts to evaluate in {data_path} (every_n={every_n})")
        return profiles[:n], counts[:n], cum_counts_ratio[:n]

    def process_qa(self, model: GPT, qa: str):
        """Score one ``"Q: <question> A: <answer>"`` line by greedy generation.

        The prompt is everything up to and including the first ``"A:"``. Generation
        stops at GPT-2 token 198 (``"\\n"``) or after 10 new tokens. Correctness is
        exact string equality of the stripped answers.

        Returns:
            Dict with prompt, golden_answer, generated_answer, question_type, person, correct.

        Raises:
            ValueError: If the line has no ``"A:"`` marker, or no ``"Q:"`` before it.
        """
        head, sep, tail = qa.partition("A:")
        if not sep:
            raise ValueError(f"QA line has no 'A:' marker: {qa!r}")
        _, q_sep, question = head.partition("Q:")
        if not q_sep:
            raise ValueError(f"QA line has no 'Q:' marker before 'A:': {qa!r}")

        prompt = head + "A:"
        question = question.strip()
        question_type, person = get_info(question)
        # Everything after the first "A:" is the answer, even if it itself contains "A:".
        golden_answer = tail.strip()

        generated_answer = inference(
            model,
            prompt,
            tokenizer=self.enc,
            max_new_tokens=10,
            stop_token=198,  # GPT-2 "\n"
            temperature=0
        ).strip()

        is_correct = (golden_answer == generated_answer)

        return {
            "prompt": prompt,
            "golden_answer": golden_answer,
            "generated_answer": generated_answer,
            "question_type": question_type,
            "person": person,
            "correct": is_correct
        }

    def eval_SFT_by_text(self, text_path: str, output_path: str = None,
                         first_n: int = -1, processes_per_gpu: int = 4):
        """Evaluate SFT performance using a text file of QA pairs (one per line).

        Args:
            text_path: Path to text file with QA pairs; blank lines are skipped.
            output_path: Path to save the JSON summary (optional). Per-sample rows go to
                the same path with a ``.csv`` extension.
            first_n: Number of samples to process (-1 for all).
            processes_per_gpu: Number of worker threads per GPU.

        Returns:
            Tuple (summary dict, per-sample DataFrame).

        Raises:
            ValueError: If the file contains no QA lines to evaluate.
        """
        with open(text_path, "r", encoding="utf-8") as f:
            qa_samples = [line for line in f if line.strip()]

        if first_n != -1:
            qa_samples = qa_samples[:first_n]
        if not qa_samples:
            raise ValueError(f"No QA samples to evaluate in {text_path}")

        all_results = self._process_func_wrapper(self.process_qa, qa_samples, processes_per_gpu, tqdm_desc="Processing QA pairs")
        # Convert results to DataFrame for analysis
        df = pd.DataFrame(all_results)
        print(df.describe())
        # Cast to builtin types: json cannot serialize numpy.int64 (e.g. from .sum()).
        per_type = df.groupby("question_type")["correct"].mean()
        summary = {
            "accuracy": float(df["correct"].mean()),
            "total_samples": int(len(df)),
            "correct_samples": int(df["correct"].sum()),
            "detail_result": {k: float(v) for k, v in per_type.items()}
        }
        print(summary)
        if output_path:
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(summary, f, indent=4)
            df.to_csv(self._derive_output_path(output_path, ".csv"), index=False)

        return summary, df

    def process_pretrain(self, model: GPT, description: str, profile: dict):
        """Compute losses for one pretraining sample, separating knowledge keywords from syntax.

        Args:
            model: The GPT model.
            description: Text description to evaluate.
            profile: The profile dictionary containing all fields.

        Returns:
            tuple: (total_loss, keyword_loss, syntax_loss) as returned by ``eval_loss``.
        """
        # Knowledge keywords: non-blank string values of keys that mention neither
        # "pronoun" nor "name"; each is added bare and with a leading space.
        knowledge_keywords = []
        for key, value in profile.items():
            if any(x in key for x in ("pronoun", "name")):
                continue
            # Skip empty/blank values: " " + "" would yield a bare-space keyword.
            if isinstance(value, str) and value.strip():
                knowledge_keywords.append(value)
                knowledge_keywords.append(" " + value)

        total_loss, keyword_loss, syntax_loss = eval_loss(
            model,
            description,
            tokenizer=self.enc,
            special_texts=[knowledge_keywords]
        )
        return (total_loss, keyword_loss, syntax_loss)

    def _pretrain_process_wrapper(self, model, x):
        """Adapt a ``(description, profile)`` pair to the ``process_func(model, item)`` shape."""
        return self.process_pretrain(model, x[0], x[1])

    def _process_profile_accuracy(self, model: GPT, profile: dict):
        """Greedy-generate each attribute in ACCURACY_KEYS for one profile and compare to the truth.

        Returns:
            Dict with ``{key}_correct``, ``{key}_generated`` and ``{key}_golden`` per key.
            Per-key failures are printed and recorded as incorrect with generated "ERROR".
        """
        results = {}
        # Values used to fill the prompt templates (the profile after pronoun_to_fullname,
        # presumably with pronouns replaced by the person's full name).
        prompt_values = pronoun_to_fullname(profile)

        for key in self.ACCURACY_KEYS:
            try:
                template = get_accuracy_template(key)
                golden_answer = profile[key]

                # Cut the trailing "{key}." slot so the prompt ends right before the answer.
                prompt_base, n_subs = re.subn(r"\{" + re.escape(key) + r"\}\.\s*$", "", template)
                if n_subs == 0:
                    # Otherwise format() would write the golden answer into the prompt.
                    raise ValueError(f"Template for {key!r} does not end with '{{{key}}}.': {template!r}")

                prompt = prompt_base.format(**prompt_values).strip()
                generated_text = inference(
                    model,
                    prompt,
                    tokenizer=self.enc,
                    max_new_tokens=30,  # Allow space for longer attributes like university names
                    stop_token=13,  # GPT-2 token 13 is "." (the template's sentence end), not newline
                    temperature=0
                ).strip()

                # Drop a trailing period if the output kept it.
                generated_answer = generated_text[:-1].strip() if generated_text.endswith('.') else generated_text

                results[f"{key}_correct"] = (golden_answer == generated_answer)
                results[f"{key}_generated"] = generated_answer  # Store for potential debugging
                results[f"{key}_golden"] = golden_answer        # Store for potential debugging

            except Exception as e:
                import traceback
                print(traceback.format_exc())
                print(f"Error processing profile {profile.get('full_name', 'N/A')} for key {key}: {e}")
                results[f"{key}_correct"] = False
                results[f"{key}_generated"] = "ERROR"
                results[f"{key}_golden"] = profile.get(key, "N/A")

        return results

    def _accuracy_process_wrapper(self, model: GPT, profile: dict):
        """Adapt ``_process_profile_accuracy`` to the ``process_func(model, item)`` shape."""
        return self._process_profile_accuracy(model, profile)

    def eval_pretrain_by_profile(self, data_path: str, output_path: str = None, every_n=1, processes_per_gpu: int = 4):
        """Evaluate pretraining loss on perturbed profile descriptions, split into knowledge/syntax.

        Args:
            data_path: Path to the dir of pretraining data (profiles.jsonl and count.json).
            output_path: Where to save results CSV (optional).
            every_n: Evaluate every nth profile.
            processes_per_gpu: Number of worker threads per GPU.

        Returns:
            DataFrame: total_loss, keyword_loss, syntax_loss, count, cum_count_ratio, rank.

        Raises:
            FileNotFoundError: If an input file is missing.
            ValueError: If there is nothing to evaluate.
        """
        profiles, counts, cum_counts_ratio = self._load_profiles_and_counts(data_path, every_n)

        from generate_bios import generate_perturbed_description

        perturbed_descriptions = [generate_perturbed_description(profile) for profile in profiles]
        all_results = self._process_func_wrapper(
            self._pretrain_process_wrapper,
            list(zip(perturbed_descriptions, profiles)),
            processes_per_gpu,
            tqdm_desc="Processing descriptions with knowledge/syntax separation"
        )

        final_results = []
        for i, (result, count) in enumerate(zip(all_results, counts)):
            total_loss, keyword_loss, syntax_loss = result
            final_results.append({
                "total_loss": total_loss,
                "keyword_loss": keyword_loss,
                "syntax_loss": syntax_loss,
                "count": count,
                "cum_count_ratio": cum_counts_ratio[i],
                "rank": i * every_n
            })

        df = pd.DataFrame(final_results)
        if output_path:
            df.to_csv(output_path, index=False)
        print(df.describe())
        return df

    def eval_pretrain_by_accuracy(self, data_path: str, output_path: str = None, every_n=1, processes_per_gpu: int = 4):
        """Evaluate pretraining performance by checking generation accuracy per profile attribute.

        Args:
            data_path: Path to the dir of pretraining data (profiles.jsonl and count.json).
            output_path: Where to save results CSV (optional). A JSON summary is written
                next to it with a ``_summary.json`` suffix.
            every_n: Evaluate every nth profile.
            processes_per_gpu: Number of worker threads per GPU.

        Returns:
            DataFrame: Per-profile accuracy columns for each attribute, count, rank, etc.

        Raises:
            FileNotFoundError: If an input file is missing.
            ValueError: If there is nothing to evaluate.
        """
        start = time.time()
        profiles, counts, cum_counts_ratio = self._load_profiles_and_counts(data_path, every_n)
        end = time.time()
        print(f"Prepare data took {end - start:.2f} seconds")

        # One result dict per profile, in profile order.
        accuracy_results = self._process_func_wrapper(
            self._accuracy_process_wrapper,
            profiles,
            processes_per_gpu,
            tqdm_desc="Processing profiles for accuracy"
        )

        accuracy_keys = self.ACCURACY_KEYS
        final_results = []
        for i, (result, count) in enumerate(zip(accuracy_results, counts)):
            profile_data = {
                "count": count,
                "cum_count_ratio": cum_counts_ratio[i],
                "rank": i * every_n,
                "full_name": profiles[i].get("full_name", "N/A")  # Add name for reference
            }
            for key in accuracy_keys:
                profile_data[f"{key}_correct"] = result.get(f"{key}_correct", False)
                profile_data[f"{key}_generated"] = result.get(f"{key}_generated", "N/A")
                profile_data[f"{key}_golden"] = result.get(f"{key}_golden", "N/A")

            correct_count = sum(result.get(f"{key}_correct", False) for key in accuracy_keys)
            profile_data["overall_accuracy"] = correct_count / len(accuracy_keys) if accuracy_keys else 0

            final_results.append(profile_data)

        df = pd.DataFrame(final_results)

        summary_stats = {}
        summary_stats["total_profiles_evaluated"] = len(df)
        summary_stats["overall_accuracy_mean"] = df["overall_accuracy"].mean()
        for key in accuracy_keys:
            summary_stats[f"{key}_accuracy_mean"] = df[f"{key}_correct"].mean()

        print("\n--- Pretrain Accuracy Evaluation Summary ---")
        for stat, value in summary_stats.items():
            print(f"{stat}: {value:.4f}")
        print("-------------------------------------------\n")
        print("DataFrame Description:")
        print(df.describe())

        if output_path:
            df.to_csv(output_path, index=False)
            print(f"Detailed accuracy results saved to {output_path}")
            summary_output_path = self._derive_output_path(output_path, "_summary.json")
            with open(summary_output_path, "w", encoding="utf-8") as f:
                json.dump(summary_stats, f, indent=4)
            print(f"Accuracy summary saved to {summary_output_path}")

        return df

def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs, and pin cuDNN.

    Args:
        seed: The value handed to every random number generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)

    # Force deterministic cuDNN algorithms and disable autotuning for reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Command-line entry point
def main():
    """Parse CLI arguments, validate them, load the model and run the selected evaluation.

    Invalid argument combinations are rejected through ``parser.error`` (exit status 2)
    before the model is loaded, so mistakes fail fast.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate model performance")
    parser.add_argument("-m", "--model_path", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("-i", "--input_path", type=str, required=True, help="Path to input data")
    parser.add_argument("-o", "--output_path", type=str, help="Path to save results")
    parser.add_argument("-n", "--first_n", type=int, default=-1, help="Process only first N samples")
    parser.add_argument("-e", "--every_n", type=int, default=1, help="Process every N samples")
    parser.add_argument("-p", "--processes_per_gpu", type=int, default=1, help="Number of processes per GPU")
    parser.add_argument("-t", "--task", type=str, choices=["sft", "pretrain", "pta"], default="sft",
                        help="Evaluation task: sft (QA accuracy), pretrain (loss), pta (pretrain generation accuracy)")
    parser.add_argument("-s", "--random_seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped under `python -O`.
    if args.processes_per_gpu < 1:
        parser.error("processes_per_gpu must be >= 1")
    if args.every_n < 1:
        parser.error("every_n must be >= 1")
    if args.task == "sft":
        if args.every_n != 1:
            parser.error("every_n must be 1 for SFT evaluation")
        if not (args.first_n == -1 or args.first_n > 0):
            parser.error("first_n must be -1 or positive for SFT")
    elif args.task == "pretrain" and args.first_n != -1:
        parser.error("first_n must be -1 (not applicable) for pretrain loss evaluation")
    elif args.task == "pta" and args.first_n != -1:
        parser.error("first_n must be -1 (not applicable) for pretrain accuracy evaluation")

    seed_everything(args.random_seed)

    if args.processes_per_gpu > 1:
        print("Warning: processes_per_gpu > 1 will not provide any speedup for current implementation")

    evaluator = EvalKit(model_path=args.model_path)

    if args.task == "sft":
        evaluator.eval_SFT_by_text(
            text_path=args.input_path,
            output_path=args.output_path,
            first_n=args.first_n,
            processes_per_gpu=args.processes_per_gpu
        )
    elif args.task == "pretrain":
        evaluator.eval_pretrain_by_profile(
            data_path=args.input_path,
            output_path=args.output_path,
            processes_per_gpu=args.processes_per_gpu,
            every_n=args.every_n
        )
    elif args.task == "pta":
        evaluator.eval_pretrain_by_accuracy(
            data_path=args.input_path,
            output_path=args.output_path,
            processes_per_gpu=args.processes_per_gpu,
            every_n=args.every_n
        )

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()