import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


# Raw evaluation-dataset names (as stored in the CSV 'dataset' column)
# mapped to the short labels printed in tables. The three ObjectNet
# spellings all collapse to one label.
dataset_converter = {
    'imagenet-v2': 'IN-V2',
    'imagenet-a': 'IN-A',
    'imagenet-o': 'IN-O',
    'imagenet-r': 'IN-R',
    'imagenet-sketch': 'IN-S',
    'imagenet-c': 'IN-C',
    'imagenet-cartoon': 'IN-Cartoon',
    'imagenet-drawing': 'IN-Drawing',
    'objectnet': 'ObjNet',
    'objectnet-1.0': 'ObjNet',
    'objectnet-v2': 'ObjNet',
}

# Pretraining identifiers (CSV 'pretrain' column, or the prefix of a model
# column name in ri.csv) mapped to display labels. Both 'in1k-orig' and
# 'in1k_orig' spellings are accepted.
pretrained_converter = {
    'imagenet': 'IN-21K + AugReg',
    'laion': 'LAION-2B',
    'in1k-orig': 'IN-1K',
    'in1k_orig': 'IN-1K',
    'openai': 'OpenAI',
    'in1k': 'IN-1K + AugReg',
    'orig': 'IN-21K',
    'sam': 'IN-1K + SAM',
    'miil': 'IN-21K-P',
}

# Architecture identifiers mapped to display labels. Insertion order matters:
# `converter` scans these keys in order as substrings of a model name.
# The '<resnet>_base_32' keys cover the raw CSV 'arch' values of ResNet rows.
arch_converter = {
    'base_16': 'ViT-B/16',
    'base_32': 'ViT-B/32',
    'small_16': 'ViT-S/16',
    'small_32': 'ViT-S/32',
    'large_16': 'ViT-L/16',
    'resnet18': 'ResNet-18',
    'resnet50': 'ResNet-50',
    'resnet18_base_32': 'ResNet-18',
    'resnet50_base_32': 'ResNet-50',
}

def converter(models):
    """Split model column names into display architecture and pretraining labels.

    A model name looks like ``<pretrain>_<arch>``, e.g. ``laion_base_16``,
    ``in1k_orig_small_32`` or ``in1k_resnet50_base_32``. Names without an
    underscore (such as the ``mode`` / ``dataset`` columns) are dropped.

    Args:
        models: iterable of column names (e.g. ``DataFrame.columns``).

    Returns:
        ``(models, arches, pretraineds)``: three parallel lists holding the
        kept model names, their ``arch_converter`` labels and their
        ``pretrained_converter`` labels.

    Raises:
        KeyError: if a pretraining prefix or ResNet name is not a known key.
        ValueError: if no ``arch_converter`` key occurs in a (non-ResNet)
            model name. Raising here keeps the returned lists aligned.
    """
    models = [model for model in models if '_' in model]
    arches = []
    pretraineds = []
    for model in models:
        # 'in1k_orig' itself contains an underscore; protect it before splitting.
        pretrained = model.replace('in1k_orig', 'in1k-orig').split('_')[0]
        pretraineds.append(pretrained_converter[pretrained])
        if 'resnet' in model:
            # e.g. 'in1k_resnet50_base_32' -> 'resnet50'
            arches.append(arch_converter[model.split('_')[1]])
            continue
        # First matching key wins; exactly one label per model.
        arch = next((value for key, value in arch_converter.items() if key in model), None)
        if arch is None:
            raise ValueError(f'No architecture key in arch_converter matches model {model!r}')
        arches.append(arch)
    return models, arches, pretraineds

# Method key (the CSV 'mode' value) -> display name with citation, for the
# paper's LaTeX tables. Not used by default: assign it to `name_converter`
# to get cited method names.
name_converter_cited = {
        'FT': 'FT',
        'HeadOnly': 'Linear Probing',
        'lora': 'LoRA~\\citep{hu2021lora}',
        'lwf': 'LwF~\\citep{li2017learning}',
        'ewc': 'EWC~\\citep{kirkpatrick2017overcoming}',
        'Soup-FT-ewc-lwf': 'Model Soup old~\\citep{wortsman2022model}',
        'Soup-PRE-FT-ewc-lwf': 'Model Soup~\\citep{wortsman2022model}',
        'Soup-PRE-FT': 'WiSE-FT~\\citep{wortsman2022robust}',
        'Prompter': 'Visual Prompt~\\citep{bahng2022exploring}',
        'LPFT': 'LP-FT~\\citep{kumar2022fine}',
        }

