from copy import deepcopy
import os
import argparse
import sklearn
import numpy as np
import ot
import torch
import torch.nn.functional as F
import torch.utils.data as data
from tqdm import tqdm

import common.utils.logging as logging
from common.utils.analysis.a_distance import ANet
from common.utils.meter import AverageMeter
from dalib.adaptation.dann import DomainAdversarialLoss as DANNLoss
from dalib.adaptation.mcd import MCD
from dalib.adaptation.mdd import MDD
from dalib.modules import entropy
from utils.clustering import class_centroids, Clustering, kmeans_clustering
from utils.ISM import logistic_metrics, mlp_metrics
from utils.eval_utils import compute_learnable_metric, compute_learnable_metric_k_fold, ConcatLoader
from utils.common_utils import merge_excels

# Module-level logger created through the project's `common.utils.logging` helper.
logger = logging.get_logger(__name__)

# Metrics that report the source-domain accuracy or add it to their score.
_SOURCE_ACC_METRICS = ('source_accuracy', 'mlp_metrics', 'ACM', 'a_distance', 'MCD', 'MDD')


def _source_accuracy(source_test_data):
    """Return the top-1 accuracy (a fraction in [0, 1]) of the cached source predictions.

    ``source_test_data`` is unpacked as ``(features, outputs, labels)``; presumably the
    tuple produced by ``image_classification_test`` and pickled per epoch (verify
    against the producer). The result is the share of rows whose ``argmax(outputs)``
    equals ``labels``.
    """
    _, outputs, labels = source_test_data
    return (torch.argmax(outputs, 1) == labels).float().mean().item()


def _neighbor_density(feature, T=0.05):
    """Soft neighborhood density (SND) score of a batch of row vectors.

    Adapted from https://github.com/VisionLearningGroup/SND/blob/main/nc_ps/eval.py:
    rows are L2-normalized, pairwise similarities are scaled by ``1/T``, each row's
    self-similarity is masked out, and the mean entropy of the row-wise softmax is
    returned.
    """
    feature = F.normalize(feature)
    mat = torch.matmul(feature, feature.t()) / T
    # Build the diagonal mask on the same device as `mat` so GPU inputs work too.
    mask = torch.eye(mat.size(0), dtype=torch.bool, device=mat.device)
    mat.masked_fill_(mask, -1 / T)
    return entropy(F.softmax(mat, dim=1)).mean()


