import torch
import torch.nn.functional as F
import pandas as pd
from config import TOKEN, map_model_name
from pathlib import Path
import argparse
from tqdm import tqdm
import numpy as np
import os, sys
import pickle
from utils import str_to_bool, get_index_of_last_no_padding_token, get_model_id, get_layers_to_process
import gc
from torch.utils.data import Dataset, DataLoader
from data import prepare_tokenized_data
from model import ProjectionLayer, load_projection, ModelWithProj,load_model_and_tokenizer

# Dataset wrapper that serves raw texts and tokenizes them per batch in collate_fn
class TextDataset(Dataset):
    """Map-style dataset over raw text strings, tokenized one batch at a time.

    Items are returned untokenized; hand ``collate_fn`` to a ``DataLoader`` so
    that every batch is encoded with a single tokenizer call.
    """

    def __init__(self, tokenizer, texts, max_length=None):
        # Only references are stored; nothing is tokenized at construction time.
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length

    def __len__(self):
        """Return the number of stored texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the raw (untokenized) text stored at ``idx``."""
        return self.texts[idx]

    def collate_fn(self, batch):
        """Tokenize a batch of raw texts into PyTorch tensors.

        With a truthy ``max_length`` every sequence is padded to exactly that
        length; otherwise the batch is padded to its longest member.
        Truncation is always enabled.
        """
        if self.max_length:
            padding_options = {
                "padding": "max_length",
                "max_length": self.max_length,
            }
        else:
            padding_options = {"padding": "longest"}

        return self.tokenizer.batch_encode_plus(
            batch,
            truncation=True,
            return_tensors="pt",
            **padding_options,
        )




def get_target_prob(model, tokenizer, inputs, target, max_length=None):
    """
    Get the next-token probability of target token(s) for a batch of prompts.

    The model is run once on the batch. For every row, the logits at the last
    non-padding position are converted to a distribution over the vocabulary,
    and the probability of each target's first token is read from it.

    Args:
        model: Language model; called with ``input_ids``/``attention_mask`` and
            expected to return an object exposing a ``.logits`` tensor
        tokenizer: Tokenizer providing ``pad_token_id`` and ``encode``
        inputs: Pre-tokenized batch, indexable by "input_ids" and
            "attention_mask"
        target: Target string or list of target strings. Only the FIRST token
            of each target's encoding is scored.
        max_length: Unused by this function; kept so existing callers that
            pass it keep working

    Returns:
        If target is a string: numpy array of probabilities (one per prompt)
        If target is a list: list of numpy arrays, one for each target
    """
    # Normalise to a list so a single code path serves both call styles.
    single_target = isinstance(target, str)
    targets = [target] if single_target else list(target)

    pad_token_id = tokenizer.pad_token_id
    input_ids = inputs["input_ids"]

    # Forward pass; hidden states are not needed, only the logits.
    outputs = model(
        input_ids=input_ids,
        attention_mask=inputs["attention_mask"],
        output_hidden_states=False,
    )
    logits = outputs.logits

    # Position of the last non-padding token in every row.
    # reshape(-1) (rather than squeeze(-1)) always yields a 1-D vector of
    # length batch_size, so a batch of size 1 cannot collapse to a 0-dim
    # tensor whatever per-row shape the helper returns.
    last_positions = torch.stack(
        [get_index_of_last_no_padding_token(row, pad_token_id) for row in input_ids]
    ).reshape(-1).to(device=logits.device, dtype=torch.long)

    # Gather one logit vector per row in a single indexing operation:
    # result is (batch_size, vocab_size).
    row_ids = torch.arange(last_positions.shape[0], device=logits.device)
    relevant_logits = logits[row_ids, last_positions]

    # Upcast before softmax: with half-precision models the probabilities (and
    # the he/she differences computed from them) would otherwise be rounded to
    # float16. Only the small (batch, vocab) slice is converted.
    probs = relevant_logits.float().softmax(dim=-1)

    # Release the large tensors as early as possible.
    del logits, outputs, relevant_logits

    # First token id of every target, looked up in one gather so that only a
    # single device-to-host transfer is needed.
    target_ids = torch.tensor(
        [tokenizer.encode(t, add_special_tokens=False)[0] for t in targets],
        dtype=torch.long,
        device=probs.device,
    )
    # (batch, n_targets) -> (n_targets, batch); each row is one target's array.
    target_probs = probs[:, target_ids].t().contiguous().cpu().numpy()
    target_probs_list = list(target_probs)

    del probs

    # Memory cleanup between batches
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    gc.collect()

    # Return single array for single target, or list for multiple targets
    if single_target:
        return target_probs_list[0]
    return target_probs_list

