import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import math
import re
import pandas as pd
import json
import time
import random
import sys
import os
import torch
import numpy as np
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import torch.nn.functional as F
from datasets import load_dataset
from simplified_evaluator.eval import parse_prediction
from util import is_equiv
import multiprocessing as mp
from torch.multiprocessing import set_start_method


# Modified function to handle batches
def calculate_comprehensive_metrics_batched(model, input_ids_batch, attention_mask_batch):
    """Calculate loss, perplexity and input-gradient metrics for a padded batch.

    The weights of ``model`` are frozen and gradients are taken with respect
    to the input embeddings only. For every sequence the next-token
    cross-entropy is averaged over the positions whose label is marked as a
    real token by ``attention_mask_batch``; perplexity is ``exp`` of that
    average, matching ``calculate_comprehensive_metrics``.

    Args:
        model: Causal LM exposing ``parameters()``, ``get_input_embeddings()``
            and a forward call that accepts ``inputs_embeds`` and
            ``attention_mask`` and returns an object with ``.logits``.
        input_ids_batch: Token ids indexed as [batch, seq_len].
        attention_mask_batch: Mask of the same shape; non-zero marks the
            positions that count towards the loss.

    Returns:
        A list with one dict per sample containing
        ``each_token_gradient_norm``, ``all_tokens_gradient_norm``,
        ``length``, ``entropy_loss`` and ``perplexity``.
    """
    batch_size = input_ids_batch.shape[0]

    # Only the input embeddings need gradients; freeze every model weight.
    for param in model.parameters():
        param.requires_grad = False

    all_embeds = model.get_input_embeddings()(input_ids_batch)  # [batch, seq_len, hidden]
    all_embeds.requires_grad_(True)

    outputs = model(inputs_embeds=all_embeds, attention_mask=attention_mask_batch)
    logits = outputs.logits  # [batch, seq_len, vocab]
    seq_length = all_embeds.size(1)

    # Next-token prediction: position t predicts the token at t + 1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids_batch[:, 1:].contiguous()

    per_token_loss = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction='none',
    ).view(batch_size, shift_labels.size(1))

    # A label is valid only if the token being predicted is not padding.
    shift_mask = attention_mask_batch[:, 1:].to(per_token_loss.dtype)
    masked_loss = per_token_loss * shift_mask

    # clamp keeps an all-padding row from dividing by zero (its loss is 0).
    valid_tokens_per_sample = shift_mask.sum(dim=1).clamp(min=1)
    per_sample_avg_loss = masked_loss.sum(dim=1) / valid_tokens_per_sample

    # Samples are independent, so the gradient of the summed loss w.r.t. one
    # sample's embeddings is the gradient of that sample's own loss.
    loss = per_sample_avg_loss.sum()
    grads = torch.autograd.grad(
        outputs=loss,
        inputs=all_embeds,
        retain_graph=False,
        create_graph=False,
    )[0]  # [batch, seq_len, hidden]

    avg_loss_values = per_sample_avg_loss.detach().float().cpu().tolist()
    # Perplexity is exp(mean NLL); padded positions must not contribute.
    perplexity_values = torch.exp(per_sample_avg_loss.detach()).float().cpu().tolist()

    metrics_list = []
    for i in range(batch_size):
        sample_grads = grads[i]  # [seq_len, hidden]
        metrics_list.append({
            'each_token_gradient_norm': sample_grads.norm(2, dim=1).tolist(),
            'all_tokens_gradient_norm': sample_grads.norm(2).item(),
            'length': seq_length,  # padded sequence length shared by the batch
            'entropy_loss': avg_loss_values[i],
            'perplexity': perplexity_values[i],
        })

    return metrics_list


