import pickle
import os
from xmeta.utils.sift import SiftFeature
from xmeta.utils.data import ImpureTasksets, get_tasksets, pollute_tasks
from xmeta.utils.seed import set_seed
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import itertools


def draw_df_row(n, df, index_dict, bins=20, highlight=()):
    """Plot the train-task score histogram for the test task with index ``n``.

    Selects the rows of ``df`` whose ``test_task_idx`` equals ``n``, extracts
    their ``train_task_idx`` and ``train_task_score`` cells and forwards them
    to ``draw_hist`` together with ``index_dict``, ``bins`` and ``highlight``.

    NOTE(review): the ``squeeze().tolist()`` step looks like it expects exactly
    one matching row whose cells are list-like -- confirm against callers.
    """
    selected = df[df['test_task_idx'] == n]
    idces = selected['train_task_idx'].values.squeeze().tolist()
    scores = selected['train_task_score'].values.squeeze().tolist()
    draw_hist(idces, scores, index_dict, bins=bins, highlight=highlight)


def draw_hist(train_task_idces, train_task_scores, index_dict, bins=20, highlight=(),
              ax=None):
    """Draw overlaid histograms of train-task scores, one per task category.

    Args:
        train_task_idces: list of train task indices (must support ``.index``).
        train_task_scores: scores aligned position-by-position with
            ``train_task_idces``.
        index_dict: mapping from category key (``'train_noise_tasks'``,
            ``'train_shuffle_tasks'``, ``'train_dark_tasks'``,
            ``'train_recolor_tasks'``, ``'train_bgr_tasks'``) to the task
            indices in that category. Missing keys are treated as empty.
        bins: number of histogram bins shared by every category.
        highlight: task indices drawn as their own opaque histogram; they are
            excluded from every other histogram.
        ax: drawing target. Either an object with the ``Axes`` interface
            (``set_xlabel`` / ``set_ylabel``) or ``None`` to draw through the
            ``plt`` module.

    Raises:
        ValueError: if a category or highlight index is not present in
            ``train_task_idces``, or if ``train_task_scores`` is empty.
    """
    # (index_dict key, legend label, colour), drawn in this order after the
    # normal tasks. 'bgr_image' uses magenta so it can be told apart from
    # 'recolor_image' (both used to be blue).
    categories = (
        ('train_noise_tasks', 'noise_image', 'g'),
        ('train_shuffle_tasks', 'shuffled_label', 'r'),
        ('train_dark_tasks', 'dark_image', 'c'),
        ('train_recolor_tasks', 'recolor_image', 'b'),
        ('train_bgr_tasks', 'bgr_image', 'm'),
    )

    if ax is None:
        ax = plt

    def score_of(idx):
        # list.index raises ValueError when idx is absent from train_task_idces.
        return train_task_scores[train_task_idces.index(idx)]

    # Collect every score list before drawing so that a bad index fails
    # before anything has been plotted.
    polluted = set()
    groups = []
    for key, label, color in categories:
        members = index_dict.get(key, [])
        polluted.update(members)
        groups.append((label, color,
                       [score_of(idx) for idx in members if idx not in highlight]))

    normal_task_scores = [score
                          for score, idx in zip(train_task_scores, train_task_idces)
                          if idx not in polluted and idx not in highlight]
    highlight_scores = [score_of(idx) for idx in highlight]

    # A common range keeps the bin edges identical across all histograms.
    score_range = [min(train_task_scores), max(train_task_scores)]

    if len(normal_task_scores) > 0:
        ax.hist(normal_task_scores, alpha=0.5, color='k', label='normal_task',
                range=score_range, bins=bins)
    for label, color, scores in groups:
        if len(scores) > 0:
            ax.hist(scores, alpha=0.5, color=color, label=label,
                    range=score_range, bins=bins)
    if len(highlight) > 0:
        # Draw on ``ax`` (not ``plt``) so the highlight lands on the same target.
        ax.hist(highlight_scores, alpha=1.0, color='r', label=str(highlight),
                range=score_range, bins=bins)

    # Axes objects expose set_xlabel/set_ylabel; the pyplot module exposes
    # xlabel/ylabel instead.
    if hasattr(ax, 'set_xlabel'):
        ax.set_xlabel('source_task_score')
        ax.set_ylabel('count')
    else:
        ax.xlabel('source_task_score')
        ax.ylabel('count')
    ax.legend()
    

