import argparse
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
import csv
import sys
from model import load_bert_model_tokenizer, ModelWithProj, load_projection
from data import BiosData, get_dataset_handler, prepare_tokenized_data
from utils import str_to_bool, set_seed, get_layers_to_process, est_Cov
import gc

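# Import evaluation helpers from eval_last_layer.py.
# NOTE: `save_results` is redefined further down in this file, so the local
# definition (not the imported one) is what `main` ends up calling.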
from eval_last_layer import (
    TPR, TPR_gap, TPR_gap_multiclass, calc_wg_acc, 
    evaluate_predictions, save_results
)


def save_results(args, results_dict, result_file_name='results.csv'):
    """Append one row of results to a CSV file under ``results/last_layer``.

    The directory is created on demand. A header row (the keys of
    ``results_dict``, in insertion order) is written when the target file is
    new or empty; afterwards ``results_dict`` is appended as a single row.

    NOTE: this definition shadows the ``save_results`` imported from
    ``eval_last_layer`` at the top of the file.

    Args:
        args: Unused; kept so existing callers do not need to change.
        results_dict: Mapping of column name -> value for the row to append.
        result_file_name: File name (not a full path) of the CSV to write.
    """
    results_dir = Path("results") / "last_layer"
    results_dir.mkdir(parents=True, exist_ok=True)
    csv_path = results_dir / result_file_name

    # Materialise the keys once so the column order is fixed for this call.
    fieldnames = list(results_dict)

    # A file that exists but is empty (e.g. created by `touch` or left behind
    # by an interrupted run) still needs its header row.
    # NOTE(review): an existing header is not compared with `fieldnames`; rows
    # with a different key set will not line up with earlier rows.
    needs_header = (not csv_path.exists()) or csv_path.stat().st_size == 0

    with open(csv_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(results_dict)



def get_predictions(model, tokenizer, texts, device, num_labels, batch_size,
                    max_length=512):
    """Tokenize ``texts`` batch by batch and return the model's predictions.

    Args:
        model: Callable model; ``model(input_ids=..., attention_mask=...)``
            must return an object with a ``.logits`` tensor.
        tokenizer: Callable tokenizer returning a mapping with ``input_ids``
            and ``attention_mask`` tensors.
        texts: Sliceable sequence of input texts (supports ``len``).
        device: Device the batch tensors are moved to (string or device
            object).
        num_labels: Number of labels. For ``num_labels > 2`` the prediction is
            the argmax over dim 1 of the logits; otherwise it is
            ``sigmoid(logits) > 0.5`` cast to integers.
        batch_size: Number of texts per forward pass.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        ``np.ndarray`` with the predictions of all batches concatenated in
        input order.
    """
    model.eval()
    predictions = []

    # Ceil division: number of batches shown on the progress bar.
    total_batches = (len(texts) + batch_size - 1) // batch_size

    # `device` may be "cuda", "cuda:0" or a device object; comparing it with
    # the literal 'cuda' misses the latter forms, so test the string prefix.
    on_cuda = str(device).startswith('cuda')

    with torch.no_grad():
        for start in tqdm(range(0, len(texts), batch_size), total=total_batches, desc="Processing batches"):
            batch_texts = texts[start:start + batch_size]

            batch_tokens = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )

            # Only the two tensors the model call needs are moved to the device.
            batch_inputs = {
                'input_ids': batch_tokens['input_ids'].to(device),
                'attention_mask': batch_tokens['attention_mask'].to(device)
            }

            logits = model(**batch_inputs).logits

            if num_labels > 2:
                preds = torch.argmax(logits, dim=1)
            else:
                # NOTE(review): the thresholded output keeps the shape of
                # `logits`, so each prediction is a per-row list rather than a
                # scalar -- confirm this is what the evaluation code expects.
                preds = (torch.sigmoid(logits) > 0.5).long()

            predictions.extend(preds.cpu().numpy().tolist())

            # Release per-batch memory before the next forward pass.
            del logits, batch_inputs, preds
            if on_cuda:
                torch.cuda.empty_cache()
            else:
                gc.collect()

    return np.array(predictions)


def _infer_num_labels(data_handler, data):
    """Return the label count from the handler, or derive it from ``y_train``."""
    if hasattr(data_handler, 'num_labels'):
        return data_handler.num_labels

    y_train = data['y_train']
    # A 2-D label array is read as one column per label; otherwise count the
    # distinct label values.
    if isinstance(y_train, np.ndarray) and y_train.ndim > 1:
        return y_train.shape[1]
    return len(np.unique(y_train))


def _attach_projections(model, device, args):
    """Wrap ``model`` in ``ModelWithProj`` and register one hook per layer."""
    model_with_proj = ModelWithProj(model, model_type="bert")
    layers_to_process = get_layers_to_process(args.layers, model, "bert")

    # The folder on disk is named after the --layers value as given.
    layers_folder = args.layers

    for layer_id in layers_to_process:
        projection = load_projection(
            model_name=args.model_name,
            dataset=args.dataset,
            device=device,
            projection_method=args.proj_method,
            layer_id=layer_id,
            embedding_strategy=args.embedding_strategy,
            projections_dir=f"projections/{args.dataset}",
            layer_folder=layers_folder
        )

        model_with_proj.register_projection_hook(
            layer_id=layer_id - 1,  # Adjust for 0-based index
            projection=projection,
            apply_strategy=args.apply_strategy
        )

        print(f"Projection loaded for layer {layer_id} with strategy {args.apply_strategy}")

    return model_with_proj


def main(args):
    """Evaluate a fine-tuned model on the test split, optionally with projections.

    Steps: load the dataset, load the model from
    ``models/<dataset>/<model_name>``, register projection hooks unless
    ``args.proj_method`` is ``'orig'``, predict on the test texts, evaluate,
    and (if ``args.save``) append the metrics to a CSV via ``save_results``.

    Args:
        args: Namespace providing ``proj_method``, ``dataset``,
            ``sample_data``, ``p_y_z``, ``model_name``, ``device``,
            ``layers``, ``embedding_strategy``, ``apply_strategy``,
            ``batch_size``, ``save`` and ``result_file``.
    """
    print(f"\nProcessing projection method: {args.proj_method}")

    # Load data
    data_handler = get_dataset_handler(args.dataset)
    data = data_handler.prepare_data(
        load_test=True,
        embeddings=False,
        sample=args.sample_data,
        p_y_z=args.p_y_z,
        p_y=0.5
    )

    num_labels = _infer_num_labels(data_handler, data)
    print(f"Dataset: {args.dataset}, Number of labels: {num_labels}")

    # Load model and tokenizer - model_name already includes seed
    model_path = Path("models") / args.dataset / args.model_name

    print(f"Loading model from: {model_path}")
    torch.backends.cudnn.benchmark = True  # Enable cudnn autotuner
    # NOTE(review): float16 weights are requested regardless of device;
    # confirm this is intended when running with the default --device cpu.
    model, tokenizer, device = load_bert_model_tokenizer(
        model_name=str(model_path),
        num_labels=num_labels,
        device=args.device,
        torch_dtype=torch.float16,
        freeze_base=True,
        freeze_all=True
    )
    model.to(device)

    # Apply projection if specified
    use_projection = args.proj_method != 'orig'
    if use_projection:
        print(f"Applying {args.proj_method} projection to model")
        model = _attach_projections(model, device, args)

    try:
        test_texts = data['X_test']
        z_test = data['z_test']
        y_test = data['y_test']

        y_pred = get_predictions(
            model,
            tokenizer,
            test_texts,
            device,
            num_labels,
            batch_size=args.batch_size
        )

        # Group ids for per-group evaluation, when the handler defines them.
        if hasattr(data_handler, 'get_group'):
            get_group_v = np.vectorize(data_handler.get_group)
            y_for_groups = np.argmax(y_test, axis=1) if y_test.ndim > 1 else y_test
            g = get_group_v(y_for_groups, z_test)
        else:
            g = None

        # Evaluate predictions
        print("\nEvaluation Results:")
        print("-" * 50)

        acc, tpr, tpr_gap, wg_acc, acc_per_g, acc_per_class = evaluate_predictions(
            y_test,
            y_pred,
            z_test,
            "Test Set",
            g
        )

        # Prepare results for saving
        result = {
            'dataset': args.dataset,
            'model_name': args.model_name,
            'proj_method': args.proj_method,
            'layers': args.layers,
            'embedding_strategy': args.embedding_strategy,
            'apply_strategy': args.apply_strategy,
            'accuracy': acc,
            'tpr': tpr,
            'tpr_gap': tpr_gap,
            'worst_group_accuracy': wg_acc,
        }

        # Add accuracy per group if available. A distinct loop variable keeps
        # the overall `acc` from being overwritten.
        if acc_per_g is not None:
            for g_val, group_acc in acc_per_g.items():
                result[f'acc_g{g_val}'] = group_acc

        # Save results
        if args.save:
            save_results(args, result, args.result_file)
            # save_results writes under results/last_layer, so report that path.
            print(f"Results saved to results/last_layer/{args.result_file}")
    finally:
        # Remove hooks if they were applied, even when evaluation failed.
        if use_projection and hasattr(model, 'remove_hooks'):
            model.remove_hooks()


def _build_arg_parser():
    """Create the command-line parser for the projection evaluation script."""
    cli = argparse.ArgumentParser(description="Evaluate model with projections")
    add = cli.add_argument

    # Which dataset and which fine-tuned checkpoint to evaluate.
    add("--dataset", type=str, default="bios",
        help="Dataset to evaluate (e.g., bios)")
    add("--model_name", type=str, required=True,
        help="Model name or path (should include seed)")

    # How the projection was built and where it is applied.
    add("--proj_method", type=str, default="orig",
        choices=['LEACE', 'opt-sep-proj', 'orig', 'LEACE-no-whitening', 'SAL'],
        help="Projection method to use")
    add("--layers", type=str, default="last",
        help="Layers to adapt: 'last', 'all', 'lm_head', 'last_x', or specific layer number")
    add("--embedding_strategy", type=str, default="cls",
        choices=['cls', 'mean', 'last', 'last_non_pad'],
        help="Strategy for handling sequence dimension in embeddings")
    add("--apply_strategy", type=str, default="all",
        choices=["all", "last_non_pad", "cls"],
        help="Which tokens to apply the projection to")

    # Runtime settings.
    add("--device", type=str, default="cpu",
        help="Device to use for computation")
    add("--batch_size", type=int, default=32,
        help="Batch size for processing")

    # Result persistence.
    add("--save", type=str, default="False",
        help="Whether to save results")
    add("--result_file", type=str, default="eval_results.csv",
        help="Filename for saving results")

    # Optional resampling of the data.
    add("--sample_data", type=str, default="False",
        help="Whether to sample data")
    add("--p_y_z", type=float, default=0.5,
        help="Probability for sampling data")

    return cli


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()

    # Boolean switches are parsed as strings; convert them before use.
    for bool_flag in ("save", "sample_data"):
        setattr(args, bool_flag, str_to_bool(getattr(args, bool_flag)))

    main(args)
