import os
from util.args_loader import get_args
from util.model_loader import get_model
from util import metrics
import torch
import numpy as np
import torch.nn.functional as F


# Make float64 the default dtype for newly created floating-point tensors
# (e.g. the torch.zeros score/mean buffers below), matching the float64
# (dtype=float) memmap caches loaded further down.
# NOTE(review): set_default_tensor_type is deprecated in recent PyTorch
# releases; torch.set_default_dtype(torch.float64) is the suggested
# replacement — confirm against the installed version before switching.
torch.set_default_tensor_type(torch.DoubleTensor)

from sklearn import svm
from sklearn.multiclass import OneVsRestClassifier

def train_svm(train_feat_log, train_label_log, val_feat_log, val_label_log):
    """Fit a one-vs-rest linear SVM and export it as a dense linear layer.

    Inputs may be torch tensors (on any device) or array-likes. Tensors are
    detached and copied to host numpy arrays because sklearn needs them.

    Args:
        train_feat_log: training features, shape (n_samples, n_features).
        train_label_log: one label per training row.
        val_feat_log: validation features with the same feature width.
        val_label_log: validation labels, used only to report accuracy.

    Returns:
        ``(clf, weight_matrix, bias_vector)``: the fitted
        ``OneVsRestClassifier``, a (n_classes, n_features) weight matrix and a
        (n_classes,) bias vector whose row ``i`` scores ``clf.classes_[i]``.
    """
    def _to_numpy(x):
        # sklearn only accepts host arrays; tensors may live on the GPU.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x)

    train_feat_log = _to_numpy(train_feat_log)
    train_label_log = _to_numpy(train_label_log)
    val_feat_log = _to_numpy(val_feat_log)
    val_label_log = _to_numpy(val_label_log)

    clf = OneVsRestClassifier(svm.LinearSVC())
    clf.fit(train_feat_log, train_label_log)

    n_classes = len(clf.classes_)
    n_features = train_feat_log.shape[1]
    weight_matrix = np.zeros((n_classes, n_features))
    bias_vector = np.zeros(n_classes)

    estimators = clf.estimators_
    if n_classes == 2 and len(estimators) == 1:
        # With two classes OvR fits a single estimator for the positive class
        # clf.classes_[1]; class 0 is scored by the negated decision function,
        # so argmax over the exported rows matches clf.predict.
        targets = [(1, estimators[0], 1.0), (0, estimators[0], -1.0)]
    else:
        targets = [(i, est, 1.0) for i, est in enumerate(estimators)]

    for row, est, sign in targets:
        # hasattr guards kept for estimators that may not expose coef_/intercept_;
        # such rows stay zero.
        if hasattr(est, 'coef_'):
            weight_matrix[row] = sign * est.coef_.ravel()
        if hasattr(est, 'intercept_'):
            # intercept_ is a length-1 array (or a plain float); take the
            # scalar explicitly instead of assigning an array to a scalar slot.
            bias_vector[row] = sign * np.ravel(est.intercept_)[0]

    preds = clf.predict(val_feat_log)
    acc = np.mean(preds == val_label_log)
    print(f"Validation accuracy: {acc}")
    print("weight shape: ", weight_matrix.shape)
    print("bias shape: ", bias_vector.shape)
    print("number of classes in training: ", n_classes)
    return clf, weight_matrix, bias_vector