# Method key -> short display name (the active mapping). 'Soup-FT-ewc-lwf'
# still carries a citation suffix; some callers strip it with `.split('~')[0]`.
name_converter = {
        'FT': 'FT',
        'HeadOnly': 'Linear Probing',
        'lora': 'LoRA',
        'lwf': 'LwF',
        'ewc': 'EWC',
        'Soup-FT-ewc-lwf': 'Model Soup old~\\citep{wortsman2022model}',
        'Soup-PRE-FT-ewc-lwf': 'MS:PRE-FT-EWC-LwF',
        'Soup-PRE-FT': 'WiSE-FT',
        'Prompter': 'Visual Prompt',
        'LPFT': 'LP-FT',
        }

# Methods reported, in table/plot row order (keys of `name_converter`).
regularizations = ['FT', 'HeadOnly', 'Prompter', 'lora', 'ewc', 'lwf', 'LPFT', 'Soup-PRE-FT', 'Soup-PRE-FT-ewc-lwf']

# Evaluation columns of summary.csv in table order. The first entry is the
# ImageNet validation accuracy; the remaining entries are the shifted
# datasets that the RI/mRI code averages over via `order_column[1:]`.
# An earlier configuration used 'objectnet-1.0' for the ObjectNet column.
order_column = ['imagenet/val', 'imagenet-v2', 'imagenet-a', 'imagenet-r', 'imagenet-sketch', 'objectnet-v2', 'imagenet-cartoon', 'imagenet-drawing', 'imagenet-c']

def convert2latex(regs, numbers):
    """Render one LaTeX table row per method.

    Each row is `` & <reg> &`` followed by the row mean and then every value
    of that row. All values use one decimal and are separated by `` & ``.
    The highest row mean and the highest value of each column are wrapped
    in ``\\textbf{...}``.

    Args:
        regs: row labels (method display names), paired with rows of *numbers*.
        numbers: 2-D numeric array; rows are methods, columns are datasets.

    Returns:
        A list of row strings, without the trailing ``\\\\`` row terminator.
    """
    row_means = numbers.mean(1)
    column_best = numbers.max(0)
    best_mean = max(row_means)

    def _cell(value, best):
        # One-decimal value, bold when it ties the best of its row/column.
        formatted = f'{value:.1f}'
        return '\\textbf{' + formatted + '}' if value == best else formatted

    return [
        f' & {reg} &' + _cell(mean, best_mean)
        + ''.join(' & ' + _cell(value, best) for value, best in zip(row, column_best))
        for reg, mean, row in zip(regs, row_means, numbers)
    ]


def lines2latex(lines, fname):
    """Write table rows to a LaTeX fragment file, one row per line.

    Every line that is not a rule (it contains none of ``midrule``,
    ``hline``, ``cline``) gets a trailing `` \\\\`` row terminator. The
    parent directory of *fname* is created if it does not exist yet, so
    callers can write to e.g. ``tex/...`` on a fresh checkout.

    Args:
        lines: row strings, e.g. as produced by ``convert2latex``.
        fname: output path; an existing file is overwritten.
    """
    print('Converting to latex', fname)
    parent = os.path.dirname(fname)
    if parent:
        os.makedirs(parent, exist_ok=True)
    rule_markers = ('midrule', 'hline', 'cline')
    with open(fname, 'w', encoding='utf-8') as f:
        for line in lines:
            if not any(marker in line for marker in rule_markers):
                line = line + ' \\\\'
            f.write(line + '\n')



