import argparse
import torch
import json
import os
import pickle
import numpy as np
import logging
import random
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F

# Configure root logging once at import time (INFO level, timestamped lines)
# and expose the module-level logger used throughout this script.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def parse_args():
    """Parse the command-line options for the evaluation script.

    Options are declared in a table and registered in order, so the generated
    ``--help`` output lists them in the same sequence as they appear below.

    Returns:
        argparse.Namespace with ``model_name``, ``data_dir``, ``output_dir``,
        ``max_samples``, ``subset``, ``eval_mode``, ``grad_scope`` and ``seed``.
    """
    eval_mode_help = (
        "Evaluation mode: 'span_obj' (default) captures grad of loss at object token (hidden state). "
        "'span_logprob' captures grad of log-prob of target span w.r.t params."
    )
    option_table = (
        ("--model_name", dict(type=str, required=True, help="HuggingFace model name or path")),
        ("--data_dir", dict(type=str, default="data/TREx", help="Directory containing LAMA relations")),
        ("--output_dir", dict(type=str, default="output_grads", help="Directory to save results")),
        ("--max_samples", dict(type=int, default=100, help="Max samples per relation to evaluate")),
        ("--subset", dict(type=str, default=None, help="Specific relation to run (e.g., P19)")),
        # Options controlling what gradient is computed and how data is sampled.
        ("--eval_mode", dict(type=str, default="span_obj",
                             choices=["span_obj", "span_logprob"], help=eval_mode_help)),
        ("--grad_scope", dict(type=str, default="default", choices=["default", "all"],
                              help="Parameter subset for gradient. 'default'=last_block+lm_head.")),
        ("--seed", dict(type=int, default=42, help="Random seed for data sampling")),
    )

    cli = argparse.ArgumentParser(description="Evaluate Qwen gradients on Negated LAMA")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_data(data_dir, relation_file="relations.jsonl"):
    """Load relation metadata records from a JSON-Lines file.

    Args:
        data_dir: Data directory path. The metadata file is looked up in this
            path with the substring ``"TREx"`` removed (so ``"data/TREx"``
            resolves to ``"data/"``).
        relation_file: Name of the JSON-Lines metadata file.

    Returns:
        A list with one decoded JSON object per non-blank line.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # NOTE(review): str.replace removes *every* "TREx" occurrence in the path,
    # not only a trailing path component; kept as-is for backward compatibility.
    metadata_dir = data_dir.replace("TREx", "")
    # Explicit UTF-8: entity labels contain non-ASCII characters and the
    # platform default encoding is not UTF-8 everywhere.
    with open(os.path.join(metadata_dir, relation_file), 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline at EOF) are skipped instead of
        # being handed to json.loads, which would raise on them.
        return [json.loads(line) for line in f if line.strip()]

def get_relation_data(data_dir, relation_name):
    """Load all samples of one relation from ``<data_dir>/<relation_name>.jsonl``.

    Args:
        data_dir: Directory holding the per-relation JSON-Lines files.
        relation_name: Relation identifier used as the file stem.

    Returns:
        A list with one decoded JSON object per non-blank line, or ``None``
        when the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    path = os.path.join(data_dir, f"{relation_name}.jsonl")
    try:
        # Explicit UTF-8: labels contain non-ASCII characters and the platform
        # default encoding is not UTF-8 everywhere.
        f = open(path, 'r', encoding='utf-8')
    except FileNotFoundError:
        # A missing relation file is an expected condition for the caller.
        return None
    with f:
        # Skip blank lines (e.g. trailing newline) rather than crash on them.
        return [json.loads(line) for line in f if line.strip()]

def parse_template(template, subject_label):
    """Turn a relation template into the prompt that precedes the object slot.

    Everything from the first "[Y]" onward is discarded, "[X]" is replaced by
    ``subject_label`` and surrounding whitespace is stripped.

    Returns:
        The prompt string (possibly empty), or ``None`` when the template has
        no "[Y]" placeholder.
    """
    before_object, marker, _ = template.partition("[Y]")
    if not marker:
        # partition() yields an empty separator when "[Y]" is absent.
        return None
    filled = before_object.replace("[X]", subject_label)
    return filled.strip()

