import os
import pandas as pd
import numpy as np
import argparse
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.celebA_dataset import CelebADataset
from wilds.datasets.coco_places_dataset import COCOonPlacesDataset
from munkres import Munkres
from examples.utils.print_logger import get_logger
# Module-wide logger obtained from the project's print_logger helper, requested at level "DEBUG".
LOGGER = get_logger(__name__, level="DEBUG")

# Value in a predicted-environment array marking a sample that is dropped
# when ClusterEvaluator is constructed with ignore_labels=True.
IGNORE_LABEL = 255

# Per-dataset display names used when printing metrics: 'env_map' is indexed
# by the integer environment label and 'class_map' by the integer class label.
DATASET_DICT = dict(
    waterbirds=dict(
        env_map=['land', 'water'],
        class_map=['landbird', 'waterbird'],
    ),
    celebA=dict(
        env_map=['female', 'male'],
        class_map=['dark', 'blond'],
    ),
    coco_on_colours=dict(
        env_map=['bias', 'non-bias', 'sys shift', 'ood'],
        class_map=['boat', 'airplane', 'truck', 'dog', 'zebra', 'horse', 'bird', 'train', 'bus'],
    ),
    coco_on_places=dict(
        env_map=['bias', 'non-bias', 'sys shift', 'ood'],
        class_map=['boat', 'airplane', 'truck', 'dog', 'zebra', 'horse', 'bird', 'train', 'bus'],
    ),
)


class ClusterEvaluator:
    """Score predicted environment clusters against ground-truth environments.

    Predicted cluster labels are first matched one-to-one to true environment
    labels with the Hungarian algorithm (``Munkres``) so that the number of
    agreeing samples is maximised; accuracies are then computed on the
    matched labels.
    """

    def __init__(self, dataset_name, true_classes, true_env_labels, pred_env_labels, ignore_labels=False):
        """
        Args:
            dataset_name: key into ``DATASET_DICT`` used to name environments
                and classes when printing.
            true_classes: per-sample class labels (tensor with ``.numpy()`` or
                any array-like).
            true_env_labels: per-sample ground-truth environment labels
                (tensor with ``.numpy()`` or any array-like).
            pred_env_labels: per-sample predicted environment labels
                (tensor with ``.numpy()`` or any array-like).
            ignore_labels: if True, drop every sample whose predicted label
                equals ``IGNORE_LABEL``.
        """
        self.dataset_name = dataset_name
        self.true_classes = self._as_numpy(true_classes)
        self.true_env_labels = self._as_numpy(true_env_labels)
        self.pred_env_labels = self._as_numpy(pred_env_labels)

        # Filter entries
        if ignore_labels:
            LOGGER.info('Ignoring test images')
            keep = self.pred_env_labels != IGNORE_LABEL
            self.true_classes = self.true_classes[keep]
            self.pred_env_labels = self.pred_env_labels[keep]
            self.true_env_labels = self.true_env_labels[keep]

    @staticmethod
    def _as_numpy(values):
        """Return ``values`` as a numpy array (tensors go through ``.numpy()``)."""
        if hasattr(values, 'numpy'):
            return values.numpy()
        return np.asarray(values)

    def compute_cost_matrix(self):
        """Return the Munkres cost matrix between predicted and true environments.

        Entry ``[i, j]`` is minus the number of samples whose predicted label is
        the i-th sorted predicted label and whose true label is the j-th sorted
        true label, so minimising total cost maximises agreement.

        Raises:
            AssertionError: if the predicted and true label sets differ.
        """
        u_pred_env_labels = np.unique(self.pred_env_labels)
        u_true_env_labels = np.unique(self.true_env_labels)
        l1 = u_pred_env_labels.size
        l2 = u_true_env_labels.size
        assert l1 == l2 and np.all(u_pred_env_labels == u_true_env_labels), (
            f'Predicted env labels {u_pred_env_labels} differ from true env labels {u_true_env_labels}'
        )

        m = np.zeros((l1, l2))
        for i, pred_label in enumerate(u_pred_env_labels):
            in_pred = self.pred_env_labels == pred_label
            for j, true_label in enumerate(u_true_env_labels):
                m[i, j] = -np.count_nonzero(in_pred & (self.true_env_labels == true_label))
        return m

    @staticmethod
    def _assignment_to_mapping(pred_labels, true_labels, assignment):
        """Turn Munkres ``(row, col)`` position pairs into ``{pred_label: true_label}``.

        ``pred_labels`` / ``true_labels`` are the sorted unique labels that
        index the rows / columns of the cost matrix.
        """
        return {pred_labels[row].item(): true_labels[col].item() for row, col in assignment}

    def get_cluster_mapping(self):
        """Match predicted labels to true labels; store and return the mapping.

        Returns:
            dict mapping each predicted environment label to the true
            environment label it is assigned to (also stored as ``self.mapper``).
        """
        cost_matrix = self.compute_cost_matrix()
        # Munkres expects a list of lists; handing it the numpy array lets its
        # internal row "copies" be views that write back into our matrix.
        indexes = Munkres().compute(cost_matrix.tolist())
        # Munkres returns matrix *positions*; translate them back to label
        # values so lookups also work for labels that are not 0..k-1.
        self.mapper = self._assignment_to_mapping(
            np.unique(self.pred_env_labels), np.unique(self.true_env_labels), indexes)
        return self.mapper

    @staticmethod
    def _accuracies(mapped_pred, true_env, true_cls):
        """Compute accuracies of already-matched predictions.

        Returns:
            ``(global_acc, per_env_acc, per_env_class_acc)`` where
            ``per_env_acc`` maps env label -> accuracy and
            ``per_env_class_acc`` maps env label -> {class label -> accuracy}.
            An (env, class) cell with no samples is NaN rather than raising
            ``ZeroDivisionError``.
        """
        correct = mapped_pred == true_env
        global_acc = float(correct.mean())
        class_ids = np.unique(true_cls)
        per_env_acc = {}
        per_env_class_acc = {}
        for env in np.unique(true_env):
            in_env = true_env == env
            per_env_acc[env.item()] = float(correct[in_env].mean())
            row = {}
            for cls in class_ids:
                cell = in_env & (true_cls == cls)
                n = np.count_nonzero(cell)
                row[cls.item()] = float(np.count_nonzero(correct & cell)) / n if n else float('nan')
            per_env_class_acc[env.item()] = row
        return global_acc, per_env_acc, per_env_class_acc

    def compute_clustering_metrics(self):
        """Match clusters, then print global, per-env and per-(env, class) accuracy.

        Returns:
            dict with keys ``'global_acc'``, ``'per_env_acc'`` and
            ``'per_env_class_acc'`` (empty dict when there are no samples).
        """
        if len(self.pred_env_labels) == 0:
            LOGGER.info('No samples to evaluate')
            return {}
        self.get_cluster_mapping()
        mapped_pred = np.array([self.mapper[p] for p in self.pred_env_labels])
        global_acc, per_env_acc, per_env_class_acc = self._accuracies(
            mapped_pred, self.true_env_labels, self.true_classes)

        env_names = DATASET_DICT[self.dataset_name]["env_map"]
        class_names = DATASET_DICT[self.dataset_name]["class_map"]
        print(f'##### Global accuracy: {global_acc} #####')
        print('##### Accuracy per environment #####')
        for u, k in per_env_acc.items():
            print(f'Environment {env_names[u]}: {k}')
        print('##### Accuracy per environment & per class #####')
        for u, sub_d in per_env_class_acc.items():
            for c, reg_acc in sub_d.items():
                print(f'Environment {env_names[u]}'
                      f' Class {class_names[c]}: {reg_acc}')
        return {
            'global_acc': global_acc,
            'per_env_acc': per_env_acc,
            'per_env_class_acc': per_env_class_acc,
        }


