# AD Statistics  
# Calculating the required stats for comparing explanations  
# 1. IOU/DICE/PDC/Jensen-Shannon Divergence/Wasserstein Distance (Done)
# 2. Per-Pixel Contribution (Done)
# 3. Overlap Across Seeds (For the same value, see if it is seed robust)
# 4. Overlap Across Thresholds (For the same value, see if localisation is robust)
# 5. Robustness across background color. (Done)

import os
os.environ["HF_HOME"] = "../models"
import timm
import argparse
import json
import numpy as np
import torch
import torch.nn as nn
from utils import calculate_all_metrics, robustness_metrics

import torchvision.transforms as T
from torchvision.models import get_model, ResNet50_Weights


from interpretable_resnet_torchvision import InterpretableResNet50
from interpretable_regnety import InterpretableRegNetY
from interpretable_efficientnetv2 import InterpretableEfficientNetV2
from interpretable_mobilenetv4 import InterpretableMobileNetV4

# Results root for the current run. It is None at import time and is rebound
# by the script entry point once the dataset and model are known.
BASE_THRESHOLD_DIR = None

# Per-dataset root directories. For every dataset the checkpoint root and the
# threshold-results root are the same directory, so one is an alias of the
# other.
BASE_THRESHOLD_DIR_IN_1k = "../ImageNet-onek"
CHECKPOINT_DIR_IN_1k = BASE_THRESHOLD_DIR_IN_1k

BASE_THRESHOLD_DIR_CT_256 = "../CalTech-256"
CHECKPOINT_DIR_CT_256 = BASE_THRESHOLD_DIR_CT_256

BASE_THRESHOLD_DIR_IN_1k_v2 = "../IN-1k_v2"
CHECKPOINT_DIR_IN_1k_v2 = BASE_THRESHOLD_DIR_IN_1k_v2

BASE_THRESHOLD_DIR_PASCAL_VOC = "../PASCAL-VOC"
CHECKPOINT_DIR_PASCAL_VOC = BASE_THRESHOLD_DIR_PASCAL_VOC

# Dataset locations handed to the metric computation as ``dataset_path``.
DATASET_PATH_IN_1k = "../ImageNet-onek/IN-onek_data/Test"
DATASET_PATH_CT_256 = "../CalTech-256/Dataset/exp_test"
DATASET_PATH_IN_1k_v2 = "../IN-1k_v2/Dataset"
DATASET_PATH_PASCAL_VOC = "../PASCAL-VOC/Dataset/exp_test"

# Ensure JSON serializable (convert numpy types)
def _json_default(obj):
    """Fallback encoder passed to ``json.dump(..., default=...)``.

    Conversion order:
      1. NumPy scalars (``np.generic``) become the matching Python scalar
         via ``.item()``.
      2. Anything exposing ``tolist`` (e.g. NumPy arrays) becomes the result
         of ``tolist()``.
      3. Everything else, including objects whose ``tolist()`` raises, is
         rendered with ``str()``.

    The function never raises for a failing ``tolist``; the string form is
    used as a best-effort fallback instead.
    """
    if isinstance(obj, np.generic):
        return obj.item()

    to_list = getattr(obj, 'tolist', None)
    if to_list is not None:
        try:
            return to_list()
        except Exception:
            # Deliberate best effort: fall through to the string form below.
            pass

    return str(obj)


