#!/usr/bin/env python3
"""
Proposed method script for ELCM (Entropy-Weighted Local Concept Matching).
This script implements the entropy-weighted local concept matching approach.
"""

import argparse
import torch
from typing import Dict
import numpy as np
from datasets.test_loader import set_test_loader
import clip_w_local
from clip_w_local import clip
from tqdm import tqdm
from torch.nn import functional as F
from utils.eval_util import get_and_print_results, add_results, add_overall_results, save_results_to_json
import os


def get_test_labels(in_dataset: str):
    """Load the class-name array used to build text prompts for *in_dataset*.

    Only 'ImageNet' is supported; its class names are read with ``np.load``
    from ``label_names/imagenet_class_clean.npy`` (a path relative to the
    current working directory).

    Raises:
        ValueError: if *in_dataset* is anything other than 'ImageNet'.
    """
    if in_dataset != 'ImageNet':
        raise ValueError(f"Invalid dataset: {in_dataset}")

    label_path = "label_names/imagenet_class_clean.npy"
    with open(label_path, 'rb') as handle:
        class_names = np.load(handle)
    return class_names


def compute_entropy_weights(probs, alpha=1.0, eps=1e-8):
    """
    Turn per-patch class distributions into confidence weights.

    For each distribution along the last axis, the Shannon entropy
    H = -sum_c p_c * log(p_c + eps) is mapped to the weight exp(-alpha * H).
    Peaked (low-entropy) distributions get weights near 1. Flat (high-entropy)
    distributions get smaller weights.

    Args:
        probs: Class probabilities with classes on the last axis,
            e.g. [batch, num_patches, num_classes].
        alpha: Scale applied to the entropy before exponentiation.
        eps: Offset added inside the log so zero probabilities stay finite.

    Returns:
        Weights with the class axis reduced, e.g. [batch, num_patches].
    """
    plogp = probs * np.log(probs + eps)
    patch_entropy = -plogp.sum(axis=-1)
    return np.exp(-alpha * patch_entropy)


def get_ood_scores(model, method, loader, test_labels, lambda_local: float = 0.5, T: float = 1.0, alpha: float = 1.0):
    """
    Score every sample in *loader* with the chosen OOD scoring rule.

    Every rule negates a maximum softmax probability, so confident matches to
    an in-distribution class give strongly negative scores.

    Args:
        model: CLIP-style model whose ``encode_image`` returns a
            ``(global_features, local_features)`` pair and which exposes
            ``encode_text``.
        method: One of 'mcm' (global only), 'gl-mcm' (global + max over
            patches) or 'elcm' (global + entropy-weighted patch sum).
        loader: Iterable of ``(images, labels)`` batches with a ``dataset``
            attribute. The labels are not used.
        test_labels: Class names used to build "a photo of a {c}" prompts.
        lambda_local: Weight of the local-patch term for 'gl-mcm' / 'elcm'.
        T: Softmax temperature.
        alpha: Entropy scale passed to ``compute_entropy_weights`` ('elcm').

    Returns:
        1-D numpy array with one score per sample, truncated to
        ``len(loader.dataset)``. The array is empty if the loader yields no
        batches.

    Raises:
        NotImplementedError: if *method* is not a supported rule.
    """
    supported_methods = ('mcm', 'gl-mcm', 'elcm')
    if method not in supported_methods:
        # Fail before any image/text encoding, not after the first batch.
        raise NotImplementedError(f"Method {method} not implemented")

    to_np = lambda x: x.data.cpu().numpy()
    _score = []
    tokenizer = clip.tokenize

    # Run on whichever device the model lives on (main() may place it on CPU).
    # Fall back to CUDA, the previous hard-coded target, if it can't be inferred.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    tqdm_object = tqdm(loader, total=len(loader))

    with torch.no_grad():
        # The prompts don't depend on the images, so encode them once rather
        # than once per batch.
        text_inputs = tokenizer([f"a photo of a {c}" for c in test_labels])
        text_features = model.encode_text(text_inputs.to(device)).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        for images, _labels in tqdm_object:  # labels are not needed for scoring
            global_features, local_features = model.encode_image(images.to(device))

            global_features = global_features.float()
            local_features = local_features.float()

            global_features /= global_features.norm(dim=-1, keepdim=True)
            local_features /= local_features.norm(dim=-1, keepdim=True)

            # Cosine similarity to each class prompt -> tempered softmax.
            smax_global = to_np(F.softmax(global_features @ text_features.T / T, dim=1))
            global_score = -np.max(smax_global, axis=1)  # [batch]

            if method == 'mcm':
                _score.append(global_score)
                continue

            smax_local = to_np(F.softmax(local_features @ text_features.T / T, dim=-1))  # batch, h*w, class

            if method == 'gl-mcm':
                # Most confident (patch, class) pair per image.
                local_score = -np.max(smax_local, axis=(1, 2))
            else:  # 'elcm': Entropy-Weighted Local Concept Matching
                # Down-weight patches whose class distribution is flat.
                entropy_weights = compute_entropy_weights(smax_local, alpha=alpha)  # [batch, num_patches]
                patch_max_probs = np.max(smax_local, axis=-1)  # [batch, num_patches]
                # NOTE: this is a sum (not a mean) over patches, so its
                # magnitude scales with the patch count. lambda_local is
                # therefore not on the same scale as for 'gl-mcm'.
                local_score = -np.sum(entropy_weights * patch_max_probs, axis=1)  # [batch]

            _score.append(global_score + lambda_local * local_score)

    if not _score:
        # np.concatenate rejects an empty list; an empty loader has no scores.
        return np.empty(0, dtype=np.float32)
    # Truncate to the dataset length and copy so the result owns its buffer.
    return np.concatenate(_score, axis=0)[:len(loader.dataset)].copy()


