import os
from util.args_loader import get_args
from util.model_loader import get_model
from util import metrics
import torch
import numpy as np
import torchvision.models as models
import torch.nn.functional as F


# Make float64 the default floating-point dtype for tensors created below.
# torch.set_default_tensor_type is deprecated (PyTorch 2.1+);
# torch.set_default_dtype is the documented replacement and has the same
# effect for the CPU DoubleTensor case used here.
torch.set_default_dtype(torch.float64)

# Module-level lists; not referenced elsewhere in the visible part of this file.
in_logger = []
out_logger = []


from sklearn import svm
from sklearn.multiclass import OneVsRestClassifier

def train_svm(train_feat_log, train_label_log, val_feat_log, val_label_log):
    """Fit a one-vs-rest linear SVM and export its parameters as dense arrays.

    Args:
        train_feat_log: Training features (tensor or array); its second
            dimension is taken as the number of features.
        train_label_log: Training labels (tensor or array).
        val_feat_log: Validation features, used only to report accuracy.
        val_label_log: Validation labels, used only to report accuracy.

    Returns:
        A tuple ``(clf, weight_matrix, bias_vector)`` where ``clf`` is the
        fitted ``OneVsRestClassifier``, ``weight_matrix`` has shape
        ``(n_classes, n_features)`` and ``bias_vector`` has shape
        ``(n_classes,)``. Rows whose estimator exposes no ``coef_`` /
        ``intercept_`` (or that have no matching estimator) stay at zero.
    """

    def _to_numpy(values):
        # scikit-learn needs host arrays; tensors are detached and copied to
        # the CPU, everything else is passed through untouched.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    train_feat_log = _to_numpy(train_feat_log)
    train_label_log = _to_numpy(train_label_log)
    val_feat_log = _to_numpy(val_feat_log)
    val_label_log = _to_numpy(val_label_log)

    # One binary linear SVM per class.
    clf = OneVsRestClassifier(svm.SVC(kernel='linear'))
    clf.fit(train_feat_log, train_label_log)

    n_classes = len(clf.classes_)
    n_features = train_feat_log.shape[1]
    weight_matrix = np.zeros((n_classes, n_features))
    bias_vector = np.zeros(n_classes)

    # Copy each per-class estimator's hyperplane into row i.
    for i, estimator in enumerate(clf.estimators_):
        if hasattr(estimator, 'coef_'):
            weight_matrix[i] = estimator.coef_.ravel()
        if hasattr(estimator, 'intercept_'):
            # intercept_ is an array; take its single element explicitly
            # instead of relying on NumPy's deprecated implicit conversion of
            # a 1-element array to a scalar.
            bias_vector[i] = np.ravel(estimator.intercept_)[0]

    preds = clf.predict(val_feat_log)
    acc = np.mean(preds == val_label_log)
    print(f"Validation accuracy: {acc}")
    print("weight shape: ", weight_matrix.shape)
    print("bias shape: ", bias_vector.shape)
    print("number of classes in training: ", n_classes)
    return clf, weight_matrix, bias_vector


def ORA(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
) -> np.ndarray:
    """Score samples by the largest angle to their decision-boundary points.

    For every sample and every class, the feature is moved along the
    difference between the predicted class's and that class's ``model.fc``
    weight rows by ``logit_diff / ||weight_diff||^2`` (presumably the point
    where the two logits tie, assuming the logits came from ``model.fc`` --
    confirm against the code that produced the cache). The angle between the
    feature and the moved point, both measured from the mean of
    ``class_means``, is divided by pi; the per-sample score is the maximum
    over classes.

    Args:
        model: Module exposing a linear head as ``model.fc``; only
            ``model.fc.weight`` is read. Put into eval mode.
        test_loader: Iterable yielding ``(features, logits)`` batches.
        num_classes: Number of classes to iterate over (rows of
            ``model.fc.weight`` / columns of the logits).
        class_means: Per-class feature means; only their average is used.

    Returns:
        1-D ``float32`` array with one score per sample, in loader order.
    """
    model.eval()

    fc_weight = model.fc.weight
    # Run wherever the classifier head lives instead of hard-coding CUDA, so
    # the function also works on CPU-only setups.
    device = fc_weight.device
    batch_scores = []

    with torch.inference_mode():
        # Reference point for all angles; identical for every batch and class.
        center = torch.mean(class_means, dim=0).to(device)

        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values

            # Class-independent quantities, computed once per batch.
            pred_weights = fc_weight[preds]
            unit_feats = F.normalize(feats - center, p=2, dim=1)

            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)

                step = torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                feats_db = feats - step * weight_diff

                unit_feats_db = F.normalize(feats_db - center, p=2, dim=1)
                cos_sim = torch.sum(unit_feats * unit_feats_db, dim=1)
                angles[:, class_id] = torch.arccos(cos_sim) / torch.pi

            # The predicted class gives 0/0 (zero weight difference) and thus
            # NaN; such entries must not win the max below.
            angles[torch.isnan(angles)] = 0
            batch_scores.append(torch.max(angles, dim=1).values)

    return np.asarray(torch.cat(batch_scores).detach().cpu().numpy(), dtype=np.float32)


