import argparse
import os
from re import M
from typing import Dict

import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm

from datasets.test_loader import set_test_loader
import clip_w_local
from clip_w_local import clip
from utils.eval_util import get_and_print_results, add_results, add_overall_results, save_results_to_json


def get_test_labels(in_dataset: str):
    """Load the class-name array for the given in-distribution dataset.

    Only ``'ImageNet'`` is supported; its names are read with ``np.load``
    from ``label_names/imagenet_class_clean.npy`` (path relative to the CWD).

    Raises:
        ValueError: if ``in_dataset`` is anything other than ``'ImageNet'``.
    """
    if in_dataset != 'ImageNet':
        raise ValueError(f"Invalid dataset: {in_dataset}")

    label_path = "label_names/imagenet_class_clean.npy"
    with open(label_path, 'rb') as handle:
        class_names = np.load(handle)
    return class_names


def get_ood_scores(model, method, loader, test_labels, lambda_local: float = 0.5, T: float = 1.0):
    """Compute one OOD score per sample in ``loader``.

    Each image is compared against the prompts ``"a photo of a {label}"``
    (one per entry of ``test_labels``). The score is the negated maximum
    softmax probability, so a more confident prediction gives a lower score.

    Args:
        model: CLIP-style model. ``encode_image`` must return a
            ``(global_features, local_features)`` pair and ``encode_text`` must
            embed tokenized prompts. Inputs are moved to the device of the
            model's parameters.
        method: ``'mcm'`` uses only the global score. ``'gl-mcm'`` adds
            ``lambda_local`` times the local score.
        loader: iterable of ``(images, labels)`` batches with ``__len__`` and a
            ``.dataset`` attribute. The labels are ignored.
        test_labels: class names used to build the text prompts.
        lambda_local: weight of the local score for ``'gl-mcm'``.
        T: softmax temperature.

    Returns:
        1-D numpy array of scores, truncated to ``len(loader.dataset)``. Empty
        if the loader yields no batches.

    Raises:
        NotImplementedError: if ``method`` is neither ``'mcm'`` nor ``'gl-mcm'``.
    """
    # Validate before doing any (expensive) encoding work.
    if method not in ('mcm', 'gl-mcm'):
        raise NotImplementedError(f"Unsupported method: {method}")

    # Follow the model's device rather than hard-coding CUDA, so this also
    # works when the caller placed the model on CPU.
    device = next(model.parameters()).device

    scores = []
    with torch.no_grad():
        # The prompts are identical for every batch: encode them once.
        prompts = [f"a photo of a {c}" for c in test_labels]
        text_features = model.encode_text(clip.tokenize(prompts).to(device)).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        for images, _ in tqdm(loader, total=len(loader)):
            global_features, local_features = model.encode_image(images.to(device))
            global_features = global_features.float()
            local_features = local_features.float()
            global_features /= global_features.norm(dim=-1, keepdim=True)
            local_features /= local_features.norm(dim=-1, keepdim=True)

            # assumes global_features is (batch, dim), so class logits are (batch, class)
            logits_global = global_features @ text_features.T
            smax_global = F.softmax(logits_global / T, dim=1).cpu().numpy()
            batch_score = -np.max(smax_global, axis=1)

            if method == 'gl-mcm':
                # Softmax over classes for every local feature. The max is then
                # taken over ALL non-batch axes, so the result is one value per
                # sample whether local features are (batch, patches, dim) or
                # (batch, grid, grid, dim).
                logits_local = local_features @ text_features.T
                smax_local = F.softmax(logits_local / T, dim=-1).cpu().numpy()
                local_score = -smax_local.reshape(smax_local.shape[0], -1).max(axis=1)
                batch_score = batch_score + lambda_local * local_score

            scores.append(batch_score)

    if not scores:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(scores, axis=0)[:len(loader.dataset)].copy()



def main(args: argparse.Namespace) -> None:
    """Run the OOD evaluation end to end.

    Scores the in-distribution test set and each dataset in ``out_datasets``,
    collects metrics through ``utils.eval_util``, and writes ``scores.npz``
    (raw scores keyed by dataset name) and ``results.json`` into
    ``args.output_dir``.

    Args:
        args: parsed command-line options from the script's argument parser.
    """
    print("Starting evaluation...")

    # Create the output directory up front so a bad path fails before the
    # long scoring runs, not after them.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, preprocess = clip_w_local.load(args.model_name)
    model = model.to(device)
    model.eval()

    # NOTE(review): the loader name is hard-coded to 'imagenet' while the
    # labels follow args.in_dataset. get_test_labels only accepts 'ImageNet'
    # today, so the two agree; revisit if more ID datasets are added.
    id_data_loader = set_test_loader(args.root, 'imagenet', preprocess, args.sample_size, args.batch_size, args.seed)
    test_labels = get_test_labels(args.in_dataset)

    def score(loader) -> np.ndarray:
        # Every dataset is scored with the same model, method, prompts and temperature.
        return get_ood_scores(
            model=model,
            method=args.method,
            loader=loader,
            test_labels=test_labels,
            lambda_local=args.lambda_value,
            T=args.T,
        )

    in_score = score(id_data_loader)

    # Per-dataset metrics; presumably appended to by get_and_print_results and
    # later aggregated by add_overall_results.
    auroc_list, fpr_list = [], []
    results_data = []

    scores_dict: Dict[str, np.ndarray] = {"ImageNet": in_score}

    out_datasets = ['iNaturalist', 'SUN', 'places365', 'Texture']
    for out_dataset in out_datasets:
        print(f"Evaluating OOD dataset: {out_dataset}")
        ood_data_loader = set_test_loader(args.root, out_dataset, preprocess, args.sample_size, args.batch_size, args.seed)
        out_score = score(ood_data_loader)

        results = get_and_print_results(
            args, in_score, out_score,
            auroc_list, fpr_list
        )

        scores_dict[out_dataset] = out_score
        results_data = add_results(results_data, args.method, results, out_dataset)

    # Append the overall results. Note: results_data is a list of dicts.
    results_data = add_overall_results(results_data, args.method, auroc_list, fpr_list)

    # os.path.join keeps an empty --output-dir relative to the current
    # directory; f"{dir}/scores.npz" would write to "/scores.npz".
    np.savez(os.path.join(args.output_dir, "scores.npz"), **scores_dict)

    save_results_to_json(results_data, args.output_dir, "results.json")
    print("Evaluation completed")



def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument(
        "--seed", type=int, default=1, help="seed for random number data generator"
    )
    parser.add_argument(
        "--batch-size", type=int, default=500, help="batch size for test set"
    )
    parser.add_argument("--model_name", type=str, default='ViT-B/16',
                        choices=['ViT-B/16', 'RN50', 'RN101'], help='which pretrained img encoder to use')
    parser.add_argument("--in_dataset", type=str, default='ImageNet', help="name of in-distribution dataset")
    parser.add_argument('--method', type=str, default='gl-mcm', choices=['mcm', 'gl-mcm'], help='method type: mcm or gl-mcm')
    # A float default: argparse applies `type` only to string defaults, so an
    # int default would leave args.T as an int.
    parser.add_argument('--T', type=float, default=1.0,
                        help='temperature for softmax')
    parser.add_argument('--lambda_value', type=float, default=0.5,
                        help='weight of the local-feature score added to the global score in gl-mcm')
    parser.add_argument('--sample-size', type=int, default=500, help='sample size for test set')
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())