import copy
import datetime
from termcolor import colored
from tqdm import tqdm
import torch
import numpy as np
from sklearn.cluster import k_means
from collections import Counter
import training.utils as utils
import torch.nn.functional as F
from data_utils import EnvSampler, is_textdata
from torch.utils.data import DataLoader
from model_utils import get_model
import pickle
from sklearn import metrics


def visualize_cluster(data, env_id, partition_res):
    '''
        Print, for every attribute, how its 0/1 values are distributed over
        the clusters, separately for clusters of label 0 and of label 1.

        The label attribute (data.label_idx) and the attribute at
        data.cor_idx are not reported. Attribute values are used directly as
        list indices into a two-element counter, so they must be 0 or 1.

        @param data: object exposing get_all_att(env_id),
            train_data.attr_names, label_idx and cor_idx
        @param env_id: environment id forwarded to data.get_all_att
        @param partition_res: iterable of clusters; each cluster is an
            iterable of row indices into the matrix returned by
            data.get_all_att(env_id). Empty clusters are skipped.

        @return None (the table is written to stdout)

        @raise ValueError: if the members of one cluster do not all share the
            same label
        @raise KeyError: if a cluster's label is neither 0 nor 1
    '''
    all_att = data.get_all_att(env_id)
    label_idx = data.label_idx
    cor_idx = data.cor_idx

    # attributes that are actually reported
    shown = [(att_idx, att_name)
             for att_idx, att_name in enumerate(data.train_data.attr_names)
             if att_idx != label_idx and att_idx != cor_idx]

    res_dict = {
        'y=0': [],
        'y=1': [],
    }

    for partition in partition_res:
        members = list(partition)
        if not members:
            # nothing to count, and no label to file the cluster under
            continue

        # the label is read once per cluster, independently of how many
        # attributes are reported
        cur_y = all_att[members[0], label_idx]
        if any(all_att[i, label_idx] != cur_y for i in members):
            raise ValueError('y has to be the same within each cluster')

        res = {}
        for att_idx, att_name in shown:
            cnt = [0, 0]  # cnt[v] = number of members whose attribute is v
            for i in members:
                cnt[all_att[i, att_idx]] += 1
            res[att_name] = cnt

        res_dict['y={}'.format(cur_y)].append(res)

    # visualize
    print('========visaulizing clusters=========')
    for k, res_list in res_dict.items():
        print('y = {}'.format(k))

        for _, att_name in shown:
            # totals over all clusters of this label come first ...
            total_cnt = [0, 0]
            for res in res_list:
                total_cnt[0] += res[att_name][0]
                total_cnt[1] += res[att_name][1]

            print('{:>20}'.format(att_name), end=', ')
            print('{:>5}, {:>5}'.format(total_cnt[0], total_cnt[1]), end=', ')

            # ... followed by one (zeros, ones) pair per cluster
            for res in res_list:
                print('{:>5}, {:>5}'.format(res[att_name][0], res[att_name][1]),
                      end=', ')
            print()


def compute_l2(XS, XQ):
    '''
        Compute the pairwise squared l2 distance
        @param XS (support x): support_size x ebd_dim
        @param XQ (query x): query_size x ebd_dim

        @return dist: query_size x support_size, where dist[q, s] is the
            squared Euclidean distance between XQ[q] and XS[s]

    '''
    # broadcast to query_size x support_size x ebd_dim
    diff = XS.unsqueeze(0) - XQ.unsqueeze(1)

    # Sum the squares directly instead of squaring a norm: this skips the
    # square root, whose derivative is not defined at zero distance (e.g. the
    # diagonal when XS and XQ are the same tensor).
    return diff.pow(2).sum(dim=2)