def get_gender_score_batched(model, tokenizer, prompts, batch_size=32, max_length=512):
    """
    Score every prompt by P("he") - P("she"), processing the prompts in batches.

    Args:
        model: Language model (its ``device`` attribute is passed to the loader)
        tokenizer: Tokenizer
        prompts: List of text prompts
        batch_size: Batch size for processing
        max_length: Maximum sequence length handed to the data preparation step

    Returns:
        Tuple of (gender_scores, he_probs, she_probs); three lists with one
        entry per prompt, in loader order
    """
    # Tokenize up front and wrap the result in a batch iterator.
    loader = prepare_tokenized_data(
        prompts, tokenizer, model.device, batch_size, max_length=max_length
    )
    print('Tokenized data')

    he_probs, she_probs, gender_scores = [], [], []

    # Inference only: gradients are never needed here.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Processing batches"):
            # Both target probabilities come from a single forward pass.
            he_batch, she_batch = get_target_prob(model, tokenizer, batch, ["he", "she"])

            he_probs.extend(he_batch)
            she_probs.extend(she_batch)
            # Positive score: "he" is more likely than "she" as the next token.
            gender_scores.extend(he_batch - she_batch)

    return gender_scores, he_probs, she_probs

def calc_gender_score(model, tokenizer, dataset_path, batch_size,
                    max_length=None):
    """Read the prompt CSV at ``dataset_path`` and attach gender-score columns.

    The CSV must contain a ``prompt`` column. When ``max_length`` is None it is
    set to the token count of the longest prompt. The model is switched to
    eval mode before scoring.

    Returns:
        The loaded DataFrame with ``gender_score``, ``he_prob`` and
        ``she_prob`` columns added.
    """
    df = pd.read_csv(dataset_path)
    prompts = df['prompt'].tolist()

    if max_length is not None:
        print(f"Using provided max_length: {max_length}")
    else:
        # Derive the limit from the data: the longest prompt, measured in tokens.
        token_counts = [len(tokenizer.encode(text)) for text in prompts]
        max_length = max(token_counts)
        print(f"Using calculated max_length: {max_length}")

    # Evaluation mode before running inference.
    model.eval()

    print("Calculating gender scores...")
    results = get_gender_score_batched(
        model, tokenizer, prompts, batch_size, max_length=max_length
    )

    # get_gender_score_batched returns (gender_scores, he_probs, she_probs).
    for column, values in zip(('gender_score', 'he_prob', 'she_prob'), results):
        df[column] = values

    return df
    

def _resolve_dataset_path(dataset):
    """Return the prompt CSV path for ``dataset``; raise ValueError if none is defined."""
    dataset_paths = {
        'dama': 'data/dama/dama_professions.csv',
        'dama_mixed': 'data/dama/dama_professions_mixed.csv',
    }
    if dataset not in dataset_paths:
        supported = ", ".join(sorted(dataset_paths))
        raise ValueError(
            f"No gender-score dataset file is defined for dataset '{dataset}' "
            f"(supported: {supported})"
        )
    return dataset_paths[dataset]


def _projection_layers_folder(layers):
    """Map the ``--layers`` spec to the folder name passed to ``load_projection``."""
    if layers in ("all", "lm_head"):
        return layers
    if layers.startswith("last_") and layers[5:].isdigit():
        return layers
    return "custom"


def _result_suffix(args, layer_ids):
    """Build the output-filename suffix describing the applied projection ('' if none)."""
    if not args.apply_projection:
        return ""
    if args.layers in ("lm_head", "all") or args.layers.startswith("last_"):
        layers_tag = args.layers
    else:
        # Explicit layer list: spell the resolved layer ids out in the name.
        layers_tag = "_".join(str(layer) for layer in layer_ids)
    return (
        f"_{args.projection_method}_{layers_tag}"
        f"_{args.embedding_strategy}_independent_{args.independent_layers}"
    )