def normal_idxes(index_dict, num_tasks):
    """Return the task indices in ``range(num_tasks)`` that are in no polluted category.

    Args:
        index_dict: mapping from category key (``'train_noise_tasks'``,
            ``'train_shuffle_tasks'``, ``'train_dark_tasks'``,
            ``'train_recolor_tasks'``, ``'train_bgr_tasks'``) to the task
            indices in that category. A missing key is treated as an empty
            category, matching how ``draw_hist`` handles ``index_dict``.
        num_tasks: total number of tasks; indices ``0 .. num_tasks - 1`` are
            considered.

    Returns:
        Ascending list of indices that appear in none of the categories.
    """
    category_keys = ('train_noise_tasks', 'train_shuffle_tasks', 'train_dark_tasks',
                     'train_recolor_tasks', 'train_bgr_tasks')
    # One set gives a single O(1) membership test per index instead of five
    # list scans.
    polluted = set()
    for key in category_keys:
        polluted.update(index_dict.get(key, ()))
    return [idx for idx in range(num_tasks) if idx not in polluted]


def _normalize(df: pd.DataFrame, score_col: str, ref_col: str, std_col: str):
    ret = df.apply(lambda row: [(s - row[ref_col]) / row[std_col]
                                for s in row[score_col]],
                   axis=1
                   )
    return ret


def _best(row, best=1, filter_idx=None):
    _pos = np.argsort(row.train_task_score)[::-1][:best]
    idxes = [idx for ii, idx in enumerate(row.train_task_idx) if ii in _pos]
    if filter_idx is not None:
        idxes = [idx for idx in idxes if idx in filter_idx]
    ret = [score for idx, score in zip(row.train_task_idx, row.train_task_score)
           if idx in idxes]
    return ret


def _worst(row, worst=1, filter_idx=None):
    _pos = np.argsort(row.train_task_score)[:worst]
    idxes = [idx for ii, idx in enumerate(row.train_task_idx) if ii in _pos]
    if filter_idx is not None:
        idxes = [idx for idx in idxes if idx in filter_idx]
    ret = [score for idx, score in zip(row.train_task_idx, row.train_task_score)
           if idx in idxes]
    return ret


def _plot_list(ax, x: pd.Series, y: pd.Series, **kwargs):
    for _x, _y in zip(x, y):
        if len(_y) > 0:
            ax.scatter([_x] * len(_y), _y, **kwargs)


def _add_category_columns(df, name, task_ids, best=None, worst=None):
    """Add the ``<name>_*`` score columns for one task category to ``df`` in place.

    Always adds ``<name>_scores`` (the row's scores restricted to ``task_ids``)
    and its ``_mean`` / ``_std`` / ``_median``. Adds ``<name>_best`` and
    ``<name>_worst`` when ``best`` / ``worst`` are not ``None``.
    """
    members = set(task_ids)
    scores_col = f'{name}_scores'
    df[scores_col] = df.apply(
        lambda row: [score for idx, score
                     in zip(row.train_task_idx, row.train_task_score)
                     if idx in members],
        axis=1)
    df[f'{name}_mean'] = df[scores_col].apply(np.mean)
    df[f'{name}_std'] = df[scores_col].apply(np.std)
    df[f'{name}_median'] = df[scores_col].apply(np.median)
    if best is not None:
        df[f'{name}_best'] = df.apply(
            lambda row: _best(row, best=best, filter_idx=task_ids), axis=1)
    if worst is not None:
        df[f'{name}_worst'] = df.apply(
            lambda row: _worst(row, worst=worst, filter_idx=task_ids), axis=1)


