import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm import tqdm
import numpy as np

class ActivationExtractor:
    """Load a causal language model and extract last-token hidden states.

    The model is loaded with ``output_hidden_states=True`` so a forward pass
    returns the hidden states of every layer; ``process_batch`` keeps the
    vector of the final token of each input text for the requested layers.
    """

    def __init__(self, model_name, load_in_8bit=False):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face model id or local path. The tokenizer is
                loaded with ``local_files_only=True``, so it must already be
                available locally.
            load_in_8bit: If True and CUDA is available, load the model with a
                bitsandbytes 8-bit quantization config. Without CUDA this flag
                is ignored (a warning is printed) and full-precision float32
                weights are loaded on CPU.

        Raises:
            Exception: Re-raises whatever the tokenizer loader raised.
        """
        # 1. Pick the device: GPU when CUDA is available, otherwise CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
            print(f"[*] CUDA Detected. Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            self.device = "cpu"
            print("[!] CUDA NOT detected. Falling back to CPU. (This will be slower but works for Qwen-1.5B)")

        print(f"[*] Loading model: {model_name}...")

        # 2. Tokenizer. Some tokenizers define no pad token; reuse EOS so one
        #    is always set.
        # NOTE: trust_remote_code=True executes Python code shipped with the
        # model repository; only use it with model sources you trust.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True, trust_remote_code=True)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        except Exception as e:
            print(f"[!] Tokenizer load error: {e}")
            raise  # bare raise re-raises the original exception unchanged

        # 3. Model. 8-bit quantization is only attempted on CUDA; otherwise
        #    load fp16 weights on GPU or fp32 weights on CPU.
        if load_in_8bit and self.device == "cuda":
            print("[*] Loading model in 8-bit to save VRAM...")
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                output_hidden_states=True,
                trust_remote_code=True
            )
        else:
            if load_in_8bit:
                # Previously this fallback happened silently.
                print("[!] 8-bit loading requires CUDA; loading full-precision weights on CPU instead.")
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map=self.device,
                output_hidden_states=True,
                trust_remote_code=True
            )
        self.model.eval()

    @staticmethod
    def _last_token_vectors(hidden_states, layer_indices):
        """Return ``{layer_idx: float32 vector}`` for each in-range layer.

        Reads ``hidden_states[layer_idx][0, -1, :]``, i.e. the first batch
        element's final position. Indices outside
        ``[-len(hidden_states), len(hidden_states))`` are omitted from the
        result rather than raising, so the caller can decide how to report them.

        Args:
            hidden_states: Sequence of per-layer tensors indexable as
                ``[batch, position, feature]``.
            layer_indices: Iterable of integer layer indices (negative allowed).

        Returns:
            Dict mapping each valid layer index to a 1-D ``np.float32`` array.
        """
        n_states = len(hidden_states)
        vecs = {}
        for layer_idx in layer_indices:
            if not -n_states <= layer_idx < n_states:
                continue
            # Convert on the torch side (.float()) because NumPy has no
            # bfloat16; .copy() ensures the array does not keep the whole
            # hidden-state tensor alive through a shared-memory view.
            vecs[layer_idx] = hidden_states[layer_idx][0, -1, :].float().cpu().numpy().copy()
        return vecs

    def process_batch(self, texts, layer_indices):
        """Extract the last-token hidden state of each text for each layer.

        Texts are run through the model one at a time (no padding). A sample
        that fails to tokenize or run is reported and skipped for ALL layers,
        so row ``i`` of every returned array refers to the same sample.

        Args:
            texts: Iterable of input strings.
            layer_indices: Iterable of layer indices into the model's
                ``hidden_states`` tuple (negative indices allowed). Duplicates
                are collapsed; out-of-range indices are warned about once and
                produce an empty array.

        Returns:
            Dict mapping each requested layer index to a 2-D ``np.float32``
            array with one row per successfully processed sample, or
            ``np.array([])`` if that layer collected no vectors.
        """
        # Dict keys de-duplicate the indices and materialize one-shot iterables.
        results = {layer_idx: [] for layer_idx in layer_indices}
        warned_layers = set()

        for sample_idx, text in enumerate(tqdm(texts, desc="Extracting")):
            try:
                inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                hidden_states = outputs.hidden_states
                sample_vecs = self._last_token_vectors(hidden_states, list(results))
            except Exception as e:
                # Deliberate best-effort: skip the bad sample, keep going.
                print(f"[!] Error processing sample {sample_idx}. Error: {e}")
                continue

            # Append only after the whole sample succeeded, so the per-layer
            # lists stay aligned row-for-row.
            for layer_idx, bucket in results.items():
                if layer_idx in sample_vecs:
                    bucket.append(sample_vecs[layer_idx])
                elif layer_idx not in warned_layers:
                    warned_layers.add(layer_idx)
                    print(f"[!] Warning: Layer {layer_idx} out of bounds. Max layer is {len(hidden_states) - 1}")

        return {
            layer_idx: np.array(vecs) if vecs else np.array([])
            for layer_idx, vecs in results.items()
        }