def fDBD(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
) -> np.ndarray:
    """Score samples by their mean normalized distance to decision-boundary points.

    For every sample and every class, the feature is moved along the
    difference between the predicted class's and that class's ``model.fc``
    weight rows by ``logit_diff / ||weight_diff||^2``. The length of that move
    is divided by the feature's distance to the mean of ``class_means``. The
    per-sample score is the mean of these ratios over all ``num_classes``
    columns, with NaN entries (e.g. the predicted class itself) counted as 0.

    Args:
        model: Module exposing a linear head as ``model.fc``; only
            ``model.fc.weight`` is read. Put into eval mode.
        test_loader: Iterable yielding ``(features, logits)`` batches.
        num_classes: Number of classes to iterate over.
        class_means: Per-class feature means; only their average is used.

    Returns:
        1-D ``float32`` array with one score per sample, in loader order.
    """
    model.eval()

    fc_weight = model.fc.weight
    # Run wherever the classifier head lives instead of hard-coding CUDA, so
    # the function also works on CPU-only setups.
    device = fc_weight.device
    batch_scores = []

    with torch.inference_mode():
        # Normalization reference; identical for every batch and class.
        center = torch.mean(class_means, dim=0).to(device)

        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values

            # Class-independent quantities, computed once per batch.
            pred_weights = fc_weight[preds]
            dist_to_center = torch.linalg.norm(feats - center, dim=1)

            ratios = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = pred_weights - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)

                step = torch.divide(logit_diff, weight_diff_norm**2).view(-1, 1)
                feats_db = feats - step * weight_diff
                distance_to_db = torch.linalg.norm(feats - feats_db, dim=1)

                ratios[:, class_id] = distance_to_db / dist_to_center

            # The predicted class gives 0/0 (zero weight difference) and thus
            # NaN; zero it so it does not poison the mean.
            ratios[torch.isnan(ratios)] = 0
            batch_scores.append(torch.mean(ratios, dim=1))

    return np.asarray(torch.cat(batch_scores).detach().cpu().numpy(), dtype=np.float32)


def knn_score(model, test_loader, num_classes, train_features, k=50):
    """Score samples by the negated distance to their k-th nearest training feature.

    Both the test and the training features are L2-normalized before the
    Euclidean distances are computed.

    Args:
        model: Only put into eval mode; not otherwise used.
        test_loader: Iterable yielding ``(features, logits)`` batches; the
            logits are ignored.
        num_classes: Unused; kept so the signature matches the other scorers.
        train_features: 2-D tensor of reference features. Test batches are
            moved to its device.
        k: Rank of the neighbor whose distance is used; ``torch.topk`` raises
            if ``k`` exceeds the number of training features.

    Returns:
        1-D ``float32`` array with one score per sample, in loader order.
    """
    model.eval()

    # Compare on the device that already holds the reference features rather
    # than hard-coding CUDA, so the function also works on CPU-only setups.
    device = train_features.device
    batch_scores = []
    norm_train_feats = F.normalize(train_features, p=2, dim=1)

    with torch.inference_mode():
        for feats, _logits in test_loader:
            norm_feats = F.normalize(feats.to(device), p=2, dim=1)
            distances = torch.cdist(
                norm_feats,
                norm_train_feats,
                p=2,
                compute_mode="donot_use_mm_for_euclid_dist",
            )
            # topk with largest=False returns ascending distances; the last
            # column is the k-th nearest neighbor.
            kth_distances = torch.topk(distances, k, largest=False)[0][:, -1]
            batch_scores.append(-kth_distances)

    return np.asarray(torch.cat(batch_scores).cpu().numpy(), dtype=np.float32)


args = get_args()

# Restrict the visible GPUs before the first CUDA query below: the CUDA
# runtime reads CUDA_VISIBLE_DEVICES when it initializes, so setting it after
# torch.cuda.is_available() may have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Feature dimensionality of each supported backbone's cached features.
_FEAT_SIZE_BY_NAME = {"resnet50": 2048, "resnet50-supcon": 2048, "vit": 768}
if args.name not in _FEAT_SIZE_BY_NAME:
    # Fail here with a clear message instead of a NameError on `feat_size`
    # further down.
    raise ValueError(
        f"unsupported model name {args.name!r}; "
        f"expected one of {sorted(_FEAT_SIZE_BY_NAME)}"
    )
