import torch
import logging
from tqdm import trange
from datasets import load_dataset, config
from pprint import pformat
# from lm_eval import simple_evaluate
# from lm_eval.api.model import LM
# from lm_eval import evaluator
# from lm_eval.models.huggingface import HFLM
# from lm_eval.models import get_model

from lm_eval import simple_evaluate
from lm_eval.models.huggingface import HFLM

from lm_eval.tasks import get_task_dict
from torch.profiler import profile, ProfilerActivity
from rich.console import Console
import sys
from pathlib import Path
# from utils import dispatch, move_to_cpu


# Module-level logger. It has no handlers of its own: setup_logger() attaches the
# file and console handlers to the root logger, and records reach them by propagation.
logger = logging.getLogger(__name__)


def eval_wikitext(model, tokenizer, dataset=None, seqlen=2048):
    """Compute WikiText-2 perplexity over non-overlapping windows and report throughput.

    Follows the GPTQ / AWQ evaluation protocol (links below). The test split is
    joined with blank lines and tokenized once. The token stream is then cut into
    consecutive ``seqlen``-token windows, and a trailing partial window is dropped.

    Args:
        model: causal LM whose forward returns an object with ``.logits``.
            As a side effect, ``model.seqlen`` is set to ``seqlen``; this is kept
            for compatibility with code that reads it afterwards.
        tokenizer: HF-style tokenizer returning an ``input_ids`` tensor.
        dataset: object indexable by ``'text'`` giving a list of strings.
            Defaults to the ``wikitext-2-raw-v1`` test split.
        seqlen: window length in tokens (default 2048, the previous hard-coded value).

    Returns:
        float: perplexity over all evaluated windows.

    Raises:
        ValueError: if the tokenized text is shorter than one window.
    """
    import math
    import time

    model.eval()
    device = next(model.parameters()).device
    on_cuda = device.type == 'cuda'

    if not dataset:
        dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # https://github.com/mit-han-lab/llm-awq/blob/b42d2f2871df64c300cd3a601fda7e3025f5348f/awq/entry.py#L301
    # https://github.com/IST-DASLab/gptq/blob/2d65066eeb06a5c9ff5184d8cebdf33662c67faf/llama.py#L206
    input_ids = tokenizer("\n\n".join(dataset['text']), return_tensors="pt").input_ids.to(device)

    model.seqlen = seqlen
    nsamples = input_ids.numel() // seqlen
    if nsamples == 0:
        raise ValueError(
            f"Tokenized text has {input_ids.numel()} tokens, fewer than one window of {seqlen}."
        )

    loss_fct = torch.nn.CrossEntropyLoss()
    total_nll = 0.0
    total_tokens = 0
    total_time = 0.0
    ppl = float('nan')

    with trange(nsamples, desc='Perplexity: 0.00', ncols=80) as pbar:
        for i in pbar:
            batch = input_ids[:, (i * seqlen):((i + 1) * seqlen)]

            # Synchronize around the forward pass so wall-clock time covers the
            # GPU work; skipped on CPU, where torch.cuda calls would fail.
            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            with torch.no_grad():
                lm_logits = model(batch).logits
            if on_cuda:
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start

            batch_tokens = batch.numel()
            tokens_per_sec = batch_tokens / elapsed if elapsed > 0 else float('inf')
            total_tokens += batch_tokens
            total_time += elapsed

            with torch.no_grad():
                shift_logits = lm_logits[:, :-1, :].contiguous().float()
                shift_labels = batch[:, 1:]
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
            # GPTQ convention: the mean loss is scaled by seqlen, not seqlen - 1.
            total_nll += loss.item() * seqlen

            # Running perplexity over the windows processed so far. After the
            # last window this equals the full-dataset value.
            ppl = math.exp(total_nll / ((i + 1) * seqlen))

            pbar.set_description(f"PPL: {ppl:.2f} | Batch {i:3d} throughput: {tokens_per_sec:.2f}")

    overall_tokens_per_sec = total_tokens / total_time if total_time > 0 else float('inf')
    print(f"Overall throughput: {overall_tokens_per_sec:.2f} tokens/sec")
    return ppl


