__author__ = "Anon"
__version__ = "0.1"
import torch
import torch.nn as nn
import os
from utils import evaluate, get_net
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from datasets import get_test_loader, TransformComposer, get_train_valid_loader
import argparse
import yaml
import json
import numpy as np
from torchvision import transforms
import pathlib
import pandas as pd
from datetime import datetime
from utils import NORM


def out_of_dist_test(args):
    """Compare prediction confidence on in-distribution vs. out-of-distribution data.

    For every config JSON in ``args.base_paths`` the checkpoint
    ``256-best.pth`` stored next to it is loaded and run over the config's own
    test set plus the 'lsun', 'tiny' and 'f-mnist' test sets. The per-sample
    maximum confidence (see ``accuracy_inn_v2``) is collected for each set,
    the confidence densities are saved to
    ``ood-<DATASET>-<ARCH>-<INN-k.png|CLS.png>`` in the working directory, and
    the ROC AUC of separating the base set (label 1) from each other set
    (label 0) is printed.

    Args:
        args: namespace with ``base_paths``, an iterable of config JSON paths.

    Raises:
        ValueError: if the config's dataset has no test transform defined here.
    """
    from sklearn.metrics import roc_auc_score

    # Datasets evaluated at native resolution vs. those needing resize + centre crop.
    norm_only_sets = ['inn_stl10', 'inn_cifar10', 'inn_cifar100', 'cifar100', 'cifar10', 'stl10']
    resize_crop_sets = ['inn_cars', 'cars', 'inn_opets', 'inn_cub20', 'inn_bmw10', 'inn_cub200',
                        'cub200', 'cub20', 'bmw10', 'pets']

    for json_file_path in args.base_paths:
        with open(json_file_path) as json_file:
            config = json.load(json_file)
        base_dir = os.path.dirname(json_file_path)
        net_weight_path = os.path.join(base_dir, '256-best.pth')
        net = get_net(config['ARCH'], config['NUM_CLASSES'], False, net_weight_path, config).cuda()

        print('Loaded weights from {}'.format(json_file_path))
        print(config)
        config['DATASET'] = config['DATASET'].replace('kway_', 'inn_')
        # Base (in-distribution) dataset.
        if config['DATASET'] in norm_only_sets:
            test_transform = TransformComposer(['NORM'], dataset=config['DATASET'],
                                               inp_size=config['IMG_SIZE']).get_composite(num_bins=256)
        elif config['DATASET'] in resize_crop_sets:
            test_transform = TransformComposer(['RESIZE', 'CCROP', 'NORM'], dataset=config['DATASET'],
                                               re_size=config['RE_SIZE'],
                                               inp_size=config['IMG_SIZE']).get_composite(num_bins=256)
        else:
            # Previously this fell through with ``test_transform`` unbound (or,
            # on a later config, silently reused the F-MNIST transform).
            raise ValueError('No test transform defined for dataset {!r}'.format(config['DATASET']))
        base_loader = get_test_loader(dataset=config['DATASET'], transform=test_transform, batch_size=8)
        # LSUN
        test_transform = TransformComposer(['RESIZE', 'NORM'], dataset=config['DATASET'],
                                           inp_size=config['IMG_SIZE']).get_composite(num_bins=256)
        lsun_loader = get_test_loader(dataset='lsun', transform=test_transform, batch_size=8)
        # Tiny ImageNet
        test_transform = TransformComposer(['RESIZE', 'NORM'], dataset=config['DATASET'],
                                           inp_size=config['IMG_SIZE']).get_composite(num_bins=256)
        tiny_loader = get_test_loader(dataset='tiny', transform=test_transform, batch_size=8)
        # Fashion-MNIST ('REPEAT' is presumably the grey -> 3-channel expansion; confirm in TransformComposer)
        test_transform = TransformComposer(['RESIZE', 'REPEAT', 'NORM'], dataset=config['DATASET'],
                                           inp_size=config['IMG_SIZE']).get_composite(num_bins=256)
        f_loader = get_test_loader(dataset='f-mnist', transform=test_transform, batch_size=8)

        # The first entry is the in-distribution set; every later entry is treated as OOD.
        loaders = [(base_loader, config['DATASET']), (lsun_loader, 'LSUN'),
                   (tiny_loader, 'Tiny ImageNet(Rescaled)'), (f_loader, 'F-MNIST')]
        if 'weight' not in config:
            config['weight'] = {'inn': 0, 'cls': 1}
        if config['weight']['inn'] == 1:
            file_name = 'INN-{}.png'.format(config['num_negatives'])
            out_idx = 0
        else:
            out_idx = 1
            file_name = 'CLS.png'

        net.eval()
        dataset_confs = {}
        with torch.no_grad():
            # One one-hot row per class, used as the candidate label inputs.
            y_dist = torch.eye(config['NUM_CLASSES']).cuda()
            for position, (loader, test_name) in enumerate(loaders):
                pred_confs = []
                for images, target in loader:
                    images = images.cuda(non_blocking=True)
                    pred_conf = accuracy_inn_v2(net, images, y_dist, target, position == 0,
                                                num_classes=config['NUM_CLASSES'], idx=out_idx)
                    # Store plain floats rather than 0-d tensors for plotting/scoring.
                    pred_confs.extend(pred_conf.tolist())
                dataset_confs[test_name] = pred_confs
                print('Finished {}'.format(test_name))

        fig, ax = plt.subplots()
        for _, test_name in loaders:
            # NOTE(review): distplot is deprecated in recent seaborn releases.
            sns.distplot(dataset_confs[test_name], ax=ax, hist=False, label='{}'.format(test_name))
        # Clear the y label. The old ``plt.ylabel = None`` replaced the pyplot
        # function itself and broke any later ``plt.ylabel(...)`` call.
        ax.set_ylabel('')
        ax.set_xlabel("Predicted Confidence")
        ax.set(xlim=(0, 1), xticks=np.arange(0, 1, 0.1))
        ax.legend()
        plt.savefig('ood-{}-{}-{}'.format(config['DATASET'], config['ARCH'], file_name))
        plt.close()

        y_pred = dataset_confs[config['DATASET']]
        y_true = [1] * len(y_pred)
        for _, ood_name in loaders[1:]:
            ood_pred = dataset_confs[ood_name]
            ood_true = [0] * len(ood_pred)
            print(roc_auc_score(y_true + ood_true, y_pred + ood_pred), ood_name)

