import os
from util.args_loader import get_args
from util.data_loader import get_loader_in, get_loader_out
from util.model_loader import get_model
from util import metrics
import torch
import numpy as np
import torchvision.models as models
import torch.nn.functional as F


# Use float64 as the default floating-point dtype for tensors created without
# an explicit dtype. (torch.set_default_tensor_type, used previously, is
# deprecated since PyTorch 2.1; this is its documented replacement.)
torch.set_default_dtype(torch.float64)


def ORA(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
) -> np.ndarray:
    """Score samples by the angle towards each decision boundary (ORA).

    For every sample, the feature vector is orthogonally projected onto the
    linear decision boundary between its predicted class and each class of
    ``model.fc``. The angle between the original feature and its projection,
    both measured from the mean of ``class_means``, is divided by pi so it
    lies in ``[0, 1]``. The sample's score is the maximum of these angles
    over all classes; the predicted class itself (an undefined 0/0
    projection) contributes 0.

    Args:
        model: Network whose final linear layer is ``model.fc``; only
            ``model.fc.weight`` is read. Computation runs on that weight's
            device.
        test_loader: Iterable of ``(features, logits)`` batches (presumably
            precomputed penultimate features and matching logits -- confirm
            against the caller).
        num_classes: Number of classes, i.e. rows of ``model.fc.weight``.
        class_means: Per-class feature means; their mean over dim 0 is the
            origin the angles are measured from.

    Returns:
        float32 array with one score per sample, in loader order (empty if
        the loader yields no batches).
    """
    model.eval()

    fc_weight = model.fc.weight
    # Follow the classifier's device rather than hard-coding CUDA; for a GPU
    # model this is the same device the original code used.
    device = fc_weight.device

    all_scores = []

    with torch.inference_mode():
        # The origin the angles are measured from does not depend on the batch.
        origin = torch.mean(class_means, dim=0).to(device)

        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values

            # Invariant over the class loop: weights of each sample's predicted
            # class and the unit-normalized centered features.
            pred_weights = fc_weight[preds]
            norm_centered = F.normalize(feats - origin, p=2, dim=1)

            trajectory = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)

                # Project the feature onto the hyperplane where the predicted
                # logit equals this class's logit. For class_id == pred this is
                # 0/0 = NaN; those entries are zeroed after the loop.
                feats_db = feats - (logit_diff / weight_diff_norm**2).view(-1, 1) * weight_diff
                norm_centered_db = F.normalize(feats_db - origin, p=2, dim=1)

                cos_sim = torch.sum(norm_centered * norm_centered_db, dim=1)
                # Rounding can push |cos| slightly past 1, where arccos returns
                # NaN (which would then be zeroed, wrongly scoring a ~pi angle
                # as 0). Clamping keeps genuine NaNs from the 0/0 case intact.
                cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
                trajectory[:, class_id] = torch.arccos(cos_sim) / torch.pi

            trajectory[torch.isnan(trajectory)] = 0
            all_scores.append(torch.max(trajectory, dim=1).values)

    if not all_scores:
        # torch.cat on an empty list raises; an empty loader has no scores.
        return np.empty(0, dtype=np.float32)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)



def fDBD(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
) -> np.ndarray:
    """Score samples by feature distance to decision boundaries (fDBD).

    For every sample and class, the Euclidean distance from the feature to
    its orthogonal projection onto the linear decision boundary between the
    predicted class and that class is divided by the feature's distance to
    the mean of ``class_means``. The sample's score is the mean of these
    ratios over all classes; the predicted class itself (an undefined 0/0
    projection) contributes 0.

    Args:
        model: Network whose final linear layer is ``model.fc``; only
            ``model.fc.weight`` is read. Computation runs on that weight's
            device.
        test_loader: Iterable of ``(features, logits)`` batches (presumably
            precomputed penultimate features and matching logits -- confirm
            against the caller).
        num_classes: Number of classes, i.e. rows of ``model.fc.weight``.
        class_means: Per-class feature means; their mean over dim 0 is the
            reference point the normalizing distance is measured to.

    Returns:
        float32 array with one score per sample, in loader order (empty if
        the loader yields no batches).
    """
    model.eval()

    fc_weight = model.fc.weight
    # Follow the classifier's device rather than hard-coding CUDA; for a GPU
    # model this is the same device the original code used.
    device = fc_weight.device

    all_scores = []

    with torch.inference_mode():
        origin = torch.mean(class_means, dim=0).to(device)

        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values

            # Invariant over the class loop: predicted-class weights and the
            # normalizer (distance of each feature to the mean class mean).
            pred_weights = fc_weight[preds]
            dist_to_origin = torch.linalg.norm(feats - origin, dim=1)

            trajectory = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)

                # Project onto the hyperplane where the predicted logit equals
                # this class's logit. For class_id == pred this is 0/0 = NaN;
                # those entries are zeroed after the loop.
                feats_db = feats - (logit_diff / weight_diff_norm**2).view(-1, 1) * weight_diff
                distance_to_db = torch.linalg.norm(feats - feats_db, dim=1)
                trajectory[:, class_id] = distance_to_db / dist_to_origin

            trajectory[torch.isnan(trajectory)] = 0
            all_scores.append(torch.mean(trajectory, dim=1))

    if not all_scores:
        # torch.cat on an empty list raises; an empty loader has no scores.
        return np.empty(0, dtype=np.float32)
    return np.asarray(torch.cat(all_scores).detach().cpu().numpy(), dtype=np.float32)


args = get_args()

# Restrict the visible GPUs before anything below can initialize CUDA. The
# CUDA runtime reads this variable when it initializes, so setting it after
# torch.cuda.is_available() (as this script previously did) may be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Feature width per backbone (presumably the penultimate layer). Note that
# feat_size stays undefined for any other args.name; it is not used below.
if args.name in ("resnet50", "resnet50-supcon"):
    feat_size = 2048
elif args.name == "vit":
    feat_size = 768

seed = args.seed
print(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

# In-distribution sizes; these numbers match ImageNet-1k (unused below).
class_num = 1000
id_train_size = 1281167
id_val_size = 50000

# Scoring method whose saved per-backbone scores are evaluated.
method = "ORA"

ood_feat_score_log = {}
ood_dataset_size = {"inat": 10000, "sun50": 10000, "places50": 10000, "dtd": 5640}

# Backbones whose scores are summed into the ensemble score, in summation order.
ensemble_models = ("resnet50", "resnet50-supcon", "vit")


def _load_ensemble_scores(prefix, suffix):
    """Return the element-wise sum of ./scores/{method}/{prefix}_{model}_{suffix}.npy
    over every backbone in ``ensemble_models`` (summed left to right)."""
    total = None
    for model_name in ensemble_models:
        model_scores = np.load(f"./scores/{method}/{prefix}_{model_name}_{suffix}.npy")
        total = model_scores if total is None else total + model_scores
    return total


score_in = _load_ensemble_scores(args.in_dataset, f"score_in_{method}")

all_results = []
all_score_out = []
for ood_dataset in args.out_datasets:
    scores_out_test = _load_ensemble_scores(
        f"{args.in_dataset}_{ood_dataset}", f"score_out_{method}"
    )

    all_score_out.extend(scores_out_test)
    results = metrics.cal_metric(score_in, scores_out_test)
    all_results.append(results)
    print(results)
# Label the summary with the evaluated method instead of a duplicated literal.
metrics.print_all_results(all_results, args.out_datasets, method)
print()