def get_target_token_id(tokenizer, target_label):
    """Return the id of the first token of ``target_label`` (span_obj mode).

    The label is first encoded with a leading space; if that produces no
    tokens, the bare label is encoded instead.

    Returns:
        The first token id, or ``None`` when both encodings are empty.
    """
    for candidate in (" " + target_label, target_label):
        token_ids = tokenizer.encode(candidate, add_special_tokens=False)
        if len(token_ids) != 0:
            return token_ids[0]
    return None

def compute_span_obj_gradient(model, tokenizer, prompt, target_token_id, device):
    """
    Method: span_obj (Legacy)
    Computes gradient of Loss(target_token) w.r.t last hidden state of prompt.

    Args:
        model: Causal LM exposing ``lm_head`` and returning ``hidden_states``
            when called with ``output_hidden_states=True``.
        tokenizer: Tokenizer called as ``tokenizer(prompt, return_tensors='pt')``.
        prompt: Prompt text whose final position predicts the target token.
        target_token_id: Vocabulary id of the target token.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``(grad, loss)`` where ``grad`` is the flattened numpy gradient of the
        cross-entropy loss w.r.t. the last-position hidden state and ``loss``
        is that loss as a Python float.
    """
    inputs = tokenizer(prompt, return_tensors='pt').to(device)

    # The gradient is taken w.r.t. the final hidden state only, so the
    # transformer body does not need an autograd graph. Running it under
    # no_grad avoids storing activations for every layer and also works when
    # the model parameters are frozen.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Last layer, last token. Detach + clone gives an independent leaf tensor
    # that autograd can differentiate with respect to.
    last_token_hidden = (
        outputs.hidden_states[-1][:, -1, :].detach().clone().requires_grad_(True)
    )

    logits = model.lm_head(last_token_hidden)
    target = torch.tensor([target_token_id], device=logits.device)
    loss = F.cross_entropy(logits, target)

    # Grad w.r.t activation (last_token_hidden)
    grad = torch.autograd.grad(loss, last_token_hidden, retain_graph=False)[0]

    return grad.detach().cpu().numpy().flatten(), loss.item()

def _select_grad_params(model, grad_scope):
    """Return the list of parameters to differentiate for ``grad_scope``."""
    if grad_scope != "default":  # scope == "all"
        return list(model.parameters())

    params = []
    # Try finding the last transformer block under the common attribute names.
    base_model = getattr(model, "model", None) or getattr(model, "transformer", None)
    layers = getattr(base_model, "layers", None) or getattr(base_model, "h", None)

    if layers is not None:
        params.extend(layers[-1].parameters())

    if hasattr(model, "lm_head"):
        params.extend(model.lm_head.parameters())

    if not params:
        logger.warning("Could not identify last layer params. Using all params.")
        params = list(model.parameters())
    return params


def compute_span_logprob_gradient(model, tokenizer, prompt, target_label, device, grad_scope="default"):
    """
    Method: span_logprob (New)
    Computes gradient of LogProb(target_span | prompt) w.r.t parameters.
    LogProb = sum( log p(y_t | x, y_<t) )
    This is equivalent to -CrossEntropySum(logits, targets).

    Tokenization strategy: [prompt_tokens] + [target_tokens] are concatenated
    manually (no special tokens) so the prompt/target boundary is exact.

    Args:
        model: Causal LM called as ``model(input_ids)`` and returning ``logits``.
        tokenizer: Tokenizer called as ``tokenizer(text, add_special_tokens=False)``.
        prompt: Conditioning text.
        target_label: Target text; tokenized with a leading space first.
        device: Device the input ids are moved to.
        grad_scope: ``"default"`` differentiates the last block plus ``lm_head``
            (falling back to all parameters if neither is found); any other
            value differentiates all parameters.

    Returns:
        ``(flat_grad, score, debug_info)``: the concatenated flattened gradient
        as a numpy array, the span log-probability as a float, and a dict with
        ``decoded_target``, ``target_ids`` and ``input_ids``.

    Raises:
        ValueError: If ``target_label`` tokenizes to nothing, since there is
            then no span to score and the gradient would be silently all-zero.

    Side effects:
        Parameters in scope that had ``requires_grad=False`` are left with
        ``requires_grad=True``.
    """
    # 1. Prepare inputs using manual concatenation.
    prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids
    # " " + target_label ensures we get the tokenization of the word start (if space-prefixed).
    target_ids = tokenizer(" " + target_label, add_special_tokens=False).input_ids

    if not target_ids:
        # Fallback if empty (e.g. pure whitespace?)
        target_ids = tokenizer(target_label, add_special_tokens=False).input_ids
    if not target_ids:
        raise ValueError(f"Target label {target_label!r} produced no tokens")

    full_ids = prompt_ids + target_ids
    input_ids = torch.tensor([full_ids], dtype=torch.long).to(device)

    # NOTE: with an empty prompt the first target token has no preceding
    # position to predict it, so the shift below drops it from the score.
    start_idx = len(prompt_ids)

    # Labels: mask the prompt (everything before start_idx) out of the loss.
    labels = input_ids.clone()
    labels[:, :start_idx] = -100

    # 2. Select parameters and make sure they are tracked BEFORE the forward
    # pass; enabling requires_grad afterwards would leave them out of the graph.
    params_to_diff = _select_grad_params(model, grad_scope)
    for p in params_to_diff:
        if not p.requires_grad:
            p.requires_grad_(True)

    # 3. Forward pass. No `labels=` here: the summed loss is computed below,
    # so the model's own loss computation would be wasted work.
    outputs = model(input_ids)

    logits = outputs.logits
    # Shift: tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)

    nll_sum = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction='sum',
        ignore_index=-100,
    )
    score = -nll_sum

    # 4. Gradients of the span log-probability w.r.t. the selected parameters.
    grads = torch.autograd.grad(score, params_to_diff, retain_graph=False)
    flat_grad = np.concatenate([g.detach().cpu().numpy().flatten() for g in grads])

    # Return debug info too
    debug_info = {
        "decoded_target": tokenizer.decode(target_ids),
        "target_ids": target_ids,
        "input_ids": input_ids.cpu().numpy().tolist()[0]
    }

    return flat_grad, score.item(), debug_info

