import os
import pandas as pd
import wandb
import torch
import numpy as np
from torch.utils.data import DataLoader

from transforms import initialize_transform

import wilds
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.metrics.all_metrics import Accuracy
from utils import collate_list, BatchLogger, standard_group_eval, get_features

from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from numpy.linalg import norm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV, ParameterGrid
import itertools
from scipy.spatial import distance_matrix
from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter

import warnings

from sklearn.cluster import KMeans
from sklearn.utils import shuffle
from configs.hparams import *


def get_feature_path(config, split, layer=None):
    '''
    Build the path of the cached feature file for `split`.

    The file name is `features_{split}_{tag}.pth.tar`, placed directly after
    `config.log_ckpts_dir` (no separator is inserted). `tag` is
    `config.eval_ckpt`, followed by `_{config.frac}` when `config.frac < 1`
    and by `_{layer}` when `layer` is given.

    Raises:
        ValueError: if `config.eval_ckpt` is None. The original code built the
            exception object but never raised it.
    '''
    if config.eval_ckpt is None:
        raise ValueError('specify eval_ckpt to get feature directory')

    # Build via f-strings so a non-string checkpoint id is formatted the same
    # way regardless of which suffixes are appended.
    tag = f'{config.eval_ckpt}'
    if config.frac < 1.:
        tag += f'_{config.frac}'
    if layer is not None:
        tag += f'_{layer}'

    return f'{config.log_ckpts_dir}features_{split}_{tag}.pth.tar'

def get_random_shuffle(n):
    '''
    Return the integers `0 .. n-1` in random order.

    The order is drawn from NumPy's global random state (`np.random`), so it
    is reproducible after `np.random.seed`.
    '''
    order = np.arange(n)
    # in-place shuffle of the freshly created index array
    np.random.shuffle(order)
    return order

def get_numpy(tens):
    '''
    Convert torch tensors to numpy arrays, leaving other objects untouched.

    tens: a single object, or a tuple of objects. Each torch tensor is
          detached, moved to CPU and converted with `.numpy()`; anything else
          is returned as-is.

    Returns a tuple (same length and order) when `tens` is a tuple, otherwise
    the single converted object. A real tuple is returned rather than a
    one-shot generator, so the result supports `len`, indexing and repeated
    iteration; unpacking works as before.
    '''
    def _to_numpy(item):
        # detach() lets tensors that track gradients be converted as well
        return item.detach().cpu().numpy() if torch.is_tensor(item) else item

    if isinstance(tens, tuple):
        return tuple(_to_numpy(t) for t in tens)

    return _to_numpy(tens)


def eval_per_group(split, y_pred, y_true, metadata, grouper, metric, eval_grouping=True, remove_counts=True):
    '''
    Evaluate predictions for one split and prefix every result key with `{split}/`.

    With `eval_grouping=False` only the plain accuracy
    `(y_pred == y_true).mean()` is returned under `{split}/acc`.
    Otherwise `standard_group_eval` is called on tensor copies of the
    predictions and labels, the split name and the result string are printed,
    and the result dict is returned; entries whose key contains 'count' are
    dropped when `remove_counts` is True.
    '''
    if eval_grouping:
        results, results_str = standard_group_eval(
            metric, grouper, torch.tensor(y_pred), torch.tensor(y_true), metadata)
        print(split)
        print(results_str)

        prefixed = {}
        for name, value in results.items():
            if remove_counts and 'count' in name:
                continue
            prefixed[f'{split}/{name}'] = value
        return prefixed

    return {f'{split}/acc': (y_pred == y_true).mean()}


def get_sklearn_model(config, x_train, y_train, mode='params', params=None, param_grid=None, clf_name='lin'):
    '''
    Fit a scikit-learn classifier on `(x_train, y_train)`.

    clf_name selects the estimator class:
        - names ending in 'sgd', or 'lin' on the 'bgchallenge' / 'spur_cifar10'
          datasets -> SGDClassifier
        - 'svm' -> SVC
        - names ending in 'knn' -> KNeighborsClassifier
        - 'noreg', 'gridlr', 'lin' -> LogisticRegression
    mode:
        - 'params': fit one estimator built from `params`; returns the model.
        - 'grid': 3-fold GridSearchCV over `param_grid`; returns
          `(best_estimator, best_params)`.

    Every estimator except kNN is seeded with `config.seed`
    (KNeighborsClassifier takes no `random_state`).

    Raises:
        KeyError: for an unknown `mode` or `clf_name`. Previously the mode
            error was constructed but never raised (the function returned
            None), and an unknown `clf_name` failed with an unbound local.
    '''
    # Validate the mode up front so a typo fails before any fitting work.
    if mode not in ('params', 'grid'):
        raise KeyError(f'mode {mode} not found')

    is_knn = clf_name.endswith('knn')
    if clf_name.endswith('sgd') or (clf_name == 'lin' and config.dataset in ['bgchallenge', 'spur_cifar10']):
        clf = SGDClassifier
    elif clf_name == 'svm':
        clf = SVC
    elif is_knn:
        clf = KNeighborsClassifier
    elif clf_name in ['noreg', 'gridlr', 'lin']:
        clf = LogisticRegression
    else:
        raise KeyError(f'classifier {clf_name} not found')

    # Use the same kNN test in both modes; the original compared against the
    # exact string 'knn' in params mode but used endswith('knn') in grid mode.
    seed_kwargs = {} if is_knn else {'random_state': config.seed}

    if mode == 'params':
        assert params is not None
        print(params)
        model = clf(**seed_kwargs, **params)
        model.fit(x_train, y_train)
        return model

    # mode == 'grid'
    model = GridSearchCV(clf(**seed_kwargs), param_grid, cv=3, n_jobs=-1)
    model.fit(x_train, y_train)
    return model.best_estimator_, model.best_params_

def sklearn_lin(config, model, x_train, y_train, md_train, test_dict, grouper=None, remove_counts=True):
    '''
    Score a fitted classifier on the training split and on every test split.

    test_dict maps a split name to `(features, labels, metadata)`.
    Predictions come from `model.predict`; each split is scored with
    `eval_per_group` using an `Accuracy` metric, and all per-split result
    dicts are merged into one dict (the training split is keyed by
    `config.train_set`).
    '''
    metric = Accuracy(prediction_fn=None)

    def _score(split, feats, labels, metadata):
        # predict and evaluate one split
        predictions = model.predict(feats)
        return eval_per_group(split, predictions, labels, metadata, grouper, metric,
                              config.eval_grouping, remove_counts=remove_counts)

    merged = _score(config.train_set, x_train, y_train, md_train)
    for split, (x_test, y_test, md_test) in test_dict.items():
        merged.update(_score(split, x_test, y_test, md_test))

    return merged

def save_features(config, model, loaders, grouper=None):
    '''
    Extract features for every loader and optionally log them.

    loaders maps split name -> data loader. For each split `get_features` is
    called with `return_md=True`; the cache path from `get_feature_path` is
    passed only when `config.save_features` is set, and
    `config.get_features` is forwarded as `recompute`.
    When a grouper is given and `config.log_features` is set, the first and
    third entry of each split's result are handed to `log_features`.

    Returns the dict split name -> `get_features` result.
    '''
    feature_dict = {}
    for split, loader in loaders.items():
        cache_path = get_feature_path(config, split) if config.save_features else None
        feature_dict[split] = get_features(
            loader, model, path=cache_path, recompute=config.get_features, return_md=True)

    if grouper is not None and config.log_features:
        for split in loaders:
            split_feats = feature_dict[split]
            log_features(split, split_feats[0], split_feats[2], grouper)

    return feature_dict

def delete_features(config, splits):
    '''
    Best-effort removal of the cached feature files of the given splits.

    A failure for one split (e.g. the file does not exist) is ignored and the
    remaining splits are still processed.
    '''
    for split in splits:
        try:
            os.remove(get_feature_path(config, split))
        except Exception:
            # Deliberately best-effort. Catching Exception rather than using a
            # bare `except:` keeps KeyboardInterrupt / SystemExit from being
            # swallowed.
            pass