def accuracy_inn_v2(model, inputs, y_dist, target, is_base, num_classes=1000, idx=0):
    """Return the highest per-sample confidence the model assigns to a batch.

    Despite its name this function does not compute an accuracy.

    Args:
        model: callable network. For ``idx == 0`` it is called as
            ``model(images, labels)`` and must return either a tensor of
            two-way logits or a tuple/list whose first element is that tensor.
            Otherwise the classifier logits are taken from the second element
            of the tuple returned by ``model(images, labels)``; if the model
            does not accept a label input or does not return a tuple, the
            output of ``model(images)`` is used.
        inputs: input batch; its first dimension is the batch size.
        y_dist: label inputs, one row per candidate class (only used when
            ``idx == 0``).
        target: unused; kept for interface compatibility.
        is_base: unused; kept for interface compatibility.
        num_classes: number of candidate classes.
        idx: ``0`` scores every (image, class) pair with the label-conditioned
            head; any other value uses the classifier head.

    Returns:
        1-D CPU tensor with one maximum confidence per input sample.
    """
    with torch.no_grad():
        batch_size = inputs.size(0)
        if idx == 0:
            # Pair every image with every candidate label row.
            paired_inputs = inputs.repeat_interleave(num_classes, dim=0)
            paired_labels = y_dist.repeat([batch_size, 1])
            result = model(paired_inputs, paired_labels)
            # Inspect the return type instead of relying on a failed tuple
            # unpack, which wrongly "succeeds" on a tensor with three rows.
            logits = result[0] if isinstance(result, (tuple, list)) else result
            match_prob = torch.nn.functional.softmax(logits, dim=1)[:, 1]
            output = match_prob.view([batch_size, num_classes])
        else:
            # The classifier head ignores the label input, so feed random values.
            dummy_labels = torch.rand((batch_size, num_classes), device=inputs.device)
            try:
                result = model(inputs, dummy_labels)
            except TypeError:
                # forward() takes no label input: plain classifier network.
                result = None
            if isinstance(result, (tuple, list)):
                logits = result[1]
            else:
                logits = model(inputs)
            output = torch.nn.functional.softmax(logits, dim=1)
        confs, _ = torch.max(output.cpu(), dim=1)
    return confs