def compute_metrics(cfg, loader, model=None, k_fold=0):
    """Compute the model-selection metrics requested in ``cfg.metrics``.

    Args:
        cfg: namespace read for ``metrics``, ``data``, ``num_classes``,
            ``bottleneck_dim``, ``metric_batch_size``, ``workers`` and
            ``metric_train_epochs``.
        loader: dict of cached evaluation data. ``'source_test_data'`` and
            ``'target_test_data'`` are unpacked as ``(features, outputs, labels)``.
            Depending on the requested metrics, the ``*_normalized_data`` entries and
            the ``'metric*'`` dataloaders are read as well. A
            ``'metric_pesudolabeled'`` dataloader of target
            ``(features, outputs, argmax(outputs))`` is added to it.
        model: only used as a switch for ``'ACM'``, which is skipped when ``None``.
        k_fold: when > 0, the learnable discrepancies (``a_distance``, ``MCD``,
            ``MDD``) are estimated with k-fold cross validation.

    Returns:
        dict mapping metric names to scores. Several metrics also add auxiliary
        entries (e.g. ``'entropy'`` adds ``'diversity'``, ``'MI'``, ``'maxsquare'``).
    """
    metric_scores = {}

    # Several composite scores add the source accuracy. It is not otherwise available
    # inside this function, so derive it from the cached source outputs.
    source_acc = None
    if any(m in _SOURCE_ACC_METRICS for m in cfg.metrics):
        source_acc = _source_accuracy(loader['source_test_data'])
        if 'source_accuracy' in cfg.metrics:
            metric_scores['source_accuracy'] = source_acc

    if 'clustering_l2' in cfg.metrics:
        source_centroids = class_centroids(loader['source_test_data'], cfg.num_classes)
        # max_iter=0: cost of the hard assignments to the source centroids only.
        _, _, center_cost, _ = Clustering(dist_type='euclidean', max_iter=0).feature_clustering(
            loader['target_test_data'], source_centroids)
        metric_scores['centroids_cost_l2'] = 1 - center_cost
        if 'office' in cfg.data.lower():
            _, target_cluster_labels, center_cost, centroids_change = kmeans_clustering(
                loader['target_test_data'], source_centroids, dist_type='euclidean', realign=True)
        else:
            _, target_cluster_labels, center_cost, centroids_change = \
                Clustering(dist_type='euclidean').feature_clustering(loader['target_test_data'], source_centroids)
        print("l2 clusting finished.")
        metric_scores['centroids_change_l2'] = 1 - centroids_change
        metric_scores['center_cost_l2'] = 1 - center_cost
        pred_labels = np.argmax(loader['target_test_data'][1], axis=1)
        metric_scores['ClassAMI'] = sklearn.metrics.adjusted_mutual_info_score(pred_labels, target_cluster_labels)

    if 'clustering_cos' in cfg.metrics:
        source_cos_centroids = class_centroids(loader['source_test_normalized_data'], cfg.num_classes)
        # max_iter=0: cost of the hard assignments to the source centroids only.
        _, _, center_cost, _ = Clustering(dist_type='cos', max_iter=0).feature_clustering(
            loader['target_test_normalized_data'], source_cos_centroids)
        metric_scores['centroids_cost_cos'] = 1 - center_cost
        if 'office' in cfg.data.lower():
            _, target_cluster_labels, center_cost, centroids_change = kmeans_clustering(
                loader['target_test_normalized_data'], source_cos_centroids, dist_type='cos', realign=True)
        else:
            # Cluster the normalized target features, like the other cosine calls here
            # and the silhouette score below, which reuses these labels.
            _, target_cluster_labels, center_cost, centroids_change = \
                Clustering(dist_type='cos').feature_clustering(
                    loader['target_test_normalized_data'], source_cos_centroids)
        print("cos clusting finished.")
        metric_scores['centroids_change_cos'] = centroids_change
        # Presumably the within-cluster variance (per the original author's note).
        metric_scores['center_cost_cos'] = center_cost
        metric_scores['silhouette_cos'] = sklearn.metrics.silhouette_score(
            loader['target_test_normalized_data'][0], target_cluster_labels, metric='cosine')
        pred_labels = np.argmax(loader['target_test_normalized_data'][1], axis=1)
        metric_scores['ClassAMI_cos'] = sklearn.metrics.adjusted_mutual_info_score(pred_labels, target_cluster_labels)

    if 'mlp_metrics' in cfg.metrics or 'ACM' in cfg.metrics:
        logger.info("Begin evaluating mlp_metrics")
        (metric_scores["mlp_entropy"], metric_scores["mlp_IS"],
         metric_scores["mlp_sourceacc"], metric_scores["mlp_acc"], mlp_clf) = mlp_metrics(
            loader['source_test_data'], loader['target_test_data'],
            hidden_layer_sizes=(cfg.bottleneck_dim,), return_clf=True)
        # Read num_classes from cfg: the script-level `args` does not exist on import.
        metric_scores["ISM"] = (1 + metric_scores["mlp_IS"] / np.log(cfg.num_classes)) / 2 + source_acc

    target_features, target_outputs, _ = loader['target_test_data']
    loader['metric_pesudolabeled'] = data.DataLoader(
        data.TensorDataset(target_features, target_outputs, torch.argmax(target_outputs, dim=1)),
        batch_size=cfg.metric_batch_size, shuffle=True, num_workers=cfg.workers, drop_last=False)

    # compute each metric
    for m in cfg.metrics:
        meter = AverageMeter(m, ":4.2f")
        metric = None
        net = None
        if m == 'entropy':
            probs = F.softmax(loader['target_test_data'][1], dim=1)
            metric_scores[m] = - entropy(probs).mean()
            mean_probs = probs.mean(0)
            metric_scores['diversity'] = (-torch.log(mean_probs + 1e-5) * mean_probs).sum()
            metric_scores['MI'] = metric_scores['entropy'] + metric_scores['diversity']
            metric_scores['maxsquare'] = (probs * probs).sum(1).mean()
        elif m == 'logistic_metrics':
            logger.info("Begin evaluating logistic_metrics")
            metric_scores["logistic_entropy"], metric_scores["logistic_IS"], \
                metric_scores["logistic_sourceacc"], metric_scores["logistic_acc"] = \
                logistic_metrics(loader['source_test_data'], loader['target_test_data'])
        elif m == 'ACM' and model is not None:
            # `meter` tracks agreement of the stored outputs between the two target
            # views; `meter2` tracks agreement of the MLP probe's predictions.
            meter2 = AverageMeter(m, ":4.2f")
            with torch.no_grad():
                for (fea1, out1, _label1), (fea2, out2, _label2) in tqdm(zip(loader["metric"], loader["metric_aug"]),
                            total=len(loader["metric"]), desc="evaluating target consistency"):
                    assert torch.all(_label1 == _label2)
                    meter.update((torch.argmax(out1, 1) == torch.argmax(out2, 1)).float().mean(), fea1.shape[0])
                    mlp_pred1 = mlp_clf.predict(fea1.cpu().numpy())
                    mlp_pred2 = mlp_clf.predict(fea2.cpu().numpy())
                    meter2.update(np.mean((mlp_pred1 == mlp_pred2).astype(np.float32)), fea1.shape[0])
            metric_scores['consist'] = meter.avg
            metric_scores['AC'] = meter2.avg
            metric_scores[m] = metric_scores['AC'] / 2 \
                + (1 + (metric_scores['mlp_IS'] - metric_scores['mlp_entropy']) / np.log(cfg.num_classes)) / 2 \
                + source_acc
        elif m == 'SND':
            if "office" in cfg.data.lower():
                # Office datasets: score the whole cached target set as a single batch.
                large_loader = [(loader['source_test_data'], loader['target_test_data'])]
            else:
                large_loader = ConcatLoader(loader['metric_source'], loader['metric'], batch_size_scale=32)
            for (_, _, _), (_, g_t, _) in tqdm(large_loader, total=len(large_loader), desc="evaluating SND"):
                probs = F.softmax(g_t, dim=1)
                meter.update(_neighbor_density(probs), g_t.shape[0])
            metric_scores[m] = meter.avg
        elif m == 'BNM':
            for fea, out, _ in tqdm(loader['metric'], total=len(loader['metric']), desc="evaluating BNM"):
                # Skip the tail once batches get too small for a stable nuclear norm.
                if fea.size(0) < cfg.metric_batch_size / 2:
                    break
                prob = F.softmax(out, dim=1)
                loss = torch.linalg.norm(prob, "nuc") / prob.shape[0]
                meter.update(loss, fea.shape[0])
            metric_scores[m] = meter.avg
        elif m == 'DEV':
            from utils.dev import compute_dev
            logger.info("Begin evaluating Deep Embedded Validation")
            metric_scores[m], metric_scores['IWCV'] = compute_dev(loader['source_test_data'], loader['target_test_data'])
            metric_scores['DEVN'], _ = compute_dev(loader['source_test_data'], loader['target_test_data'], normalization="standardize")
        elif m == 'split_accuracy':
            for _, out, label in loader['metric_target_test']:
                pred = torch.argmax(out, 1).float()
                acc = torch.sum((pred == label).float()) / float(out.size(0))
                meter.update(acc, out.size(0))
            metric_scores[m] = meter.avg
        elif m == 'a_distance':
            net = ANet(cfg.bottleneck_dim)
            metric = DANNLoss(net).cuda()
            if k_fold > 0:
                a_distance = compute_learnable_metric_k_fold(cfg, loader['source_test_data'], loader['target_test_data'],
                                                metric, meter, n_splits=k_fold, lr=0.01, epochs=cfg.metric_train_epochs)
            else:
                a_distance = compute_learnable_metric(loader['metric_train'], loader['metric_test'],
                                                metric, meter, lr=0.01, epochs=cfg.metric_train_epochs)
            metric_scores[m] = source_acc + 1 - a_distance  # NOTE(review): original note said "4 * acc - 2"
        elif m == 'MCD':
            metric = MCD(cfg.bottleneck_dim, cfg.num_classes, training_classifer=False).cuda()
            if k_fold > 0:
                MCD_distance = compute_learnable_metric_k_fold(cfg, loader['source_test_data'], loader['target_test_data'],
                                                    metric, meter, n_splits=k_fold, lr=0.004, epochs=cfg.metric_train_epochs)
            else:
                MCD_distance = compute_learnable_metric(loader['metric_train'], loader['metric_test'],
                                                    metric, meter, lr=0.004, epochs=cfg.metric_train_epochs)
            metric_scores[m] = source_acc + 1 - MCD_distance
        elif m == 'MDD':
            metric = MDD(cfg.bottleneck_dim, cfg.num_classes, training_classifer=True).cuda()
            if k_fold > 0:
                MDD_distance = compute_learnable_metric_k_fold(cfg, loader['source_test_data'], loader['target_test_data'],
                                                    metric, meter, n_splits=k_fold, lr=0.004, epochs=cfg.metric_train_epochs)
            else:
                MDD_distance = compute_learnable_metric(loader['metric_train'], loader['metric_test'],
                                                    metric, meter, lr=0.004, epochs=cfg.metric_train_epochs)
            metric_scores[m] = source_acc + 1 - MDD_distance
        # Drop per-metric networks before the next metric allocates its own.
        if metric is not None:
            del metric
        if net is not None:
            del net
    torch.cuda.empty_cache()
    return metric_scores