def log_features(split, features, mds, grouper):
    '''
    Log a feature matrix together with its group ids as a wandb table.

    Columns are named `d1 .. dD` for the D feature dimensions, plus a
    string-typed `group` column obtained from
    `grouper.metadata_to_group(mds)`. The table is logged under the key
    `feat_{split}`.
    '''
    n_dims = features.shape[1]
    column_names = [f'd{i + 1}' for i in range(n_dims)]
    table = pd.DataFrame(data=features, columns=column_names)

    group_ids = grouper.metadata_to_group(mds, return_counts=False)
    table['group'] = group_ids.numpy().astype(str)

    wandb.log({f'feat_{split}': wandb.Table(dataframe=table)})

def prepare_features(config, train_split=None, test_split=None, feature_dict=None):
    '''
    Load (or take from `feature_dict`) train/test features and prepare them
    for scikit-learn: convert to numpy, optionally shuffle the training set,
    optionally standardise.

    train_split: name of the training split; defaults to `config.train_set`.
    test_split: a split name, an iterable of split names, or None. None means
        `[config.eval_set] + config.extra_vals`.
    feature_dict: optional dict split name -> (features, labels, metadata).
        When None, the tuples are read with `torch.load` from the paths given
        by `get_feature_path`.

    Returns `((x_train, y_train, md_train), test_dict)` where `test_dict` maps
    each test split to `(x_test, y_test, md_test)`. Metadata is passed
    through unconverted.

    Fixes over the previous version:
      * the shuffle called the permutation array (`perm(x_train)`), which
        raises TypeError; the arrays are now indexed with it;
      * `test_split` is normalised once, so both branches accept None, a
        single name, or a list of names.
    '''
    train_key = config.train_set if train_split is None else train_split
    if test_split is None:
        test_keys = [config.eval_set] + list(config.extra_vals)
    elif isinstance(test_split, str):
        test_keys = [test_split]
    else:
        test_keys = list(test_split)

    test_dict = {}
    if feature_dict is None:
        # load features
        train_path = get_feature_path(config, train_key)
        x_train, y_train, md_train = torch.load(train_path)
        test_paths = []
        for sp in test_keys:
            test_path = get_feature_path(config, sp)
            test_paths.append(test_path)
            test_dict[sp] = torch.load(test_path)
        print('train and test path for eval_grid_search: \n', train_path, '\n', ', '.join(test_paths))
    else:
        x_train, y_train, md_train = feature_dict[train_key]
        for sp in test_keys:
            test_dict[sp] = feature_dict[sp]

    # prepare features [numpyify + scale]
    x_train, y_train = get_numpy((x_train, y_train))
    if config.shuffle_train:
        perm = get_random_shuffle(x_train.shape[0])
        x_train, y_train = x_train[perm], y_train[perm]
        # metadata is not numpy-converted above; index a tensor with a tensor
        if torch.is_tensor(md_train):
            md_train = md_train[torch.as_tensor(perm)]
        else:
            md_train = md_train[perm]

    scaler = None
    if config.eval_sc:
        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)

    prepared = {}
    for sp, (x_test, y_test, md_test) in test_dict.items():
        x_test, y_test = get_numpy((x_test, y_test))
        if scaler is not None:
            # reuse the statistics fitted on the training features
            x_test = scaler.transform(x_test)
        prepared[sp] = (x_test, y_test, md_test)

    return (x_train, y_train, md_train), prepared


def eval_grid_search(config, train_split=None, test_split=None, mode=('noreg',), grouper=None, feature_dict=None):
    '''
    Fit and evaluate scikit-learn classifiers on pre-extracted features.

    mode: iterable of mode names (a single string is also accepted); None
        means `grid_db[config.dataset]`. Hyper-parameters for each mode are
        read from `hparam_db[config.dataset][mode_name]`.
        - 'noreg', 'svm', 'knn', 'sgd': fit one model with fixed parameters.
        - names starting with 'grid': cross-validated grid search.
        - 'all': fit one model per parameter set, collect all results in a
          DataFrame and pick the row maximising
          `{config.eval_set}/{config.val_metric}`.

    Returns `(df, lrbest_res)`. Both are None unless mode 'all' was run
    (previously the return raised UnboundLocalError in that case).

    Raises:
        KeyError: if a mode name is not handled (previously the exception was
            constructed but not raised) or is missing from `hparam_db`.
    '''
    if mode is None:
        mode = grid_db[config.dataset]
    elif isinstance(mode, str):
        # a bare string would otherwise be iterated character by character
        mode = [mode]

    (x_train, y_train, md_train), test_dict = prepare_features(config, train_split=train_split, test_split=test_split, feature_dict=feature_dict)

    res_list = []
    df = None
    lrbest_res = None
    acc_col = f'{config.eval_set}/{config.val_metric}'

    for m in mode:
        hparams = hparam_db[config.dataset][m]
        print(f'mode: {m} -------- {hparams}')
        if m in ['noreg', 'svm', 'knn', 'sgd']:
            model = get_sklearn_model(config, x_train, y_train, mode='params', params=hparams, clf_name=m)
            print(f'mode: {m} -- ', model)
            if config.eval_norm:
                # NOTE(review): not guarded by config.use_wandb, and models
                # without `coef_` (e.g. kNN) would fail here -- confirm intent.
                wandb.log({f'{m}_norm': norm(model.coef_)})

            res_dict = sklearn_lin(config, model, x_train, y_train, md_train, test_dict, grouper=grouper, remove_counts=False)
            if config.use_wandb:
                wandb.log({f'{m}_{key}': item for key, item in res_dict.items()})
            res_list.append({**hparams, **res_dict})

        elif m.startswith('grid'):
            model, best_params = get_sklearn_model(config, x_train, y_train, mode='grid', param_grid=hparams, clf_name=m)
            print(f'mode: {m} -- ', model)
            res_dict = sklearn_lin(config, model, x_train, y_train, md_train, test_dict, grouper=grouper)

            if config.use_wandb:
                wandb.log({f'{m}_{key}': item for key, item in res_dict.items()})
                wandb.log({f'{m}_{key}': item for key, item in best_params.items()})

        elif m == 'all':
            for params in hparams:
                model = get_sklearn_model(config, x_train, y_train, mode='params', params=params)
                res_dict = sklearn_lin(config, model, x_train, y_train, md_train, test_dict, grouper=grouper)
                res_list.append({**params, **res_dict})
            # res_list also holds results appended by fixed-parameter modes
            # processed earlier in this call.
            df = pd.DataFrame(res_list)
            lrbest_res = {f'lrbest_{key}': item for key, item in df.loc[df[acc_col].idxmax()].to_dict().items()}
            if config.use_wandb:
                wandb.log(lrbest_res)
                if config.val_metric == 'acc_wg':
                    wandb.log({f'lrbestavg_{key}': item for key, item in df.loc[df[f'{config.eval_set}/acc_avg'].idxmax()].to_dict().items()})
                wandb.log({'ds-hpsearch': wandb.Table(dataframe=df)})
        else:
            raise KeyError(f'mode {m} not found in db')

    return df, lrbest_res



def select_groups(features, s1, s2, train_grouper):
    '''
    Keep only the samples that belong to group `s1` or group `s2`.

    features: tuple `(x, y, metadata)`.
    Group ids come from `train_grouper.metadata_to_group(metadata)`.

    Returns `(x, y, metadata, group_ids)`, each restricted to the selected
    samples.
    '''
    feats, labels, metadata = features
    group_ids = train_grouper.metadata_to_group(metadata, return_counts=False)
    keep = (group_ids == s1) | (group_ids == s2)
    return tuple(arr[keep] for arr in (feats, labels, metadata, group_ids))