def run_proposed_experiment(output_dir_path: str, alpha: float = 1.0, root: str = "/datasets/LoCoOp") -> None:
    """
    Run the proposed ELCM experiment with a fixed configuration.

    Builds the argument namespace expected by ``main``, prints it, creates the
    output directory and runs the evaluation.

    Args:
        output_dir_path: Directory where the results are written (created if
            missing).
        alpha: Entropy weighting parameter for the ELCM local score.
        root: Dataset root forwarded to the test loaders. Defaults to the
            previously hard-coded ``/datasets/LoCoOp``.
    """
    args = argparse.Namespace(
        root=root,                  # Path to dataset
        output_dir=output_dir_path,
        seed=1,
        batch_size=500,
        model_name='ViT-B/16',
        in_dataset='ImageNet',
        method='elcm',              # Use ELCM method
        T=1.0,                      # Temperature for softmax
        # Weight of the local-patch term in the OOD score (lambda_local in
        # get_ood_scores), kept at the baseline's value.
        lambda_value=0.5,
        alpha=alpha,                # Entropy weighting parameter
        sample_size=500,            # Forwarded to set_test_loader for each test set
    )

    print("Running ELCM experiment with parameters:")
    print(f"  Root: {args.root}")
    print(f"  Output directory: {args.output_dir}")
    print(f"  Seed: {args.seed}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Model: {args.model_name}")
    print(f"  In-distribution dataset: {args.in_dataset}")
    print(f"  Method: {args.method}")
    print(f"  Temperature (T): {args.T}")
    print(f"  Lambda value: {args.lambda_value}")
    print(f"  Alpha (entropy weighting): {args.alpha}")
    print(f"  Sample size: {args.sample_size}")
    print()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Run the main experiment
    main(args)


def main(args: argparse.Namespace) -> None:
    """
    Run evaluation and save per-dataset scores and metrics.

    Scores the ImageNet in-distribution loader and each OOD dataset
    (iNaturalist, SUN, places365, Texture) with ``get_ood_scores``. For each
    OOD set it calls ``get_and_print_results`` and ``add_results``, then
    writes ``scores.npz`` (one array per dataset name) and ``results.json``
    into ``args.output_dir``.

    Args:
        args: Namespace providing root, output_dir, seed, batch_size,
            model_name, in_dataset, method, T, lambda_value, alpha and
            sample_size.
    """
    print("Starting ELCM evaluation...")

    # Create the output directory here as well, so main() also works when
    # called directly. Doing it up front makes a bad path fail before the
    # expensive evaluation instead of at save time.
    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load model
    model, preprocess = clip_w_local.load(args.model_name)
    model = model.to(device)
    model.eval()

    id_data_loader = set_test_loader(args.root, 'imagenet', preprocess, args.sample_size, args.batch_size, args.seed)
    test_labels = get_test_labels(args.in_dataset)

    def _score(loader):
        # The same scoring configuration is used for the ID and every OOD loader.
        return get_ood_scores(
            model=model,
            method=args.method,
            loader=loader,
            test_labels=test_labels,
            lambda_local=args.lambda_value,
            T=args.T,
            alpha=args.alpha,
        )

    # Calculate in-distribution scores
    in_score = _score(id_data_loader)

    # Lists filled in by get_and_print_results for the overall summary
    auroc_list, fpr_list = [], []
    results_data = []

    # Evaluate out-of-distribution datasets
    out_datasets = ['iNaturalist', 'SUN', 'places365', 'Texture']

    scores_dict: Dict[str, np.ndarray] = {"ImageNet": in_score}

    for out_dataset in out_datasets:
        print(f"Evaluating OOD dataset: {out_dataset}")
        ood_data_loader = set_test_loader(args.root, out_dataset, preprocess, args.sample_size, args.batch_size, args.seed)
        out_score = _score(ood_data_loader)

        results = get_and_print_results(
            args, in_score, out_score,
            auroc_list, fpr_list
        )

        scores_dict[out_dataset] = out_score
        # Save results
        results_data = add_results(results_data, args.method, results, out_dataset)

    # Add overall results to results_data
    results_data = add_overall_results(results_data, args.method, auroc_list, fpr_list)

    # Save scores to .npz
    np.savez(os.path.join(args.output_dir, "scores.npz"), **scores_dict)

    # Save results to JSON
    save_results_to_json(results_data, args.output_dir, "results.json")
    print("ELCM evaluation completed")


if __name__ == "__main__":
    # Command-line entry point: only the output location and alpha are
    # configurable here; everything else is fixed by run_proposed_experiment.
    cli = argparse.ArgumentParser(description="ELCM experiment script")
    cli.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="path to output directory where results will be saved",
    )
    cli.add_argument(
        "--alpha",
        type=float,
        default=1.0,
        help="entropy weighting parameter (default: 1.0)",
    )
    cli_args = cli.parse_args()

    run_proposed_experiment(cli_args.output_dir, cli_args.alpha)