def image_classification_test(cfg, loader, model, train_metric):
    """Evaluate ``model`` on ``loader`` and collect its accuracy, loss and outputs.

    Args:
        cfg: namespace whose ``method`` selects how ``train_metric`` is applied.
        loader: iterable of ``(inputs, labels)`` batches; inputs are moved with ``.cuda()``.
        model: callable returning ``(outputs, feature)`` for a batch of inputs.
        train_metric: the method's transfer loss, or ``None`` to report a zero loss.

    Returns:
        ``(accuracy, mean_loss, (features, outputs, labels))``. ``accuracy`` is the
        top-1 fraction in [0, 1]; the tensors are concatenated on the CPU.
    """
    method = cfg.method.lower()  # loop-invariant
    all_loss = AverageMeter("test_loss", ":4.2f")
    all_feature = []
    all_output = []
    all_label = []
    with torch.no_grad():
        for inputs, labels in tqdm(loader, total=len(loader), desc="evaluating target test accuracy"):
            outputs, feature = model(inputs.cuda())
            if train_metric is None:
                loss = torch.zeros((), device=feature.device)
            elif method in ['dann', 'cdan']:
                # Both "domains" are the same target batch, labelled as domain 0.
                d_label = torch.zeros((2 * feature.size(0), 1), device=feature.device)
                loss = train_metric(outputs, feature, outputs, feature, d_label=d_label)
            elif method == "mcc":
                loss = train_metric(outputs)
            elif method == "proto":
                prototypes = model.head.weight.data.clone()
                loss = train_metric(prototypes, feature)
            else:
                # "consist", "mdd" and any other method have no test-time loss here;
                # without this fallback `loss` would be unbound for unlisted methods.
                loss = torch.zeros((), device=feature.device)
            all_loss.update(loss.cpu(), feature.size(0))
            all_feature.append(feature.cpu())
            all_output.append(outputs.cpu())
            all_label.append(labels)
    all_feature = torch.cat(all_feature, 0)
    all_output = torch.cat(all_output, 0)
    all_label = torch.cat(all_label, 0)
    predict = torch.argmax(all_output, 1)
    accuracy = torch.sum((predict == all_label).float()) / float(all_label.size(0))
    return accuracy.item(), all_loss.avg, (all_feature, all_output, all_label)