def eval_c4(model, tokenizer, dataset, max_length=512, num_samples=2000):
    """Estimate token-level perplexity on the first ``num_samples`` documents of ``dataset``.

    Each document is tokenized on its own (truncated to ``max_length``) and scored
    with the model's built-in causal-LM loss. The loss is a mean over predicted
    tokens, so it is re-weighted by the number of predicted tokens before
    accumulating. Padding positions are excluded.

    Args:
        model: HF causal LM that accepts ``labels=`` and returns ``.loss``.
        tokenizer: tokenizer matching ``model``.
        dataset: iterable of dicts with a ``'text'`` field.
        max_length: truncation length in tokens.
        num_samples: maximum number of documents to score (default 2000, the
            previous hard-coded value). Iteration stops early if ``dataset``
            runs out.

    Returns:
        float: perplexity, or ``0`` if no tokens were scored.
    """
    import math

    model.eval()
    device = next(model.parameters()).device

    total_nll = 0.0
    total_tokens = 0
    perplexity = 0
    dataset_iter = iter(dataset)

    with torch.no_grad(), trange(0, num_samples, ncols=80) as pbar:
        for _ in pbar:
            try:
                sample = next(dataset_iter)
            except StopIteration:
                break  # dataset shorter than num_samples
            inputs = tokenizer(
                sample['text'],
                return_tensors="pt",
                max_length=max_length,
                truncation=True,
                padding=True
            )
            inputs = {key: value.to(device) for key, value in inputs.items()}

            # The model shifts labels internally. Mask padding with -100 so it is
            # ignored by the loss.
            labels = inputs["input_ids"].clone()
            if "attention_mask" in inputs:
                labels[inputs["attention_mask"] == 0] = -100

            # Number of targets the averaged loss is taken over (labels shifted by one).
            num_targets = int((labels[:, 1:] != -100).sum().item())
            if num_targets == 0:
                continue  # e.g. a single-token document: nothing to predict

            outputs = model(**inputs, labels=labels)
            total_nll += outputs.loss.item() * num_targets
            total_tokens += num_targets

            # total_nll is a Python float, so use math.exp. torch.exp only accepts tensors.
            perplexity = math.exp(total_nll / total_tokens)
            pbar.set_description(f"Perplexity: {perplexity:.2f}")
    return perplexity


def _legacy_task_result(results, task_name):
    """Return ``task_name``'s metrics from a legacy ``evaluator.evaluate`` result.

    Accepts both a ``{'results': {task: metrics}}`` layout and a flat
    ``{task: metrics}`` mapping. Returns ``{}`` when the task is absent.
    """
    scores = results['results'] if 'results' in results else results
    return scores.get(task_name, {})