def run_test(args):
    """Run test-set validation for every checkpoint listed in ``args.base_paths``.

    Each path points at a config JSON; the weights ``256-best.pth`` are read
    from the same directory. The top-1 results returned by ``utils.validate``
    are printed for each config.
    """
    from utils import validate

    # Datasets whose images are resized and centre-cropped before normalisation.
    resize_crop_sets = ['cars', 'inn_cars', 'bmw10', 'inn_bmw10', 'cub200', 'inn_cub200',
                        'inn_opets', 'inn_cub20', 'cub20', 'pets']

    for config_path in args.base_paths:
        with open(config_path) as handle:
            config = json.load(handle)
        weights_path = os.path.join(os.path.dirname(config_path), '256-best.pth')
        net = get_net(config['ARCH'], config['NUM_CLASSES'], False, weights_path, config).cuda()
        print(config)
        print('Loaded weights from {}'.format(config_path))
        config['DATASET'] = config['DATASET'].replace('kway', 'inn')

        if config['DATASET'] in resize_crop_sets:
            transform_names = ['RESIZE', 'CCROP', 'NORM']
        else:
            transform_names = ['NORM']
        composer = TransformComposer(transforms=transform_names, dataset=config['DATASET'],
                                     inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])
        test_transform = composer.get_composite(num_bins=256)

        # Test batch size is picked from the number of classes.
        class_count = config['NUM_CLASSES']
        test_b_size = 128 if class_count >= 100 else (8 if class_count > 25 else 16)

        # Configs without explicit loss weights are treated as classifier-only.
        config.setdefault('weight', {'inn': 0, 'cls': 1})
        test_loader = get_test_loader(dataset=config['DATASET'], batch_size=test_b_size,
                                      pin_memory=True, transform=test_transform)
        inn_t1, _inn_t5, cls_t1, _cls_t5, _inn_cf, _cls_cf = validate(net, test_loader, config)
        print('InnAcc: {}, ClsAcc:{}'.format(inn_t1, cls_t1))