feat_size = _FEAT_SIZE_BY_NAME[args.name]

# Seed every RNG this script touches.
seed = args.seed
print(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Shapes of the cached in-distribution arrays.
# NOTE(review): these values match ImageNet-1k (1000 classes, 1,281,167 train
# and 50,000 val images) -- confirm that args.in_dataset is always ImageNet.
class_num = 1000
id_train_size = 1281167
id_val_size = 50000

def _read_cached(path, shape):
    """Map a cache file read-only as float64 and return it as a tensor on `device`."""
    mapped = np.memmap(path, dtype=float, mode="r", shape=shape)
    return torch.from_numpy(mapped).to(device)


# In-distribution training split: features, logits and labels.
cache_dir = f"cache/{args.in_dataset}_train_{args.name}_in"
feat_log = _read_cached(f"{cache_dir}/feat.mmap", (id_train_size, feat_size))
score_log = _read_cached(f"{cache_dir}/score.mmap", (id_train_size, class_num))
label_log = _read_cached(f"{cache_dir}/label.mmap", (id_train_size,))


# In-distribution validation split, same layout.
cache_dir = f"cache/{args.in_dataset}_val_{args.name}_in"
feat_log_val = _read_cached(f"{cache_dir}/feat.mmap", (id_val_size, feat_size))
score_log_val = _read_cached(f"{cache_dir}/score.mmap", (id_val_size, class_num))
label_log_val = _read_cached(f"{cache_dir}/label.mmap", (id_val_size,))


# Number of cached samples per supported OOD dataset.
ood_dataset_size = {"inat": 10000, "sun50": 10000, "places50": 10000, "dtd": 5640}

# Maps OOD dataset name -> (features, logits), both on `device`.
ood_feat_score_log = {}

for ood_dataset in args.out_datasets:
    n_ood = ood_dataset_size[ood_dataset]

    ood_feat_map = np.memmap(
        f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out/feat.mmap",
        dtype=float,
        mode="r",
        shape=(n_ood, feat_size),
    )
    ood_feat_log = torch.from_numpy(ood_feat_map).to(device)

    ood_score_map = np.memmap(
        f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out/score.mmap",
        dtype=float,
        mode="r",
        shape=(n_ood, class_num),
    )
    ood_score_log = torch.from_numpy(ood_score_map).to(device)

    ood_feat_score_log[ood_dataset] = (ood_feat_log, ood_score_log)


num_classes = 1000
model = get_model(args, num_classes, load_ckpt=True)

# Optional: replace the classifier head with a linear SVM fitted on the
# cached training features.
# clf, weights, bias = train_svm(feat_log, label_log, feat_log_val, label_log_val)
# model.fc.weight.data = torch.from_numpy(weights).to(device)
# model.fc.bias.data = torch.from_numpy(bias).to(device)

# The loader serves CPU copies of the validation tensors as
# (features, logits) batches.
in_dataset = torch.utils.data.TensorDataset(feat_log_val.cpu(), score_log_val.cpu())
in_loader = torch.utils.data.DataLoader(
    in_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=2,
)

# Mean training feature of every class, one row per class id.
class_means = torch.stack(
    [torch.mean(feat_log[label_log == class_id], dim=0) for class_id in range(num_classes)]
).to(device)


train_feats = feat_log

# Pick exactly one scoring method; the alternatives are kept for reference.
score_in = ORA(model, in_loader, num_classes, class_means)
# score_in = fDBD(model, in_loader, num_classes, class_means)
# score_in = knn_score(model, in_loader, num_classes, train_feats, k=1000)

all_results = []
all_score_out = []

# Score every OOD dataset with the same method used for `score_in` and
# compare it against the in-distribution scores.
for ood_dataset_name, (ood_feats, ood_logits) in ood_feat_score_log.items():
    ood_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(ood_feats.cpu(), ood_logits.cpu()),
        batch_size=512,
        shuffle=False,
        num_workers=2,
    )

    scores_out_test = ORA(model, ood_loader, num_classes, class_means)
    # scores_out_test = fDBD(model, ood_loader, num_classes, class_means)
    # scores_out_test = knn_score(model, ood_loader, num_classes, train_feats, k=1000)

    all_score_out.extend(scores_out_test)

    results = metrics.cal_metric(score_in, scores_out_test)
    print(results)
    all_results.append(results)

metrics.print_all_results(all_results, args.out_datasets, "ORA")
print()
