# -*- coding: utf-8 -*-
# @Time    : 2023/09/29
# @Author  : first author of the submitted paper of 'Generalization or Specificity? Spectral Meta Estimation and Ensemble (SMEE) with Domain-specific Experts' at ICLR24
# @File    : smee_demo.py
# All rights reserved.
# requirements: numpy, scikit-learn, pytorch, and your model predictions, that's it.

import argparse
import random
import time

import torch
import numpy as np
from sklearn.metrics import accuracy_score
from torch.nn.functional import softmax


def SML_onevsrest_weights(probs):
    """Return the SML-onevsrest weight of every predictor.

    For each class ``c`` the predictors' hard decisions are turned into a
    binary matrix ``P`` of shape (n_predictors, n_samples) with
    ``P[k, s] = 1`` iff predictor ``k`` chose class ``c`` for sample ``s``.
    The leading eigenvector of ``P @ P.T`` is normalised to sum to one, and
    these per-class vectors are summed over all classes.

    Args:
        probs: array indexed as ``probs[predictor, sample, class]``. Only the
            per-sample argmax is used, so logits and softmax probabilities
            give the same result.

    Returns:
        1-D float array of length ``n_predictors``.
    """
    n_predictors, _, n_classes = probs.shape
    # Hard decision of every predictor on every sample: (n_predictors, n_samples).
    # This does not depend on the class, so compute it once.
    hard = np.argmax(probs, axis=-1)
    weights_final = np.zeros(n_predictors)
    for c in range(n_classes):
        indicator = (hard == c).astype(np.float64)
        if not indicator.any():
            # No predictor chose class c: P @ P.T is the zero matrix and any
            # eigenvector of it is arbitrary, so it carries no information.
            continue
        gram = indicator @ indicator.T
        # The Gram matrix is symmetric, so use eigh. It returns real values with
        # eigenvalues in ascending order, so the last column is the leading
        # eigenvector. np.linalg.eig neither sorts nor guarantees real output.
        _, eigvecs = np.linalg.eigh(gram)
        leading = eigvecs[:, -1]
        total = leading.sum()
        # Dividing by the sum also removes the arbitrary eigenvector sign; skip
        # degenerate cases where that sum vanishes.
        if not np.isfinite(total) or np.isclose(total, 0.0):
            continue
        weights_final += leading / total
    if not weights_final.any():
        # Every class was degenerate: fall back to uniform weights (averaging).
        weights_final = np.ones(n_predictors)
    return weights_final


def SML_onevsrest(preds, labels):
    """Ensemble the predictors with SML-onevsrest weights and report accuracy.

    Args:
        preds: NumPy logits indexed as ``preds[predictor, sample, class]``.
        labels: ground-truth class index for each sample.

    Returns:
        Accuracy in percent (accuracy rounded to 5 decimals, then scaled by 100).
    """
    start_time = time.perf_counter()
    probs = softmax(torch.from_numpy(preds), dim=-1).numpy()
    weights_final = SML_onevsrest_weights(probs)
    print('weights in the SML ensemble for the predictors:')
    print(weights_final)
    # Weighted sum of the predictors' probabilities -> (n_samples, n_classes).
    predictions = np.einsum('a,abc->bc', weights_final, probs)
    predict = np.argmax(predictions, axis=1)
    score = np.round(accuracy_score(labels, predict), 5)
    end_time = time.perf_counter()
    print('SML-onevsrest computation time at the end of test set in seconds:', end_time - start_time)
    print('SML-onevsrest: {:.2f}'.format(score * 100))
    return score * 100


def SML_onevsrest_online_vs_offline(preds_all, labels, warmup=15):
    """Count how often online SML-onevsrest agrees with the offline result.

    The test samples are shuffled once.

    * Offline: SML-onevsrest weights are fit on the whole shuffled set.
    * Online: the sample at 0-based position ``t`` is predicted with weights
      fit only on samples ``0..t``. The first evaluated prefix (length
      ``warmup``) instead uses softmax-probability averaging.

    Args:
        preds_all: NumPy logits indexed as ``preds_all[predictor, sample, class]``.
        labels: per-sample ground truth; only its length is used here.
        warmup: first prefix length that is evaluated (clamped to >= 1). The
            default 15 matches the 15 predictors of the demo data.

    Returns:
        ``(same, different)``: number of samples where the online prediction
        matches / differs from the offline one.
    """

    def _weights(probs):
        # SML-onevsrest weights: for each class, the sum-normalised leading
        # eigenvector of the predictor x predictor Gram matrix of one-hot
        # decisions, summed over classes.
        hard = np.argmax(probs, axis=-1)
        acc = np.zeros(probs.shape[0])
        for c in range(probs.shape[-1]):
            indicator = (hard == c).astype(np.float64)
            if not indicator.any():
                # Class never predicted in this prefix: the Gram matrix is zero
                # and its eigenvectors are arbitrary.
                continue
            # eigh: symmetric input, real output, eigenvalues ascending -> last
            # column is the leading eigenvector (np.linalg.eig does not sort).
            _, eigvecs = np.linalg.eigh(indicator @ indicator.T)
            leading = eigvecs[:, -1]
            total = leading.sum()
            if not np.isfinite(total) or np.isclose(total, 0.0):
                continue
            acc += leading / total
        if not acc.any():
            acc = np.ones(probs.shape[0])
        return acc

    warmup = max(int(warmup), 1)

    # offline: shuffle once, then fit on every sample
    indices = np.arange(len(labels))
    random.shuffle(indices)
    # softmax acts per sample, so slicing a prefix of probs_all later is the
    # same as applying softmax to a prefix of the logits.
    probs_all = softmax(torch.from_numpy(preds_all[:, indices, :]), dim=-1).numpy()
    offline = np.argmax(np.einsum('a,abc->bc', _weights(probs_all), probs_all), axis=1)

    # online
    same = 0
    different = 0
    n_samples = probs_all.shape[1]
    # sample_num is the prefix length; the decision is for the newest sample
    # (index sample_num - 1). The end bound n_samples + 1 makes the last
    # sample part of the comparison too.
    for sample_num in range(warmup, n_samples + 1):
        prefix = probs_all[:, :sample_num, :]
        newest = prefix[:, -1, :]  # (n_predictors, n_classes)
        if sample_num <= warmup:
            # Too few samples: resort to averaging softmax probabilities,
            # the same averaging used by pred_averaging.
            online = np.argmax(newest.mean(axis=0))
        else:
            online = np.argmax(_weights(prefix) @ newest)
        if offline[sample_num - 1] == online:
            same += 1
        else:
            different += 1
    print('total same predictions', same)
    print('total different predictions', different)
    return same, different