def get_connectivities(config, full_dataset, train_grouper, featurizer, ssl_transform):
    '''
    Estimate pairwise group connectivity from featurizer representations.

    Features of the 'ds_train' and 'ds_test' subsets are extracted with
    `save_features`. For every pair among groups 0..3 a logistic regression
    is trained to tell the two groups apart on the `config.train_set`
    features, and its error rate on the `config.eval_set` features is stored.

    Returns a dict mapping `(s1, s2)` to `(error_rate, n_eval_samples)`.
    '''
    def _build_loaders(transform):
        # one standard eval loader per split, on the full subset
        loaders = {}
        for split in ['ds_train', 'ds_test']:
            subset = full_dataset.get_subset(split, frac=1., transform=transform)
            loaders[split] = get_eval_loader(dataset=subset,
                                             loader='standard', grouper=train_grouper,
                                             batch_size=config.batch_size, **config.loader_kwargs)
        return loaders

    feature_dict = save_features(config, featurizer, _build_loaders(ssl_transform), grouper=train_grouper)
    connectivity = {}

    for pair in itertools.combinations(range(4), 2):
        s1, s2 = pair
        # only the prepared test splits are used; training data for the pair
        # is selected from the raw feature_dict below
        _, test_dict = prepare_features(config, train_split=config.train_set, test_split=config.eval_set, feature_dict=feature_dict)
        x_train, y_train, md_train, g_train = select_groups(feature_dict[config.train_set], s1, s2, train_grouper)
        x_test, y_test, md_test, g_test = select_groups(feature_dict[config.eval_set], s1, s2, train_grouper)
        print(g_train, g_test)

        clf = LogisticRegression(random_state=0, C=0.316, max_iter=1000)
        clf.fit(x_train, g_train)
        # result is not stored; the call is kept for its printed evaluation
        sklearn_lin(config, clf, x_train, g_train, md_train, test_dict, grouper=train_grouper)

        g_test_np = g_test.clone().detach().cpu().numpy()
        error_rate = 1. - (clf.predict(x_test) == g_test_np).mean()
        connectivity[pair] = error_rate, len(g_test)

    return connectivity

def per_sample_alignment(Z):
    '''
    Per-sample alignment score derived from k-means distances.

    Z: (N, D) numpy array or torch tensor.

    KMeans is fitted with 2 and with 4 clusters. For each sample the largest
    distance to any centre is taken in both models and averaged; the averages
    are min-max scaled to [0, 1] over the samples and `1 - scaled` is
    returned.
    '''
    if torch.is_tensor(Z):
        Z = Z.detach().cpu().numpy()

    # fit both models first (2 clusters, then 4), then measure distances
    models = [KMeans(n_clusters=k).fit(Z) for k in (2, 4)]
    far_2, far_4 = (km.transform(Z).max(axis=1) for km in models)

    raw = (far_2 + far_4) / 2
    lowest = raw.min()
    span = raw.max() - lowest
    scaled = (raw - lowest) / span  # scale to [0, 1]

    return 1 - scaled

def alignment_loss(Z, y, a):
    '''
    Computes alignment loss as per Eq (4) of https://arxiv.org/pdf/2203.01517.pdf
    Z: (N, D) numpy array or torch tensor
    y: (N,) or (N, 1) numpy array or torch tensor, categorical targets
    a: (N,) or (N, 1) numpy array or torch tensor, categorical attributes

    For every target value, the mean pairwise distance between samples of two
    different attribute values is computed for each pair of attribute values
    (pairs with no samples on one side contribute 0), summed and divided by
    the number of attribute pairs. The result is the mean over target values.

    Will return nonsensical values when not all (y, a) have samples (e.g. when invar_str = 1)
    Returns NaN when fewer than two distinct attribute values are present
    (this previously raised ZeroDivisionError).
    '''
    if torch.is_tensor(Z):
        Z = Z.detach().cpu().numpy()
    if torch.is_tensor(y):
        y = y.detach().cpu().numpy()
    if torch.is_tensor(a):
        a = a.detach().cpu().numpy()
    Z, y, a = np.asarray(Z), np.asarray(y), np.asarray(a)
    N = Z.shape[0]

    # Accept the documented (N, 1) layout: flatten column vectors so the
    # boolean masks below are 1-D and can index the rows of Z.
    if y.ndim > 1 and y.size == N:
        y = y.reshape(N)
    if a.ndim > 1 and a.size == N:
        a = a.reshape(N)

    unique_groups = np.unique(a)
    all_combinations = list(itertools.combinations(unique_groups, 2))
    if not all_combinations:
        # no pair of attribute values to compare -> loss is undefined
        return float('nan')

    losses = {}
    for y_i in np.unique(y):
        mask_y = (y == y_i)
        accum = 0
        for g_j, g_jp in all_combinations:
            mask_1 = mask_y & (a == g_j)
            mask_2 = mask_y & (a == g_jp)
            n_1, n_2 = mask_1.sum(), mask_2.sum()
            if n_1 > 0 and n_2 > 0:
                # mean over all cross pairs of the two sample sets
                accum += distance_matrix(Z[mask_1], Z[mask_2]).sum() / n_1 / n_2
        losses[y_i] = accum / len(all_combinations)  # average over combinations of groups
    return np.mean(list(losses.values()))  # average over y's

def per_group_alignment_loss(Z, y, metadata_array, grouper):
    '''
    Mean pairwise distance between the representations of every pair of groups.

    Z: (N, D) numpy array or torch tensor
    y: numpy array or torch tensor of targets (moved to CPU, otherwise unused)
    metadata_array: per-sample metadata passed to `grouper.metadata_to_group`
    grouper: provides `metadata_to_group` and `group_field_str`

    Returns a dict keyed by `(group_field_str(g), group_field_str(g'))` for
    each pair of distinct groups present, holding the sum of all cross-pair
    distances divided by both group sizes.
    '''
    if torch.is_tensor(Z):
        Z = Z.detach().cpu()
    if torch.is_tensor(y):
        y = y.detach().cpu()
    if torch.is_tensor(metadata_array):
        metadata_array = metadata_array.detach().cpu()

    groups, _ = grouper.metadata_to_group(metadata_array, return_counts=True)

    summary_losses = {}
    for g_j, g_jp in itertools.combinations(np.unique(groups), 2):
        in_first = (groups == g_j)
        in_second = (groups == g_jp)
        size_first, size_second = in_first.sum(), in_second.sum()
        if not (size_first > 0 and size_second > 0):
            continue
        pair_key = (grouper.group_field_str(g_j), grouper.group_field_str(g_jp))
        total = distance_matrix(Z[in_first], Z[in_second]).sum()
        summary_losses[pair_key] = total / size_first / size_second
    return summary_losses


def mi(Z, Y, perc=20):
    '''
    Computes I(Z; Y) using the method from E.2.1 of https://arxiv.org/pdf/2203.01517.pdf

    Z: (N, D) numpy array or torch tensor
    Y: (N,) or (N, K) numpy array or torch tensor, where K>1 denotes e.g. a group variable with multiple attributes
    perc: number of log-spaced values of `C` (between 1e-5 and 1e1) tried in the grid search

    Each distinct row of Y is mapped to one class id; a logistic regression
    (3-fold grid search, scored by one-vs-rest ROC AUC) is fitted on Z, and
    the estimate is computed from its predicted probabilities on Z and the
    empirical class proportions.
    '''
    if torch.is_tensor(Z):
        Z = Z.detach().cpu().numpy()
    if torch.is_tensor(Y):
        Y = Y.detach().cpu().numpy()
    n_samples = Z.shape[0]

    if Y.ndim == 1:
        Y = np.expand_dims(Y, -1)
    assert n_samples == Y.shape[0]

    # encodes multi-attribute target: one integer code per distinct row of Y
    unique_rows, row_counts = np.unique(Y, axis=0, return_counts=True)
    label_props = row_counts / n_samples
    encoded = np.zeros(shape=(n_samples, 1))
    for code, row in enumerate(unique_rows):
        matches = np.equal(Y, tuple(row)).all(axis=1)
        encoded[matches] = code

    search_space = {
        'C': 10**np.linspace(-5, 1, perc),
        'multi_class': ['ovr', 'multinomial'],
    }
    search = GridSearchCV(estimator=LogisticRegression(random_state=42),
                          param_grid=search_space,
                          n_jobs=-1, cv=3, scoring='roc_auc_ovr')
    search = search.fit(Z, encoded.ravel())

    pred_proba = search.predict_proba(Z)
    weighted_log_ratio = pred_proba * np.log2(pred_proba / label_props)

    return weighted_log_ratio.sum() / n_samples