def _mlp_head(in_features, hidden_dims, num_classes):
    """Build a fully connected classification head.

    The head is ``Linear -> ReLU`` for every width in ``hidden_dims``,
    followed by a final ``Linear`` that maps to ``num_classes`` outputs. The
    layers are placed in an ``nn.Sequential`` in that order, so parameter
    names ("0.weight", "2.weight", ...) are the same as for a hand-written
    ``nn.Sequential(nn.Linear(...), nn.ReLU(), ...)``.

    Args:
        in_features: Input width of the first linear layer.
        hidden_dims: Iterable of hidden layer widths (may be empty).
        num_classes: Output width of the last linear layer.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    layers = []
    width = in_features
    for hidden in hidden_dims:
        layers.append(nn.Linear(width, hidden))
        layers.append(nn.ReLU())
        width = hidden
    layers.append(nn.Linear(width, num_classes))
    return nn.Sequential(*layers)


if __name__ == "__main__":
    # Command line: which dataset / backbone pair to compute statistics for.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="IN-1k", choices=["IN-1k", "CT-256", "IN-1k_v2", "PASCAL-VOC"])
    parser.add_argument("--model", type=str, default="resnet", choices=["resnet", "regnet", "efn_v2", "mbnet_v4"])
    args = parser.parse_args()

    # (results root, checkpoint root, dataset path) for each dataset choice.
    dataset_dirs = {
        "IN-1k": (BASE_THRESHOLD_DIR_IN_1k, CHECKPOINT_DIR_IN_1k, DATASET_PATH_IN_1k),
        "IN-1k_v2": (BASE_THRESHOLD_DIR_IN_1k_v2, CHECKPOINT_DIR_IN_1k_v2, DATASET_PATH_IN_1k_v2),
        "CT-256": (BASE_THRESHOLD_DIR_CT_256, CHECKPOINT_DIR_CT_256, DATASET_PATH_CT_256),
        "PASCAL-VOC": (BASE_THRESHOLD_DIR_PASCAL_VOC, CHECKPOINT_DIR_PASCAL_VOC, DATASET_PATH_PASCAL_VOC),
    }
    BASE_THRESHOLD_DIR, CHECKPOINT_DIR, DATASET_PATH = dataset_dirs[args.dataset]

    # argparse restricts --dataset to the four choices above, so "not
    # ImageNet" always means CT-256 or PASCAL-VOC.
    is_imagenet = args.dataset in ('IN-1k', 'IN-1k_v2')
    is_caltech = args.dataset == 'CT-256'
    is_pascal = args.dataset == 'PASCAL-VOC'
    # Output width of the replaced classifier head (unused for ImageNet).
    num_classes = 257 if is_caltech else 20

    # NOTE(review): the ResNet/RegNet wrappers are called with `caltech256`
    # while the EfficientNet/MobileNet wrappers use `caltech_256`. Both
    # spellings are kept exactly as they were; confirm against the wrapper
    # constructors.
    if args.model == "resnet":
        BASE_THRESHOLD_DIR = os.path.join(BASE_THRESHOLD_DIR, "Results_resnet")
        model = get_model('resnet50', weights="DEFAULT")
        if is_imagenet:
            model.transforms = ResNet50_Weights.DEFAULT.transforms()
            ad_model = InterpretableResNet50()
        else:
            if is_caltech:
                model.fc = _mlp_head(model.fc.in_features, (512,), 257)

                # NOTE(review): the crop comes before the resize here (the
                # reverse of the PASCAL-VOC pipeline below), so the network
                # receives 232x232 inputs. Kept as-is; confirm it matches the
                # preprocessing the checkpoint was trained with.
                model.transforms = T.Compose([
                    T.CenterCrop((224, 224)),
                    T.Resize((232, 232)),
                    T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])
            else:
                model.fc = _mlp_head(model.fc.in_features, (1024, 1024), 20)

                model.transforms = T.Compose([
                    T.Resize((232, 232)),
                    T.CenterCrop((224, 224)),
                    T.ToTensor(),
                    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ])

            # Shared by CT-256 and PASCAL-VOC. This used to sit inside the
            # PASCAL-VOC arm only, which left CT-256 without loaded weights
            # and without `ad_model` (NameError further down).
            model_weights = torch.load(os.path.join(CHECKPOINT_DIR, "checkpoints_resnet", "final_model.pth"), map_location="cpu")["model_state_dict"]
            model.load_state_dict(model_weights)
            ad_model = InterpretableResNet50(caltech256=is_caltech, pascal_voc=is_pascal)
            ad_model.model.load_state_dict(model_weights)

    elif args.model == "regnet":
        BASE_THRESHOLD_DIR = os.path.join(BASE_THRESHOLD_DIR, "Results_regnet")
        model = timm.create_model('regnety_120.sw_in12k_ft_in1k', pretrained=True)
        if is_imagenet:
            ad_model = InterpretableRegNetY()
            data_config = timm.data.resolve_model_data_config(model)
            model.transforms = timm.data.create_transform(**data_config, is_training=False)
        else:
            model.head.fc = _mlp_head(model.head.fc.in_features, (1024, 1024), num_classes)

            model.transforms = T.Compose([
                T.Resize((384, 384), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
                T.CenterCrop((384, 384)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

            model_weights = torch.load(os.path.join(CHECKPOINT_DIR, "checkpoints_regnet", "final_model.pth"), map_location="cpu")["model_state_dict"]
            model.load_state_dict(model_weights)
            ad_model = InterpretableRegNetY(caltech256=is_caltech, pascal_voc=is_pascal)
            ad_model.model.load_state_dict(model_weights)

    elif args.model == "efn_v2":
        BASE_THRESHOLD_DIR = os.path.join(BASE_THRESHOLD_DIR, "Results_efn_v2")
        model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=True)
        if is_imagenet:
            ad_model = InterpretableEfficientNetV2()
            data_config = timm.data.resolve_model_data_config(model)
            model.transforms = timm.data.create_transform(**data_config, is_training=False)
        else:
            model.classifier = _mlp_head(model.classifier.in_features, (1024, 1024), num_classes)

            # Unlike the other backbones, this pipeline normalises with
            # mean/std of 0.5 per channel.
            model.transforms = T.Compose([
                T.Resize((300, 300), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
                T.CenterCrop((300, 300)),
                T.ToTensor(),
                T.Normalize(mean=[0.500, 0.500, 0.500], std=[0.500, 0.500, 0.500])
            ])

            model_weights = torch.load(os.path.join(CHECKPOINT_DIR, "checkpoints_efn_v2", "final_model.pth"), map_location="cpu")["model_state_dict"]
            model.load_state_dict(model_weights)
            ad_model = InterpretableEfficientNetV2(caltech_256=is_caltech, pascal_voc=is_pascal)
            ad_model.model.load_state_dict(model_weights)

    elif args.model == "mbnet_v4":
        BASE_THRESHOLD_DIR = os.path.join(BASE_THRESHOLD_DIR, "Results_mbnet_v4")
        model = timm.create_model('mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k', pretrained=True)
        if is_imagenet:
            ad_model = InterpretableMobileNetV4()
            data_config = timm.data.resolve_model_data_config(model)
            model.transforms = timm.data.create_transform(**data_config, is_training=False)
        else:
            model.classifier = _mlp_head(model.classifier.in_features, (1024, 1024), num_classes)

            model.transforms = T.Compose([
                T.Resize((448, 448), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
                T.CenterCrop((448, 448)),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

            model_weights = torch.load(os.path.join(CHECKPOINT_DIR, "checkpoints_mbnet_v4", "final_model.pth"), map_location="cpu")["model_state_dict"]
            model.load_state_dict(model_weights)
            ad_model = InterpretableMobileNetV4(caltech_256=is_caltech, pascal_voc=is_pascal)
            ad_model.model.load_state_dict(model_weights)

    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model.to(device)
    model.eval()

    ad_model.to(device)
    ad_model.eval()

    # Single source of truth for the sweep, used by both stages below.
    # The trailing 0 is intentionally an int so the directory name is
    # "Threshold_0", not "Threshold_0.0".
    # NOTE(review): an earlier comment said "0.9 already completed", yet 0.9
    # is still part of the sweep; confirm whether it should be recomputed.
    thresholds = (0.9, 0.7, 0.5, 0.3, 0.1, 0)
    num_backgrounds = 100

    # NOTE(review): ' Max vs AD' carries a leading space, unlike the other
    # column names. Left untouched because results already on disk may use
    # this exact key; confirm before normalising it.
    calculate_all_metrics(thresholds = list(thresholds),
                        seeds = [42],
                        num_backgrounds = num_backgrounds,
                        base_dir = BASE_THRESHOLD_DIR,
                        columns = ['Max vs Zero', 'Max vs Mean', 'Max vs Min', ' Max vs AD', 'Zero vs Mean', 'Zero vs Min', 'Zero vs AD', 'Mean vs Min', 'Mean vs AD', 'Min vs AD'],
                        per_pixel_columns = ['Max', 'Zero', 'Mean', 'Min', 'AD'],
                        dataset_path = DATASET_PATH,
                        model = model,
                        ad_model = ad_model,
                        device = device)

    # Per threshold: compute the robustness dictionaries and store them as
    # JSON next to that threshold's other results.
    for threshold in thresholds:
        threshold_dir = os.path.join(BASE_THRESHOLD_DIR, f"Threshold_{threshold}")
        beta_goodness_correct, beta_goodness_incorrect = robustness_metrics(threshold_dir,
                                                                            num_backgrounds = num_backgrounds)

        outputs = (
            ("beta_goodness_correct.json", beta_goodness_correct),
            ("beta_goodness_incorrect.json", beta_goodness_incorrect),
        )
        for filename, payload in outputs:
            with open(os.path.join(threshold_dir, filename), 'w') as f:
                # _json_default converts NumPy scalars/arrays that json
                # cannot encode natively.
                json.dump(payload, f, indent=2, default=_json_default)