def pred_voting(preds, labels):
    """Hard majority vote across predictors, with ties broken uniformly at random.

    Args:
        preds: array indexed as ``preds[predictor, sample, class]``.
        labels: ground-truth class index for each sample.

    Returns:
        Accuracy in percent (accuracy rounded to 5 decimals, then scaled by 100).
    """
    _, n_samples, n_classes = preds.shape
    hard = preds.argmax(axis=2)  # (n_predictors, n_samples)
    # tally[c, s] = how many predictors chose class c for sample s
    class_ids = np.arange(n_classes)[:, None, None]
    tally = (hard[None, :, :] == class_ids).sum(axis=1).astype(float)
    # one np.random.choice per sample, in sample order
    winners = np.array([
        np.random.choice(np.flatnonzero(column == column.max()))
        for column in tally.T
    ])
    score = np.round(accuracy_score(labels, winners), 5)
    print('Voting: {:.2f}'.format(score * 100))
    return score * 100


def pred_averaging(preds, labels):
    """Average the predictors' softmax probabilities and report accuracy.

    Args:
        preds: NumPy logits indexed as ``preds[predictor, sample, class]``.
        labels: ground-truth class index for each sample.

    Returns:
        Accuracy in percent (accuracy rounded to 5 decimals, then scaled by 100).
    """
    probabilities = softmax(torch.from_numpy(preds), dim=-1).numpy()
    mean_probs = probabilities.mean(axis=0)  # (n_samples, n_classes)
    decisions = mean_probs.argmax(axis=1)
    score = np.round(accuracy_score(labels, decisions), 5)
    print('Averaging: {:.2f}'.format(score * 100))
    return score * 100


def pred_single(preds, labels):
    """Return the accuracy of every individual predictor.

    Args:
        preds: logits indexed as ``preds[predictor, sample, class]``.
        labels: ground-truth class index for each sample.

    Returns:
        One accuracy per predictor, in percent (rounded to 5 decimals before
        scaling), in the order of ``preds``.
    """
    # Iterate over the predictors actually present in ``preds``. The global
    # ``args.model_names`` exists only when this file runs as __main__ (a
    # NameError otherwise), and its length need not match preds.shape[0].
    return [
        np.round(accuracy_score(labels, np.argmax(pred, axis=1)), 5) * 100
        for pred in preds
    ]


if __name__ == '__main__':

    print('meta ranking selection')

    # For reproducible runs, uncomment the seeding below.
    # SEED = 0
    # torch.manual_seed(SEED)
    # torch.cuda.manual_seed(SEED)
    # np.random.seed(SEED)
    # random.seed(SEED)

    # number of models for each source expert selected to use in ensemble
    num_k = 5

    dataset_name = 'PACS'
    test_domain_name = 'Sketch'

    approaches = ['single_avg', 'voting', 'averaging', 'SML-onevsrest']

    print('Dataset name:', dataset_name)
    print('Test domain name:', test_domain_name)
    print('Number of models per each source:', num_k)

    scores = np.zeros(len(approaches))

    # Total number of domains; one is held out as the test domain, so the
    # ensemble has num_k models from each of the remaining ones.
    num_source_domains = 4
    args = argparse.Namespace(dataset_name=dataset_name)
    args.model_names = ['ResNet-50'] * int(num_k * (num_source_domains - 1))

    preds = np.load('./example_preds.npy')
    labels = np.load('./example_labels.npy')

    print('(# predictors, # test samples, # classes) of logits outputs (before softmax)')
    print(preds.shape)
    # Explicit check rather than assert, so it is not stripped under `python -O`.
    if preds.shape[1] != labels.shape[0]:
        raise ValueError('predictions cover {} samples but {} labels were loaded'.format(
            preds.shape[1], labels.shape[0]))

    avg_single_scores = pred_single(preds, labels)
    scores[0] = np.average(avg_single_scores)
    p_avg_single_scores = [np.round(s, 2) for s in avg_single_scores]
    print('Single model score:')
    print(p_avg_single_scores)
    print('Which source domain these models are from:')
    source_ids = [2, 1, 1, 0, 1, 2, 1, 2, 0, 2, 0, 1, 0, 0, 2]
    source_names = np.array(['A', 'C', 'P'])  # Art, Cartoon, Photo
    print(source_names[source_ids])

    scores[1] = pred_voting(preds, labels)
    scores[2] = pred_averaging(preds, labels)
    scores[3] = SML_onevsrest(preds, labels)

    print(approaches)
    print(np.round(scores, 3))

    # the following lines compare offline and online SML
    print('testing online vs offline SML result difference as in paper Appendix A.6')
    SML_onevsrest_online_vs_offline(preds, labels)