def calculate_comprehensive_metrics(model, input_ids):
    """Compute loss, perplexity, attention and gradient metrics for one sequence.

    Model weights are frozen; gradients are taken with respect to the input
    embeddings of ``input_ids`` only.

    Args:
        model: Causal LM exposing ``parameters()``, ``get_input_embeddings()``
            and a forward call that accepts ``inputs_embeds`` and
            ``output_attentions`` and returns ``.logits`` and ``.attentions``.
        input_ids: Token ids; only row 0 is used for the loss and gradients.

    Returns:
        Dict with ``each_token_gradient_norm``, ``all_tokens_gradient_norm``,
        ``length``, ``attention_scores``, ``entropy_loss`` and ``perplexity``.
    """
    # Freeze the weights so autograd only tracks the embedding tensor.
    for weight in model.parameters():
        weight.requires_grad = False

    embeds = model.get_input_embeddings()(input_ids)
    embeds.requires_grad_(True)

    result = model(inputs_embeds=embeds, output_attentions=True)

    # Last-layer attention averaged over every axis except the final one.
    last_layer_attention = result.attentions[-1].cpu().detach().numpy()
    attention_scores = np.mean(last_layer_attention, axis=(0, 1, 2)).tolist()

    # Next-token cross-entropy: position t predicts the token at t + 1.
    token_losses = F.cross_entropy(
        result.logits[0, :-1, :], input_ids[0, 1:], reduction='none'
    )
    mean_loss = torch.mean(token_losses)
    with torch.no_grad():
        ppl = torch.exp(mean_loss).item()

    (embed_grads,) = torch.autograd.grad(
        outputs=mean_loss,
        inputs=embeds,
        retain_graph=False,
        create_graph=False,
    )
    token_grads = embed_grads[0]  # drop the batch dimension

    return {
        'each_token_gradient_norm': token_grads.norm(2, dim=1).tolist(),
        'all_tokens_gradient_norm': token_grads.norm(2).item(),
        'length': embeds.size(1),
        'attention_scores': attention_scores,
        'entropy_loss': mean_loss.item(),
        'perplexity': ppl,
    }