def infer_group_eval(train_distr, train_groups, val_distr, val_groups, val_m, val_set, log_dir, use_wandb):
    '''
    Predict group membership with several grid-searched classifiers and log
    how well each does.

    Four 5-fold grid searches (L1 logistic regression, kNN, RBF SVM, linear
    SVM) are fitted on `(train_distr, train_groups)`. The predictions of each
    best estimator on `val_distr` are scored with
    `val_set.eval(pred, val_groups, val_m)`; the results are printed and
    written to `{log_dir}/pred_group_{name}.csv`, and the best parameters are
    sent to wandb when `use_wandb` is set.
    '''
    from sklearn.linear_model import LogisticRegression, SGDClassifier, LassoLars
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.svm import SVC
    from sklearn.model_selection import GridSearchCV

    # (name, base estimator, parameter grid), in evaluation order
    candidates = [
        ('lin', LogisticRegression(random_state=0),
         {'C': [1., .7, .3, .1, 0.07, 0.03, 0.01, 10, 100],
          'penalty': ['l1'], 'solver': ['liblinear']}),
        ('knn', KNeighborsClassifier(),
         {'n_neighbors': [3, 5, 10, 20]}),
        ('rbf_svm', SVC(kernel='rbf'),
         {'C': [0.01, 0.1, 1., 10, 100],
          'gamma': ['auto', 'scale', 1e-3, 1e-4]}),
        ('svm', SVC(kernel='linear'),
         {'C': [0.01, 0.1, 1., 10, 100]}),
    ]
    clfs = {name: GridSearchCV(estimator, grid, cv=5) for name, estimator, grid in candidates}
    # all log files are opened before any model is fitted
    loggers = {name: BatchLogger(f'{log_dir}/pred_group_{name}.csv', mode='w', use_wandb=(use_wandb))
               for name in clfs}

    for name, search in clfs.items():
        search.fit(train_distr, train_groups)
        g_pred = torch.tensor(search.best_estimator_.predict(val_distr))
        results, results_str = val_set.eval(g_pred, val_groups, val_m)
        print(f'{name} best parameters: ', search.best_params_)
        print(f'sklearn {name} result ----------------')
        print(results_str)
        if use_wandb:
            wandb.log(search.best_params_)
        loggers[name].log(results)


def evaluate_clustering(train_distr, train_m, grouper, log_dir, use_wandb):
    '''
    Cluster the features with several algorithms and compare each clustering
    with the true groups.

    The number of clusters equals the number of entries in the counts
    returned by `grouper.metadata_to_group(train_m, return_counts=True)`.
    KMeans, Ward agglomerative clustering, spectral clustering and MeanShift
    are fitted on `train_distr`. For each, external scores (ARS, NMI,
    homogeneity, completeness, V-measure, Fowlkes-Mallows), internal scores
    (silhouette, Davies-Bouldin, Calinski-Harabasz) and the number of found
    clusters are printed along with the contingency matrix, and logged to
    `{log_dir}/{method}.csv`.
    '''
    from sklearn.cluster import KMeans, SpectralClustering, AgglomerativeClustering, MeanShift
    from sklearn.metrics.cluster import adjusted_rand_score as ARS
    from sklearn.metrics.cluster import contingency_matrix
    from sklearn.metrics import normalized_mutual_info_score as NMI
    from sklearn.metrics import homogeneity_score, completeness_score, v_measure_score, fowlkes_mallows_score
    from sklearn.metrics import davies_bouldin_score, calinski_harabasz_score, silhouette_score

    def _score_clustering(assignments, feats, true_groups):
        # external scores compare against the true groups, internal ones use the features
        scores = {
            'ARS': ARS(true_groups, assignments),
            'NMI': NMI(true_groups, assignments),
            'hom': homogeneity_score(true_groups, assignments),
            'cmp': completeness_score(true_groups, assignments),
            'v_ms': v_measure_score(true_groups, assignments),
            'fow': fowlkes_mallows_score(true_groups, assignments),
            'sil': silhouette_score(feats, assignments, metric='euclidean'),
            'dav': davies_bouldin_score(feats, assignments),
            'cal': calinski_harabasz_score(feats, assignments),
            'n': len(np.unique(assignments)),
        }
        for key, value in scores.items():
            print(f'{key}: ', value)

        table = contingency_matrix(true_groups, assignments)
        print('contingency_matrix:')
        print(table)
        print()
        print(table.sum(1))
        print(table.sum(0))
        return scores

    groups, counts = grouper.metadata_to_group(train_m, return_counts=True)
    n_clusters = len(counts)

    clust_methods = {
        'kmeans': KMeans(n_clusters=n_clusters, random_state=0),
        'agg': AgglomerativeClustering(linkage='ward', n_clusters=n_clusters),
        'spec': SpectralClustering(assign_labels='kmeans', n_clusters=n_clusters, random_state=0),
        'meansh': MeanShift(),
    }
    loggers = {name: BatchLogger(f'{log_dir}/{name}.csv', mode='w', use_wandb=(use_wandb))
               for name in clust_methods}

    for c_name, clust in clust_methods.items():
        print(f'{c_name} result -------------------------------')
        clust.fit(train_distr)
        results = _score_clustering(clust.labels_, train_distr, groups)
        loggers[c_name].log(results)
    

def quick_eval(self, config, datasets):
    '''
    Evaluate `self.model` with a linear and a kNN probe on the given datasets.

    datasets: dict with entries 'train' and `config.eval_set`, each a dict
        holding a 'loader' and a 'dataset'. These replace the loaders and
        datasets stored on `self`.

    Features are extracted with `utils.get_features`, standardised when
    `self.eval_sc` is set, and each probe is run through
    `utils.eval_sklearn_clf` with `return_pred=True`.

    Returns a dict probe name -> `eval_sklearn_clf` result.
    '''
    from utils import get_features, eval_sklearn_clf

    train_entry = datasets['train']
    eval_entry = datasets[config.eval_set]
    self.train_dataloader = train_entry['loader']
    self.val_dataloader = eval_entry['loader']
    self.train_set = train_entry['dataset']
    self.val_set = eval_entry['dataset']

    x_train, y_train, train_md = get_features(self.train_dataloader, self.model)
    x_val, y_val, val_md = get_features(self.val_dataloader, self.model)

    if self.eval_sc:
        scaler = StandardScaler()
        x_train = scaler.fit_transform(x_train)
        x_val = scaler.transform(x_val)

    keys = [name for name, enabled in (('lin', self.eval_lin), ('knn', self.eval_knn)) if enabled]
    # NOTE(review): the selection above is unconditionally replaced, so both
    # probes always run regardless of eval_lin / eval_knn -- confirm intent.
    keys = ['lin', 'knn']

    return {
        key: eval_sklearn_clf(key, x_train, y_train, train_md, x_val, y_val, val_md, self.val_set, return_pred=True)
        for key in keys
    }