def _read_jsonl(path):
    """Read a UTF-8 JSON-Lines file into a list, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _result_filename(rel_name, eval_mode, grad_scope):
    """Return the pickle filename under which one relation's results are stored."""
    if eval_mode == "span_obj":
        # span_obj keeps its shorter legacy name so existing outputs are
        # still recognised by the resume check.
        return f"{rel_name}_grads.pkl"
    return f"{rel_name}_{eval_mode}_{grad_scope}_grads.pkl"


def _evaluate_sample(model, tokenizer, device, args, rel_name, sample, index,
                     template, template_negated, log_sanity):
    """Compute the normal-vs-negated gradient similarity for one sample.

    Returns the result dict, or ``None`` when the sample is skipped (no usable
    prompt from either template, or no target token in span_obj mode).
    Exceptions from tokenization / gradient computation propagate to the caller.
    """
    subj = sample['sub_label']
    obj = sample['obj_label']

    prompt_normal = parse_template(template, subj)
    prompt_negated = parse_template(template_negated, subj)
    if not prompt_normal or not prompt_negated:
        return None

    if args.eval_mode == "span_obj":
        target_id = get_target_token_id(tokenizer, obj)
        if target_id is None:
            return None

        if log_sanity:
            logger.info(f"--- Sanity Check ({rel_name}) ---")
            logger.info(f"Target Token (Decoded): '{tokenizer.decode([target_id])}' (ID: {target_id})")

        grad_normal, _ = compute_span_obj_gradient(model, tokenizer, prompt_normal, target_id, device)
        grad_negated, _ = compute_span_obj_gradient(model, tokenizer, prompt_negated, target_id, device)

        sim = float(cosine_similarity([grad_normal], [grad_negated])[0][0])
        return {
            "sample_uuid": sample.get('uuid', index),
            "subj": subj,
            "obj": obj,
            "target_token_id": target_id,
            "cosine_sim": sim,
            # cos(a, -b) == -cos(a, b); no need for a second pass over the vectors.
            "cosine_sim_flipped": -sim,
        }

    if args.eval_mode == "span_logprob":
        grad_pos, score_pos, debug_pos = compute_span_logprob_gradient(
            model, tokenizer, prompt_normal, obj, device, args.grad_scope)
        grad_neg, score_neg, _ = compute_span_logprob_gradient(
            model, tokenizer, prompt_negated, obj, device, args.grad_scope)

        if log_sanity:
            logger.info(f"--- Sanity Check ({rel_name}) ---")
            logger.info(f"Subject: {subj}")
            logger.info(f"Object: {obj}")
            logger.info(f"Prompt (Normal): '{prompt_normal}'")
            logger.info(f"Target Span (Decoded): '{debug_pos['decoded_target']}'")
            logger.info(f"Target IDs: {debug_pos['target_ids']}")
            logger.info(f"Input IDs Tail: {debug_pos['input_ids'][-5:]}")

        sim = float(cosine_similarity([grad_pos], [grad_neg])[0][0])
        return {
            "sample_uuid": sample.get('uuid', index),
            "subj": subj,
            "obj": obj,
            "cosine_sim": sim,            # cos_pos_neg
            "cosine_sim_flipped": -sim,   # cos_pos_minus_neg == -cos_pos_neg
            "score_pos": float(score_pos),
            "score_neg": float(score_neg),
        }

    raise ValueError(f"Unknown eval_mode: {args.eval_mode}")