def eval_tasks_lower_version(model, tokenizer, tasks):
    """Evaluate ``tasks`` one at a time with the legacy lm_eval API.

    ``'wikitext'`` is scored with ``eval_wikitext``. Every other task runs
    0-shot on a few examples (``limit=5``) through ``lm_eval.evaluator.evaluate``.
    If a task raises a CUDA error or OOM, it is retried on CPU with ``limit=3``,
    and the model is then moved back to its original device, even if the retry
    failed. A failing task never aborts the loop; it is recorded as
    ``{'error': ..., 'status': 'FAILED'}``.

    Side effects: sets HF-datasets environment variables, enables
    ``datasets.config.HF_DATASETS_TRUST_REMOTE_CODE`` and deletes cached PIQA
    dataset directories.

    Args:
        model: torch model to evaluate.
        tokenizer: matching tokenizer.
        tasks: task names. The caller's list is not modified.

    Returns:
        dict mapping each task name to its metrics (wikitext: perplexity) or an
        error record.

    NOTE(review): relies on ``lm_eval.models.gpt2`` and ``lm_eval.tasks.get_task``,
    which presumably exist only in older lm_eval releases. Confirm the installed version.
    """
    import os
    import shutil

    # Quick fix - set proper encoding and cache handling
    os.environ["PYTHONIOENCODING"] = "utf-8"
    os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
    os.environ["HF_DATASETS_OFFLINE"] = "0"

    # Clear any problematic PIQA cache so the dataset is fetched again.
    for cache_dir in (
        os.path.expanduser("~/.cache/huggingface/datasets/piqa"),
        "/tmp/hf_datasets_cache/piqa",
    ):
        if os.path.exists(cache_dir):
            try:
                shutil.rmtree(cache_dir)
            except OSError as e:
                # Best effort: a stuck cache directory should not abort the evaluation.
                print(f"Could not clear cache {cache_dir}: {e}")
    if not getattr(config, 'HF_DATASETS_TRUST_REMOTE_CODE', False):
        config.HF_DATASETS_TRUST_REMOTE_CODE = True

    device = next(model.parameters()).device
    tasks = list(tasks)
    eval_results = {}

    if 'wikitext' in tasks:
        try:
            print("Evaluating wikitext...")
            result_wikitext = eval_wikitext(model, tokenizer)
            eval_results['wikitext'] = result_wikitext
            print(f"Wikitext completed: {result_wikitext}")
        except Exception as e:
            print(f"Wikitext failed: {str(e)}")
            eval_results['wikitext'] = {'error': str(e), 'status': 'FAILED'}
        tasks.remove('wikitext')

    # Evaluate each remaining task individually
    if tasks:
        # `evaluator` used to be referenced without ever being imported, so every
        # task below failed with a NameError and was recorded as FAILED.
        from lm_eval import evaluator
        from lm_eval.models.gpt2 import HFLM
        from lm_eval.tasks import get_task

        for task_name in tasks:
            print(f"\n{'='*50}")
            print(f"Starting evaluation of task: {task_name}")
            print(f"{'='*50}")

            try:
                # Clear memory before each task
                torch.cuda.empty_cache()
                if torch.cuda.is_available():
                    allocated = torch.cuda.memory_allocated() / 1024**3
                    print(f"GPU Memory before {task_name}: {allocated:.2f}GB")

                # Create fresh model wrapper for each task
                print(f"Creating model wrapper for {task_name}...")
                lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=1)

                print(f"Loading task class for {task_name}...")
                task_dict = {task_name: get_task(task_name)()}
                print(f"Task {task_name} loaded successfully")

                # Evaluate single task with minimal resources
                print(f"Starting evaluation of {task_name}...")
                results = evaluator.evaluate(
                    lm=lm,
                    task_dict=task_dict,
                    num_fewshot=0,
                    bootstrap_iters=1,
                    limit=5  # Very small limit to avoid memory issues
                )
                task_result = _legacy_task_result(results, task_name)
                eval_results[task_name] = task_result
                print(f"✅ Task {task_name} completed successfully")
                print(f"Result: {task_result}")

            except RuntimeError as e:
                if "CUDA error" in str(e) or "out of memory" in str(e):
                    print(f"❌ CUDA error in task {task_name}: {str(e)}")
                    print(f"Attempting CPU fallback for {task_name}...")

                    try:
                        print(f"Moving model to CPU for {task_name}...")
                        model_cpu = move_to_cpu(model)
                        lm_cpu = HFLM(pretrained=model_cpu, tokenizer=tokenizer, batch_size=1)
                        task_dict = {task_name: get_task(task_name)()}

                        print(f"Running {task_name} on CPU...")
                        results = evaluator.evaluate(
                            lm=lm_cpu,
                            task_dict=task_dict,
                            num_fewshot=0,
                            bootstrap_iters=1,
                            limit=3  # Even smaller limit for CPU
                        )
                        task_result = _legacy_task_result(results, task_name)
                        eval_results[task_name] = task_result
                        eval_results[task_name]['note'] = 'Evaluated on CPU due to CUDA error'
                        print(f"✅ Task {task_name} completed on CPU")
                        print(f"Result: {task_result}")

                    except Exception as cpu_e:
                        print(f"❌ CPU fallback also failed for {task_name}: {str(cpu_e)}")
                        eval_results[task_name] = {
                            'error': f'GPU: {str(e)}, CPU: {str(cpu_e)}',
                            'status': 'FAILED'
                        }
                    finally:
                        # Restore the model even when the CPU run failed. Otherwise
                        # every later task would silently run on CPU.
                        try:
                            print("Moving model back to GPU...")
                            model = model.to(device)
                            torch.cuda.empty_cache()
                        except Exception as restore_e:
                            print(f"❌ Could not move model back to {device}: {restore_e}")
                else:
                    print(f"❌ Non-CUDA error in task {task_name}: {str(e)}")
                    eval_results[task_name] = {'error': str(e), 'status': 'FAILED'}

            except Exception as e:
                print(f"❌ Unexpected error in task {task_name}: {str(e)}")
                eval_results[task_name] = {'error': str(e), 'status': 'FAILED'}

            finally:
                # Clean up after each task
                torch.cuda.empty_cache()
                if torch.cuda.is_available():
                    allocated = torch.cuda.memory_allocated() / 1024**3
                    print(f"GPU Memory after {task_name}: {allocated:.2f}GB")

    # Print summary
    print(f"\n{'='*60}")
    print("EVALUATION SUMMARY")
    print(f"{'='*60}")

    successful_tasks = []
    failed_tasks = []
    for task_name, result in eval_results.items():
        if isinstance(result, dict) and ('error' in result or 'status' in result):
            failed_tasks.append(task_name)
            print(f"❌ {task_name}: FAILED")
        else:
            successful_tasks.append(task_name)
            print(f"✅ {task_name}: SUCCESS")

    print(f"\nSuccessful: {len(successful_tasks)}/{len(eval_results)}")
    print(f"Failed: {len(failed_tasks)}/{len(eval_results)}")

    if successful_tasks:
        print(f"Successful tasks: {', '.join(successful_tasks)}")
    if failed_tasks:
        print(f"Failed tasks: {', '.join(failed_tasks)}")

    return eval_results