def load_split_indices(path, split_array, split_id):
    """Load sample indices from ``path`` and verify they match one dataset split.

    Args:
        path: text file of dataset indices readable by ``np.loadtxt``.
        split_array: per-sample split ids of the dataset.
        split_id: split the indices must correspond to (``main`` uses 0 for
            train and 1 for val).

    Returns:
        The loaded indices as an int array (always 1-D).

    Raises:
        ValueError: if the loaded indices differ from the dataset's indices
            for ``split_id``.
    """
    # ndmin=1 keeps a single-index file from collapsing to a 0-d scalar.
    indices = np.loadtxt(path, ndmin=1).astype('int')
    expected = np.flatnonzero(np.asarray(split_array) == split_id)
    # array_equal also rejects shape mismatches, which `==` would broadcast
    # (or error on) instead of reporting as a mismatch.
    if not np.array_equal(indices, expected):
        raise ValueError(f'Indices in {path} do not match split {split_id} of the dataset')
    return indices


def main():
    """Evaluate predicted environments from a CSV against a dataset's ground truth.

    The CSV at ``--pred_file_path`` must have an ``env`` column indexed by
    dataset sample position. ``train_indices.txt`` (and ``val_indices.txt``
    with ``--pred_val``) are read from the same directory and checked against
    the dataset's own split assignment before evaluation.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', required=True,
                        help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
    parser.add_argument("--dataset", type=str, default='waterbirds', help='')
    parser.add_argument('--pred_file_path', required=True, help='Path of the pred csv file. ')
    parser.add_argument("--pred_val", action="store_true", default=False, help="Pred environments also on val set")

    parser.add_argument("--ignore_labels", action="store_true", default=False, help="Ignoring test labels")
    config = parser.parse_args()

    # dataset name -> (log message, dataset class)
    loaders = {
        'waterbirds': ('Loading Waterbirds dataset', WaterbirdsDataset),
        'celebA': ('Loading CelebA dataset', CelebADataset),
        'coco_on_places': ('Loading COCO-on-Places dataset', COCOonPlacesDataset),
    }
    if config.dataset not in loaders:
        raise ValueError(f"Dataset {config.dataset} not recognized")
    message, dataset_cls = loaders[config.dataset]
    LOGGER.info(message)
    dataset = dataset_cls(root_dir=config.root_dir, get_img_idx=True)

    pred_env = pd.read_csv(config.pred_file_path)['env'].values

    # Extract train and val indices
    env_root_dir = os.path.dirname(config.pred_file_path)
    splits = [('Train', load_split_indices(
        os.path.join(env_root_dir, 'train_indices.txt'), dataset.split_array, 0))]
    if config.pred_val:
        splits.append(('Val', load_split_indices(
            os.path.join(env_root_dir, 'val_indices.txt'), dataset.split_array, 1)))

    for split_name, ind in splits:
        print('-' * 10)
        print(split_name)
        print('-' * 10)
        c_eval = ClusterEvaluator(
            dataset_name=config.dataset,
            true_classes=dataset.metadata_array[ind, 1],
            true_env_labels=dataset.metadata_array[ind, 0],
            pred_env_labels=pred_env[ind],
            ignore_labels=config.ignore_labels,
        )
        c_eval.compute_clustering_metrics()
    print('*' * 20)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