def plot_score_dist(df: pd.DataFrame, index_dict: dict, num_tasks: int,
                    x_col: str = 'test_error', plot_index: bool = True,
                    normalize: bool = False,
                    best: int = None, worst: int = None, ax=None, stats_only=False):
    """Plot per-category train-task score statistics against each test task.

    Rows of ``df`` are sorted by ``x_col``. For the normal tasks and for every
    non-empty polluted category in ``index_dict``, the per-row mean, standard
    deviation and median of the category's scores are computed. Normal and
    noise tasks are drawn as mean and mean +/- std lines; the other categories
    are drawn as error bars. When ``best`` / ``worst`` are given, the category
    members among the overall top / bottom scoring train tasks are scattered
    on top.

    Args:
        df: one row per test task, with list-valued ``train_task_idx`` and
            ``train_task_score`` columns and the column named by ``x_col``.
            The input frame is not modified.
        index_dict: mapping from category key (``'train_noise_tasks'``,
            ``'train_shuffle_tasks'``, ``'train_dark_tasks'``,
            ``'train_recolor_tasks'``, ``'train_bgr_tasks'``) to task indices.
            Missing keys are treated as empty categories.
        num_tasks: total number of train tasks, used to derive the normal ones.
        x_col: column to sort by, and to plot against when ``plot_index`` is
            false.
        plot_index: plot against the sorted row position instead of ``x_col``.
        normalize: standardise each row's scores by that row's mean and std
            before computing the category statistics.
        best: number of top-scoring train tasks to mark, or ``None``.
        worst: number of bottom-scoring train tasks to mark, or ``None``.
        ax: drawing target providing ``plot`` / ``errorbar`` / ``scatter``;
            defaults to the ``plt`` module.
        stats_only: print the ordering statistics and return without plotting.

    Returns:
        ``None`` when ``stats_only`` is true; otherwise the sorted frame
        restricted to the identifying columns plus the computed category
        columns.
    """
    # (column prefix, index_dict key, line colour or None to draw error bars)
    polluted_categories = (
        ('noise', 'train_noise_tasks', 'green'),
        ('shuffle', 'train_shuffle_tasks', None),
        ('dark', 'train_dark_tasks', None),
        ('recolor', 'train_recolor_tasks', None),
        ('bgr', 'train_bgr_tasks', None),
    )

    if ax is None:
        ax = plt

    # Default missing categories to empty so a partial index_dict works.
    index_dict = {**{key: [] for _, key, _ in polluted_categories}, **index_dict}

    df = df.copy()
    df = df.sort_values([x_col], ascending=True).reset_index()
    x = df.index if plot_index else df[x_col]

    ret_cols = ['test_task_idx',
                'test_error', 'test_accuracy',
                'adaptation_error', 'adaptation_accuracy',
                'zeroshot_error', 'zeroshot_accuracy']
    if x_col not in ret_cols:
        ret_cols.append(x_col)

    if normalize:
        df['mean'] = df['train_task_score'].apply(np.mean)
        df['std'] = df['train_task_score'].apply(np.std)
        df['train_task_score'] = _normalize(df, score_col='train_task_score',
                                            ref_col='mean', std_col='std')

    # (prefix, task ids, line colour, best-marker colour, worst-marker colour);
    # empty categories are skipped entirely.
    groups = [('normal', normal_idxes(index_dict, num_tasks), 'black', 'blue', 'black')]
    groups += [(name, index_dict[key], line_color, 'green', 'red')
               for name, key, line_color in polluted_categories]
    groups = [group for group in groups if len(group[1]) > 0]

    for name, task_ids, *_ in groups:
        _add_category_columns(df, name, task_ids, best=best, worst=worst)

    if ('noise_scores' in df) and ('normal_scores' in df):
        print(
            f'n_test: {len(df)} '
            f'n_proper_order(mean): {(df["noise_mean"] < df["normal_mean"]).sum()}'
            )
        print(
            f'n_test: {len(df)} '
            f'n_proper_order(median): {(df["noise_median"] < df["normal_median"]).sum()}'
            )
    if stats_only:
        return

    for name, _, line_color, best_color, worst_color in groups:
        mean, std = df[f'{name}_mean'], df[f'{name}_std']
        if line_color is None:
            ax.errorbar(x=x, y=mean, yerr=std,
                        fmt='o', markersize=2, capsize=2)
        else:
            for curve in (mean, mean + std, mean - std):
                ax.plot(x, curve, marker=None, color=line_color, linewidth=1)
        ret_cols += [f'{name}_scores', f'{name}_mean', f'{name}_std',
                     f'{name}_median']
        if best is not None:
            _plot_list(ax, x, df[f'{name}_best'], color=best_color, s=5)
            ret_cols.append(f'{name}_best')
        if worst is not None:
            _plot_list(ax, x, df[f'{name}_worst'], color=worst_color, s=5)
            ret_cols.append(f'{name}_worst')

    # Some identifying columns may be absent from the input frame.
    ret_cols = [col for col in ret_cols if col in df.columns]
    return df[ret_cols]
        