def _store_simple_evaluate_results(results, task_name, actual_task_name, eval_results):
    """Copy a task's metrics from a ``simple_evaluate`` return value into ``eval_results``.

    Metrics are stored under the caller's ``task_name`` even though lm_eval was
    asked for ``actual_task_name``. If that name is missing from the results
    (presumably a group reported per sub-task; confirm), every returned entry
    is merged in as-is.

    Returns:
        bool: False when ``results`` is empty or has no ``'results'`` section.
    """
    if not results or 'results' not in results:
        return False
    per_task = results['results']
    if actual_task_name in per_task:
        eval_results[task_name] = per_task[actual_task_name]
    else:
        eval_results.update(per_task)
    return True


def eval_tasks(model, tokenizer, tasks):
    """Evaluate ``tasks`` one at a time with ``lm_eval.simple_evaluate``.

    ``'wikitext'`` is scored with ``eval_wikitext``. Every other task is wrapped
    in a fresh ``HFLM`` (batch size 1) and evaluated on its own. On a
    ``UnicodeDecodeError`` the task's dataset caches are cleared and the task is
    retried once. Failures never abort the loop; they are recorded as
    ``{'error': ..., 'status': 'FAILED'}``.

    Side effects: sets HF-datasets environment variables, enables
    ``datasets.config.HF_DATASETS_TRUST_REMOTE_CODE``, and deletes cached
    SocialIQA/PIQA dataset directories plus ``/tmp/hf_datasets_cache``.

    Args:
        model: torch model to evaluate.
        tokenizer: matching tokenizer.
        tasks: list of task names. The caller's list is not modified.

    Returns:
        dict mapping task names to lm_eval metrics (wikitext: perplexity) or an
        error record.
    """
    import os
    import shutil

    if not getattr(config, 'HF_DATASETS_TRUST_REMOTE_CODE', False):
        config.HF_DATASETS_TRUST_REMOTE_CODE = True

    # Map task names to correct dataset names
    # NOTE(review): simple_evaluate expects lm_eval *task* names. In lm_eval 0.4
    # SocialIQA appears to be registered as 'social_iqa' (its dataset is
    # 'social_i_qa'), so this mapping may make the task unresolvable. Confirm
    # against the installed lm_eval version.
    task_name_mapping = {
        'social_iqa': 'social_i_qa',
        'piqa': 'piqa',
    }

    # Set environment variables to handle dataset loading issues
    os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "1"
    os.environ["PYTHONIOENCODING"] = "utf-8"
    os.environ["HF_DATASETS_OFFLINE"] = "0"

    # Clear problematic dataset cache
    cache_dirs = [
        os.path.expanduser("~/.cache/huggingface/datasets/social_i_qa"),
        os.path.expanduser("~/.cache/huggingface/datasets/social_iqa"),
        os.path.expanduser("~/.cache/huggingface/datasets/piqa"),
        "/tmp/hf_datasets_cache"
    ]
    for cache_dir in cache_dirs:
        if os.path.exists(cache_dir):
            try:
                shutil.rmtree(cache_dir)
                logger.info(f"Cleared cache directory: {cache_dir}")
            except Exception as e:
                logger.warning(f"Could not clear cache {cache_dir}: {e}")

    device = next(model.parameters()).device
    cpu = device.type == 'cpu'
    if cpu:
        # No dispatch helper is wired in (see the commented-out utils import), so
        # the model is NOT moved; the previous "Dispatching model to GPU." was false.
        logger.info("Model is on CPU; evaluating on CPU (no GPU dispatch).")
    tasks = tasks.copy()
    eval_results = {}

    # Handle wikitext separately
    if 'wikitext' in tasks:
        try:
            logger.info("Evaluating wikitext...")
            result_wikitext = eval_wikitext(model, tokenizer)
            eval_results['wikitext'] = result_wikitext
            logger.info(f"Wikitext completed: {result_wikitext}")
        except Exception as e:
            logger.error(f"Wikitext failed: {str(e)}")
            eval_results['wikitext'] = {'error': str(e), 'status': 'FAILED'}
        tasks.remove('wikitext')

    # Evaluate each remaining task individually to avoid CUDA memory errors
    for task_name in tasks:
        logger.info(f"Starting evaluation of task: {task_name}")

        actual_task_name = task_name_mapping.get(task_name, task_name)
        if actual_task_name != task_name:
            logger.info(f"Mapping {task_name} to {actual_task_name}")

        try:
            torch.cuda.empty_cache()
            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated() / 1024**3
                logger.info(f"GPU Memory before {task_name}: {allocated:.2f}GB")

            # Fresh model wrapper for each task
            try:
                lm = HFLM(model,
                          tokenizer=tokenizer,
                          device=device,
                          dtype=torch.bfloat16,
                          trust_remote_code=True,
                          batch_size=1  # Use smaller batch size for stability
                          )
            except Exception as wrapper_e:
                logger.error(f"Failed to create HFLM wrapper for {task_name}: {wrapper_e}")
                eval_results[task_name] = {'error': f'HFLM wrapper creation failed: {str(wrapper_e)}', 'status': 'FAILED'}
                continue

            try:
                results = simple_evaluate(
                    model=lm,
                    tasks=[actual_task_name],
                    batch_size=1,
                    device=str(device),
                )
                if _store_simple_evaluate_results(results, task_name, actual_task_name, eval_results):
                    logger.info(f"✅ Task {task_name} completed successfully")
                else:
                    logger.warning(f"No results returned for {task_name}")
                    eval_results[task_name] = {'error': 'No results returned', 'status': 'FAILED'}

            except UnicodeDecodeError as unicode_e:
                logger.error(f"UTF-8 decoding error for {task_name}: {unicode_e}")
                # Clear this task's caches (both names) and retry once.
                try:
                    logger.info(f"Attempting to fix UTF-8/gzip issues for {task_name}...")
                    for task_cache in (
                        os.path.expanduser(f"~/.cache/huggingface/datasets/{task_name}"),
                        os.path.expanduser(f"~/.cache/huggingface/datasets/{actual_task_name}"),
                    ):
                        if os.path.exists(task_cache):
                            shutil.rmtree(task_cache)
                            logger.info(f"Cleared cache: {task_cache}")

                    # Force offline mode off and retry (`config` is datasets.config)
                    config.HF_DATASETS_OFFLINE = False

                    results = simple_evaluate(
                        model=lm,
                        tasks=[actual_task_name],
                        batch_size=1,
                        device=str(device),
                    )
                    if _store_simple_evaluate_results(results, task_name, actual_task_name, eval_results):
                        logger.info(f"✅ Task {task_name} completed after UTF-8 fix")
                    else:
                        eval_results[task_name] = {'error': f'UTF-8 error: {str(unicode_e)}', 'status': 'FAILED'}

                except Exception as retry_e:
                    logger.error(f"Retry also failed for {task_name}: {retry_e}")
                    eval_results[task_name] = {'error': f'UTF-8: {str(unicode_e)}, Retry: {str(retry_e)}', 'status': 'FAILED'}

            except Exception as eval_e:
                logger.error(f"Evaluation failed for {task_name}: {eval_e}")
                eval_results[task_name] = {'error': str(eval_e), 'status': 'FAILED'}

        except RuntimeError as e:
            # Only reachable from the memory bookkeeping above; evaluation errors
            # are handled by the inner handlers.
            if "CUDA error" in str(e) or "out of memory" in str(e):
                logger.error(f"❌ CUDA error in task {task_name}: {str(e)}")
            else:
                logger.error(f"❌ Runtime error in task {task_name}: {str(e)}")
            eval_results[task_name] = {'error': str(e), 'status': 'FAILED'}

        except Exception as e:
            logger.error(f"❌ Unexpected error in task {task_name}: {str(e)}")
            eval_results[task_name] = {'error': str(e), 'status': 'FAILED'}

        finally:
            torch.cuda.empty_cache()
            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated() / 1024**3
                logger.info(f"GPU Memory after {task_name}: {allocated:.2f}GB")

    # The module logger's level is no longer forced to INFO here. That silently
    # overrode whatever level the caller had configured.

    if cpu:
        logger.info("Ensuring model is on CPU.")
        model = move_to_cpu(model)

    successful_tasks = []
    failed_tasks = []
    for task_name, result in eval_results.items():
        if isinstance(result, dict) and ('error' in result or 'status' in result):
            failed_tasks.append(task_name)
        else:
            successful_tasks.append(task_name)

    logger.info(f"Evaluation Summary - Successful: {len(successful_tasks)}, Failed: {len(failed_tasks)}")
    if failed_tasks:
        logger.info(f"Failed tasks: {', '.join(failed_tasks)}")

    return eval_results