class Validation(object):
    def __init__(self, args, config, model, rep_dim=None, grouper=None, train_features=None, val_features=None, eval_set='val', train_set='train', extra_evals=[]):
        '''
        Set up everything needed to evaluate `model`: evaluation switches,
        CSV loggers, datasets and data loaders.

        args: provides `batch_size` and `num_workers` for the data loaders.
        config: experiment configuration; evaluation switches, logging
            options and dataset settings are read from it.
        model: the network to evaluate; its first parameter decides the device.
        rep_dim: stored as-is.
        grouper: stored as-is.
        train_features, val_features: not used in this constructor.
        eval_set, train_set: names of the evaluation / training subsets.
        extra_evals: names of additional subsets; a dataset and loader is
            created for each and stored in `self.extra_vals`.
        '''
        self.model = model
        self.device = torch.device('cuda' if next(model.parameters()).is_cuda else 'cpu')
        self.args = args
        self.rep_dim = rep_dim
        self.K = config.knn_k
        self.linreg_c = config.linreg_c
        # mirror the evaluation switches from config under the same names
        for flag in ('eval_lin', 'eval_knn', 'eval_mi', 'eval_layer_wise',
                     'eval_layer_wise_lr', 'eval_layer_wise_knn', 'eval_layers_num',
                     'eval_train', 'eval_spur',
                     'eval_grouping',  # group logging for evaluations
                     'uniform_over_groups', 'eval_group_alignment',
                     'eval_gridsearch', 'eval_sc'):
            setattr(self, flag, getattr(config, flag))
        assert self.eval_knn or self.eval_lin
        self.best_acc = 0
        self.use_wandb = config.use_wandb
        self.grouper = grouper
        self.infer_group = config.infer_group
        self.shuffle_train = config.shuffle_train
        print('shuffle train for model evaluation: ', self.shuffle_train)

        self.config = config
        self.eval_split = eval_set
        self.train_split = train_set

        from utils import get_model_module
        self.dtype = get_model_module(self.model).conv1.weight.dtype

        # log files are always opened in write mode
        mode = 'w'
        self.mode = mode
        self.log_dir = config.log_dir
        self.seed = config.seed
        self.loggers = {}

        def _register_logger(name):
            # every logger writes to <log_dir>/<name>.csv
            csv_path = os.path.join(config.log_dir, f'{name}.csv')
            self.loggers[name] = BatchLogger(csv_path, mode=mode, use_wandb=(config.use_wandb))

        _register_logger('metrics')
        if self.eval_knn:
            _register_logger('knn')
        if self.eval_lin:
            _register_logger('lin')
            if self.eval_spur:
                _register_logger('spur_lin')
            if self.eval_train:
                _register_logger('train_lin')
                if self.eval_spur:
                    _register_logger('spur_train_lin')
            if self.eval_mi and self.eval_layer_wise:
                _register_logger('mi')
                _register_logger('align')
            if self.eval_layer_wise_lr:
                _register_logger('lr_lw')
            if self.eval_layer_wise_knn:
                self.lw_knn_preds = None
            if self.eval_group_alignment:
                _register_logger('group_align')

        full_dataset = wilds.get_dataset(
            dataset=config.dataset,
            root_dir='',
            split_scheme=config.split_scheme,
            **config.dataset_kwargs)

        base_transforms = initialize_transform(config.transform, config, full_dataset, is_training=False)
        train_transforms = initialize_transform(config.transform, config, full_dataset, is_training=True)
        if config.dataset in ['spur_cifar10', 'cmnist', 'cifar10']:
            # these datasets use the evaluation transform for the training subset too
            train_transforms = base_transforms
        train_frac = 1. if config.dataset in ['spur_cifar10', 'celebA'] else config.frac

        train_dataset = full_dataset.get_subset(
            self.train_split,
            frac=train_frac,
            transform=train_transforms)
        val_dataset = full_dataset.get_subset(
            self.eval_split,
            frac=config.frac,
            transform=base_transforms)
        self.val_set = val_dataset
        self.extra_vals = dict()

        def _make_loader(dataset, shuffle):
            return DataLoader(dataset,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=args.num_workers,
                              pin_memory=True)

        self.train_dataloader = _make_loader(train_dataset, self.shuffle_train)
        self.val_dataloader = _make_loader(val_dataset, False)

        for val in extra_evals:
            subset = full_dataset.get_subset(val, frac=config.frac, transform=base_transforms)
            self.extra_vals[val] = {'dataset': subset, 'loader': _make_loader(subset, False)}

        print('Validation datasets: ')
        print(f'    {self.config.train_set} : n = {len(train_dataset):.0f}')
        print(f'    {self.config.eval_set}  : n = {len(val_dataset):.0f}')
        for val in extra_evals:
            sp = len(self.extra_vals[val]['dataset'])
            print(f'    {val}  : n = {sp:.0f}')

        # a random order over the training set is always drawn, but only
        # applied to the stored labels / metadata when shuffle_train is set
        self.train_index = np.arange(len(train_dataset))
        np.random.shuffle(self.train_index)

        train_data = self.train_dataloader.dataset
        self.y_train = train_data.y_array.clone().detach()
        self.g_train = train_data.metadata_array.clone().detach()
        if self.shuffle_train:
            self.y_train = self.y_train[self.train_index]
            self.g_train = self.g_train[self.train_index, :]

        val_data = self.val_dataloader.dataset
        self.y_val = val_data.y_array.clone().detach()
        self.g_val = val_data.metadata_array.clone().detach()


    def quick_eval(self, config=None, datasets=None):
        '''
        Evaluate `self.model` with a linear and a kNN probe.

        config: defaults to `self.config`.
        datasets: optional dict with entries 'train' and `config.eval_set`,
            each holding a 'loader' and a 'dataset'; when given they replace
            the loaders and datasets stored on the instance.

        Features come from `utils.get_features` and are standardised when
        `self.eval_sc` is set.

        Returns `(res, (x_train, y_train, train_md), (x_val, y_val, val_md))`
        where `res` maps the probe name to the `utils.eval_sklearn_clf` result.
        '''
        from utils import get_features, eval_sklearn_clf
        from sklearn.preprocessing import StandardScaler

        config = self.config if config is None else config
        if datasets is not None:
            train_entry = datasets['train']
            eval_entry = datasets[config.eval_set]
            self.train_dataloader = train_entry['loader']
            self.val_dataloader = eval_entry['loader']
            self.train_set = train_entry['dataset']
            self.val_set = eval_entry['dataset']

        train_feats = get_features(self.train_dataloader, self.model)
        val_feats = get_features(self.val_dataloader, self.model)
        x_train, y_train, train_md = train_feats
        x_val, y_val, val_md = val_feats

        if self.eval_sc:
            scaler = StandardScaler()
            x_train = scaler.fit_transform(x_train)
            x_val = scaler.transform(x_val)

        keys = [name for name, enabled in (('lin', self.eval_lin), ('knn', self.eval_knn)) if enabled]
        # NOTE(review): the selection above is unconditionally replaced, so
        # both probes always run regardless of eval_lin / eval_knn -- confirm intent.
        keys = ['lin', 'knn']

        res = {
            key: eval_sklearn_clf(key, x_train, y_train, train_md, x_val, y_val, val_md, self.val_set)
            for key in keys
        }

        return res, (x_train, y_train, train_md), (x_val, y_val, val_md)

    def _linear_cls(self, x_train, x_val):
        '''
        Fit a linear probe on the training features and score it.

        x_train, x_val: feature matrices (numpy arrays or torch tensors).
        Labels are taken from `self.y_train` / `self.y_val`.

        The classifier is a 5-fold grid-searched LogisticRegression (L1 grid
        when `self.eval_gridsearch`, a single L1 `C` when `self.linreg_c` is
        set, unpenalised lbfgs otherwise); on the 'bgchallenge' dataset an
        L1 SGDClassifier is used instead.

        Returns a dict name -> `(accuracy, predictions)` with the entries
            'lin'             validation accuracy on the labels
            'train_lin'       training accuracy (if `self.eval_train`)
            'spur_lin'        validation accuracy when predicting the first
                              metadata column (if `self.eval_spur`)
            'spur_train_lin'  the same on the training set (if both flags
                              are set). This entry was previously computed
                              but dropped.
        '''
        if torch.is_tensor(x_train):
            x_train = x_train.cpu().numpy()
        if torch.is_tensor(x_val):
            x_val = x_val.cpu().numpy()
        clf = LogisticRegression

        if self.eval_gridsearch:
            lr_grid = {'C':
                [1., .7, .3, .1, 10],
                'penalty': ['l1'], 'solver': ['liblinear']}
        elif self.linreg_c is not None:
            lr_grid = {'C': [self.linreg_c], 'penalty': ['l1'], 'solver': ['liblinear']}
        else:
            # NOTE(review): recent scikit-learn releases expect penalty=None
            # rather than the string 'none' -- check against the pinned version.
            lr_grid = {'solver': ['lbfgs'], 'penalty': ['none'], 'max_iter': [1000]}
        if self.config.dataset == 'bgchallenge':
            clf = SGDClassifier
            lr_grid = {'penalty': ['l1']}
        model = GridSearchCV(clf(random_state=self.seed), lr_grid, cv=5, n_jobs=-1)

        model.fit(x_train, self.y_train.cpu())
        pred = torch.tensor(model.best_estimator_.predict(x_val))
        acc_val = (pred == self.y_val).float().mean().item()
        print(f'lin best parameters: ', model.best_params_)
        res = {'lin': (acc_val, pred)}

        if self.eval_train:
            pred_train = torch.tensor(model.predict(x_train))
            acc_train = (pred_train == self.y_train).float().mean().item()
            res['train_lin'] = (acc_train, pred_train)

        if not self.eval_spur:
            return res

        # take the spuriously correlated attribute as the label
        train_labels = self.g_train[:, 0]
        val_labels = self.g_val[:, 0]
        model.fit(x_train, train_labels.cpu())
        pred = torch.tensor(model.predict(x_val))
        acc_val = (pred == val_labels).float().mean().item()
        res['spur_lin'] = (acc_val, pred)
        if self.eval_train:
            pred_train = torch.tensor(model.predict(x_train))
            acc_train = (pred_train == train_labels).float().mean().item()
            # store the result so it is reported like the other entries
            res['spur_train_lin'] = (acc_train, pred_train)

        return res
    
    def _knn_cls(self, x_train, x_val):
        '''
        Fit a k-nearest-neighbour probe on the training features and score it
        on the validation features.

        x_train, x_val: feature matrices (numpy arrays or torch tensors).
        The number of neighbours is chosen by 5-fold grid search over a fixed
        list when `self.eval_gridsearch` is set, otherwise it is `self.K`.

        Returns `{'knn': (validation_accuracy, predictions)}`.
        '''
        from sklearn.neighbors import KNeighborsClassifier

        x_train, x_val = (
            feats.cpu().numpy() if torch.is_tensor(feats) else feats
            for feats in (x_train, x_val)
        )

        neighbour_options = [3, 5, 7, 10, 15, 20] if self.eval_gridsearch else [self.K]
        search = GridSearchCV(KNeighborsClassifier(n_neighbors=self.K),
                              {'n_neighbors': neighbour_options}, cv=5, n_jobs=-1)
        search.fit(x_train, self.y_train.cpu())

        pred = torch.tensor(search.best_estimator_.predict(x_val))
        acc_val = (pred == self.y_val).float().mean().item()
        print(f'knn best parameters: ', search.best_params_)
        return {'knn': (acc_val, pred)}

    def _rep_metrics(self, x_val):
        '''
        Representation metrics on the validation features.

        Computes `mi` between `x_val` and the labels, `mi` between `x_val`
        and the group metadata, and `alignment_loss`. When the metadata has
        more than one column, columns that equal the label on every sample
        are removed first.

        Returns `{'mi': {'mi_y': ..., 'mi_g': ...}, 'alignment_loss': ...}`.
        '''
        if self.g_val.ndim == 2 and self.g_val.shape[1] > 1:
            # drop metadata columns that merely duplicate the label
            duplicates_label = (self.g_val == self.y_val.unsqueeze(-1)).all(axis = 0)
            grp_without_labels = self.g_val[:, ~duplicates_label]
        else:
            grp_without_labels = self.g_val

        return {
            'mi': {
                'mi_y': mi(x_val, self.y_val),
                'mi_g': mi(x_val, grp_without_labels),
            },
            'alignment_loss': alignment_loss(x_val, self.y_val, grp_without_labels.squeeze()),
        }

    def _rep_layer_wise(self, conv_layers, epoch):
        """Evaluate the intermediate representations of the given layers.

        The outputs of ``conv_layers`` are captured through ``MidGetter`` and
        extracted with ``self.get_features``. Depending on the enabled flags:

        * ``self.eval_layer_wise``: ``mi`` against labels and group metadata,
          and ``alignment_loss``, on the validation activations;
        * ``self.eval_layer_wise_lr``: ``self._linear_cls`` per layer;
        * ``self.eval_layer_wise_knn``: ``self._knn_cls`` per layer, with the
          predictions appended to ``self.lw_knn_preds``.

        Args:
            conv_layers: names of the modules whose outputs are captured.
            epoch: epoch number stored alongside the results.

        Returns:
            Tuple ``(mi_res, align_res, lr_res, knn_res)`` of dicts. The first
            three always carry an ``'epoch'`` entry; ``knn_res`` only holds the
            ``'knn_<layer>'`` accuracies.
        """
        print('conv_layers ', conv_layers)
        if self.g_val.ndim == 2 and self.g_val.shape[1] > 1:
            # Keep only the metadata columns that are not a copy of the labels.
            grp_without_labels = self.g_val[:, ~(self.g_val == self.y_val.unsqueeze(-1)).all(axis = 0)]
        else:
            grp_without_labels = self.g_val

        mi_res = {'epoch': epoch}
        align_res = {'epoch': epoch}
        lr_res = {'epoch': epoch}
        knn_res = {}

        return_layers = dict(zip(conv_layers, conv_layers))
        mid_getter = MidGetter(self.model, return_layers=return_layers, keep_output=True)

        # Train activations are only read by the linear / kNN probes, so skip the
        # extra forward pass over the train split when just MI/alignment is wanted.
        inter_rep_train = None
        if self.eval_layer_wise_lr or self.eval_layer_wise_knn:
            inter_rep_train = self.get_features(set='train', mid_getter=mid_getter, return_md=False)
        inter_rep = self.get_features(set='eval', mid_getter=mid_getter, return_md=False)

        perc = 7  # forwarded to ``mi`` as its ``perc`` argument
        print('computing layer wise metrics...')
        if self.eval_layer_wise:
            for layer, x_val in inter_rep.items():
                mi_res[f'y_{layer}'] = mi(x_val, self.y_val, perc=perc)
                mi_res[f'g_{layer}'] = mi(x_val, grp_without_labels, perc=perc)
                align_res[f'{layer}'] = alignment_loss(x_val, self.y_val, grp_without_labels.squeeze())

        if self.eval_layer_wise_lr:
            for layer, x_val in inter_rep.items():
                x_train = inter_rep_train[layer]
                print(x_val.shape)
                print(x_train.shape)
                lr_res[f'lin_{layer}'] = self._linear_cls(x_train, x_val)

        if self.eval_layer_wise_knn:
            for layer, x_val in inter_rep.items():
                x_train = inter_rep_train[layer]
                print(x_train.shape, x_val.shape)
                top1_knn, pred_knn = self._knn_cls(x_train, x_val)['knn']

                # Record this layer's predictions for this epoch; the layer name is
                # passed explicitly instead of being captured from the loop variable.
                preds = pred_knn.cpu().numpy()
                layer_pred = pd.DataFrame({'idx': range(len(preds)), 'pred': preds})
                layer_pred['layer'] = layer
                layer_pred['epoch'] = epoch
                if self.lw_knn_preds is None:
                    self.lw_knn_preds = layer_pred
                else:
                    self.lw_knn_preds = pd.concat([self.lw_knn_preds, layer_pred])

                knn_res[f'knn_{layer}'] = top1_knn
        return mi_res, align_res, lr_res, knn_res
    
    def get_shuffled(self, arr, set):
        """Reorder train-split data with ``self.train_index``.

        Data is returned untouched unless ``set == 'train'`` and
        ``self.shuffle_train`` is truthy.

        Args:
            arr: a single array/tensor, a tuple of them, or a dict mapping names
                to arrays/tensors (the per-layer features produced by
                ``get_features`` with a ``mid_getter``). ``None`` entries are
                passed through unchanged.
            set: split name; only ``'train'`` is ever reordered.

        Returns:
            The same container type as ``arr`` with every array indexed along
            its first axis by ``self.train_index``.
        """
        if set != 'train' or not self.shuffle_train:
            return arr

        index = self.train_index

        def reorder(item):
            # Index the first axis only, so 1-D label vectors work as well as
            # 2-D feature/metadata matrices (for which this equals [index, :]).
            return item if item is None else item[index]

        if isinstance(arr, dict):
            return {name: reorder(item) for name, item in arr.items()}
        if isinstance(arr, tuple):
            return tuple(reorder(item) for item in arr)
        return reorder(arr)
    
    def get_features(self, set='train', dataloader=None, mid_getter=None, return_md=False, epoch=-1):
        """Extract (or load cached) features for one data split.

        Args:
            set: ``'train'`` / ``'eval'`` select ``self.train_dataloader`` /
                ``self.val_dataloader``. Any other value requires ``dataloader``
                and is used as the split name in the cache path.
            dataloader: loader to use when ``set`` is neither ``'train'`` nor
                ``'eval'``; each batch must provide inputs, labels and metadata
                at positions 0, 1 and 2.
            mid_getter: optional callable returning ``(mid_outputs, output)``;
                when given, the flattened per-layer activations are collected
                instead of the model output, and nothing is cached or loaded.
            return_md: when features are loaded from the cache, return the whole
                stored object instead of its first element.
            epoch: unused; kept for interface compatibility.

        Returns:
            With ``mid_getter``: dict mapping layer name to a float CPU tensor of
            flattened activations. Otherwise the feature tensor (first element
            of ``(features, labels, metadata)``), or the full cached object when
            loading with ``return_md=True``. Train data goes through
            ``self.get_shuffled`` before being returned.
        """
        save_features_flag = self.config.save_features

        def return_res(res):
            return res if return_md else res[0]

        if set == 'train':
            dataloader = self.train_dataloader
            split = self.train_split
        elif set == 'eval':
            dataloader = self.val_dataloader
            split = self.eval_split
        else:
            assert dataloader is not None
            # Custom loaders have no registered split name; fall back to the
            # caller-supplied set name so the cache path can still be built.
            split = set

        print('len of dataloader: ', len(dataloader.dataset))
        feature_path = get_feature_path(self.config, split)
        if not self.config.get_features and mid_getter is None and self.config.eval_ckpt is not None:
            print(f'loading featuers from {feature_path}')
            # NOTE: torch.load unpickles the file; only load feature caches
            # written by this code.
            return return_res(self.get_shuffled(torch.load(feature_path), set))

        # NOTE(review): this forbids return_md when the model output is extracted
        # directly (no mid_getter); confirm the condition is not inverted.
        if mid_getter is None:
            assert not return_md

        print('getting features...')
        # Per-batch chunks are collected in lists and concatenated once at the
        # end, instead of re-copying the growing tensor on every batch.
        layer_chunks = {}
        feat_chunks, label_chunks, md_chunks = [], [], []
        with torch.no_grad():
            for batch in dataloader:
                labels, metadata = batch[1], batch[2]
                # Follow the configured device rather than assuming CUDA.
                inputs = batch[0].to(self.device)
                if mid_getter is not None:
                    mid_outputs, _ = mid_getter(inputs)
                    for layer, feats in mid_outputs.items():
                        print('feats', feats.shape)
                        feats = torch.flatten(feats, 1)
                        print('feats ', feats.shape)
                        layer_chunks.setdefault(layer, []).append(feats.float().cpu())
                else:
                    feats = self.model(inputs)
                    feat_chunks.append(feats.float().cpu())
                    label_chunks.append(labels.float())
                    md_chunks.append(metadata.float().cpu())

        if mid_getter is not None:
            inter_rep = {layer: torch.cat(chunks, 0) for layer, chunks in layer_chunks.items()}
            return self.get_shuffled(inter_rep, set)

        if not feat_chunks:
            raise RuntimeError('dataloader yielded no batches; no features to return')

        res = (torch.cat(feat_chunks, 0), torch.cat(label_chunks, 0), torch.cat(md_chunks, 0))

        if save_features_flag:
            print(f'saving featuers to {feature_path}')
            torch.save(res, feature_path)

        return return_res(self.get_shuffled(res, set))
        
    
    def _eval(self): # feature extraction and knn if needed
        """Extract features from validation split and search on train split features.

        Train and validation features are obtained through
        ``self.get_features`` and optionally standardised (``self.eval_sc``).
        The enabled evaluations then run on them: linear probe
        (``self.eval_lin``), k-NN (``self.eval_knn``), representation metrics
        (``self.eval_mi``), per-group alignment (``self.eval_group_alignment``)
        and group inference / clustering (``self.infer_group``).

        Returns:
            Dict merging the results of the enabled evaluations.
        """
        # catch_warnings() restores the previous filters when its block exits,
        # so the whole evaluation has to run inside it for "ignore" to apply.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            if self.rep_dim is None:
                self.rep_dim = self.args.feat_dim

            self.model.eval()
            if str(self.device) == 'cuda':
                torch.cuda.empty_cache()

            train_features = self.get_features(set='train', return_md=False)
            val_features = self.get_features(set='eval', return_md=False)

            if self.eval_sc:
                # Fit the scaler on train only and reuse its statistics on val.
                sc = StandardScaler()
                x_train = sc.fit_transform(train_features)
                x_val = sc.transform(val_features)
            else:
                x_train, x_val = train_features, val_features

            res = {}
            if self.eval_lin:
                res = self._linear_cls(x_train, x_val)
            if self.eval_knn:
                res = {**res, **self._knn_cls(x_train, x_val)}

            # The representation metrics use the unscaled validation features.
            if self.eval_mi:
                mi_res = self._rep_metrics(val_features)
                res = {**res, **mi_res}
            if self.eval_group_alignment:
                grp_algn_res = per_group_alignment_loss(val_features, self.y_val, self.g_val, self.grouper)
                res['group_align'] = grp_algn_res

            if self.infer_group:
                train_groups = self.grouper.metadata_to_group(self.g_train, return_counts=False)
                val_groups = self.grouper.metadata_to_group(self.g_val, return_counts=False)
                infer_group_eval(x_train, train_groups, x_val, val_groups, self.g_val, self.val_dataloader.dataset, self.log_dir, self.use_wandb)
                evaluate_clustering(x_train, self.g_train, self.grouper, self.log_dir, self.use_wandb)

            return res
    
    def group_eval(self, pred, epoch, logging_mode, train=False):
        """Run the dataset's group evaluation on predictions and log the result.

        Args:
            pred: tensor of predictions for the chosen split.
            epoch: epoch number stored in the results under ``'epoch'``.
            logging_mode: key into ``self.loggers``. When it contains ``'spur'``
                the targets are the first metadata column instead of the labels.
            train: evaluate against the train split instead of the validation
                split.

        Returns:
            The results dict returned by the dataset's ``eval`` method, with
            ``'epoch'`` added.
        """
        if 'spur' in logging_mode:
            print('logging mode ', logging_mode)
            # The first metadata column is used as the target in this mode.
            train_labels = self.g_train[:, 0]
            val_labels = self.g_val[:, 0]
        else:
            train_labels = self.y_train
            val_labels = self.y_val

        if train:
            dataset, targets, metadata = self.train_dataloader.dataset, train_labels, self.g_train
        else:
            dataset, targets, metadata = self.val_dataloader.dataset, val_labels, self.g_val

        results, results_str = dataset.eval(pred.cpu(), targets.cpu(), metadata.cpu())
        print(results_str)
        results['epoch'] = epoch
        self.loggers[logging_mode].log(results)
        return results
    
    def close_loggers(self):
        """Close every logger registered in ``self.loggers``."""
        # Iterating the mapping itself yields only the keys, so go through
        # .values() to reach the logger objects.
        for logger in self.loggers.values():
            logger.close()

    def eval(self, epoch):
        """Run all configured evaluations for one epoch and log the results.

        Steps, each gated by the corresponding ``self.eval_*`` flag:

        1. ``self._eval()`` and, for every classifier entry of its result,
           group evaluation (``self.eval_grouping``) or plain wandb logging.
        2. Representation metrics (``self.eval_mi``) logged to the
           ``'metrics'`` logger.
        3. Layer-wise MI/alignment, linear probes and k-NN probes; for k-NN the
           per-sample prediction depth is computed and written to
           ``<log_dir>/pdepth_<seed>.csv``.

        Args:
            epoch: epoch number attached to every logged record.

        Returns:
            ``(is_best, res)`` where ``res`` is the dict returned by
            ``self._eval()`` and ``is_best`` is True when a classifier entry
            whose key contains neither ``'train'`` nor ``'spur'`` has an
            accuracy above ``self.best_acc``.
        """
        is_best = False
        print("Validating...")
        res = self._eval()

        # Only these entries are (accuracy, predictions) pairs; other entries of
        # ``res`` (e.g. 'mi', 'alignment_loss') must not be unpacked that way.
        classifier_accs = {}
        for key, item in res.items():
            if key in ['knn', 'lin', 'train_lin', 'spur_lin', 'spur_train_lin'] or key.startswith('lr_lw'):
                acc, pred = item
                classifier_accs[key] = acc
                print('Top1 {}: {}'.format(key, acc))
                if self.use_wandb:
                    wandb.define_metric(f"{key}/acc*", step_metric=f"{key}/epoch")

                if self.eval_grouping:
                    self.group_eval(pred, epoch, key, train=('train' in key))
                if 'train' not in key and 'spur' not in key and acc > self.best_acc:
                    is_best = True

        if not self.eval_grouping and self.use_wandb and classifier_accs:
            # Each metric family gets its own step value, matching the
            # define_metric calls above.
            log_dict = {}
            for name, acc in classifier_accs.items():
                log_dict[f'{name}/acc'] = acc
                log_dict[f'{name}/epoch'] = epoch
            wandb.log(log_dict)

        if self.eval_mi:
            print('I(Z; Y): {}'.format(res['mi']['mi_y']))
            print('I(Z; G): {}'.format(res['mi']['mi_g']))
            print('alignment: {}'.format(res['alignment_loss']))
            metric_results = {**res['mi'], 'alignment_loss': res['alignment_loss'], 'epoch': epoch}
            if self.eval_group_alignment:
                self.loggers['group_align'].log(res['group_align'])
            if self.use_wandb:
                wandb.define_metric("metrics/*", step_metric="metrics/epoch")
            self.loggers['metrics'].log(metric_results)

        if not (self.eval_layer_wise or self.eval_layer_wise_lr or self.eval_layer_wise_knn):
            return is_best, res

        mi_res, align_res = {}, {}
        lr_res = {}
        knn_res = {}

        print('layer wise ----------------------------------------------')
        all_conv_layers = [name for name, module in self.model.named_modules()
                           if isinstance(module, torch.nn.Conv2d) and 'conv' in name]

        if hasattr(self.model, 'fc'):
            all_conv_layers = all_conv_layers[1:] + ['fc']

        if self.eval_layers_num != -1:
            all_conv_layers = all_conv_layers[-self.eval_layers_num:]
        print('Evaluating on: ')
        print(all_conv_layers)

        # Layers are processed one at a time, clearing the CUDA cache in between.
        for layer_name in all_conv_layers:
            torch.cuda.empty_cache()
            mi_r, align_r, lr_r, knn_r = self._rep_layer_wise([layer_name], epoch)
            if self.eval_layer_wise:
                mi_res.update(mi_r)
                align_res.update(align_r)
            if self.eval_layer_wise_lr:
                lr_res.update(lr_r)
            if self.eval_layer_wise_knn:
                knn_res.update(knn_r)

        print('-----------------------------------')
        if self.eval_layer_wise:
            print('MI accross layers: ')
            for k, e in sorted(mi_res.items()):
                print(f'{k}: {e}')
            print('\nAlign accross layers: ')
            for k, e in align_res.items():
                print(f'{k}: {e}')

            self.loggers['mi'].log(mi_res)
            self.loggers['align'].log(align_res)

        if self.eval_layer_wise_lr:
            print('LR accross layers: ')
            for k, e in sorted(lr_res.items()):
                print(f'{k}: {e}')
            for key, item in lr_res.items():
                if key == 'epoch':
                    # Bookkeeping entry added by _rep_layer_wise, not a probe result.
                    continue
                acc, pred = item['lin']
                print('Top1 {}: {}'.format(key, acc))
                if self.use_wandb:
                    wandb.define_metric(f"{key}/acc*", step_metric=f"{key}/epoch") # , summary="max",
                    wandb.log({f'val_{key}_acc': acc})
                if self.eval_grouping:
                    self.group_eval(pred, epoch, key, train=('train' in key))

        if self.eval_layer_wise_knn:
            df = self.lw_knn_preds
            for layer in all_conv_layers:
                pred = torch.tensor(df[df['layer'] == layer]['pred'].values)
                if len(pred) == 0:
                    continue
                if self.eval_grouping:
                    # NOTE(review): the logging mode is whatever `key` was left at
                    # by the loops above; confirm which logger the layer-wise
                    # k-NN results are meant to go to.
                    # Stored under its own name so the `res` returned to the
                    # caller is not overwritten.
                    group_res = self.group_eval(pred, epoch, key, train=('train' in key))
                    if self.use_wandb:
                        print(f'WG {key}: ', )
                        print('knn_lw_wg', group_res['acc_wg'])
                        print('knn_lw_avg', group_res['acc_avg'])
                        wandb.log({'knn_lw_wg': group_res['acc_wg'], 'knn_lw_avg': group_res['acc_avg']})
                else:
                    print('knn_lw_avg', knn_res[f'knn_{layer}'])
                    if self.use_wandb:
                        wandb.log({'knn_lw_avg': knn_res[f'knn_{layer}']})

            # Order layers by their position in the network rather than by name
            # (lexicographic order would place 'fc' before 'layer...').
            layer_rank = {name: rank for rank, name in enumerate(all_conv_layers)}

            def get_pd(group_df):  # compute prediction depth based on a given dataframe
                ordered = group_df.assign(layer_rank=group_df['layer'].map(layer_rank))
                # Compare plain arrays: shifted pandas Series carry different
                # index labels and cannot be compared element-wise.
                preds = ordered.sort_values(by=['layer_rank'])['pred'].to_numpy()
                flips = np.flatnonzero(preds[:-1] != preds[1:])
                return 0 if len(flips) == 0 else flips[-1] + 1

            pdepths = df.groupby(['idx', 'epoch']).apply(get_pd).reset_index(name='pd')
            pdepths['group'] = self.grouper.metadata_to_group(self.g_val, return_counts=False).cpu().numpy()

            path = f'{self.log_dir}/pdepth_{self.seed}.csv'
            pdepths.to_csv(path, mode='w')
            self.pdepths = pdepths
            self.lw_knn_preds = None
            if self.use_wandb:
                wandb.log({'pred-depth': wandb.Table(dataframe=pdepths),
                           'pd-hist': pdepths.pivot_table(index=['group', 'pd'], values='idx', aggfunc=len).reset_index()})
        return is_best, res