def main():
    """Run the normal-vs-negated gradient evaluation over LAMA relations.

    Loads the model and tokenizer, reads ``relations.jsonl`` (looked up next
    to ``--data_dir`` first, then inside it), and for every relation (or only
    ``--subset``) samples up to ``--max_samples`` records, evaluates each one
    and pickles the list of per-sample result dicts into ``--output_dir``.
    Relations whose output file already exists are skipped, which allows an
    interrupted run to be resumed.
    """
    args = parse_args()

    # Seed the global RNG as before; per-relation sampling below uses its own
    # generator so it does not depend on this shared state.
    random.seed(args.seed)

    logger.info(f"Loading model: {args.model_name}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return

    logger.info("Model loaded successfully.")

    # Load relations (metadata): parent of data_dir first, then data_dir itself.
    relations_file = os.path.join(os.path.dirname(args.data_dir.rstrip("/")), "relations.jsonl")
    if not os.path.exists(relations_file):
        relations_file = os.path.join(args.data_dir, "relations.jsonl")

    if not os.path.exists(relations_file):
        logger.error(f"Could not find relations.jsonl at {relations_file}")
        return

    relations_meta = _read_jsonl(relations_file)

    model_tag = args.model_name.strip("/").split("/")[-1]
    # Only expand the output path when it is the default name, to allow override.
    if os.path.basename(args.output_dir) == "output_grads":
        args.output_dir = f"output_grads/{model_tag}/{args.eval_mode}_{args.grad_scope}"

    os.makedirs(args.output_dir, exist_ok=True)
    logger.info(f"Saving results to: {args.output_dir}")

    device = model.device

    logger.info(f"Starting evaluation in mode: {args.eval_mode}")

    for rel_meta in relations_meta:
        rel_name = rel_meta['relation']

        if args.subset and rel_name != args.subset:
            continue

        # Resume capability: skip relations whose output already exists.
        out_path = os.path.join(
            args.output_dir,
            _result_filename(rel_name, args.eval_mode, args.grad_scope),
        )
        if os.path.exists(out_path):
            logger.info(f"Skipping {rel_name}, output file already exists.")
            continue

        logger.info(f"Processing relation: {rel_name}")

        template = rel_meta.get('template', "")
        template_negated = rel_meta.get('template_negated', "")
        if not template or not template_negated:
            logger.warning(f"Skipping {rel_name}: missing template or template_negated")
            continue

        data = get_relation_data(args.data_dir, rel_name)
        if not data:
            logger.warning(f"No data found for {rel_name}")
            continue

        # Shuffle and sample with a generator seeded per relation. Using the
        # shared global RNG here would make the sampled subset depend on which
        # relations were processed (or skipped on resume / via --subset) before
        # this one.
        rel_rng = random.Random(f"{args.seed}:{rel_name}")
        rel_rng.shuffle(data)
        data = data[:args.max_samples]

        rel_results = []
        for i, sample in enumerate(tqdm(data)):  # iterating over the sampled subset
            try:
                result = _evaluate_sample(
                    model, tokenizer, device, args, rel_name, sample, i,
                    template, template_negated,
                    log_sanity=not rel_results,  # log details once per relation
                )
            except Exception as e:
                # Best-effort per sample: log and move on so one bad record
                # (or one failed gradient computation) does not abort the run.
                logger.error(f"Error processing sample {i}: {e}")
                continue

            if result is not None:
                rel_results.append(result)

        if not rel_results:
            logger.warning(f"No results produced for {rel_name}; writing an empty result file.")

        # Write to a temporary file and rename, so an interrupted run never
        # leaves a truncated pickle that the resume check would treat as done.
        tmp_path = out_path + ".tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(rel_results, f)
        os.replace(tmp_path, out_path)

    logger.info("Evaluation complete.")

# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
