import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import argparse
import os
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import numpy as np
from IPython.display import clear_output

from matplotlib.lines import Line2D
import threading
import torch
import pynvml
import time

# Set at import time, before any GPU work is issued by the code below.
# NOTE(review): presumably intended to make CUDA kernel launches synchronous so
# the timing/energy windows cover the actual GPU execution — confirm this is
# still wanted, since it applies to the whole process.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def get_power_usage():
    """Return the GPU's present power draw, expressed in watts.

    Reads the module-level ``handle`` (assigned by the script entry point
    before any measurement starts) and divides the NVML reading by 1000.
    """
    raw_reading = pynvml.nvmlDeviceGetPowerUsage(handle)
    return raw_reading / 1000

def measure_energy_during(func, *args, **kwargs):
    """Measure average GPU power and energy (joules) during execution of func.

    A background thread polls ``get_power_usage()`` about every 0.1 s while
    ``func(*args, **kwargs)`` runs on the calling thread.

    Args:
        func: Callable to run and measure.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Tuple ``(result, duration, avg_power, energy_joules)`` where
        ``result`` is the return value of ``func``, ``duration`` is the
        elapsed seconds spent inside ``func``, ``avg_power`` is the mean of
        the collected power samples (0 if none were collected) and
        ``energy_joules`` is ``avg_power * duration``.

    Raises:
        Any exception raised by ``func``. The sampler thread is always
        stopped and joined before the exception propagates.
    """
    power_samples = []
    stop_sampling = threading.Event()

    def sample_power():
        while not stop_sampling.is_set():
            power_samples.append(get_power_usage())
            # Wait on the event instead of sleeping so the sampler exits as
            # soon as the measured call finishes.
            stop_sampling.wait(0.1)

    sampler_thread = threading.Thread(target=sample_power)
    sampler_thread.start()

    # perf_counter is monotonic, so the duration cannot be skewed by wall
    # clock adjustments.
    start_time = time.perf_counter()
    try:
        # Run the target function
        result = func(*args, **kwargs)
        # Stop the clock before joining the sampler so thread shutdown time
        # is not billed to the measured call.
        duration = time.perf_counter() - start_time
    finally:
        # Always stop the sampler, even if func raised; otherwise the
        # non-daemon thread would keep the process alive forever.
        stop_sampling.set()
        sampler_thread.join()

    avg_power = sum(power_samples) / len(power_samples) if power_samples else 0
    energy_joules = avg_power * duration
    return result, duration, avg_power, energy_joules

def generate_text(model, tokenizer, inputs, **gen_kwargs):
    """Run ``model.generate`` on the encoded prompt and return its output.

    ``inputs`` is unpacked into keyword arguments together with
    ``gen_kwargs``. ``tokenizer`` is accepted but not used by this function.
    """
    run_generation = model.generate
    generated = run_generation(**inputs, **gen_kwargs)
    return generated

def score_sequence(model, input_ids, attention_mask=None):
    """Run one forward pass with gradients disabled and return the logits.

    The model is called with ``input_ids`` and ``attention_mask`` as keyword
    arguments; no tokens are generated.
    """
    with torch.no_grad():
        forward_result = model(input_ids=input_ids, attention_mask=attention_mask)
    return forward_result.logits
    


if __name__ == "__main__":
    # Benchmark driver: for every prompt, sample a completion and then re-score
    # the full prompt + completion with a single forward pass, recording the
    # GPU energy of both phases. Results are pickled at the end.

    parser = argparse.ArgumentParser()

    parser.add_argument('--prompts', nargs="+", type=str, required=False, default=["How are you?", "Tell me a story"])
    parser.add_argument('--seed', type=int, required=False, default=42)
    parser.add_argument('--model', type=str, required=False, default="meta-llama/Llama-3.2-1B-Instruct")

    args = parser.parse_args()
    model_cache = "../models"
    model_name = args.model

    # Short names used in the results file name. Known checkpoints keep the
    # exact names used so far (note the capitalised "Gemma").
    known_model_names = {
        "meta-llama/Llama-3.2-1B-Instruct": "Llama-3.2-1B-Instruct",
        "meta-llama/Llama-3.2-3B-Instruct": "Llama-3.2-3B-Instruct",
        "mistralai/Ministral-8B-Instruct-2410": "Ministral-8B-Instruct-2410",
        "google/gemma-3-4b-it": "Gemma-3-4b-it",
        "google/gemma-3-1b-it": "Gemma-3-1b-it",
    }
    # Any other --model value falls back to its last path component, so the
    # results can still be saved instead of failing with a NameError after all
    # the measurement work has been done.
    model_str = known_model_names.get(model_name) or model_name.rstrip("/").split("/")[-1]

    # Set random seeds for reproducibility. NumPy is seeded too because the
    # per-prompt max_new_tokens below is drawn from np.random.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float32, cache_dir=model_cache).to(device)

    final_results = []

    # Initialize NVML once for the whole run. `handle` is deliberately a
    # module-level name: get_power_usage() reads it as a global.
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)

        for prompt_idx, prompt_str in enumerate(args.prompts):

            print(f"Processing prompt {prompt_idx}...")
            # Draw a random cap on the number of new tokens from [100, 500)
            max_new_tokens = int(np.random.randint(100, 500))

            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt_str},
            ]

            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,              # return as string
                add_generation_prompt=True
            )

            # The chat template already inserts the special tokens the model
            # needs; letting the tokenizer add them again would duplicate BOS.
            inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

            # Measure energy during generation
            generated_outputs, gen_time, gen_power, gen_energy = measure_energy_during(
                generate_text,
                model,
                tokenizer,
                inputs,
                max_new_tokens=max_new_tokens,
                temperature=1.0,
                repetition_penalty=1.2,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=True,
                use_cache=True,  # Enable KV cache
                pad_token_id=tokenizer.eos_token_id  # Important for Gemma
            )

            # generate() returns prompt + completion; keep only the new tokens.
            input_len = inputs["input_ids"].shape[1]
            generated_tokens = generated_outputs[0][input_len:].unsqueeze(0).to(model.device)

            number_gen_tok = generated_tokens.shape[1]
            generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

            # Prepare full sequence for scoring: prompt + generated tokens
            full_sequence = torch.cat([inputs["input_ids"], generated_tokens], dim=1).to(model.device)

            # Prepare attention mask if present, extended for generated tokens
            attention_mask = None
            if "attention_mask" in inputs:
                attn_mask = inputs["attention_mask"]
                gen_mask = torch.ones((1, number_gen_tok), dtype=attn_mask.dtype).to(model.device)
                attention_mask = torch.cat([attn_mask, gen_mask], dim=1)

            # Measure energy during scoring (forward pass)
            logits, score_time, score_power, score_energy = measure_energy_during(
                score_sequence,
                model,
                full_sequence,
                attention_mask=attention_mask
            )

            print(f"Generation energy: {gen_energy:.2f} J, Scoring energy: {score_energy:.2f} J")
            print("Number of generated tokens:", number_gen_tok)

            final_results.append({"out_length": number_gen_tok, "gen_energy": gen_energy, "score_energy": score_energy, "decoded_text": generated_text})
    finally:
        # Release NVML even if a prompt fails part-way through.
        pynvml.nvmlShutdown()

    # Save results to a file
    results_file = f"energy_results_{model_str}.pkl"
    with open(results_file, "wb") as f:
        pickle.dump(final_results, f)