if __name__ == '__main__':
    import pickle
    import pandas as pd
    from scripts.utils import get_dataset_names
    from utils.eval_utils import construct_metric_dataloader
    from common.utils.metric import ConfusionMatrix
    parser = argparse.ArgumentParser(description='evaluation for Unsupervised Domain Adaptation')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=get_dataset_names(),
                        help='dataset: ' + ' | '.join(get_dataset_names()) + ' (default: Office31)')
    parser.add_argument('--method', type=str, default='CDAN', help="The DA training method")
    parser.add_argument('--metric_batch_size', default=64, type=int, help='mini-batch size for metrics')
    parser.add_argument('--metric_train_epochs', default=10, type=int, help='train_epochs for metrics')
    parser.add_argument('--k_fold', default=0, type=int, help='k_fold for cross validation ')
    parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    # `args.visualize` is read in evaluate_one; without this flag every epoch raised AttributeError.
    parser.add_argument('--visualize', action='store_true',
                        help='save a t-SNE plot of source/target features for every epoch')
    parser.add_argument("--log", type=str, default='/tmp/DAmetric_lib_logs/cdan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument('--output_path', type=str,
                        default='/home/username/DAmetric_logs/officehome_cdan_A2C_newer',
                        help="The log dir")
    args = parser.parse_args()
    general_metrics = ['accuracy', 'test_loss', 'source_accuracy', 'split_accuracy']
    discrepancy_metrics = ['a_distance', 'MCD', 'MDD']
    assign_cost_metrics = ['entropy', 'clustering_l2', 'clustering_cos', 'mlp_metrics']
    image_metrics = ['ACM']
    other_metrics = ['DEV', 'SND', 'BNM']
    all_metrics = general_metrics + discrepancy_metrics + assign_cost_metrics + image_metrics + other_metrics

    def evaluate_one(log_path):
        """Compute the metrics missing from ``log_path``'s result sheet and rewrite it.

        Reads ``metric_scores.xlsx`` plus one ``epoch{i}_data.pkl`` per row of its
        ``train_epoch`` column. Returns early if the sheet is absent or unreadable,
        or if there is nothing left to compute.
        """
        result_file = os.path.join(log_path, "metric_scores.xlsx")
        if os.path.exists(result_file):
            print("successfully load result file:", result_file)
        else:
            return
        try:
            df = pd.read_excel(result_file, engine='openpyxl')
        except Exception:  # not BaseException: let Ctrl-C stop the sweep
            print("error excel:", result_file)
            return
        logging.setup_logging(log_path)

        # Only compute the metrics that the existing sheet does not have yet.
        args.metrics = [m for m in all_metrics if m not in df.columns]

        if 'office' in args.data.lower():
            args.k_fold = 3
        elif 'visda' in args.data.lower():
            args.metric_train_epochs = 3
            args.k_fold = 0
        elif 'domainnet' in args.data.lower():
            args.k_fold = 0

        if len(args.metrics) == 0 and not args.per_class_eval:
            return

        dset_loaders = {}
        all_metric_scores = []
        epochs = len(df['train_epoch'])
        for epoch in range(epochs):
            print(f"evaluating epoch {epoch}")
            # SECURITY: pickle executes code on load; only point this script at
            # log directories you produced yourself.
            with open(os.path.join(log_path, f"epoch{epoch}_data.pkl"), mode="rb") as f:
                test_data = pickle.load(f)
            dset_loaders['source_test_data'], dset_loaders['target_test_data'] = test_data[0], test_data[1]
            args.bottleneck_dim = dset_loaders['target_test_data'][0].size(1)
            args.num_classes = dset_loaders['target_test_data'][1].size(1)
            dset_loaders = construct_metric_dataloader(args, dset_loaders, split_ratio=0.66)
            metric_scores = compute_metrics(args, dset_loaders, model=None, k_fold=args.k_fold)
            for k, v in metric_scores.items():
                logger.info(f"epoch: {epoch}/{epochs}, metric {k}: {v:.5f}")
            if args.per_class_eval:
                confmat = ConfusionMatrix(args.num_classes)
                for _, outputs, labels in tqdm(dset_loaders['metric'], total=len(dset_loaders['metric']), desc="evaluating target test accuracy"):
                    confmat.update(labels, outputs.argmax(1))
                metric_scores["class accuracy"] = " ".join([str(s) for s in np.round((confmat.compute()[1] * 100).numpy(),2)])
            all_metric_scores.append(metric_scores)

            if args.visualize:
                from common.utils.analysis import tsne
                os.makedirs(os.path.join(log_path, 'visual'), exist_ok=True)
                tSNE_filename = os.path.join(log_path, 'visual', f'TSNE_{epoch}.png')
                tsne.visualize(dset_loaders['source_test_normalized_data'][0][:5000],
                               dset_loaders['target_test_normalized_data'][0][:5000], tSNE_filename)
                logger.info(f"Saving t-SNE to {tSNE_filename}")

        new_df = pd.DataFrame(data=all_metric_scores)
        # Object columns (e.g. 0-dim tensors) are cast to float where possible.
        for column in new_df.columns:
            if new_df[column].dtype == object:
                try:
                    new_df[column] = new_df[column].astype(float)
                except Exception:
                    continue
        # Carry over previously computed columns; skip pandas' unnamed index columns.
        for column in df.columns:
            if column not in new_df and "Unnamed" not in column:
                new_df[column] = df[column]

        # Write to a temp file and swap it in, so a failed write cannot lose the
        # existing results (the previous code deleted the file first).
        tmp_file = os.path.join(log_path, "metric_scores.tmp.xlsx")
        new_df.to_excel(tmp_file, index=False)
        os.replace(tmp_file, result_file)

    if os.path.exists(args.output_path):
        for path in sorted(os.listdir(args.output_path)):
            log_path = os.path.join(args.output_path, path)
            if os.path.isdir(log_path):
                evaluate_one(log_path)
        merge_excels(args.output_path)
    else:
        evaluate_one(args.log)