def log_results(prefix: str, eval_results: dict):
    """Log ``prefix``, then each line of ``pformat(eval_results)`` prefixed with a tab."""
    logger.info(prefix)
    pretty_lines = pformat(eval_results).split('\n')
    for line in pretty_lines:
        logger.info('\t' + line)
        
def log_results_lower_version(prefix: str, eval_results: dict):
    """Log ``prefix`` and the results after ``format_results_with_precision(..., precision=4)``.

    Each line of the pretty-printed, formatted results is logged with a leading tab.
    """
    logger.info(prefix)
    rounded = format_results_with_precision(eval_results, precision=4)
    for line in pformat(rounded).split('\n'):
        logger.info('\t' + line)

# Metric names whose values are rounded when they appear as dict keys.
_ROUNDED_METRICS = frozenset({
    'acc', 'acc_norm', 'acc_stderr', 'acc_norm_stderr',
    'bleu', 'rouge1', 'rouge2', 'rougeL', 'bertscore',
})


def _is_rounded_metric(key):
    """True if ``key`` names a rounded metric, including lm_eval>=0.4 ``"metric,filter"`` keys."""
    return isinstance(key, str) and key.split(',', 1)[0] in _ROUNDED_METRICS


def format_results_with_precision(results, precision=4):
    """
    Recursively round metric values in a (nested) results structure.

    Inside dicts, numbers are rounded to ``precision`` only when their key names a
    known metric (``acc``, ``acc_norm``, ...). lm_eval>=0.4 keys such as
    ``"acc,none"`` count too. Other numbers are left as-is, except that numpy/torch
    scalars are converted to Python scalars via ``.item()``. Bools and
    non-numeric values are never changed. Multi-element arrays are returned
    unchanged instead of raising.

    A bare (non-dict) number passed at the top level is always rounded.
    """
    if isinstance(results, dict):
        formatted = {}
        for key, value in results.items():
            if isinstance(value, dict):
                formatted[key] = format_results_with_precision(value, precision)
            elif isinstance(value, bool):
                formatted[key] = value
            elif isinstance(value, (float, int)):
                formatted[key] = round(float(value), precision) if _is_rounded_metric(key) else value
            elif hasattr(value, 'item'):  # numpy / torch scalars
                try:
                    scalar = value.item()
                    formatted[key] = round(float(scalar), precision) if _is_rounded_metric(key) else scalar
                except (AttributeError, TypeError, ValueError, RuntimeError):
                    # e.g. multi-element arrays, whose .item() raises
                    formatted[key] = value
            else:
                formatted[key] = value
        return formatted
    if isinstance(results, bool):
        return results
    if isinstance(results, (float, int)):
        return round(float(results), precision)
    if hasattr(results, 'item'):
        try:
            return round(float(results.item()), precision)
        except (AttributeError, TypeError, ValueError, RuntimeError):
            return results
    return results