def _draw_plot(data, pres, names, colors, ax, ylabel=None, dnames=None):
    """Draw one line per method on *ax*, plus a black "Pre-Trained" baseline.

    Args:
        data: 2-D array-like; column ``i`` is drawn as a line labelled
            ``names[i]`` in colour ``colors[i]``. Rows are the x positions.
        pres: baseline values drawn as the black "Pre-Trained" line. The
            values are flattened, so a column vector is accepted.
        names: legend label per column of *data*.
        colors: line colour per column of *data*.
        ax: matplotlib Axes to draw on.
        ylabel: optional y-axis label.
        dnames: optional x tick labels, one per row of *data*. A ``' +'`` in
            a label becomes a line break to keep the labels narrow.
    """
    data = np.asarray(data)
    ax.plot(np.ravel(pres), label='Pre-Trained', marker='.', color='black')
    for column, name, color in zip(data.T, names, colors):
        ax.plot(column, label=name, marker='.', color=color, alpha=0.75)
    if dnames is not None:
        ax.set_xticks(range(len(dnames)))
        ax.set_xticklabels([d.replace(' +', '\n+') for d in dnames])
    if ylabel:
        ax.set_ylabel(ylabel)


def draw_bar(accs, pre_accs, mris, names, dnames, colors, ylabel, ft_only=True):
    """Save a grouped bar chart of pre-trained vs. fine-tuned accuracy to ``bar.pdf``.

    At every x position (one per row of *accs*) a blue bar shows the
    pre-trained accuracy and a red bar shows the accuracy after fine-tuning
    with one method.

    Args:
        accs: 2-D array; rows match *dnames*, columns are methods.
        pre_accs: 2-D array whose column 0 holds the pre-trained accuracy
            per row of *accs*.
        mris: unused; kept so existing calls keep working.
        names: method names, one per column of *accs*. The caller's list is
            not modified.
        dnames: x tick labels, one per row of *accs*. A ``' +'`` in a label
            becomes a line break.
        colors: unused (bars are always blue/red); kept for call compatibility.
        ylabel: y-axis label.
        ft_only: if true (default), draw only the first method in a single
            small panel. Otherwise draw one panel per method on a 3x3 grid.
    """
    # Copy so the caller's list is not mutated.
    names = list(names)
    names[-1] = 'Model Soup: PRE-FT-EWC-LwF'
    dnames = [d.replace(' +', '\n+') for d in dnames]
    accs = np.asarray(accs)
    pre_accs = np.asarray(pre_accs)

    if ft_only:
        fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=1000)
        axes = [ax]
    else:
        fig, axes = plt.subplots(3, 3, figsize=(12, 9), dpi=1000)
        # subplots returns a 3x3 array; flatten so zip yields one Axes per method.
        axes = axes.flatten()

    width = 0.3
    offset = width / 2
    x = np.arange(len(pre_accs))
    for i, (acc, name, ax) in enumerate(zip(accs.T, names, axes)):
        ax.bar(x - offset, pre_accs[:, 0], width=width, color='blue')
        ax.bar(x + offset, acc, width=width, color='red')
        if not ft_only:
            ax.set_title(name)
        ax.set_xticks(x)
        ax.set_xticklabels(dnames, fontsize=7)
        ax.set_ylim(20, 65)
        if i == 3:
            # Middle-left panel of the 3x3 grid carries the shared y label.
            ax.set_ylabel(ylabel, fontsize=14)
    if ft_only:
        axes[0].set_ylabel(ylabel, fontsize=7)

    labels = ['Pre-Trained Model', 'Fine-Tuned Model']
    handles = [plt.Rectangle((0, 0), 1, 1, color=c) for c in ('blue', 'red')]
    legend_kwargs = {'prop': {'size': 7}} if ft_only else {}
    fig.legend(handles, labels, ncol=4, loc='upper center', **legend_kwargs)
    fig.tight_layout(rect=[0, 0, 1.0, 0.95])
    fig.savefig('bar.pdf', dpi=1000)
    # Release the figure; repeated calls would otherwise accumulate open figures.
    plt.close(fig)



    