def train_partition_loop(train_loaders, model, opt, ep, args):
    '''
        Run one pass over the paired partition loaders and update model['ebd'].

        Loaders are consumed in lock-step and read in consecutive pairs:
        loader 2*i supplies the "pos" batch and loader 2*i+1 the "neg" batch.
        Both batches are truncated to the shorter one, embedded, and a hinge
        loss with margin args.thres is applied to
        (pos-pos distance - pos-neg distance). Gradients of all pairs are
        accumulated before a single optimizer step per lock-step batch.

        @param train_loaders: list of loaders, paired as described above
        @param model: dict holding the embedding network under 'ebd'
        @param opt: optimizer stepped once per lock-step batch
        @param ep: epoch index (not used here)
        @param args: namespace providing thres

        @return dict with the mean over all steps of 'loss', 'dis_pos',
            'dis_neg' and 'dis_cross'
    '''
    history = {name: [] for name in ('loss', 'dis_pos', 'dis_neg', 'dis_cross')}

    # note: this is the number of loaders, i.e. twice the number of pairs
    num_loaders = len(train_loaders)

    for batches in zip(*train_loaders):
        model['ebd'].train()

        pair_losses = []
        pos_dists, neg_dists, cross_dists = [], [], []

        for first in range(0, (num_loaders // 2) * 2, 2):
            x_pos = utils.to_cuda(utils.squeeze_batch(batches[first]))['X']
            x_neg = utils.to_cuda(utils.squeeze_batch(batches[first + 1]))['X']

            # both sides must have the same number of rows for the
            # element-wise comparison of the distance matrices below
            keep = min(len(x_pos), len(x_neg))
            ebd_pos = model['ebd'](x_pos[:keep])
            ebd_neg = model['ebd'](x_neg[:keep])

            diff_pos_pos = compute_l2(ebd_pos, ebd_pos)
            diff_pos_neg = compute_l2(ebd_pos, ebd_neg)
            diff_neg_neg = compute_l2(ebd_neg, ebd_neg)

            # monitoring only, detached from the graph
            pos_dists.append(diff_pos_pos.detach().mean().item())
            neg_dists.append(diff_neg_neg.detach().mean().item())
            cross_dists.append(diff_pos_neg.detach().mean().item())

            margin = (diff_pos_pos - diff_pos_neg +
                      torch.ones_like(diff_pos_pos) * args.thres)
            loss = torch.mean(torch.max(torch.zeros_like(diff_pos_pos), margin))
            loss = loss / num_loaders

            pair_losses.append(loss.item())
            loss.backward()

        opt.step()
        opt.zero_grad()

        history['loss'].append(sum(pair_losses))
        history['dis_pos'].append(sum(pos_dists) / len(pos_dists))
        history['dis_neg'].append(sum(neg_dists) / len(neg_dists))
        history['dis_cross'].append(sum(cross_dists) / len(cross_dists))

    return {name: float(np.mean(np.array(values)))
            for name, values in history.items()}


def test_partition_loop(test_loaders, model, ep, args):
    '''
        Evaluate model['ebd'] on the paired partition loaders.

        Loaders are read in consecutive pairs (2*i is "pos", 2*i+1 is "neg"),
        truncated to the shorter batch and embedded. The loss of a pair is
        the sum of two hinge terms with margin args.thres: one on
        (pos-pos - pos-neg) distances and one on (neg-neg - pos-neg)
        distances. No gradient step is taken here.

        @param test_loaders: list of loaders, paired as described above
        @param model: dict holding the embedding network under 'ebd'
        @param ep: epoch index (not used here)
        @param args: namespace providing thres

        @return dict with the mean 'loss' over all lock-step batches
    '''
    def hinge(within, cross):
        # mean over all entries of max(0, within - cross + thres)
        margin = within - cross + torch.ones_like(within) * args.thres
        return torch.mean(torch.max(torch.zeros_like(within), margin))

    step_losses = []

    # note: this is the number of loaders, i.e. twice the number of pairs
    num_loaders = len(test_loaders)

    for batches in zip(*test_loaders):
        model['ebd'].eval()

        pair_losses = []
        for first in range(0, (num_loaders // 2) * 2, 2):
            x_pos = utils.to_cuda(utils.squeeze_batch(batches[first]))['X']
            x_neg = utils.to_cuda(utils.squeeze_batch(batches[first + 1]))['X']

            keep = min(len(x_pos), len(x_neg))
            ebd_pos = model['ebd'](x_pos[:keep])
            ebd_neg = model['ebd'](x_neg[:keep])

            diff_pos_pos = compute_l2(ebd_pos, ebd_pos)
            diff_pos_neg = compute_l2(ebd_pos, ebd_neg)
            diff_neg_neg = compute_l2(ebd_neg, ebd_neg)

            loss = hinge(diff_pos_pos, diff_pos_neg) + hinge(diff_neg_neg,
                                                             diff_pos_neg)
            loss = loss / num_loaders
            pair_losses.append(loss.item())

        step_losses.append(sum(pair_losses))

    return {'loss': float(np.mean(np.array(step_losses)))}


def train_dro_loop(train_loaders, model, opt, ep, args):
    '''
        Run one pass of worst-group training over the given loaders.

        The loaders are consumed in lock-step. For every lock-step batch the
        cross-entropy of each loader's examples is computed separately and
        only the largest of these losses is back-propagated.

        @param train_loaders: list of loaders, one per group
        @param model: dict with the embedding network under 'ebd' and the
            classifier under 'clf_all'
        @param opt: optimizer stepped once per lock-step batch
        @param ep: epoch index (not used here)
        @param args: namespace providing dataset

        @return dict with the mean over all steps of 'worst_loss',
            'avg_loss', 'worst_acc' and 'avg_acc'

        @raise ValueError: if no group produced a comparable (non-NaN) loss
    '''
    stats = {}
    for k in ['worst_loss', 'avg_loss', 'worst_acc', 'avg_acc']:
        stats[k] = []

    for batches in zip(*train_loaders):
        # work on each batch
        model['ebd'].train()
        model['clf_all'].train()

        x, y = [], []

        for batch in batches:
            batch = utils.to_cuda(utils.squeeze_batch(batch))
            x.append(batch['X'])
            y.append(batch['Y'])

        if is_textdata(args.dataset):
            # text models have varying length between batches
            pred = []
            for cur_x in x:
                pred.append(model['clf_all'](model['ebd'](cur_x)))
            pred = torch.cat(pred, dim=0)
        else:
            pred = model['clf_all'](model['ebd'](torch.cat(x, dim=0)))

        cur_idx = 0

        avg_loss = 0
        avg_acc = 0

        # Track the worst group starting from "nothing seen yet" rather than
        # from the number 0: a group whose loss is exactly 0.0 must still be
        # selectable, otherwise there would be no tensor to call backward on.
        worst_loss = None
        worst_loss_val = float('-inf')
        worst_acc = 0

        for cur_true in y:
            # predictions are laid out group after group, in loader order
            cur_pred = pred[cur_idx:cur_idx+len(cur_true)]
            cur_idx += len(cur_true)

            loss = F.cross_entropy(cur_pred, cur_true)
            acc = torch.mean((torch.argmax(cur_pred, dim=1) == cur_true).float()).item()

            loss_val = loss.item()
            avg_loss += loss_val
            avg_acc += acc

            # strict comparison: on ties the earliest group is kept
            if loss_val > worst_loss_val:
                worst_loss = loss
                worst_loss_val = loss_val
                worst_acc = acc

        if worst_loss is None:
            raise ValueError('no group produced a comparable loss')

        opt.zero_grad()
        worst_loss.backward()
        opt.step()

        avg_loss /= len(y)
        avg_acc /= len(y)

        stats['avg_acc'].append(avg_acc)
        stats['avg_loss'].append(avg_loss)
        stats['worst_acc'].append(worst_acc)
        stats['worst_loss'].append(worst_loss_val)

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats


def train_loop(train_loader, model, opt, ep, args):
    '''
        Train the embedding network and the classifier for one pass over
        train_loader, taking one optimizer step per batch.

        @param train_loader: loader yielding batches with 'X' and 'Y'
        @param model: dict with the embedding network under 'ebd' and the
            classifier under 'clf_all'; the classifier is called as
            clf_all(x, y, return_pred=False, grad_penalty=False) and must
            return (acc, loss)
        @param opt: optimizer for the model
        @param ep: epoch index (not used here)
        @param args: run configuration (not used here)

        @return dict with the mean 'acc' and 'loss' over all batches
    '''
    acc_log, loss_log = [], []

    for batch in train_loader:
        model['ebd'].train()
        model['clf_all'].train()

        batch = utils.to_cuda(utils.squeeze_batch(batch))

        features = model['ebd'](batch['X'])
        acc, loss = model['clf_all'](features, batch['Y'], return_pred=False,
                                     grad_penalty=False)

        opt.zero_grad()
        loss.backward()
        opt.step()

        acc_log.append(acc)
        loss_log.append(loss.item())

    return {
        'acc': float(np.mean(np.array(acc_log))),
        'loss': float(np.mean(np.array(loss_log))),
    }


def test_loop(test_loader, model, ep, args, return_idx=False, att_idx_dict=None):
    '''
        Evaluate the embedding network + classifier on test_loader.

        @param test_loader: loader yielding batches with 'X', 'Y' and 'C'
            (and 'idx' when return_idx is set or att_idx_dict is given)
        @param model: dict with the embedding network under 'ebd' and the
            classifier under 'clf_all'; the classifier is called as
            clf_all(x, y, return_pred=True) and must return (y_hat, loss)
        @param ep: epoch index (not used here)
        @param args: run configuration (not used here)
        @param return_idx: if True, additionally return the example indices,
            'C' values and labels, split into correctly and wrongly
            predicted examples. Takes precedence over att_idx_dict.
        @param att_idx_dict: if given (and return_idx is False), the result
            is whatever utils.get_worst_acc returns for it

        @return dict with 'acc' and 'loss'; with return_idx also
            'correct_idx', 'correct_cor', 'correct_y', 'wrong_idx',
            'wrong_cor' and 'wrong_y'
    '''
    collect_idx = return_idx or (att_idx_dict is not None)

    batch_losses = []
    true, pred, cor = [], [], []
    idx = [] if collect_idx else None

    for batch in test_loader:
        model['ebd'].eval()
        model['clf_all'].eval()

        batch = utils.to_cuda(utils.squeeze_batch(batch))

        y = batch['Y']
        c = batch['C']
        y_hat, loss = model['clf_all'](model['ebd'](batch['X']), y,
                                       return_pred=True)

        true.append(y)
        pred.append(y_hat)
        cor.append(c)
        if collect_idx:
            idx.append(batch['idx'])

        batch_losses.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)

    acc = (true == pred).float().mean().item()
    loss = np.mean(np.array(batch_losses))

    if return_idx:
        # one bucket for hits, one for misses; each keeps the example index,
        # its 'C' value and its label
        buckets = {
            True: {'idx': [], 'cor': [], 'y': []},
            False: {'idx': [], 'cor': [], 'y': []},
        }

        rows = zip(torch.cat(idx).tolist(), true.tolist(), pred.tolist(),
                   torch.cat(cor).tolist())
        for i, y, y_hat, c in rows:
            bucket = buckets[True] if y == y_hat else buckets[False]
            bucket['idx'].append(i)
            bucket['cor'].append(c)
            bucket['y'].append(y)

        return {
            'acc': acc,
            'loss': loss,
            'correct_idx': buckets[True]['idx'],
            'correct_cor': buckets[True]['cor'],
            'correct_y': buckets[True]['y'],
            'wrong_idx': buckets[False]['idx'],
            'wrong_cor': buckets[False]['cor'],
            'wrong_y': buckets[False]['y'],
        }

    if att_idx_dict is not None:
        # idx is passed on as the list of per-batch index tensors
        return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)

    return {'acc': acc, 'loss': loss}


def cluster_loop(test_loader, model, args):
    '''
        Embed every example of test_loader with model['ebd'] and run k-means
        separately on the examples of each label.

        For every label, the cluster assignment is also scored against the
        'C' values of the examples with several sklearn clustering metrics;
        the scores, averaged over labels, are printed together with the size
        and 'C' composition of every cluster.

        @param test_loader: loader yielding batches with 'X', 'Y', 'C', 'idx'
        @param model: dict holding the embedding network under 'ebd'
        @param args: namespace providing num_clusters

        @return list of clusters over all labels; each cluster is the list of
            'idx' values of its members
    '''
    model['ebd'].eval()

    # label -> embeddings, 'C' values and example indices of that label
    groups = {}
    for batch in test_loader:
        batch = utils.to_cuda(utils.squeeze_batch(batch))

        x_s = model['ebd'](batch['X']).cpu().numpy()
        y_s = batch['Y'].cpu().numpy()
        c_s = batch['C'].cpu().numpy()
        idx_s = batch['idx'].cpu().numpy()

        for x, y, c, idx in zip(x_s, y_s, c_s, idx_s):
            bucket = groups.setdefault(int(y), {'x': [], 'c': [], 'idx': []})
            bucket['x'].append(x)
            bucket['c'].append(c)
            bucket['idx'].append(idx)

    clustering_metrics = [metrics.homogeneity_score,
                          metrics.completeness_score,
                          metrics.v_measure_score,
                          metrics.adjusted_rand_score,
                          metrics.adjusted_mutual_info_score]
    clustering_results = []
    clusters = []

    for v in groups.values():
        x = np.stack(v['x'], axis=0)
        print(x.shape)

        centroid, label, inertia = k_means(x, args.num_clusters)
        print(inertia)

        # score the clustering against the 'C' values of the examples
        metric_c = np.stack(v['c'], axis=0)
        clustering_results.append([score(metric_c, label)
                                   for score in clustering_metrics])

        # cluster id -> member indices / member 'C' values, in order of
        # first appearance of the cluster id
        members, colors = {}, {}
        for cluster_id, idx, c in zip(label, v['idx'], v['c']):
            members.setdefault(cluster_id, []).append(idx)
            colors.setdefault(cluster_id, []).append(c)

        for cluster_id, cluster in members.items():
            clusters.append(cluster)

            cluster_colors = colors[cluster_id]
            size = len(cluster_colors)
            print('size: {}, color '.format(size), end='')
            for c, cur_cnt in sorted(Counter(cluster_colors).items()):
                print('{}={:.2f}, '.format(c, cur_cnt / size), end='')
            print()

    clustering_results = np.array(clustering_results)
    print('clustering metrics')
    print(clustering_results.shape)
    print(np.mean(clustering_results, axis=0))

    return clusters


def print_partition_res(train_res, val_res, ep):
    '''
        Print a one-line summary of one epoch of partition-model training.

        @param train_res: dict with 'loss', 'dis_pos', 'dis_neg', 'dis_cross'
        @param val_res: dict with 'loss'
        @param ep: epoch index
    '''
    template = ("epoch {epoch}, train {loss} {train_loss:>10.7f} pos {dis_pos:>10.5f} "
                "neg {dis_neg:>10.5f} cross {dis_cross:>10.5f} "
                "val {loss} {val_loss:>10.7f}")

    fields = {
        'epoch': ep,
        'loss': colored("loss", "yellow"),
        'train_loss': train_res["loss"],
        'val_loss': val_res["loss"],
        'dis_pos': train_res["dis_pos"],
        'dis_neg': train_res["dis_neg"],
        'dis_cross': train_res["dis_cross"],
    }

    print(template.format(**fields), flush=True)


def print_res(train_res, val_res, ep):
    '''
        Print a one-line summary of one epoch of worst-group training.

        @param train_res: dict with 'avg_acc', 'worst_acc', 'avg_loss' and
            'worst_loss'
        @param val_res: dict with 'acc' and 'loss'
        @param ep: epoch index
    '''
    # only the fields referenced by the template are passed to format
    print(("epoch {epoch}, train {acc} {train_acc:>7.4f} {train_worst_acc:>7.4f} "
           "{loss} {train_loss:>10.7f} {train_worst_loss:>10.7f} "
           "val {acc} {val_acc:>10.7f}, {loss} {val_loss:>10.7f}").format(
               epoch=ep,
               acc=colored("acc", "blue"),
               loss=colored("loss", "yellow"),
               train_acc=train_res["avg_acc"],
               train_worst_acc=train_res["worst_acc"],
               train_loss=train_res["avg_loss"],
               train_worst_loss=train_res["worst_loss"],
               val_acc=val_res["acc"],
               val_loss=val_res["loss"]), flush=True)


def print_pretrain_res(train_res, test_res, ep, i):
    '''
        Print a one-line summary of one epoch of training an
        environment-specific classifier.

        @param train_res: dict with 'acc' and 'loss'
        @param test_res: dict with 'acc' and 'loss'
        @param ep: epoch index
        @param i: index of the training environment the classifier is fit on
    '''
    # only the fields referenced by the template are passed to format
    print(("petrain {i}, epoch {epoch}, train {acc} {train_acc:>7.4f} "
           "{loss} {train_loss:>7.4f}, "
           "val {acc} {test_acc:>7.4f}, {loss} {test_loss:>7.4f} ").format(
               epoch=ep,
               i=i,
               acc=colored("acc", "blue"),
               loss=colored("loss", "yellow"),
               train_acc=train_res["acc"],
               train_loss=train_res["loss"],
               test_acc=test_res["acc"],
               test_loss=test_res["loss"]), flush=True)


def ours(train_data, test_data, model, opt, args, partition_model=None,
         train_partition_loaders=None, val_partition_loaders=None,
         train_ebd=None):
    '''
        Train on one task, either inferring groups from the training
        environments or re-using a trained partition model to create them.

        Source task (partition_model is None):
          1. one classifier is trained on each of train environments 0 and 1
             and early-stopped on its accuracy on the other environment;
          2. each environment is split into the examples that the classifier
             of the *other* environment predicts correctly / wrongly;
          3. unless args.tar_method == 'ours', a classifier is trained on
             these four groups with train_dro_loop and the function returns;
          4. if args.tar_method == 'ours', step 3 is skipped. Per environment
             and label, the correct/wrong groups are split 70/30 into
             train/validation loaders, which are appended to the loader lists
             passed in. If args.dataset_remaining is non-empty the function
             returns here; otherwise a partition model is trained on these
             loaders with train_partition_loop.

        Target task (partition_model is not None):
          train environment 0 is clustered with cluster_loop, one group per
          cluster, and `model` is trained on these groups with
          train_dro_loop. Unless the dataset name contains 'MNIST', test
          environment 2 is clustered as well and the worst cluster accuracy
          is used for early stopping.

        @param train_data: training dataset exposing envs[i]['idx_list'] and
            get_all_y(env) (and optionally vocab)
        @param test_data: dataset with four environments; environment 2 is
            used for validation and environment 3 for the final test
        @param model: dict with 'ebd' and 'clf_all' networks trained by
            train_dro_loop
        @param opt: optimizer for `model`
        @param args: run configuration (batch_size, num_batches, num_epochs,
            patience, tar_method, dataset, dataset_remaining, num_clusters)
        @param partition_model: trained partition model, or None on a source
            task
        @param train_partition_loaders: loaders collected from previous
            source tasks, extended in place; None starts fresh lists
        @param val_partition_loaders: validation counterpart of the above
        @param train_ebd: not used here

        @return res alone if partition_model was given; otherwise the tuple
            (res, train_partition_loaders, val_partition_loaders), whose two
            loader entries are None when args.tar_method != 'ours'.
            res holds 'train', 'val', 'test' and 'partition' (plus
            'metric_train' and 'metric_val' once a partition model has been
            trained).
    '''
    # loading data in testing mode
    test_loaders = []
    for i in range(4):
        test_loaders.append(DataLoader(
            test_data,
            sampler=EnvSampler(-1, args.batch_size, i,
                               test_data.envs[i]['idx_list']),
            num_workers=10))

    val_loaders = []

    ######
    # creating split for the train data
    # if no partition model is available, use PI to create partitions
    # otherwise use clustering to generate partitions
    ######
    if partition_model is None:
        # no partition model available
        # this is source task, use PI algorithm to contrast the training
        # environments

        # create a training loader for each train env
        train_loaders = []
        for i in range(2):
            train_loaders.append(DataLoader(
                train_data,
                sampler=EnvSampler(args.num_batches, args.batch_size, i,
                                   train_data.envs[i]['idx_list']),
                num_workers=10))

        # training the environment-specific classifier
        models = []
        for i in range(2):
            if hasattr(train_data, 'vocab'):
                cur_model, cur_opt = get_model(args, train_data.vocab)
            else:
                cur_model, cur_opt = get_model(args)

            # plain %y/%m/%d are already zero-padded; width flags inside a
            # directive are not portable across platforms
            print("{}, Start training classifier on train env {}".format(
                datetime.datetime.now().strftime('%y/%m/%d %H:%M:%S'), i),
                  flush=True)

            best_acc = -1
            best_model = {}
            cycle = 0

            # start training the env specific model
            for ep in range(args.num_epochs):
                train_res = train_loop(train_loaders[i], cur_model, cur_opt, ep, args)

                with torch.no_grad():
                    # evaluate on the other training environment
                    val_res = test_loop(train_loaders[1-i], cur_model, ep, args)

                print_pretrain_res(train_res, val_res, ep, i)

                if val_res['acc'] > best_acc:
                    best_acc = val_res['acc']
                    cycle = 0
                    # save best ebd
                    for k in 'ebd', 'clf_all':
                        best_model[k] = copy.deepcopy(cur_model[k].state_dict())
                else:
                    cycle += 1

                if cycle == args.patience:
                    break

            # load best model
            for k in 'ebd', 'clf_all':
                cur_model[k].load_state_dict(best_model[k])

            models.append(cur_model)

        # load training data in test mode
        test_train_loaders = []
        for i in range(2):
            test_train_loaders.append(DataLoader(
                train_data,
                sampler=EnvSampler(-1, args.batch_size, i,
                                   train_data.envs[i]['idx_list']),
                num_workers=10))

        # split the dataset based on the model predictions: every environment
        # is scored by the classifier trained on the other environment
        pretrain_res = []
        pretrain_res.append(test_loop(test_train_loaders[0], models[1], ep, args, True))
        pretrain_res.append(test_loop(test_train_loaders[1], models[0], ep, args, True))

        # train a new unbiased model through dro
        # groups are ordered env0-correct, env0-wrong, env1-correct,
        # env1-wrong
        train_loaders = []
        print('\n######################\nCreate New Groups')
        for i in range(len(pretrain_res)):
            train_loaders.append(DataLoader(
                train_data, sampler=EnvSampler(args.num_batches, args.batch_size, i,
                                               pretrain_res[i]['correct_idx']),
                num_workers=10))

            train_loaders.append(DataLoader(
                train_data, sampler=EnvSampler(args.num_batches, args.batch_size, i,
                                               pretrain_res[i]['wrong_idx']),
                num_workers=10))

    else:
        # use trained partition_model to cluster the input
        # and create the groups
        test_train_loader = DataLoader(
            train_data, sampler=EnvSampler(-1, args.batch_size, 0,
                                           train_data.envs[0]['idx_list']),
            num_workers=10)

        print('partition the train data')
        with torch.no_grad():
            partition_res = cluster_loop(test_train_loader, partition_model,
                                         args)

            # visualize_cluster prints its table itself and returns nothing
            visualize_cluster(train_data, 0, partition_res)

        train_loaders = []
        for group in partition_res:
            train_loaders.append(DataLoader(
                train_data, sampler=EnvSampler(args.num_batches,
                                               args.batch_size, 0, group),
                num_workers=int(10 / args.num_clusters * 2)))

        if 'MNIST' not in args.dataset:
            # use trained partition_model to cluster the validation input
            # and create the groups
            # MNIST has artificial label noise
            with torch.no_grad():
                partition_res = cluster_loop(test_loaders[2], partition_model, args)

            for group in partition_res:
                val_loaders.append(DataLoader(
                    test_data, sampler=EnvSampler(args.num_batches,
                                                   args.batch_size, 2, group),
                    num_workers=int(10 / args.num_clusters * 2)))

    ######
    # given the splits, if we are already given the partition model
    # this means we are dealing with the target task
    # we train a stable classifier and return its result.
    # otherwise we skip this part by filling a placeholder res dict
    ######
    if partition_model is None and args.tar_method == 'ours':
        # our method doesn't need to learn a robust source classifier
        # define a placeholder res (the last pretraining stats fill every
        # slot)
        res = {
            'train': train_res,
            'val': train_res,
            'test': train_res,
            'partition': None,
        }
    else:
        # start training a stable classifier by minimizing the worst case risk
        # across all groups
        best_acc = -1
        best_val_res = None
        best_model = {}
        cycle = 0
        for ep in range(args.num_epochs):
            train_res = train_dro_loop(train_loaders, model, opt, ep, args)

            with torch.no_grad():
                # validation
                if len(val_loaders) == 0:
                    val_res = test_loop(test_loaders[2], model, ep, args)
                else:
                    # use the worst-cluster acc in the validation data for early
                    # stopping
                    val_res = {'acc': [], 'loss': []}
                    for val_loader in val_loaders:
                        cur_val_res = test_loop(val_loader, model, ep, args)
                        val_res['acc'].append(cur_val_res['acc'])
                        val_res['loss'].append(cur_val_res['loss'])
                    val_res['acc'] = min(val_res['acc'])
                    val_res['loss'] = max(val_res['loss'])

            print_res(train_res, val_res, ep)

            # model selection uses the weaker of train worst-group accuracy
            # and validation accuracy
            if min(train_res['worst_acc'], val_res['acc']) > best_acc:
                best_acc = min(train_res['worst_acc'], val_res['acc'])
                best_val_res = val_res
                best_train_res = train_res
                cycle = 0
                # save best ebd
                for k in 'ebd', 'clf_all':
                    best_model[k] = copy.deepcopy(model[k].state_dict())
            else:
                cycle += 1

            if cycle == args.patience:
                break

        # load best model
        for k in 'ebd', 'clf_all':
            model[k].load_state_dict(best_model[k])

        # get the results
        test_res = test_loop(test_loaders[3], model, ep, args,
                             att_idx_dict=test_data.test_att_idx_dict)

        # report the statistics of the restored (best) epoch, not those of
        # the last epoch that happened to run
        train_res = best_train_res
        val_res = best_val_res

        print('Best train')
        print(train_res)
        print('Best val')
        print(val_res)
        print('Test')
        print(test_res)

        # this is inference on target task, do not need to retrain the partition
        # model
        res = {
            'train': train_res,
            'val': val_res,
            'test': test_res,
            'partition': None,
        }

    if partition_model is not None:
        # target task: callers receive the result dict alone
        return res

    # from here on partition_model is None, i.e. this is a source task
    if args.tar_method != 'ours':
        # this happens if we use re-use, finetune, we don't need to learn the
        # unstable feature
        return res, None, None

    ########################################################################
    # learn the unstable features from the groups
    ########################################################################
    print('\n######################\nLearning the partition model')
    if hasattr(train_data, 'vocab'):
        partition_model, opt = get_model(args, train_data.vocab)
    else:
        partition_model, opt = get_model(args)

    # create data loader from each partition x label
    if train_partition_loaders is None:
        train_partition_loaders, val_partition_loaders = [], []

    print('\n######################\nCreate New Groups')
    for env in range(len(pretrain_res)):
        # look at each environment, each label
        groups = {}
        label_list = train_data.get_all_y(env)

        # sets give constant-time membership tests; the index lists hold
        # plain ints (they come from Tensor.tolist() in test_loop)
        correct_set = set(pretrain_res[env]['correct_idx'])
        wrong_set = set(pretrain_res[env]['wrong_idx'])

        for idx, label in zip(train_data.envs[env]['idx_list'], label_list):
            if label not in groups:
                groups[label] = {
                    'yes': [],
                    'no': [],
                }

            # normalise to a plain int so hashing matches the set entries
            # whatever integer type idx_list stores
            key = int(idx)
            if key in correct_set:
                groups[label]['yes'].append(idx)
            elif key in wrong_set:
                groups[label]['no'].append(idx)
            else:
                raise ValueError('unknown idx')

        for group in groups.values():
            # each group contains pos (yes) and neg (no)
            # each group corresponds to examples from one label value
            # use 70% for training, 30% for validation
            if min(len(group['yes']), len(group['no'])) * 0.3 < 1:
                # too few examples on one side to leave anything for
                # validation
                continue

            for k in ['yes', 'no']:
                split = int(len(group[k]) * 0.7)

                train_partition_loaders.append(
                    DataLoader(train_data,
                               sampler=EnvSampler(args.num_batches, args.batch_size,
                                                  env,
                                                  group[k][:split]),
                               num_workers=2),
                )

                val_partition_loaders.append(
                    DataLoader(train_data,
                               sampler=EnvSampler(args.num_batches, args.batch_size,
                                                  env,
                                                  group[k][split:]),
                               num_workers=2),
                )

    if len(args.dataset_remaining) != 0:
        # we still have another source task to train on.
        print('early exit')
        return res, train_partition_loaders, val_partition_loaders

    # start training the unstable feature representation
    best_loss = float('inf')
    best_val_res = None
    best_model = {}
    cycle = 0
    for ep in range(args.num_epochs):
        train_res = train_partition_loop(train_partition_loaders, partition_model, opt,
                                         ep, args)

        with torch.no_grad():
            # validation
            val_res = test_partition_loop(val_partition_loaders, partition_model, ep,
                                          args)

        print_partition_res(train_res, val_res, ep)
        if val_res['loss'] < best_loss:
            best_loss = val_res['loss']
            best_val_res = val_res
            best_train_res = train_res
            cycle = 0
            # save best ebd
            for k in 'ebd', 'clf_all':
                best_model[k] = copy.deepcopy(partition_model[k].state_dict())
        else:
            cycle += 1

        if cycle == args.patience:
            break

    # load best model
    for k in 'ebd', 'clf_all':
        partition_model[k].load_state_dict(best_model[k])

    # NOTE(review): the metrics stored here are those of the last epoch that
    # ran, while the returned partition model is the best checkpoint;
    # best_train_res / best_val_res are tracked above but not reported.
    # Confirm which of the two is intended.
    res.update({
        'metric_train': train_res,
        'metric_val': val_res,
        'partition': partition_model
    })
    return res, train_partition_loaders, val_partition_loaders