def vis_emb(args):
    """Project training and label-modulated test embeddings to 2-D with UMAP.

    For each config JSON in ``args.base_paths`` the checkpoint ``256-best.pth``
    next to it is loaded. The 'inn_feat' features of the training split and
    the 'cls_feat' features of the test split are collected, the test features
    of one class are multiplied by the 'cls_in' representation of every class,
    and the result is embedded together with the training features and saved
    as ``imgs/<DATASET>_<cls_idx>_<class_name>.png``.

    NOTE(review): ``exit()`` at the end of the class loop terminates the
    process after the first class of the first config has been plotted.
    """
    import umap
    for json_file_path in args.base_paths:
        with open(json_file_path) as json_file:
            config = json.load(json_file)
        base_dir = os.path.dirname(json_file_path)
        net_weight_path = os.path.join(base_dir, '256-best.pth')
        net = get_net(config['ARCH'], config['NUM_CLASSES'], False, net_weight_path, config).cuda()
        print(config)
        print('Loaded weights from {}'.format(json_file_path))

        # base dataset #Assumes to be fmnist else will need adjustment
        # NOTE: the local name `transforms` shadows the module-level torchvision import inside this function.
        if config['DATASET'] in ['inn_opets', 'inn_cub20', 'inn_stl10', 'inn_cifar10', 'inn_bmw10']:
            transforms = ['RESIZE', 'CCROP', 'NORM']
        else:
            transforms = ['NORM']
        config['num_negatives'] = 0
        test_transform_composer = TransformComposer(transforms=transforms, dataset=config['DATASET'],
                                                    inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])
        test_transform = test_transform_composer.get_composite(num_bins=256)
        train_transform_composer = TransformComposer(transforms=config['TRANSFORMS'], dataset=config['DATASET'],
                                                    inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])
        train_transform = train_transform_composer.get_composite(num_bins=256)

        # Only the first loader returned here is used; the second is discarded.
        loader, _ = get_train_valid_loader(dataset=config['DATASET'], batch_size=config['BATCH_SIZE'],
                                           transform=train_transform, pin_memory=True, opts=config)
        test_loader = get_test_loader(dataset=config['DATASET'], transform=test_transform)

        num_classes = config['NUM_CLASSES']

        embedding_tr_inn = []
        embedding_y_inn = []  # never filled; the append below is commented out
        embedding_te_cls = []
        targets_tr = []
        targets_te = []
        net.eval()

        # Pass 1: 'inn_feat' features and class targets of the training split.
        with torch.no_grad():
            for _, data in enumerate(loader):
                # Batches are indexed as (images, label input, <unused>, class target).
                images, input_y, _, cls_y = data[0], data[1], data[2], data[3]
                # The net is expected to return three values, the third being a dict of features.
                _, _, feat_dict = net(images.cuda(), input_y.cuda())
                embedding_tr_inn.append(feat_dict['inn_feat'].cpu())
                # embedding_y_inn.append(feat_dict['cls_in'].cpu())
                targets_tr.append(cls_y.cpu())

            embedding_tr_inn = torch.cat(embedding_tr_inn, dim=0).cpu()
            targets_tr = torch.cat(targets_tr, dim=0).cpu()

        with torch.no_grad():
            # y_dist is a num_classes x num_classes one-hot (identity) matrix.
            y_dist = torch.zeros([num_classes, num_classes]).cuda()
            indicator = torch.LongTensor([i for i in range(num_classes)]).cuda()
            y_dist.scatter_(1, indicator.unsqueeze(1)
                            , torch.cat(num_classes * [torch.FloatTensor([1])]).unsqueeze(1).cuda())
            # Random images are fed alongside the one-hot labels; only 'cls_in' is read
            # from the output (presumably it depends on the label input alone - TODO confirm).
            rand_inp_imgs = torch.rand([num_classes, 3, config['IMG_SIZE'], config['IMG_SIZE']]).cuda()
            _, _, data = net(rand_inp_imgs, y_dist)
            hidden_ys = data['cls_in'].cpu()

            # Pass 2: 'cls_feat' features of the test split, computed with an all-zero label input.
            for iter, (images, ys) in enumerate(test_loader):
                # Correct labels
                y_in = torch.zeros([images.shape[0], num_classes])
                _, _, data = net(images.cuda(), y_in.cuda())
                # Left on the model's device here; moved to the CPU after concatenation below.
                embedding_te_cls.append(data['cls_feat'])
                targets_te.append(ys)

            embedding_te_cls = torch.cat(embedding_te_cls, dim=0).cpu()
            targets_te = torch.cat(targets_te, dim=0).cpu()
            print(targets_te.shape)
            torch.cuda.empty_cache()

            for cls_idx in range(num_classes):
                class_name = test_loader.dataset.classes[cls_idx]
                print(class_name)
                # Test samples whose target is the current class.
                t_idxs = torch.nonzero(targets_te == cls_idx)
                # NOTE(review): squeeze() with no dim also drops the sample axis when
                # exactly one sample matches - confirm that cannot happen.
                f_cls = embedding_te_cls[t_idxs].squeeze()
                num_items = f_cls.shape[0]
                print(num_items)
                cls_targets = []
                updated_reps = []
                # Multiply this class's test features element-wise with every class's 'cls_in' row.
                for t_cls_idx in range(num_classes):
                    y_in = hidden_ys[t_cls_idx].repeat(num_items, 1)
                    cls_targets.extend([t_cls_idx for i in range(num_items)])
                    updated_rep = f_cls * y_in
                    updated_reps.append(updated_rep)
                updated_reps = torch.cat(updated_reps, dim=0).cpu()
                cls_targets = torch.LongTensor(cls_targets)
                print(targets_tr.shape, cls_targets.shape, updated_reps.shape)

                # Training features come first, followed by the modulated test features.
                combined_reps = torch.cat([embedding_tr_inn, updated_reps], dim=0).cpu()
                cls_targets = torch.cat([targets_tr.unsqueeze(1), cls_targets.unsqueeze(1)], dim=0).cpu().squeeze()
                print(combined_reps.shape, cls_targets.shape)
                reducer = umap.UMAP(random_state=42)
                # Targets are passed as the second argument (presumably UMAP's supervised `y`).
                data = reducer.fit_transform(combined_reps.numpy(), cls_targets.numpy())

                # te_data = reducer.transform(updated_reps)
                num_te_samples = updated_reps.shape[0]
                num_tr_samples = embedding_tr_inn.shape[0]
                plt.figure(figsize=(8, 8), facecolor="grey")
                # Training points: large, faint circles coloured by class.
                plt.scatter(data[:num_tr_samples, 0], data[:num_tr_samples, 1], c=cls_targets.numpy()[:num_tr_samples], cmap='Spectral', s=150, marker='o', alpha=0.1)
                # Stepping by num_te_samples//num_classes (== num_items) picks the first
                # modulated test sample of each target-class group.
                selected_samples = np.asarray([(data[num_tr_samples+x, 0], data[num_tr_samples+x, 1], cls_targets.numpy()[num_tr_samples+x]) for x in range(0, num_te_samples, num_te_samples//num_classes)])
                plt.scatter(selected_samples[:, 0], selected_samples[:, 1], c=selected_samples[:, 2], cmap='Spectral', s=200, marker='*', alpha=0.8)
                plt.gca().set_aspect('equal', 'datalim')
                plt.colorbar(boundaries=np.arange(config['NUM_CLASSES'] + 1) - 0.5).set_ticks(
                    np.arange(config['NUM_CLASSES']))
                plt.xticks([])
                plt.yticks([])
                plt.tight_layout()
                plt.savefig('imgs/{}_{}_{}.png'.format(config['DATASET'], cls_idx, class_name))
                plt.close()
                # NOTE(review): ends the whole process after the first class - confirm this is intended.
                exit()


def label_label_l2(args):
    """Print nearest label representations and save a label-to-label L2 heatmap.

    For each config JSON in ``args.base_paths`` the checkpoint ``256-best.pth``
    next to it is loaded and fed one one-hot label per class (together with
    random images). The 'cls_in' entry of the returned feature dict is taken
    as the per-class label representation. The negated pairwise L2 distances
    between those representations are printed (top two per class) and drawn
    as a heatmap saved to ``imgs/<DATASET>_lbl_lbl_l2.png``.

    Args:
        args: namespace with ``base_paths``, an iterable of config JSON paths.
    """
    for json_file_path in args.base_paths:
        with open(json_file_path) as json_file:
            config = json.load(json_file)
        base_dir = os.path.dirname(json_file_path)
        net_weight_path = os.path.join(base_dir, '256-best.pth')
        net = get_net(config['ARCH'], config['NUM_CLASSES'], False, net_weight_path, config).cuda()
        print(config)
        print('Loaded weights from {}'.format(json_file_path))

        # assumes dataset to be cifar10 or stl10
        config['num_negatives'] = 0

        # The loader is only needed for the class names of its dataset.
        loader, _ = get_train_valid_loader(dataset=config['DATASET'], batch_size=config['BATCH_SIZE'],
                                           transform=None, pin_memory=True, opts=config)
        num_classes = config['NUM_CLASSES']
        class_names = loader.dataset.classes
        net.eval()

        with torch.no_grad():
            # One one-hot label row per class.
            y_dist = torch.eye(num_classes).cuda()
            rand_inp_imgs = torch.rand([num_classes, 3, config['IMG_SIZE'], config['IMG_SIZE']]).cuda()
            _, _, data = net(rand_inp_imgs, y_dist)
            hidden_ys = data['cls_in'].cpu()

        # Negated distances, so that larger means closer.
        l2_score = -1 * torch.cdist(hidden_ys, hidden_ys, p=2)
        print(l2_score.shape, hidden_ys.shape)

        for i in range(num_classes):
            val, topk = torch.topk(l2_score[i], 2)
            names = [(class_names[j], val[k]) for k, j in enumerate(topk)]
            print('{}: {}'.format(class_names[i], names))

        # Plot the L2 scores; the former ``dot_score`` only existed in a
        # commented-out line, so this call raised NameError.
        g = sns.heatmap(l2_score.numpy(), linewidths=0.5, xticklabels=class_names,
                        yticklabels=class_names)
        g.set_xticklabels(g.get_xticklabels(), rotation=45)
        g.figure.tight_layout()
        plt.savefig('imgs/{}_{}.png'.format(config['DATASET'], 'lbl_lbl_l2'))
        plt.close()


class SaveFeatures():
    """Forward-hook helper that keeps a module's latest output as a NumPy array."""

    # Most recent hooked output; stays None until the module has run once.
    features = None

    def __init__(self, m):
        # Register on construction and keep the handle so the hook can be detached.
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Copy the output to host memory and store it without autograd history.
        host_output = output.cpu()
        self.features = host_output.data.numpy()

    def remove(self):
        # Detach the forward hook from the module.
        self.hook.remove()

def cam(args):
    """Save gradient-weighted class-activation overlays for a fixed set of images.

    For each config JSON in ``args.base_paths`` the checkpoint ``256-best.pth``
    next to it is loaded. Every image is run through the network with
    ``hook_layer=True``, the selected output is back-propagated, and the
    activations returned by ``net.get_activations`` are weighted by the
    spatially pooled gradients from ``net.get_activations_gradient``. The
    resulting heatmap is drawn over the image and written to ``imgs/``.

    Only 'stl10' / 'inn_stl10' configs are supported. The class index is the
    last ``_``-separated field of the image file name.

    Args:
        args: namespace with ``base_paths``, an iterable of config JSON paths.

    Raises:
        NotImplementedError: for any other dataset.
    """
    from torchvision.transforms import Compose, Normalize, ToTensor
    from utils import NORM
    from PIL import Image
    import skimage.transform

    for json_file_path in args.base_paths:
        with open(json_file_path) as json_file:
            config = json.load(json_file)
        base_dir = os.path.dirname(json_file_path)
        net_weight_path = os.path.join(base_dir, '256-best.pth')
        net = get_net(config['ARCH'], config['NUM_CLASSES'], False, net_weight_path, config).cuda()
        print(config)
        print('Loaded weights from {}'.format(json_file_path))

        config['num_negatives'] = 0
        # Same default as the other experiments: no weights means classifier-only.
        if 'weight' not in config:
            config['weight'] = {'inn': 0, 'cls': 1}
        use_cls_head = config['weight']['cls'] == 1
        proxy_dset = config['DATASET'].replace('kway', 'inn')
        num_classes = config['NUM_CLASSES']
        net.eval()
        if config['DATASET'] in ['inn_stl10', 'stl10']:
            imgs = ['imgs/bird_3_1.png']
            test_transform = Compose([ToTensor(), Normalize(mean=NORM[proxy_dset]['256'][0],
                                                            std=NORM[proxy_dset]['256'][1])])
            sz = (300, 300)  # size of the saved overlay
        else:
            raise NotImplementedError('Dataset not supported')
        for img_path in imgs:
            label_idx = int(os.path.basename(img_path).split('.')[0].split('_')[-1])
            with Image.open(img_path) as pil_img:
                img = test_transform(pil_img).unsqueeze(0).cuda()
            y_in = np.zeros((1, num_classes), dtype=int)
            y_in[0][label_idx] = 1
            y_in = torch.FloatTensor(y_in).cuda()

            result = None
            try:
                result = net(img, y_in, hook_layer=True)
            except TypeError:
                # forward() takes no label input; only valid for a pure classifier.
                if not use_cls_head:
                    raise
            if isinstance(result, (tuple, list)):
                # Three-output network. The original left ``pred`` undefined on
                # this path and crashed with NameError.
                # NOTE(review): choosing the head from config['weight'] mirrors
                # the single-output path below - confirm this is intended.
                out_inn, out_clf, _ = result
                pred = out_clf if use_cls_head else out_inn
            elif use_cls_head:
                pred = net(img, hook_layer=True)
            else:
                pred = result
            if not use_cls_head:
                # The label-conditioned output is indexed at position 1
                # (presumably the "label matches" logit - TODO confirm).
                label_idx = 1
            print(pred)
            print(pred.shape)

            pred[:, label_idx].backward()
            gradients = net.get_activations_gradient()
            pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])
            activations = net.get_activations(img).detach()
            # Weight each channel by its pooled gradient (any channel count, not just 512).
            activations = activations * pooled_gradients.view(1, -1, 1, 1)
            heatmap = torch.mean(activations, dim=1).squeeze().cpu()
            heatmap = torch.clamp(heatmap, min=0)
            peak = torch.max(heatmap)
            if peak > 0:
                # Normalise to [0, 1]; skipped for an all-zero map to avoid NaNs.
                heatmap = heatmap / peak
            overlay = heatmap.numpy()

            with Image.open(img_path) as pil_img:
                background = pil_img.resize(sz)
            fig = plt.figure()
            plt.imshow(background)
            plt.imshow(skimage.transform.resize(overlay, sz), alpha=0.5, cmap='jet')
            plt.rcParams["axes.grid"] = False
            plt.xticks([])
            plt.yticks([])
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            plt.savefig('imgs/{}_cls-{}_act-{}_cam_{}'.format(config['DATASET'], config['weight']['cls'],
                                                              config['class_in_activation'],
                                                              os.path.basename(img_path)),
                        bbox_inches='tight', pad_inches=0)
            plt.close()


def var_y_eval(opts):
    """Measure how often the model outputs class 1 for different label inputs.

    For each config JSON in ``opts.base_paths`` the checkpoint ``256-best.pth``
    next to it is loaded and every test batch is evaluated four times, with
    the label input set to: the correct one-hot label (C), a random incorrect
    one-hot label (I), all zeros (0) and all ones (1). The fraction of samples
    whose arg-max output is index 1 is printed for each variant.

    Args:
        opts: namespace with ``base_paths``, an iterable of config JSON paths.
    """

    def acc(model, input, target, y_dist):
        """Count the samples whose arg-max output equals ``target``."""
        model.eval()
        with torch.no_grad():
            output = model(input.cuda(), y_dist.cuda())
            _, preds = output.max(1)
            return preds.eq(target.cuda()).sum()

    for json_file_path in opts.base_paths:
        with open(json_file_path) as json_file:
            config = json.load(json_file)
        base_dir = os.path.dirname(json_file_path)
        net_weight_path = os.path.join(base_dir, '256-best.pth')
        model = get_net(config['ARCH'], config['NUM_CLASSES'], False, net_weight_path, config).cuda()
        print(config)
        print('Loaded weights from {}'.format(json_file_path))

        # transform is for stl-10 and cifar-10, for others needs to be updated accordingly
        test_transform = TransformComposer(transforms=['NORM'], dataset=config['DATASET'],
                                           inp_size=config['IMG_SIZE']).get_composite(num_bins=256)
        test_loader = get_test_loader(dataset=config['DATASET'], batch_size=config['BATCH_SIZE'], pin_memory=True,
                                      transform=test_transform)
        num_classes = config['NUM_CLASSES']
        correct_acc = 0
        incorrect_acc = 0
        ones_acc = 0
        zeros_acc = 0
        model.eval()

        with torch.no_grad():
            for images, ys in test_loader:
                batch = images.shape[0]
                # Every variant is scored against output index 1.
                target = torch.ones([batch]).long()
                # correct label
                y_c = torch.zeros([batch, num_classes])
                y_c.scatter_(1, ys.unsqueeze(1), 1)
                correct_acc += acc(model, images, target, y_c)
                # incorrect label: a non-zero offset modulo num_classes never maps back to ys
                y_i = torch.zeros([batch, num_classes])
                random_label_add = torch.randint(1, num_classes, (batch,))
                y_incorrect = (ys + random_label_add) % num_classes
                y_i.scatter_(1, y_incorrect.unsqueeze(1), 1)
                incorrect_acc += acc(model, images, target, y_i)
                # all ones
                y_o = torch.ones([batch, num_classes])
                ones_acc += acc(model, images, target, y_o)
                # all zeros
                y_z = torch.zeros([batch, num_classes])
                zeros_acc += acc(model, images, target, y_z)
        n_images = float(len(test_loader.dataset))
        # The '0' slot reports the all-zeros input and '1' the all-ones input
        # (these two values were previously printed in swapped order).
        print('C:{}, I:{}, 0:{}, 1:{}'.format(float(correct_acc)/n_images, float(incorrect_acc)/n_images,
                                              float(zeros_acc)/n_images, float(ones_acc)/n_images))

def clf_zee(opts):
    """Fit a linear classifier on frozen network features and print its test score.

    For each config JSON in ``opts.base_paths`` the checkpoint ``256-best.pth``
    next to it is loaded, ``net.features`` is evaluated on the training and
    test splits, an ``SGDClassifier`` is fitted on the training features and
    its mean accuracy on the test features is printed.

    Args:
        opts: namespace with ``base_paths``, an iterable of config JSON paths.
    """
    from sklearn.linear_model import SGDClassifier
    for json_file_path in opts.base_paths:
        with open(json_file_path) as json_file:
            config = json.load(json_file)
        base_dir = os.path.dirname(json_file_path)
        net_weight_path = os.path.join(base_dir, '256-best.pth')
        net = get_net(config['ARCH'], config['NUM_CLASSES'], False, net_weight_path, config).cuda()
        print(config)
        print('Loaded weights from {}'.format(json_file_path))

        # base dataset #Assumes to be fmnist else will need adjustment
        if config['DATASET'] in ['inn_bmw10', 'inn_opets', 'inn_cub20']:
            test_transforms = ['RESIZE', 'CCROP', 'NORM']
        else:
            test_transforms = ['NORM']
        test_transform_composer = TransformComposer(transforms=test_transforms, dataset=config['DATASET'],
                                                    inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])
        test_transform = test_transform_composer.get_composite(num_bins=256)
        train_transform_composer = TransformComposer(transforms=config['TRANSFORMS'], dataset=config['DATASET'],
                                                     inp_size=config['IMG_SIZE'], re_size=config['RE_SIZE'])
        train_transform = train_transform_composer.get_composite(num_bins=256)

        loader, _ = get_train_valid_loader(dataset=config['DATASET'], batch_size=config['BATCH_SIZE'],
                                           transform=train_transform, pin_memory=True, opts=config)

        test_loader = get_test_loader(dataset=config['DATASET'], batch_size=config['BATCH_SIZE'], transform=test_transform)

        tr_features = []
        tr_targets = []
        te_features = []
        te_targets = []
        net.eval()

        with torch.no_grad():
            for data in loader:
                # Training batches are indexed as (images, ..., ..., class target).
                images, cls_y = data[0], data[3]
                tr_features.append(net.features(images.cuda()).cpu())
                tr_targets.append(cls_y.cpu())

            for images, ys in test_loader:
                # Move each batch to the CPU right away, as in the training
                # loop, so features do not pile up in GPU memory.
                te_features.append(net.features(images.cuda()).cpu())
                te_targets.append(ys)

        tr_embedding_cls = torch.cat(tr_features, dim=0).numpy()
        tr_targets = torch.cat(tr_targets, dim=0).cpu().numpy()
        te_features = torch.cat(te_features, dim=0).numpy()
        te_targets = torch.cat(te_targets, dim=0).cpu().numpy()

        # Fit the linear model.
        # NOTE(review): newer scikit-learn releases rename loss='log' to
        # 'log_loss'; update once the minimum supported version is known.
        lin_model = SGDClassifier(loss='log', random_state=42, verbose=0, tol=1e-5)
        lin_model.fit(tr_embedding_cls, tr_targets)

        print(lin_model.score(te_features, te_targets))

if __name__ == '__main__':
    # Command-line entry point: pick one experiment and run it on the given configs.
    cli = argparse.ArgumentParser(description='Generating result figures')
    cli.add_argument('--base_paths', '-b', nargs='+', help='Path to config files', required=True)
    cli.add_argument('--type', '-t', help='Type of experiment', required=True,
                     choices=['ood', 'eval', 'vis_emb', 'zee', 'var_y', 'll_l2', 'cam'])

    # Unrecognised command-line arguments are collected and otherwise ignored.
    args, unknown = cli.parse_known_args()

    # Experiment name -> function implementing it.
    experiments = {
        'ood': out_of_dist_test,
        'eval': run_test,
        'vis_emb': vis_emb,
        'll_l2': label_label_l2,
        'cam': cam,
        'var_y': var_y_eval,
        'zee': clf_zee,
    }
    selected = experiments.get(args.type)
    if selected is not None:
        selected(args)