def main(args):
    """Run the gender-score evaluation described by the parsed CLI ``args``.

    Steps: resolve the dataset file, load model and tokenizer, optionally wrap
    the model and register one projection hook per selected layer, score every
    prompt, remove the hooks, and write the results to
    ``data/result_data/<dataset>_professions_<model><suffix>.csv``.

    Raises:
        ValueError: if ``args.dataset`` has no prompt file defined here. This
            is checked before the model is loaded so the run fails immediately
            instead of after an expensive model load.
    """
    # Fail fast: argparse accepts datasets that this script has no file for.
    dataset_path = _resolve_dataset_path(args.dataset)

    # "auto" and explicit devices (cuda, cpu, mps) are both passed through as-is.
    device_map = args.device

    # Load model and tokenizer
    print(f"Loading model on device: {device_map}")
    model, tokenizer = load_model_and_tokenizer(
        model_name=args.model_name,
        model_type=args.model_type,
        task_type="causal-lm" if args.model_type == "llama" else None,
        device_map=device_map,
        torch_dtype=args.torch_dtype,
    )

    layer_ids = []
    if args.apply_projection:
        # Wrap the model so projection hooks can be registered on it.
        model = ModelWithProj(model, model_type=args.model_type)
        layer_ids = get_layers_to_process(args.layers, model, args.model_type)

        # 'same' means: apply with the strategy the embeddings were built with.
        apply_strategy = (
            args.embedding_strategy if args.apply_strategy == 'same' else args.apply_strategy
        )
        # Depends only on args.layers, so compute it once for all layers.
        layers_folder = _projection_layers_folder(args.layers)

        for layer_id in layer_ids:
            projection = load_projection(
                model_name=args.model_name,
                dataset=args.dataset,
                device=model.device,
                projection_method=args.projection_method,
                layer_id=layer_id,
                embedding_strategy=args.embedding_strategy,
                projections_dir=f"projections/{args.dataset}",
                layer_folder=layers_folder,
                independent_layers=args.independent_layers,
            )
            model.register_projection_hook(
                layer_id=layer_id, projection=projection, apply_strategy=apply_strategy
            )
    else:
        print("No projection applied")

    # Process dataset
    df = calc_gender_score(
        model,
        tokenizer,
        dataset_path,
        args.batch_size,
        max_length=args.max_length,
    )

    # remove hooks
    if args.apply_projection:
        model.remove_hooks()

    # map model name
    model_name_short = map_model_name(args.model_name)

    # The suffix records the projection settings in the result file name.
    suffix = _result_suffix(args, layer_ids)
    output_path = f'data/result_data/{args.dataset}_professions_{model_name_short}{suffix}.csv'

    # Make sure the target directory exists so a long run is not lost at save time.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f"Results saved to {output_path}")

if __name__ == "__main__":
    # Register tqdm's progress-bar integration with pandas
    tqdm.pandas()

    # Command-line interface
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["dama", "dama_mixed", "bios", "toy"],
        help="Dataset to use for gender score calculation",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-2-7b-hf",
        help="Model name to use",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="llama",
        help="Model type (llama or bert)",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument(
        "--max_length",
        type=int,
        default=None,
        help="Maximum sequence length for tokenization (calculated from data if not provided)",
    )
    parser.add_argument(
        "--apply_projection",
        type=str,
        default="False",
        help="Whether to apply projection during evaluation",
    )
    parser.add_argument(
        "--projection_method",
        type=str,
        default="LEACE",
        help="Projection method used (LEACE, LEACE-no-whitening, etc.)",
    )
    parser.add_argument(
        "--layers",
        type=str,
        default="lm_head",
        help="Layers to apply projection: 'last', 'all', 'lm_head', 'last_x' (where x is a number), or specific layers like '0,1,2'",
    )
    parser.add_argument(
        "--embedding_strategy",
        type=str,
        default="mean",
        help="Embedding strategy (mean, last, etc.)",
    )
    parser.add_argument(
        "--apply_strategy",
        type=str,
        default="all",
        choices=["all", "last_non_pad", "same"],
        help="Which layers to apply the projection to",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="mps",
        help="Device to use (auto, cuda, cpu, mps)",
    )
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="float16",
        choices=["float16", "float32", "bfloat16"],
        help="Precision to use for computations",
    )
    parser.add_argument(
        "--independent_layers",
        type=str,
        default="False",
        help="Whether to use independent layers for projection",
    )
    args = parser.parse_args()

    # These flags are parsed as strings; convert them to real booleans
    # before handing the namespace to main().
    for flag_name in ("apply_projection", "independent_layers"):
        setattr(args, flag_name, str_to_bool(getattr(args, flag_name)))

    main(args)