def draw_plot(data, dnames, pre_val=None):
    """Plot pre-trained vs. fine-tuned OOD accuracy for one architecture.

    The fine-tuned accuracy of each method is ``pre_val + mRI``. The chart
    itself is produced (and saved) by ``draw_bar``.

    Args:
        data: sequence with one entry per pretraining, each a 1-D array of
            mRI values (one per method in ``regularizations``). It is stacked
            to a ``(n_pretrainings, n_methods)`` array.
        dnames: pretraining labels, one per entry of *data*.
        pre_val: 1-D array of the pre-trained models' average OOD accuracy,
            aligned with *data*. Required, despite the ``None`` default.

    Raises:
        ValueError: if *pre_val* is None.
    """
    if pre_val is None:
        raise ValueError('draw_plot requires pre_val (average pre-trained OOD accuracy per pretraining)')
    colors = ['tab:red', 'tab:pink', 'tab:orange', 'gold', 'tab:green', 'tab:olive', 'tab:cyan', 'tab:blue', 'tab:purple']
    data = np.stack(data)  # (n_pretrainings, n_methods)
    names = [
        name_converter[reg].split('~')[0].replace('Model Soup', 'MS:PRE-FT-EWC-LwF')
        for reg in regularizations
    ]
    pre_acc = np.asarray(pre_val).reshape(-1, 1)
    acc = pre_acc + data
    ylabel = 'Average Accuracy on OOD Datasets ($mRI + \\frac{1}{n}\\sum_i^n A_{pre}^{(i)})$'
    draw_bar(acc, pre_acc, data, names, dnames, colors, ylabel)


def get_pretraining_tex():
    """Write ``tex/pre.tex``: accuracy of every pre-trained (``mode == 'PRE'``) model.

    Rows are grouped by architecture (``order_arch``, one ``\\multirow``
    per group) and ordered by pretraining (``order_pretrained``). Columns
    follow ``order_column``. The maximum of each column over all PRE rows
    is bold. Architecture groups are separated by ``\\midrule``.

    Relies on the module globals ``order_arch`` and ``order_pretrained``,
    which are set in the ``__main__`` block.
    """
    df = pd.read_csv('summary.csv')
    # resnet contained pretrain is all in1k instead of imagenet.
    df.loc[df['arch'] == 'resnet18_base_32', 'pretrain'] = 'in1k'
    df.loc[df['arch'] == 'resnet50_base_32', 'pretrain'] = 'in1k'
    df = df[df['mode'] == 'PRE']

    pretraineds = np.asarray([pretrained_converter[pre] for pre in df['pretrain']])
    # ResNet arch values look like 'resnet50_base_32'; keep the 'resnet50' part.
    archs = np.asarray([
        arch_converter[arch.split('_')[0] if 'resnet' in arch else arch]
        for arch in df['arch']
    ])
    performance = df[order_column].values
    # Bold the best value per column across all pre-trained models.
    max_val = performance.max(0)

    texts = []
    for arch in order_arch:
        arch_mask = archs == arch
        rows = []
        for pre in order_pretrained:
            mask = arch_mask & (pretraineds == pre)
            if not mask.any():
                continue
            text = f' & {pre}'
            for n, best in zip(performance[mask][0], max_val):
                if n == best:
                    text += ' & \\textbf{' + f'{n:.1f}' + '}'
                else:
                    text += ' & ' + f'{n:.1f}'
            rows.append(text)
        if not rows:
            continue
        # The multirow spans exactly the rows emitted for this architecture.
        rows[0] = '\\multirow{' + str(len(rows)) + '}{*}{' + arch + '}' + rows[0]
        if texts:
            # Separate architecture groups; no trailing rule after the last one.
            texts.append('\\midrule')
        texts.extend(rows)
    lines2latex(texts, 'tex/pre.tex')