def setup_logger(path):
    """Configure the root logger to write to ``<path>/log.txt`` and to the console.

    Creates ``path`` if needed. Attaches a UTF-8 file handler (truncating an
    existing log) and a ``TqdmLoggingHandler`` console handler. Applies one
    timestamped format to every root handler, sets the root level to INFO and
    installs ``log_exception`` as ``sys.excepthook``.

    Calling it again does not add a second console handler or a second handler
    for the same log file, so records are not duplicated.

    Args:
        path: output directory (``str`` or path-like).
    """
    import os

    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    log_file = path / "log.txt"

    root = logging.getLogger()
    # FileHandler stores its target as an absolute path in `baseFilename`.
    open_files = {getattr(h, 'baseFilename', None) for h in root.handlers}
    if os.path.abspath(log_file) not in open_files:
        # Explicit UTF-8: messages contain non-ASCII symbols (✅/❌) that the
        # platform default encoding may not be able to write.
        root.addHandler(logging.FileHandler(log_file, mode='w', encoding='utf-8'))
    if not any(isinstance(h, TqdmLoggingHandler) for h in root.handlers):
        root.addHandler(TqdmLoggingHandler(sys.stdout))

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    for handler in root.handlers:
        handler.setFormatter(formatter)
    root.setLevel(logging.INFO)
    sys.excepthook = log_exception
    