def process_subset(gpu_id, dataset_subset, dataset_name, model_name, samples_data, output_queue, initial_batch_size):
    """Process a subset of questions on one GPU, shrinking the batch size on OOM.

    For every row the prompt metrics are computed first, then the saved
    generations for that problem are scored in batches. When a batch runs out
    of CUDA memory the batch size is reduced by one and the batch is retried;
    a sample that does not fit even alone is skipped.

    Args:
        gpu_id: Index of the CUDA device this worker uses.
        dataset_subset: Iterable of dataset rows (dict-like) to process.
        dataset_name: ``"math500"`` or ``"gsm8k"``; selects the row fields.
        model_name: Name or path passed to ``from_pretrained`` for both the
            tokenizer and the model.
        samples_data: Mapping from problem text to its saved generations
            (presumably a list of strings - it is passed to the tokenizer
            and indexed per sample).
        output_queue: Queue that receives the result dict exactly once,
            even if the worker fails part-way.
        initial_batch_size: Batch size tried first for every problem
            (values below 1 are treated as 1).

    Returns:
        Dict keyed by problem text holding ground truth, ``input_metrics``
        and per-sample ``output_metrics_i`` / ``answer_i`` / ``acc_i``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    output_dict = {}
    count = 0

    # The parent expects one result per worker on output_queue, so always
    # report whatever was collected, even when this worker fails part-way.
    try:
        if dataset_name not in ("math500", "gsm8k"):
            raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

        device = torch.device(f"cuda:{gpu_id}")
        print(f"Process {os.getpid()} using GPU {gpu_id}")

        # Load model and tokenizer for this process
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model = model.to(device)
        model.eval()

        for row in dataset_subset:
            count += 1

            # Extract problem content and ground truth based on dataset
            if dataset_name == "math500":
                prob_content = row["problem"]
                gt = row["answer"]
                output_dict[prob_content] = {
                    'level': row["level"],
                    'type': row["subject"],
                    'ground_truth': gt
                }
            else:  # gsm8k: the final answer follows the '####' marker
                prob_content = row["question"]
                gt = row["answer"].split('####')[-1].strip()
                output_dict[prob_content] = {
                    'ground_truth': gt
                }

            # Skip if this problem doesn't have samples
            if prob_content not in samples_data:
                print(f"GPU {gpu_id} - Warning: No samples found for problem: {prob_content[:50]}...")
                continue

            try:
                # Metrics for the prompt alone
                messages = [
                    {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
                    {"role": "user", "content": prob_content}
                ]
                prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                encodings = tokenizer([prompt], return_tensors='pt')
                input_ids = encodings['input_ids'].to(device)
                metrics = calculate_comprehensive_metrics(model, input_ids)
                output_dict[prob_content].update({
                    'input_metrics': metrics,
                })
                model.zero_grad()

                # Saved generations for this problem
                all_samples = samples_data[prob_content]

                tokenizer.pad_token = tokenizer.eos_token
                sample_encodings = tokenizer(all_samples, return_tensors='pt', padding=True, truncation=True, max_length=513)
                sample_ids_batch = sample_encodings['input_ids'].to(device)
                sample_attention_mask = sample_encodings['attention_mask'].to(device)

                num_samples = sample_ids_batch.shape[0]

                # Prefix every sample with the prompt; prompt tokens are never padding
                input_ids_repeated = input_ids.repeat(num_samples, 1)
                input_attention_mask = torch.ones_like(input_ids_repeated)
                output_ids_batch = torch.cat([input_ids_repeated, sample_ids_batch], dim=1)
                output_attention_mask_batch = torch.cat([input_attention_mask, sample_attention_mask], dim=1)

                # Each problem starts again from the requested batch size
                current_batch_size = max(1, initial_batch_size)

                i = 0
                while i < num_samples:
                    end_idx = min(i + current_batch_size, num_samples)
                    try:
                        print(f"GPU {gpu_id} - Processing {count}th question, samples {i}-{end_idx-1} with batch_size={current_batch_size}", flush=True)

                        # Clear CUDA cache before processing
                        torch.cuda.empty_cache()

                        batched_metrics = calculate_comprehensive_metrics_batched(
                            model,
                            output_ids_batch[i:end_idx],
                            output_attention_mask_batch[i:end_idx],
                        )
                        model.zero_grad()

                        # Score each sample in the batch
                        for j, sample_idx in enumerate(range(i, end_idx)):
                            if sample_idx >= len(all_samples):
                                break

                            sample = all_samples[sample_idx]
                            answer = parse_prediction(sample, gt, 'math')
                            try:
                                is_correct = is_equiv(gt, answer)
                            except Exception:
                                # An unparseable answer counts as incorrect
                                is_correct = False

                            output_dict[prob_content].update({
                                f"output_metrics_{sample_idx}": batched_metrics[j],
                                f"answer_{sample_idx}": answer,
                                f"acc_{sample_idx}": is_correct,
                            })

                        # Batch processed successfully
                        i = end_idx

                    except torch.cuda.OutOfMemoryError:
                        torch.cuda.empty_cache()
                        if current_batch_size > 1:
                            # Shrink by one and retry the same starting sample
                            new_batch_size = current_batch_size - 1
                            print(f"GPU {gpu_id} - CUDA OOM! Reducing batch size from {current_batch_size} to {new_batch_size}", flush=True)
                            current_batch_size = new_batch_size
                        else:
                            # Nothing left to shrink: skip instead of retrying forever
                            print(f"GPU {gpu_id} - Failed to process sample {i} even with batch_size=1, skipping", flush=True)
                            i += 1

                    except Exception as e:
                        # Any other failure: drop the first sample of the batch and go on
                        print(f"GPU {gpu_id} - Error processing batch: {str(e)}", flush=True)
                        i = min(i + 1, num_samples)

            except Exception as e:
                print(f"GPU {gpu_id} - Error processing question {count}: {str(e)}", flush=True)
                # Continue with next question
                continue
    finally:
        output_queue.put(output_dict)

    print(f"GPU {gpu_id} finished processing {count} questions")
    return output_dict

def run_multi_gpu(model_name, tokenizer_name, dataset_name, samples_data, num_gpus, initial_batch_size):
    """Run the processing on several GPUs, one worker process per GPU.

    Args:
        model_name: Name or path of the model each worker loads.
        tokenizer_name: Accepted for interface compatibility; currently unused
            (workers load the tokenizer from ``model_name``).
        dataset_name: ``"math500"`` or ``"gsm8k"``.
        samples_data: Mapping from problem text to saved generations,
            forwarded to every worker.
        num_gpus: Number of GPUs requested; capped at the number available.
        initial_batch_size: Starting batch size forwarded to every worker.

    Returns:
        The merged result dicts of all workers, or ``{}`` when no GPU is
        usable or the dataset is empty.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    import queue  # local import: only needed for the Empty exception below

    if dataset_name not in ("math500", "gsm8k"):
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    random.seed(0)

    # Count available GPUs and adjust if needed
    available_gpus = torch.cuda.device_count()
    if available_gpus < num_gpus:
        print(f"Warning: Requested {num_gpus} GPUs but only {available_gpus} are available")
        num_gpus = available_gpus

    # Checked before loading the dataset so a GPU-less run exits immediately
    if num_gpus <= 0:
        print("No GPUs available. Exiting.")
        return {}

    if dataset_name == "math500":
        ds = load_dataset("HuggingFaceH4/MATH-500")['test']
    else:
        ds = load_dataset("gsm8k", "main", split="test")

    dataset_list = list(ds)
    if not dataset_list:
        return {}

    # Never start more workers than there are rows, and spread the remainder
    # over the first workers so chunk sizes differ by at most one.
    num_workers = min(num_gpus, len(dataset_list))
    base, extra = divmod(len(dataset_list), num_workers)
    dataset_chunks = []
    start = 0
    for worker in range(num_workers):
        stop = start + base + (1 if worker < extra else 0)
        dataset_chunks.append(dataset_list[start:stop])
        start = stop

    # Create a Queue to get results from processes
    output_queue = mp.Queue()

    processes = []
    for gpu_id in range(num_workers):
        p = mp.Process(
            target=process_subset,
            args=(gpu_id, dataset_chunks[gpu_id], dataset_name, model_name, samples_data, output_queue, initial_batch_size)
        )
        processes.append(p)
        p.start()

    # Collect one result per worker. Poll with a timeout so that a worker
    # that died without reporting cannot block the parent forever.
    results = {}
    received = 0
    while received < num_workers:
        try:
            chunk_results = output_queue.get(timeout=5)
        except queue.Empty:
            if any(p.is_alive() for p in processes):
                continue
            # Every worker has exited; pick up anything still in flight.
            try:
                chunk_results = output_queue.get(timeout=1)
            except queue.Empty:
                print(f"Warning: only {received} of {num_workers} workers returned results")
                break
        results.update(chunk_results)
        received += 1

    # Wait for all processes to finish
    for p in processes:
        p.join()

    return results

