#!/usr/bin/env python3

import numpy as np
import torch
from pathlib import Path
from typing import List, Tuple
import logging

# Configure the root logger when this module is imported so that INFO-level
# progress messages emitted below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class VectorExtractor:
    """Extract last-token hidden states from a causal LM and build steering vectors.

    Typical use: construct, call :meth:`load_model`, then call
    :meth:`extract_activations` or :meth:`compute_steering_vector`.
    """

    # Case-insensitive substring of ``model_name`` -> Hugging Face checkpoint id.
    # Entries are tried in insertion order; the first match wins.
    _MODEL_IDS = {
        'gemma': 'google/gemma-7b-it',
        'llama': 'meta-llama/Meta-Llama-3-8B-Instruct',
        'mistral': 'mistralai/Mistral-7B-Instruct-v0.2',
    }

    def __init__(self, model_name: str, device: str = 'cuda'):
        """Store the configuration; no model is loaded until :meth:`load_model`.

        Args:
            model_name: Name containing ``gemma``, ``llama`` or ``mistral``
                (case-insensitive); used to pick the checkpoint.
            device: Fallback device for input tensors, used only when the
                loaded model does not expose a ``device`` attribute.
        """
        self.model_name = model_name
        self.device = device
        self.model = None
        self.tokenizer = None

    def load_model(self) -> None:
        """Load the tokenizer and the float16 model for ``self.model_name``.

        Raises:
            ValueError: If ``model_name`` matches none of the known families.
        """
        lowered = self.model_name.lower()
        model_id = next(
            (ckpt for key, ckpt in self._MODEL_IDS.items() if key in lowered),
            None,
        )
        if model_id is None:
            raise ValueError(f"Unknown model: {self.model_name}")

        # Imported lazily so the module can be imported (and unknown names
        # rejected) without transformers being importable.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # NOTE: placement is delegated to ``device_map='auto'``, so the model
        # may not end up on ``self.device``; see ``extract_activations``.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map='auto'
        )

    def extract_activations(self, texts: List[str], layer: int = -1) -> np.ndarray:
        """Return the last-token hidden state of ``layer`` for every text.

        Each text is tokenized on its own (truncated to 512 tokens) and run
        through the model in a separate forward pass without gradients.

        Args:
            texts: Input strings, processed one at a time.
            layer: Index into the model's ``hidden_states`` tuple.

        Returns:
            A float32 array with one row per text. For empty ``texts`` the
            result is an empty array.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model is not loaded; call load_model() first")

        # Send inputs to wherever the model actually lives; ``device_map='auto'``
        # may have chosen something other than ``self.device``.
        target_device = getattr(self.model, 'device', None)
        if target_device is None:
            target_device = self.device

        activations = []
        for text in texts:
            encoded = self.tokenizer(
                text, return_tensors='pt', truncation=True, max_length=512
            )
            encoded = {k: v.to(target_device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = self.model(**encoded, output_hidden_states=True)

            hidden_states = outputs.hidden_states[layer]
            # Upcast before leaving torch: the model runs in float16, and
            # float16 NumPy reductions (mean, norm) overflow/lose precision.
            last_token = hidden_states[0, -1, :].float().cpu().numpy()
            activations.append(last_token)

        return np.array(activations)

    def compute_steering_vector(self, positive_texts: List[str],
                                negative_texts: List[str],
                                layer: int = -1) -> np.ndarray:
        """Return the unit-norm difference of mean positive and negative activations.

        Args:
            positive_texts: Texts whose mean activation is the minuend.
            negative_texts: Texts whose mean activation is the subtrahend.
            layer: Hidden-state index forwarded to :meth:`extract_activations`.

        Returns:
            ``mean(pos) - mean(neg)`` scaled to unit L2 norm; returned
            unscaled (all zeros) when the difference has zero norm.

        Raises:
            ValueError: If either list of texts is empty.
            RuntimeError: If :meth:`load_model` has not been called.
        """
        if not positive_texts or not negative_texts:
            raise ValueError(
                "positive_texts and negative_texts must both be non-empty"
            )

        pos_mean = np.mean(self.extract_activations(positive_texts, layer), axis=0)
        neg_mean = np.mean(self.extract_activations(negative_texts, layer), axis=0)

        steering_vector = pos_mean - neg_mean

        norm = np.linalg.norm(steering_vector)
        if norm > 0:
            steering_vector = steering_vector / norm

        return steering_vector

def extract_all_vectors(model_name: str, traits: List[str],
                        data_dir: Path, output_dir: Path,
                        max_pairs: int = 2500):
    """Compute and save one steering vector per text pair for each trait.

    For every trait, ``<data_dir>/<trait>_pairs.json`` is read, a steering
    vector is computed for each ``(positive, negative)`` pair, and the stacked
    result is written to
    ``<output_dir>/<model_name>_<trait>_vectors_webscale.npy``. Traits whose
    data file is missing are skipped with a warning.

    Args:
        model_name: Model name handed to ``VectorExtractor``; also used in
            the output file name.
        traits: Trait names to process.
        data_dir: Directory holding the ``*_pairs.json`` files.
        output_dir: Destination directory; created if it does not exist.
        max_pairs: Upper bound on the number of pairs used per trait
            (the first ``max_pairs`` entries of each file).
    """
    # Function-scope import, hoisted out of the per-trait loop.
    import json

    extractor = VectorExtractor(model_name)
    extractor.load_model()

    output_dir.mkdir(parents=True, exist_ok=True)

    for trait in traits:
        logger.info("Extracting vectors for %s", trait)

        data_file = data_dir / f'{trait}_pairs.json'
        if not data_file.exists():
            logger.warning("Data file not found: %s", data_file)
            continue

        # Read as UTF-8 explicitly instead of relying on the platform default.
        with open(data_file, 'r', encoding='utf-8') as f:
            # Each entry is unpacked below as a two-item (positive, negative)
            # sequence.
            pairs = json.load(f)

        # One independently normalized vector per pair.
        vectors = [
            extractor.compute_steering_vector([pos_text], [neg_text])
            for pos_text, neg_text in pairs[:max_pairs]
        ]

        vectors_array = np.array(vectors)

        output_file = output_dir / f'{model_name}_{trait}_vectors_webscale.npy'
        np.save(output_file, vectors_array)

        logger.info("Saved %s vectors to %s", len(vectors), output_file)

def main():
    """Run vector extraction for every configured model.

    Reads ``MODELS``, ``TRAITS`` and ``BASE_DIR`` from the project's
    ``utils.config`` module, takes pair data from ``BASE_DIR/prepared_data``
    and writes results under ``BASE_DIR/vectors/raw``.

    The configuration is imported relatively, so this module has to be
    executed as part of its package (e.g. with ``python -m``) rather than as
    a standalone file.
    """
    from ..utils.config import MODELS, TRAITS, BASE_DIR

    pairs_root = BASE_DIR / 'prepared_data'
    raw_vectors_root = BASE_DIR / 'vectors' / 'raw'

    # Each model is loaded and processed in turn for the full trait list.
    for model in MODELS:
        logger.info(f"Processing model: {model}")
        extract_all_vectors(model, TRAITS, pairs_root, raw_vectors_root)

    logger.info("Vector extraction complete")

# Run the extraction only when this file is the entry point, not when imported.
if __name__ == "__main__":
    main()