def get_training_tex():
    """Write the fine-tuning accuracy tables ``tex/down_{l}.tex``.

    Uses the non-PRE rows of ``summary.csv``. One table row is built per
    (architecture, pretraining, method) that has at least one ``val_acc``
    value, with one column per dataset in ``order_dataset``. Architectures
    are split over three files: ``order_arch[:2]``, ``order_arch[2:5]`` and
    ``order_arch[-2:]``.

    Within each (architecture, pretraining) group the best value of every
    column is bold, and datasets without a result are shown as ``-``.
    Pretraining groups are separated by ``\\cline{2-11}`` and architecture
    groups end with ``\\midrule``.

    Relies on the module globals ``order_arch``, ``order_pretrained`` and
    ``order_dataset``, which are set in the ``__main__`` block.
    """
    missing = -1  # placeholder accuracy for a dataset without a result
    df = pd.read_csv('summary.csv')
    # resnet contained pretrain is all in1k instead of imagenet.
    df.loc[df['arch'] == 'resnet18_base_32', 'pretrain'] = 'in1k'
    df.loc[df['arch'] == 'resnet50_base_32', 'pretrain'] = 'in1k'
    df = df[df['mode'] != 'PRE']
    pretraineds = np.asarray([pretrained_converter[pre] for pre in df['pretrain']])
    datasets = np.asarray([dataset_converter[da] for da in df['dataset']])
    archs = np.asarray([
        arch_converter[arch.split('_')[0] if 'resnet' in arch else arch]
        for arch in df['arch']
    ])
    accs = df['val_acc'].to_numpy()
    regs = df['mode'].to_numpy()

    # Collect one accuracy row per (arch, pretraining, method), ordered by
    # order_dataset; keep rows with at least one real result.
    new_archs, new_pretraineds, new_regs, performance = [], [], [], []
    for arch in order_arch:
        arch_mask = archs == arch
        for pre in order_pretrained:
            pre_mask = arch_mask & (pretraineds == pre)
            for reg in regularizations:
                reg_mask = pre_mask & (regs == reg)
                row = []
                for dataset in order_dataset:
                    acc = accs[reg_mask & (datasets == dataset)]
                    row.append(acc.item() if len(acc) else missing)
                row = np.asarray(row)
                if (row > missing).any():
                    new_archs.append(arch)
                    new_pretraineds.append(pre)
                    new_regs.append(reg)
                    performance.append(row)
    pretraineds = np.asarray(new_pretraineds)
    archs = np.asarray(new_archs)
    regs = np.asarray(new_regs)
    performance = np.stack(performance)

    arch_groups = [order_arch[:2], order_arch[2:5], order_arch[-2:]]
    for l, arch_group in enumerate(arch_groups):
        texts = []
        for arch in arch_group:
            arch_mask = archs == arch
            present = [pre for pre in order_pretrained if (arch_mask & (pretraineds == pre)).any()]
            head = '\\multirow{' + str(arch_mask.sum()) + '}{*}{' + arch + '}'
            for k, pre in enumerate(present):
                pre_mask = arch_mask & (pretraineds == pre)
                num_pre = pre_mask.sum()
                if 'wo' in pre:
                    subhead = '& \\multirow{' + str(num_pre) + '}{*}{\\shortstack[c]{IN-21K\\\\(wo Augreg)}}'
                else:
                    subhead = '& \\multirow{' + str(num_pre) + '}{*}{' + pre + '}'
                if k == 0:
                    subhead = head + subhead
                max_val = performance[pre_mask].max(0)
                is_first = True
                for reg in regularizations:
                    mask = pre_mask & (regs == reg)
                    if not mask.any():
                        continue
                    text = "& " + name_converter[reg]
                    if is_first:
                        is_first = False
                        text = subhead + text
                    else:
                        # Empty arch and pretraining cells.
                        text = "& " + text
                    for n, best in zip(performance[mask][0], max_val):
                        if n == missing:
                            text += ' & -'
                        elif n == best:
                            text += ' & \\textbf{' + f'{n:.1f}' + '}'
                        else:
                            text += ' & ' + f'{n:.1f}'
                    texts.append(text)
                # The last present pretraining closes the architecture group.
                # \cline{2-11} assumes 3 label columns + 8 datasets.
                if k == len(present) - 1:
                    texts.append('\\midrule')
                else:
                    texts.append('\\cline{2-11}')
        lines2latex(texts, f"tex/down_{l}.tex")


def prepare_div_pre():
    """Load ``summary.csv`` for the accuracy-normalised RI tables.

    ResNet rows are relabelled to the ``'in1k'`` pretraining, matching the
    other loaders in this file.

    Returns:
        ``(summary, pretraineds, arches)``:

        - ``summary`` holds the ``imagenet/val``, ``mode``, ``arch``,
          ``pretrain`` and ``dataset`` columns of all rows.
        - ``pretraineds`` and ``arches`` are NumPy arrays, aligned with
          those rows, holding the display labels from
          ``pretrained_converter`` and ``arch_converter``.
    """
    frame = pd.read_csv('summary.csv')
    for resnet_arch in ('resnet18_base_32', 'resnet50_base_32'):
        frame.loc[frame['arch'] == resnet_arch, 'pretrain'] = 'in1k'
    wanted = ['imagenet/val', 'mode', 'arch', 'pretrain', 'dataset']
    summary = frame[wanted]

    pretrain_labels = np.asarray([pretrained_converter[key] for key in frame['pretrain']])
    arch_labels = np.asarray([arch_converter[key] for key in frame['arch']])
    return summary, pretrain_labels, arch_labels