class TqdmLoggingHandler(logging.StreamHandler):
    """Logging handler that prints formatted records through a ``rich`` Console.

    Messages are printed verbatim. Rich markup and ``:emoji:`` codes are disabled,
    so log text containing ``[...]`` or ``:name:`` is neither restyled nor dropped.
    A single Console is created lazily on first use and reused.

    NOTE(review): the ``stream`` passed to the constructor is not used for
    output, and despite the name nothing is routed through ``tqdm.write``.
    """

    _console = None  # created on first emit, then reused

    def emit(self, record):
        try:
            msg = self.format(record)
            if self._console is None:
                self._console = Console()
            self._console.print(msg, markup=False, emoji=False)
        except RecursionError:
            # Same convention as logging.StreamHandler.emit.
            raise
        except Exception:
            self.handleError(record)
            
def log_exception(exc_type, exc_value, exc_traceback):
    """``sys.excepthook`` replacement that records uncaught exceptions in the log.

    KeyboardInterrupt goes straight to the default hook. For any other
    exception, console (``TqdmLoggingHandler``) handlers are detached while the
    traceback is logged, so it reaches only the remaining handlers (e.g. the
    log file). The default hook then prints it to stderr once. The handlers
    are re-attached afterwards.
    """
    if issubclass(exc_type, KeyboardInterrupt):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return

    # setup_logger attaches the console handler to the *root* logger, and records
    # from `logger` reach it by propagation. Detaching only from `logger` (as
    # before) was a no-op and printed the traceback twice.
    detached = []
    for owner in (logger, logging.getLogger()):
        for handler in list(owner.handlers):
            if isinstance(handler, TqdmLoggingHandler):
                owner.removeHandler(handler)
                detached.append((owner, handler))

    try:
        logger.error("Exception Occured:\n", exc_info=(exc_type, exc_value, exc_traceback))
    finally:
        for owner, handler in detached:
            owner.addHandler(handler)
    sys.__excepthook__(exc_type, exc_value, exc_traceback)

def move_to_cpu(model):
    """
    Relocate ``model`` to host memory and return it.

    Every parameter and buffer is re-pointed at a CPU copy, any
    ``hf_device_map`` attribute is removed, ``model.cpu()`` is applied, and
    ``torch.cuda.empty_cache()`` is called to release cached GPU blocks.
    """
    # Parameters first, then buffers, matching the original traversal order.
    for tensor in (*model.parameters(), *model.buffers()):
        tensor.data = tensor.data.cpu()

    # Presumably dropped because the recorded placement no longer matches.
    if hasattr(model, 'hf_device_map'):
        delattr(model, 'hf_device_map')

    cpu_model = model.cpu()
    torch.cuda.empty_cache()
    return cpu_model