def convert_to_serializable(obj):
    """Convert numpy and torch objects to JSON-serializable Python types.

    Intended as the ``default=`` hook of ``json.dump``.

    Args:
        obj: Object the JSON encoder could not serialize on its own.

    Returns:
        A nested list for tensors and arrays, or the matching Python scalar
        for numpy scalar types (e.g. ``np.float32`` -> ``float``).

    Raises:
        TypeError: If ``obj`` is not a torch tensor, numpy array or numpy scalar.
    """
    if isinstance(obj, torch.Tensor):
        # tolist() yields a plain scalar for 0-dim tensors
        return obj.detach().cpu().tolist()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Covers np.float32 as before, plus integer, bool and other scalars
        return obj.item()
    raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")


if __name__ == "__main__":
    # CUDA cannot be re-initialised in forked children, so workers are spawned.
    try:
        set_start_method('spawn')
    except RuntimeError:
        # The start method was already set; keep it.
        pass

    # Command line: EXP_NAME DATASET MODEL_NAME SAMPLE_PATH BATCH_SIZE [NUM_GPUS]
    if len(sys.argv) < 6:
        print(
            f"Usage: {sys.argv[0]} EXP_NAME DATASET MODEL_NAME SAMPLE_PATH BATCH_SIZE [NUM_GPUS]",
            file=sys.stderr,
        )
        sys.exit(2)

    EXP_NAME = sys.argv[1]
    dataset = sys.argv[2]
    MODEL_NAME = sys.argv[3]
    SAMPLE_PATH = sys.argv[4]
    BATCH_SIZE = int(sys.argv[5])

    # Get number of GPUs to use (default to all available)
    NUM_GPUS = int(sys.argv[6]) if len(sys.argv) > 6 else torch.cuda.device_count()

    # Create the output directory, including missing parents, if needed
    os.makedirs(EXP_NAME, exist_ok=True)

    # Load samples data
    with open(SAMPLE_PATH, 'r', encoding="utf-8") as f:
        samples_data = json.load(f)

    print(f"Using {NUM_GPUS} GPUs for processing")

    # Stays None unless run_multi_gpu returned, so the handler below can tell
    # whether there is anything to save.
    response_dict = None
    try:
        response_dict = run_multi_gpu(MODEL_NAME, MODEL_NAME, dataset, samples_data, NUM_GPUS, BATCH_SIZE)

        # Write results to file
        with open(os.path.join(EXP_NAME, 'output.json'), 'w', encoding="utf-8") as f:
            json.dump(response_dict, f, default=convert_to_serializable)
        print("Successfully completed processing and saved results")

    except Exception as e:
        print(f"Error in multi-GPU processing: {str(e)}")

        # Try to save any partial results if available
        if response_dict:
            try:
                with open(os.path.join(EXP_NAME, 'partial_output.json'), 'w', encoding="utf-8") as f:
                    json.dump(response_dict, f, default=convert_to_serializable)
                print("Saved partial results")
            except Exception as save_err:
                print(f"Could not save partial results: {str(save_err)}")

        # Report the failure to the caller instead of exiting with status 0
        sys.exit(1)