def ORA(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
) -> np.ndarray:
    """Score samples by the largest angle to their decision-boundary projections.

    For every sample ``z`` with predicted class ``c`` and every class ``k``,
    ``z`` is projected onto the boundary between ``c`` and ``k`` of
    ``model.fc``::

        z_db = z - (l_c - l_k) / ||w_c - w_k||^2 * (w_c - w_k)

    Both ``z`` and ``z_db`` are centred on ``mu``, the mean of
    ``class_means``. The angle between them, divided by pi so it lies in
    [0, 1], is recorded. The sample's score is the maximum over all classes.
    The predicted class itself gives 0/0 = NaN and counts as 0.

    (By the law of sines on the triangle mu, z, z_db, fDBD's ratio
    ``||z - z_db|| / ||z - mu||`` equals ``sin(angle at mu) / sin(angle at z_db)``;
    this score keeps only the angle at ``mu``.)

    Args:
        model: network whose final linear layer is ``model.fc``. Only its
            ``weight`` is read; the logits come from ``test_loader``.
        test_loader: yields ``(features, logits)`` batches.
        num_classes: number of classes (rows of ``model.fc.weight``) to visit.
        class_means: per-class mean features. Only their overall mean is used,
            as the centring origin.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model.eval()
    fc_weight = model.fc.weight
    # Run wherever the classifier lives instead of assuming CUDA.
    device = fc_weight.device
    # Centring origin shared by every sample and class.
    origin = torch.mean(class_means, dim=0).to(device)

    all_scores = []
    with torch.inference_mode():
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            w_pred = fc_weight[preds]

            # Class-independent: direction of each centred feature.
            norm_centered_feats = F.normalize(feats - origin, p=2, dim=1)

            angles = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = w_pred - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection onto the boundary between the predicted
                # class and class_id. It is exact when the cached logits equal
                # model.fc(feats). For class_id == pred this is 0/0 = NaN.
                feats_db = feats - (logit_diff / weight_diff_norm ** 2).view(-1, 1) * weight_diff
                norm_centered_db = F.normalize(feats_db - origin, p=2, dim=1)

                cos_sim = torch.sum(norm_centered_feats * norm_centered_db, dim=1)
                # Clamp rounding excursions outside [-1, 1] so arccos does not
                # turn (anti-)parallel vectors into NaN; NaN inputs stay NaN.
                angles[:, class_id] = torch.arccos(cos_sim.clamp(-1.0, 1.0)) / torch.pi

            angles[torch.isnan(angles)] = 0
            all_scores.append(torch.max(angles, dim=1).values)

    if not all_scores:
        # Empty loader: torch.cat([]) would raise.
        return np.empty(0, dtype=np.float32)
    return np.asarray(torch.cat(all_scores).cpu().numpy(), dtype=np.float32)



def fDBD(
    model: torch.nn.Module,
    test_loader: torch.utils.data.DataLoader,
    num_classes: int,
    class_means: torch.Tensor,
) -> np.ndarray:
    """Score samples by their normalised distance to the decision boundaries.

    For every sample ``z`` with predicted class ``c`` and every class ``k``,
    the distance from ``z`` to its projection on the boundary between ``c``
    and ``k`` of ``model.fc`` is divided by ``||z - mu||``, where ``mu`` is
    the mean of ``class_means``. The score is the mean of these ratios over
    all ``num_classes`` columns. The predicted class itself gives
    0/0 = NaN and counts as 0.

    Args:
        model: network whose final linear layer is ``model.fc``. Only its
            ``weight`` is read; the logits come from ``test_loader``.
        test_loader: yields ``(features, logits)`` batches.
        num_classes: number of classes (rows of ``model.fc.weight``) to visit.
        class_means: per-class mean features. Only their overall mean is used.

    Returns:
        float32 array with one score per sample, in loader order.
    """
    model.eval()
    fc_weight = model.fc.weight
    # Run wherever the classifier lives instead of assuming CUDA.
    device = fc_weight.device
    origin = torch.mean(class_means, dim=0).to(device)

    all_scores = []
    with torch.inference_mode():
        for feats, logits in test_loader:
            feats = feats.to(device)
            logits = logits.to(device)
            preds = logits.argmax(1)
            max_logits = logits.max(dim=1).values
            w_pred = fc_weight[preds]
            # Class-independent denominator, computed once per batch.
            dist_to_origin = torch.linalg.norm(feats - origin, dim=1)

            ratios = torch.zeros(feats.size(0), num_classes, device=device)
            for class_id in range(num_classes):
                logit_diff = max_logits - logits[:, class_id]
                weight_diff = w_pred - fc_weight[class_id]
                weight_diff_norm = torch.linalg.norm(weight_diff, dim=1)
                # Orthogonal projection onto the boundary between the predicted
                # class and class_id. It is exact when the cached logits equal
                # model.fc(feats).
                feats_db = feats - (logit_diff / weight_diff_norm ** 2).view(-1, 1) * weight_diff
                distance_to_db = torch.linalg.norm(feats - feats_db, dim=1)
                ratios[:, class_id] = distance_to_db / dist_to_origin

            ratios[torch.isnan(ratios)] = 0
            all_scores.append(torch.mean(ratios, dim=1))

    if not all_scores:
        # Empty loader: torch.cat([]) would raise.
        return np.empty(0, dtype=np.float32)
    return np.asarray(torch.cat(all_scores).cpu().numpy(), dtype=np.float32)


def knn_score(model, test_loader, num_classes, train_features, k=50):
    """Score samples by the negated distance to their k-th nearest training feature.

    Test and training features are both L2-normalised before Euclidean
    distances are taken, so a larger (less negative) score means the sample
    lies closer to the training set.

    Args:
        model: only switched to eval mode; the features come from the loader.
        test_loader: yields ``(features, logits)`` batches; logits are unused.
        num_classes: unused; kept for an interface matching ORA/fDBD.
        train_features: (n_train, dim) reference features.
        k: rank of the neighbour whose distance is used (1 = nearest).

    Returns:
        float32 array of ``-distance_to_kth_neighbour``, one per sample.
    """
    model.eval()
    norm_train_feats = F.normalize(train_features, p=2, dim=1)
    # Compute next to the reference set instead of assuming CUDA.
    device = norm_train_feats.device

    all_scores = []
    with torch.inference_mode():
        for feats, _logits in test_loader:
            norm_feats = F.normalize(feats.to(device), p=2, dim=1)
            # Exact (non-matmul) Euclidean distances, shape (batch, n_train).
            distances = torch.cdist(
                norm_feats, norm_train_feats, p=2,
                compute_mode="donot_use_mm_for_euclid_dist",
            )
            # k smallest distances in ascending order; the last is the k-th.
            kth_distances = torch.topk(distances, k, dim=1, largest=False).values[:, -1]
            all_scores.append(-kth_distances)

    if not all_scores:
        # Empty loader: torch.cat([]) would raise.
        return np.empty(0, dtype=np.float32)
    return np.asarray(torch.cat(all_scores).cpu().numpy(), dtype=np.float32)



# Select the GPU before anything can initialise CUDA. The CUDA runtime reads
# CUDA_VISIBLE_DEVICES only once, when it initialises, and
# torch.cuda.is_available() below may trigger that initialisation. A value set
# after that point would be silently ignored.
args = get_args()
# NOTE(review): os.environ values must be str — assumes args.gpu is a string
# such as "0" or "0,1"; confirm against util.args_loader.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Fixed seeds for reproducibility.
torch.manual_seed(1)
torch.cuda.manual_seed(1)
np.random.seed(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the cached features / logits / labels written by the feature-extraction
# step. Every cache file is a raw float64 memmap (dtype=float).

class_num = 10
id_train_size = 50000
id_val_size = 10000
feat_dim = 512  # width of the cached feature vectors


def _load_cache(path, shape):
    """Memory-map a float64 cache file with ``shape`` and move it to ``device``."""
    return torch.from_numpy(np.memmap(path, dtype=float, mode='r', shape=shape)).to(device)


cache_dir = f"cache/{args.in_dataset}_train_{args.name}_in"
feat_log = _load_cache(f"{cache_dir}/feat.mmap", (id_train_size, feat_dim))
score_log = _load_cache(f"{cache_dir}/score.mmap", (id_train_size, class_num))
label_log = _load_cache(f"{cache_dir}/label.mmap", (id_train_size,))

cache_dir = f"cache/{args.in_dataset}_val_{args.name}_in"
feat_log_val = _load_cache(f"{cache_dir}/feat.mmap", (id_val_size, feat_dim))
score_log_val = _load_cache(f"{cache_dir}/score.mmap", (id_val_size, class_num))
label_log_val = _load_cache(f"{cache_dir}/label.mmap", (id_val_size,))

# Number of cached samples per supported OOD dataset (memmaps need the shape).
ood_dataset_size = {
    'SVHN': 26032,
    'iSUN': 8925,
    'places365': 10000,
    'dtd': 5640,
}

# dataset name -> (features, logits), in args.out_datasets order.
ood_feat_score_log = {}
for ood_dataset in args.out_datasets:
    if ood_dataset not in ood_dataset_size:
        raise ValueError(
            f"Unknown OOD dataset {ood_dataset!r}; cached sizes are known for "
            f"{sorted(ood_dataset_size)}"
        )
    n_ood = ood_dataset_size[ood_dataset]
    ood_cache_dir = f"cache/{ood_dataset}vs{args.in_dataset}_{args.name}_out"
    ood_feat_score_log[ood_dataset] = (
        _load_cache(f"{ood_cache_dir}/feat.mmap", (n_ood, feat_dim)),
        _load_cache(f"{ood_cache_dir}/score.mmap", (n_ood, class_num)),
    )


from functools import partial

num_classes = class_num  # single source of truth for the ID class count
model = get_model(args, num_classes, load_ckpt=True)
model.to(device)

in_dataset = torch.utils.data.TensorDataset(feat_log_val.cpu(), score_log_val.cpu())
in_loader = torch.utils.data.DataLoader(in_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

# Per-class mean of the ID training features. Labels are stored as floats,
# so ``label_log == i`` compares against the float class index.
class_means = torch.zeros(num_classes, feat_log.size(1)).to(device)
for i in range(num_classes):
    class_means[i] = torch.mean(feat_log[label_log == i], dim=0)

# OOD scoring method. Each entry maps a DataLoader of (features, logits)
# batches to a float32 array of per-sample scores. partial binds the tensors
# now, so later rebinding of these module-level names cannot change what is
# scored.
OOD_METHOD = 'ORA'  # one of: 'ORA', 'fDBD', 'KNN'
score_fns = {
    'ORA': partial(ORA, model, num_classes=num_classes, class_means=class_means),
    'fDBD': partial(fDBD, model, num_classes=num_classes, class_means=class_means),
    'KNN': partial(knn_score, model, num_classes=num_classes, train_features=feat_log, k=50),
}
score_fn = score_fns[OOD_METHOD]

score_in = score_fn(in_loader)

all_results = []
# Distinct loop names so the ID training tensors feat_log / score_log are not
# overwritten by each OOD dataset.
for ood_feats, ood_logits in ood_feat_score_log.values():
    ood_data = torch.utils.data.TensorDataset(ood_feats.cpu(), ood_logits.cpu())
    ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=False, num_workers=2)
    scores_out_test = score_fn(ood_loader)
    all_results.append(metrics.cal_metric(score_in, scores_out_test))
metrics.print_all_results(all_results, args.out_datasets, OOD_METHOD)