def get_ri_tex(div_pre=False):
    """Write per-pretraining relative-improvement (RI) tables and an mRI summary.

    Reads ``ri.csv``, which has ``mode`` and ``dataset`` columns plus one
    column per model named ``<pretrain>_<arch>``. For each pretraining in
    ``order_pretrained`` it writes ``tex/ri_{p}.tex`` (or
    ``tex/ri_div_pre_{p}.tex``) with one ``\\multirow`` group per
    architecture in ``order_arch``, separated by ``\\midrule``.

    It then writes the mRI table ``tex/mri.tex`` (or
    ``tex/mri_div_pre.tex``). That table has one row per method in
    ``regularizations`` and one column per (architecture, pretraining).
    Methods without results, and LoRA on ResNets, are shown as ``-``.
    As a side effect, ``draw_plot`` is called for the ViT-B/16 models.

    Args:
        div_pre: if true, divide each RI value by the change in ImageNet
            validation accuracy relative to the pre-trained model.

    Relies on the module globals ``order_arch`` and ``order_pretrained``,
    which are set in the ``__main__`` block.
    """
    # Sentinel mRI for "no result / not applicable"; keeps each model's mRI
    # list aligned with `regularizations`. Printed as '-' in the mRI table.
    missing = -2147483647

    if div_pre:
        summary, summary_pretraineds, summary_arches = prepare_div_pre()

    # Average accuracy of each pre-trained model over order_column[1:].
    pre_df = pd.read_csv('summary.csv')
    # Same ResNet relabelling as the other loaders (and as the ri.csv column
    # rename below), so ResNet pre-trained rows match their mRI entries.
    pre_df.loc[pre_df['arch'] == 'resnet18_base_32', 'pretrain'] = 'in1k'
    pre_df.loc[pre_df['arch'] == 'resnet50_base_32', 'pretrain'] = 'in1k'
    pre_df = pre_df[pre_df['mode'] == 'PRE']
    pre_val = pre_df[order_column[1:]].T.mean().values
    pre_arch = np.asarray([
        arch_converter[a.split('_')[0] if 'resnet' in a else a]
        for a in pre_df['arch']
    ])
    pre_dataset = np.asarray([pretrained_converter[p] for p in pre_df['pretrain']])

    df = pd.read_csv('ri.csv')
    df = df.rename(columns={'imagenet_resnet50_base_32': 'in1k_resnet50_base_32',
                            'imagenet_resnet18_base_32': 'in1k_resnet18_base_32'})
    models, arches, pretraineds = converter(df.columns)
    models = np.array(models)
    arches = np.array(arches)
    pretraineds = np.array(pretraineds)

    # Per-pretraining RI tables.
    mris = []      # one list of per-method mRIs per model
    mri_arch = []
    mri_pre = []
    for p in order_pretrained:
        output = []
        for a in order_arch:
            mask = (arches == a) & (pretraineds == p)
            if not mask.any():
                continue
            if div_pre:
                _summary = summary[(summary_arches == a) & (summary_pretraineds == p)]
                pre_acc = _summary[_summary['mode'] == 'PRE']['imagenet/val'].values
            model = models[mask][0]
            data = df[['mode', 'dataset', model]]
            numbers = []
            regs = []
            mri = []
            for reg in regularizations:
                _data = data[data['mode'] == reg][['dataset', model]]
                if len(_data) == 0 or (reg == 'lora' and 'resnet' in model):
                    mri.append(missing)
                    continue
                regs.append(name_converter[reg])
                # Reorder the per-dataset values to follow order_column[1:].
                _data = pd.concat([_data[_data['dataset'] == dname] for dname in order_column[1:]])
                number = _data.T.values[1]
                if div_pre:
                    temp = _summary[_summary['mode'] == reg]
                    temp = pd.concat([temp[temp['dataset'] == dname] for dname in order_column[1:]])
                    diff = temp['imagenet/val'].values - pre_acc
                    # Avoid division by zero.
                    diff[diff == 0] = 1e-6
                    number = number / diff
                numbers.append(number.astype(float))
                mri.append(number.mean())
            mris.append(mri)
            mri_arch.append(a)
            mri_pre.append(p)
            if not numbers:
                continue
            lines = convert2latex(regs, np.stack(numbers))
            lines[0] = '\\multirow{%d}{*}{%s}' % (len(lines), a) + lines[0]
            if output:
                # Separate architecture groups; no trailing rule after the last one.
                output.append('\\midrule')
            output += lines
        if div_pre:
            path = f'tex/ri_div_pre_{p}.tex'
        else:
            path = f'tex/ri_{p}.tex'
        lines2latex(output, path)

    # mRI table: reorder models by order_arch, then order_pretrained.
    mri_arch = np.asarray(mri_arch)
    mri_pre = np.asarray(mri_pre)
    mris = np.asarray(mris, dtype=float)  # (n_models, len(regularizations))
    new_mris = []
    arches = []
    pres = []
    for a in order_arch:
        arch_mask = mri_arch == a
        if not arch_mask.any():
            continue
        _new_mris = []
        _pres = []
        _new_pre_val = []
        for p in order_pretrained:
            sel = arch_mask & (mri_pre == p)
            if not sel.any():
                continue
            _pres.append(p)
            _new_pre_val.append(pre_val[(pre_arch == a) & (pre_dataset == p)])
            _new_mris.append(mris[sel][0])
        if _new_pre_val:
            _new_pre_val = np.concatenate(_new_pre_val)
        new_mris.append(_new_mris)
        if a == 'ViT-B/16':
            draw_plot(_new_mris, _pres, _new_pre_val)
        pres.append(_pres)
        arches.append(a)
    new_mris = np.concatenate(new_mris)  # (n_columns, len(regularizations))

    # Header: one multicolumn per architecture, then the pretraining labels.
    texts = []
    head = "$\\Dpre$"
    subhead = "Method"
    for a, _pres in zip(arches, pres):
        align = '{c|}' if a != arches[-1] else '{c}'
        head += f' & \\multicolumn{{{len(_pres)}}}' + align + '{' + a + '}'
        subhead += ' & ' + ' & '.join(_pres)
    texts.append(head)
    texts.append('\\hline')
    texts.append(subhead)
    texts.append('\\midrule')
    # Bold the best method within each (arch, pretraining) column.
    max_val = new_mris.max(1)
    for reg, line in zip(regularizations, new_mris.T):
        text = name_converter[reg].split('~')[0]
        for n, best in zip(line, max_val):
            if n == missing:
                text += ' & -'
            elif n == best:
                text += ' & \\textbf{' + f'{n:.1f}' + '}'
            else:
                text += ' & ' + f'{n:.1f}'
        texts.append(text)
    if div_pre:
        path = 'tex/mri_div_pre.tex'
    else:
        path = 'tex/mri.tex'
    lines2latex(texts, path)

if __name__ == '__main__':
    # Row/column orderings for the generated tables. The table and plot
    # functions in this file read these as module globals, so they exist
    # only when the file is run as a script.

    # Pretraining display labels (values of pretrained_converter), in table order.
    order_pretrained = ['IN-1K + AugReg', 'IN-1K + SAM', 'IN-21K', 'IN-21K-P',
                        'IN-21K + AugReg', 'OpenAI', 'LAION-2B']
    # Architecture display labels (values of arch_converter), in table order.
    order_arch = ['ViT-B/16', 'ViT-B/32', 'ViT-S/16', 'ViT-S/32', 'ViT-L/16',
                  'ResNet-18', 'ResNet-50']
    # Dataset display labels (values of dataset_converter), in column order;
    # used by get_training_tex.
    order_dataset = ['IN-V2', 'IN-A', 'IN-R', 'IN-S', 'ObjNet', 'IN-Cartoon',
                     'IN-Drawing', 'IN-C']

    # Other outputs: get_pretraining_tex(), get_training_tex(),
    # get_ri_tex(div_pre=True).
    get_ri_tex()
    