def permute_labels(a: np.ndarray, labels: list):
    """Relabel ``a`` so that every value ``i`` becomes ``labels[i]``.

    Values of ``a`` that are not in ``range(len(labels))`` are left unchanged.
    All replacements are decided against the original ``a``, so a swap such as
    ``0 -> 1, 1 -> 0`` is applied simultaneously rather than sequentially.
    """
    relabelled = a
    for source, target in enumerate(labels):
        relabelled = np.where(a == source, target, relabelled)
    return relabelled


def compare_arrays(a: np.ndarray, b: np.ndarray):
    """Find the relabelling of ``b`` that agrees with ``a`` in the most positions.

    Every permutation of ``range(a.max() + 1)`` is applied to ``b`` through
    ``permute_labels`` and the number of element-wise matches with ``a`` is
    counted. The search is exhaustive, so its cost grows factorially with the
    number of labels.

    Returns:
        ``(max_match, best_permutation)``: the highest match count and the
        first permutation (a tuple) that reaches it.
    """
    n_labels = a.max() + 1
    best_count, best_perm = -1, None
    for perm in itertools.permutations(range(n_labels)):
        count = (a == permute_labels(b, list(perm))).sum()
        # Strict comparison keeps the earliest permutation on ties.
        if count > best_count:
            best_count, best_perm = count, perm
    return best_count, best_perm


def compare_tensors(a: torch.Tensor, b: torch.Tensor):
    """Run ``compare_arrays`` on two tensors.

    Each tensor is detached, moved to the CPU and converted to a numpy array
    before the comparison; the result of ``compare_arrays`` is returned as is.
    """
    arrays = [tensor.detach().to('cpu').numpy() for tensor in (a, b)]
    return compare_arrays(*arrays)


def compare_tasksets(a, b, idxes: list):
    """Compare element ``[1]`` of ``a[idx]`` and ``b[idx]`` for every ``idx`` in ``idxes``.

    Returns:
        A list with one ``compare_tensors`` result per index, in the order of
        ``idxes``.

    NOTE(review): element ``[1]`` is presumably the label tensor of a task --
    confirm against the taskset type.
    """
    return [compare_tensors(a[idx][1], b[idx][1]) for idx in idxes]


def aggregate_shuffle_scores(df: pd.DataFrame, index_dict: dict):
    """Summarise, per shuffled-label task, its scores and positions across ``df``.

    For every task in ``index_dict['train_shuffle_tasks']`` this collects, over
    all rows of ``df``, the task's score and its position within the row's
    ``train_task_idx`` list (reported as ``rank``), then derives mean / max /
    min of both. The first element of the matching entry of
    ``index_dict['num_correct_labels']`` is attached as ``n_correct_label``.

    NOTE(review): the position is only a true rank if each row's
    ``train_task_idx`` is already ordered by score -- confirm against the
    producer of ``df``.

    Returns:
        A frame indexed by task index with the columns ``score``,
        ``mean_score``, ``max_score``, ``min_score``, ``rank``, ``mean_rank``,
        ``max_rank``, ``min_rank`` and ``n_correct_label``.

    Raises:
        AssertionError: if there are no shuffled-label tasks.
    """
    shuffle_tasks = index_dict['train_shuffle_tasks']
    assert len(shuffle_tasks) > 0

    def _summary_frame(collected, name):
        # One row per task; the single column holds the collected list.
        frame = pd.DataFrame.from_dict(collected, orient='index').sort_index()
        frame.columns = [name]
        for stat, func in (('mean', np.mean), ('max', max), ('min', min)):
            frame[f'{stat}_{name}'] = frame[name].apply(func)
        return frame

    n_correct = {task: [matched[0]]
                 for task, matched
                 in zip(shuffle_tasks, index_dict['num_correct_labels'])}
    df_q = pd.DataFrame.from_dict(n_correct, orient='index').sort_index()
    df_q.columns = ['n_correct_label']

    # Each value wraps its list in an outer list so that from_dict produces a
    # single list-valued cell per task.
    scores = {task: [[]] for task in shuffle_tasks}
    ranks = {task: [[]] for task in shuffle_tasks}
    for row in df.itertuples():
        pairs = zip(row.train_task_idx, row.train_task_score)
        for position, (task, value) in enumerate(pairs):
            if task in shuffle_tasks:
                scores[task][0].append(value)
                ranks[task][0].append(position)

    df_score = _summary_frame(scores, 'score')
    df_rank = _summary_frame(ranks, 'rank')
    return pd.concat([df_score, df_rank, df_q], axis=1)
