import torch
import aim
import os, argparse, joblib, shutil
import numpy as np
import torch.nn.functional as F
from torch import nn
from time import time
from datetime import datetime
from options import parse_args
from superloss import SuperLoss
from datasets import load_from_disk
from spl import SPL
from mentornet import MentorNet
from ent_curr import EntropyCurriculum
from diff_pred_weighting import DPWeighting
from torch.utils.data import DataLoader
from data import MyDataset, get_dataloader
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoModel, AutoTokenizer, AdamW, logging, AutoConfig
# `logging` here is transformers.logging (imported above): only let the
# transformers library emit ERROR-level messages; its info/warnings are muted.
logging.set_verbosity_error()

def mean(l):
    """Return the arithmetic mean of the sequence ``l``; 0 when it is empty."""
    if len(l) == 0:
        return 0
    return sum(l) / len(l)

def get_dataloaders(args):
    """Load ``<args.data_dir>/<args.data>`` from disk and build the dataloaders.

    The saved ``DatasetDict`` must provide 'train', 'dev' and 'test' splits.
    Optional preprocessing, applied in this order:

    * ``args.lng`` set: copy that column into ``entropy_class`` (all splits).
    * ``args.data_fraction < 1``: keep a random subset of the train split.
    * ``args.noise > 0``: shuffle the labels among a random ``args.noise``
      fraction of the (possibly subsampled) train split.

    Returns:
        ``(train_dataloader, dev_dataloader, test_dataloader, test_samples_data,
        train_dataset, dev_dataset)``. The train loader shuffles; the test
        loader uses batch size 1. ``test_samples_data`` maps 'easy'/'med'/'hard'
        to pre-collated batches of the test rows listed in ``test_samples.npz``
        (SNLI data only, and only when that file exists); otherwise None.
    """
    data_dir = os.path.join(args.data_dir, args.data)
    data = load_from_disk(data_dir)
    is_snli = 'snli' in data_dir

    if args.lng is not None:
        def use_lng_as_entropy_class(sample):
            sample['entropy_class'] = sample[args.lng]
            return sample
        data = data.map(use_lng_as_entropy_class)

    train_dataset = data['train']

    if args.data_fraction < 1:
        n = len(train_dataset)
        ids = np.random.choice(n, int(args.data_fraction * n), replace=False)
        train_dataset = train_dataset.select(ids)

    if args.noise > 0:
        # Sample from the *current* train split: after subsampling above, the
        # original length would produce out-of-range indices.
        n = len(train_dataset)
        noisy_ids = np.random.choice(n, int(args.noise * n), replace=False).tolist()
        shuffled = np.random.permutation(train_dataset[noisy_ids]['label'])
        # dict gives O(1) membership tests inside map() (a numpy `in` is O(n)).
        noisy_labels = dict(zip(noisy_ids, shuffled.tolist()))

        def add_label_noise(sample, idx):
            if idx in noisy_labels:
                sample['label'] = noisy_labels[idx]
            return sample
        train_dataset = train_dataset.map(add_label_noise, with_indices=True)

    dev_dataset = data['dev']
    test_dataset = data['test']

    columns = ['diff', 'label', 'entropy_class']
    if is_snli:
        columns += ['sentence1', 'sentence2']
    else:
        columns += ['t']
    # Optional columns, kept only when the dataset provides them.
    for optional in ('feature_set1', 'ins_weight', 'loss_class'):
        if optional in data['train'].column_names:
            columns.append(optional)

    for split in (train_dataset, dev_dataset, test_dataset):
        split.set_format(type=None, columns=columns)

    train_dataloader = DataLoader(train_dataset, args.batch_size, True)
    dev_dataloader = DataLoader(dev_dataset, args.batch_size)
    test_dataloader = DataLoader(test_dataset, 1)

    test_samples_data = None
    if is_snli:
        if os.path.exists('test_samples.npz'):
            with np.load('test_samples.npz') as test_samples:
                test_samples_data = {}
                for group in ['easy', 'med', 'hard']:
                    batch = [test_dataset[idx] for idx in test_samples[group].tolist()]
                    # Text columns stay Python lists (the tokenizer consumes
                    # them); everything else is collated into a tensor.
                    test_samples_data[group] = {
                        k: [sample[k] for sample in batch]
                        if 'sentence' in k or k == 't'
                        else torch.tensor([sample[k] for sample in batch])
                        for k in columns}
        else:
            print('[Warning] test_samples.npz not found; skipping test-vis samples')

    return train_dataloader, dev_dataloader, test_dataloader, test_samples_data, \
            train_dataset, dev_dataset

class Model(nn.Module):
    """Transformer backbone with a linear classification head.

    Optional auxiliary "entropy class" head, selected by ``args``:

    * ``aux_ent``: a linear head on the sentence embedding; ``forward`` then
      returns ``(logits, ent_logits)``.
    * ``aux_ent2``: a small MLP over 3 per-sample features. It is not called
      by ``forward``; the trainer calls ``ent_classifier`` itself, and a
      ``loss_avg`` buffer (initially None) is registered for that use.
    """

    def __init__(self, args):
        super().__init__()

        model_name, num_labels, ckpt = args.model_name, args.num_labels, args.ckpt
        self.aux_ent = args.aux_ent

        # With a checkpoint, the backbone weights come from the directory saved
        # next to the meta file; otherwise from the pretrained ``model_name``.
        if ckpt is not None:
            self.backbone = AutoModel.from_pretrained(self._backbone_path(ckpt))
        else:
            self.backbone = AutoModel.from_pretrained(model_name)

        # Head sizes and init range always come from the base model's config.
        config = AutoConfig.from_pretrained(model_name)
        self.config = config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.classifier.apply(self.init_weights)

        if args.aux_ent:
            self.ent_classifier = nn.Linear(config.hidden_size, args.num_ent_labels)
            self.ent_classifier.apply(self.init_weights)
        elif args.aux_ent2:
            self.ent_classifier = nn.Sequential(
                    nn.BatchNorm1d(3),
                    nn.Linear(3, 256),
                    nn.ReLU(),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Linear(256, args.num_ent_labels),
                    )
            self.ent_classifier.apply(self.init_weights)
            self.register_buffer('loss_avg', None)

    @staticmethod
    def _backbone_path(ckpt):
        """Map a ``<prefix>_meta.pt`` checkpoint path to its backbone directory.

        ``Trainer.save`` writes the non-backbone state to ``<prefix>_meta.pt``
        and the backbone (``save_pretrained``) to the directory
        ``<prefix>_model``. Only the trailing suffix is rewritten so that
        "meta" elsewhere in the path is left alone. Paths without that suffix
        keep the legacy ``replace('meta', 'model')`` mapping.
        """
        suffix = '_meta.pt'
        if ckpt.endswith(suffix):
            return ckpt[:-len(suffix)] + '_model'
        return ckpt.replace('meta', 'model')

    def init_weights(self, module):
        """Init ``nn.Linear`` layers: N(0, config.initializer_range) weights, zero bias."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, fs1 = None):
        """Classify tokenized input ``x`` (passed to the backbone as kwargs).

        ``fs1`` is accepted but unused. Returns ``logits``, or
        ``(logits, ent_logits)`` when ``aux_ent`` is enabled.
        """
        model_out = self.backbone(**x)
        # Element 1 of the backbone output is used as the sentence embedding
        # (presumably the pooled output - depends on the backbone; confirm).
        sent_emb = self.dropout(model_out[1])

        logits = self.classifier(sent_emb)
        if self.aux_ent:
            return logits, self.ent_classifier(sent_emb)
        return logits

def _build_curriculum(args, device, ent_cfg=None):
    """Instantiate the curriculum selected by ``args.curr`` (None if unknown/unset)."""
    if args.curr == 'sl':
        return SuperLoss(mode=args.sl_mode, lam=args.sl_lam).to(device)
    if args.curr in ('ent', 'loss'):
        return EntropyCurriculum(args.ent_cfg, args.epochs, cfg=ent_cfg)
    if args.curr == 'ent+':
        return EntropyCurriculum(args.ent_cfg, args.epochs, avgloss = True, cfg=ent_cfg)
    if args.curr == 'spl':
        return SPL(mode = args.spl_mode)
    if args.curr == 'mentornet':
        return MentorNet(args.num_labels, args.epochs).to(device)
    if args.curr == 'dp':
        return DPWeighting(args.dp_tao, args.dp_alpha)
    return None


def _run_name(args):
    """Return a fresh run name: a ``%b%d_%H-%M-%S`` timestamp plus a curriculum tag."""
    name = datetime.now().strftime('%b%d_%H-%M-%S')
    if args.curr == 'sl':
        name += '_sl_%s'%args.sl_mode
    elif args.curr in ('ent', 'ent+'):
        name += '_ent_%s'%args.ent_cfg
    elif args.curr == 'loss':
        name += '_loss_%s'%args.ent_cfg
    elif args.curr == 'spl':
        name += '_spl_%s'%args.spl_mode
    elif args.curr == 'mentornet':
        name += '_mentornet'
    elif args.curr == 'dp':
        name += '_dp'
    return name


def init_model(args, device, ent_cfg=None):
    """Build the tokenizer, model and curriculum, resuming from ``args.ckpt`` if set.

    Returns:
        ``(model, tokenizer, curr, name, step)``. When resuming, ``name`` and
        ``step`` come from the checkpoint and the saved head weights are
        restored; otherwise ``name`` is a fresh timestamped run name and
        ``step`` is 0.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if args.ckpt:
        print('[Resuming]')
        state = torch.load(args.ckpt, map_location=device)
        step = state['step']
        # Checkpoints are '<run>_best_meta.pt': cut at the last '_' before the
        # 8-character '_meta.pt' suffix to recover '<run>'.
        name = os.path.basename(args.ckpt)
        name = name[:name.rfind('_', 0, -8)]
        model = Model(args).to(device)
        # Model() already loads the backbone from the checkpoint; restore the
        # remaining (non-backbone) weights saved by Trainer.save as well.
        if 'model' in state:
            model.load_state_dict(state['model'], strict=False)
    else:
        step = 0
        name = _run_name(args)
        model = Model(args).to(device)

    curr = _build_curriculum(args, device, ent_cfg)
    return model, tokenizer, curr, name, step

def init_opt(model, args):
    """Create the AdamW optimizer for ``model``, restoring its state when resuming.

    Parameters whose names contain 'bias' or 'LayerNorm.weight' get no weight
    decay; all others use 0.01. When ``args.ckpt`` is set, the optimizer
    state stored under its 'optimizer' key is loaded.
    """
    no_decay = ('bias', 'LayerNorm.weight')
    decay_params, no_decay_params = [], []
    for param_name, param in model.named_parameters():
        if any(nd in param_name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = AdamW([
        {'params': decay_params, 'weight_decay': 0.01},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ], lr=args.lr)

    if args.ckpt:
        # Load on CPU so GPU-saved checkpoints also open on CPU-only hosts;
        # Optimizer.load_state_dict moves the state to each parameter's device.
        state = torch.load(args.ckpt, map_location='cpu')
        optimizer.load_state_dict(state['optimizer'])
    return optimizer

def test(model, test_dataloader):
    """Evaluate on ``test_dataloader``, print the accuracy and return it.

    Evaluation is implemented by ``Trainer.evaluate`` (there is no module-level
    ``evaluate``), so ``model`` must be an object exposing
    ``evaluate(dataloader)`` - e.g. a ``Trainer`` - whose result holds the
    accuracy at index 2.

    Raises:
        TypeError: if ``model`` has no callable ``evaluate``.
    """
    evaluate_fn = getattr(model, 'evaluate', None)
    if not callable(evaluate_fn):
        raise TypeError('test() needs an object with an evaluate(dataloader) '
                        'method (e.g. a Trainer), got %s' % type(model).__name__)
    print("[Testing]")
    acc = evaluate_fn(test_dataloader)[2]
    print("Acc:", acc)
    return acc

class Trainer():
    """Training / evaluation driver for ``Model`` with optional curriculum weighting.

    Tracks the best dev accuracy seen during training, checkpointing via
    ``save()`` whenever it improves, and logs metrics to ``writer`` (an aim
    session, or None to disable logging).
    """
    def __init__(self, model, tokenizer, crit, optimizer, curr, epochs,
            writer, name, step, epoch_size, debug, device, args):
        self.model = model
        self.tokenizer = tokenizer
        # Expected to return per-sample losses (``__main__`` uses reduction='none').
        self.crit = crit
        self.optimizer = optimizer
        self.curr = curr
        self.writer = writer
        self.name = name
        self.step = step
        self.epoch_size = epoch_size
        self.debug = debug
        self.best_acc = 0
        self.best_step = None
        self.device = device
        self.args = args
        self.epochs = epochs
        self.total_steps = epochs * epoch_size
        self.save_losses = args.save_losses
        if self.save_losses:
            self.losses = {'train': [], 'dev': []}

    @staticmethod
    def _split_by_entropy_class(values, entropy_class):
        """Return ``[values[entropy_class == c] for c in (0, 1, 2)]`` (easy/med/hard)."""
        return [values[entropy_class == c] for c in range(3)]

    def get_loss(self, batch):
        """Run the model on ``batch`` and compute curriculum-weighted losses.

        Args:
            batch: dataloader dict with 'label', 'entropy_class' and either
                'sentence1'/'sentence2' (sentence pairs) or 't' (single text);
                optionally 'ins_weight', 'diff', 'loss_class'.

        Returns:
            ``(outputs, confs, total_loss, mean_unweighted_loss)``, where
            ``outputs`` is ``logits`` or, with aux_ent/aux_ent2,
            ``(logits, ent_logits)`` and ``confs`` are per-sample weights.

        Raises:
            KeyError: if the batch contains neither 'sentence1' nor 't'.
        """
        if 'sentence1' in batch:
            x = self.tokenizer(batch['sentence1'], batch['sentence2'], padding=True,
                    truncation=True, return_tensors="pt").to(self.device)
        elif 't' in batch:
            x = self.tokenizer(batch['t'], padding=True,
                    truncation=True, return_tensors="pt").to(self.device)
        else:
            raise KeyError("batch must contain 'sentence1'/'sentence2' or 't'")
        if self.args.aux_ent:
            logits, ent_logits = self.model(x)
        else:
            logits = self.model(x)

        labels = batch['label'].to(self.device)

        # Per-sample, unweighted classification losses.
        loss_unw = self.crit(logits, labels)

        # Fractional epochs elapsed after burn-in (negative during burn-in).
        relative_epoch = self.step/self.epoch_size - self.args.burn_in

        if self.args.aux_ent2:
            # Exponential moving average of the batch loss, stored on the model.
            # NOTE(review): this also updates when called from evaluate() /
            # eval_test_samples(); confirm that is intended.
            if self.model.loss_avg is None:
                loss_avg = loss_unw.mean()
            else:
                loss_avg = 0.9 * self.model.loss_avg + 0.1 * loss_unw.mean()
            self.model.loss_avg = loss_avg.detach()

            loss_diff = loss_unw - loss_avg
            epoch = torch.ones_like(loss_unw) * relative_epoch
            loss_and_diff = torch.stack([loss_unw, loss_diff, epoch], -1)
            if self.args.detach_loss:
                loss_and_diff = loss_and_diff.detach()
            # Must be self.model: a module-level ``model`` only exists when this
            # file runs as a script.
            ent_logits = self.model.ent_classifier(loss_and_diff)

        if relative_epoch >= 0:
            if self.args.feed_ent:
                # Predicted entropy classes; ent_logits only exists with
                # aux_ent or aux_ent2.
                if self.args.soft_ent:
                    entropy_class = F.softmax(ent_logits, 1)
                else:
                    entropy_class = ent_logits.argmax(1)
            else:
                entropy_class = batch['entropy_class']

            if self.curr is not None:
                if self.args.soft_ent and self.args.curr == 'ent':
                    # Mix per-class curriculum confidences by class probability.
                    confs = entropy_class[:,0] * self.curr(loss_unw,
                            relative_epoch, torch.ones_like(loss_unw) * 0) \
                            + entropy_class[:,1] * self.curr(loss_unw,
                                    relative_epoch, torch.ones_like(loss_unw) * 1)\
                            + entropy_class[:,2] * self.curr(loss_unw,
                                    relative_epoch, torch.ones_like(loss_unw) * 2)
                elif self.args.curr == 'mentornet':
                    confs = self.curr(loss_unw, labels, self.step // self.epoch_size)
                elif self.args.curr == 'dp':
                    confs = self.curr(loss_unw, batch['diff'].to(self.device))
                elif self.args.curr == 'loss':
                    confs = self.curr(loss_unw, self.step / self.total_steps, batch['loss_class'])
                else:
                    confs = self.curr(loss_unw, self.step / self.total_steps, entropy_class)
            else:
                confs = torch.ones_like(loss_unw)
        else:
            # Burn-in: every sample gets full weight.
            confs = torch.ones_like(loss_unw)

        if 'ins_weight' in batch:
            weights = batch['ins_weight'].to(self.device)
        else:
            weights = torch.ones_like(loss_unw)

        eps = 1e-5
        if self.args.balance_logits:
            norm = confs * weights
        else:
            norm = confs
        # Weighted mean of the losses; clamp (on-device, no host sync) keeps the
        # denominator away from zero when all weights vanish.
        loss_w = norm * loss_unw
        total_loss = loss_w.sum() / norm.sum().clamp(min=eps)

        if self.args.aux_ent or self.args.aux_ent2:
            if relative_epoch > 0 and not self.args.soft_ent:
                loss_aux = self.crit(ent_logits, batch['entropy_class'].to(self.device))
                if self.args.balance_aux_ent:
                    total_loss += self.args.ent_alpha*(weights*loss_aux).mean()
                else:
                    total_loss += self.args.ent_alpha*loss_aux.mean()
            return (logits, ent_logits), confs, total_loss, loss_unw.mean()
        return logits, confs, total_loss, loss_unw.mean()

    def eval_test_samples(self, batch):
        """Return ``(weighted_loss, unweighted_loss, mean_conf)`` for a pre-collated batch, without grads."""
        with torch.no_grad():
            _, confs, loss_w, loss_unw = self.get_loss(batch)
        if confs is not None:
            confs = confs.mean()
        return loss_w, loss_unw, confs

    def evaluate(self, dataloader, count=None, return_pred = False, return_loss = False):
        """Evaluate on ``dataloader`` with the model in eval mode and no grads.

        Args:
            count: if set, stop after the batch with index ``count`` (its
                confidences are not recorded).
            return_pred: append the list of predicted labels to the result.
            return_loss: append the per-sample unweighted losses to the result.

        Returns:
            With aux_ent/aux_ent2: ``[loss_unw, loss, acc, confs, f1, ent_acc,
            ent_f1, class_acc, ent_class_acc]``; otherwise ``[loss_unw, loss,
            acc, confs, f1, class_acc]`` - followed by preds and/or losses if
            requested. Losses/accuracies are means over batches; ``confs`` holds
            per-batch mean confidences for entropy classes 0/1/2; per-class
            accuracies are only collected from batches of size 1.
        """
        losses = []
        losses_unw = []
        accs = []
        ent_accs = []
        self.model.eval()
        acc_class = [[] for _ in range(3)]
        confs = [[] for _ in range(3)]
        ent_acc_class = [[] for _ in range(3)]
        ent_true, ent_pred = [], []
        trues, preds = [], []
        full_loss = []
        has_aux = self.args.aux_ent or self.args.aux_ent2
        for i, batch in enumerate(dataloader):
            with torch.no_grad():
                logits, conf, loss, loss_unw = self.get_loss(batch)
            if has_aux:
                logits, ent_logits = logits
                true = batch['entropy_class']
                pred = ent_logits.argmax(-1).cpu()
                ent_acc = accuracy_score(true, pred)
                ent_accs.append(ent_acc)
                ent_true.extend(true.tolist())
                ent_pred.extend(pred.tolist())
            true = batch['label']
            pred = logits.argmax(-1).cpu()
            if return_loss:
                full_loss += self.crit(logits, true.to(logits.device)).tolist()
            acc = accuracy_score(true, pred)
            trues.extend(true.tolist())
            preds.extend(pred.tolist())
            losses_unw.append(loss_unw.detach().item())
            losses.append(loss.detach().item())
            accs.append(acc)
            if torch.numel(batch['label']) == 1:
                cls = batch['entropy_class'].item()
                acc_class[cls].append(acc)
                if has_aux:
                    ent_acc_class[cls].append(ent_acc)
            if count and i > 0 and i % count == 0:
                break
            split = self._split_by_entropy_class(conf, batch['entropy_class'])
            for bucket, conf_c in zip(confs, split):
                if conf_c.numel() != 0:
                    bucket.append(conf_c.mean().item())

        self.model.train()

        f1 = f1_score(trues, preds, average = 'macro')
        if len(ent_accs) > 0:
            ent_f1 = f1_score(ent_true, ent_pred, average='macro')
            res = [mean(losses_unw), mean(losses), mean(accs), confs, f1, mean(ent_accs), ent_f1,\
                    [mean(a) for a in acc_class], [mean(a) for a in ent_acc_class]]
        else:
            res = [mean(losses_unw), mean(losses), mean(accs), confs, f1, [mean(a) for a in acc_class]]

        if return_pred:
            res.append(preds)
        if return_loss:
            res.append(full_loss)

        return res

    def save(self):
        """Checkpoint the model as ``<ckpt_dir>/<name>_best_model`` + ``_best_meta.pt``.

        The backbone goes through ``save_pretrained``; the remaining weights,
        optimizer state, step, curriculum object and best step/accuracy go into
        the meta file.
        """
        self.model.backbone.save_pretrained('{}/{}_best_model'.format(self.args.ckpt_dir,
            self.name))
        torch.save({
            'model': {k: v for k,v in self.model.state_dict().items() if 'backbone' not in k},
            'optimizer': self.optimizer.state_dict(),
            'step': self.step,
            'curr': self.curr,
            'best_step': self.best_step,
            'best_acc': self.best_acc},
            '{}/{}_best_meta.pt'.format(self.args.ckpt_dir, self.name))

    def load(self, name):
        """Restore trainer/model state from ``<name>_meta.pt`` and ``<name>_model``."""
        state = torch.load("%s_meta.pt"%name, map_location=self.device)
        self.model.load_state_dict(state['model'], strict=False)
        self.step = state['step']
        self.curr = state['curr']
        self.optimizer.load_state_dict(state['optimizer'])
        self.best_step = state['best_step']
        self.best_acc = state['best_acc']

        # Copy the saved backbone weights into the existing module instead of
        # replacing it, so the optimizer keeps pointing at the live parameters.
        saved_backbone = AutoModel.from_pretrained("%s_model"%name)
        self.model.backbone.load_state_dict(saved_backbone.state_dict())

    def load_best(self):
        """Load the best checkpoint of this run; return ``(best_acc, best_step)``."""
        if self.best_step is not None:
            print("[Loading Best] Current: %d -> Best: %d (%.4f)"%(self.step, self.best_step,
                self.best_acc))
        self.load('{}/{}_best'.format(self.args.ckpt_dir, self.name))
        return self.best_acc, self.best_step

    def cleanup(self):
        """Delete the best-checkpoint files written by ``save()``."""
        path = '{}/{}_best'.format(self.args.ckpt_dir, self.name)
        os.remove("%s_meta.pt"%path)
        shutil.rmtree("%s_model"%path)

    def train(self, train_dataloader, dev_dataloader, test_samples, train_ns = None):
        """Train for ``self.epochs`` epochs.

        Every ~third of an epoch (when ``test_samples`` and a writer exist) the
        fixed easy/med/hard test samples are logged; every ~half epoch the dev
        set is evaluated and the model is checkpointed on a new best accuracy.
        ``train_ns`` is only used when ``args.save_losses`` is set (in
        ``__main__`` it is an unshuffled train loader).
        """
        # max(1, ...) avoids ZeroDivisionError for epochs shorter than 3 / 2 steps.
        vis_every = max(1, self.epoch_size // 3)
        eval_every = max(1, self.epoch_size // 2)
        has_aux = self.args.aux_ent or self.args.aux_ent2
        subsets = ('easy', 'med', 'hard')
        for _ in range(self.epochs):
            self.model.train()
            for batch in train_dataloader:
                logits, conf, loss, loss_unw = self.get_loss(batch)
                if has_aux:
                    logits, ent_logits = logits
                    ent_acc = accuracy_score(batch['entropy_class'],
                            ent_logits.argmax(-1).cpu())
                    ent_f1 = f1_score(batch['entropy_class'],
                            ent_logits.argmax(-1).cpu(), average = 'macro')
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                acc = accuracy_score(batch['label'], logits.argmax(-1).cpu())
                f1 = f1_score(batch['label'], logits.argmax(-1).cpu(), average = 'macro')
                if self.writer is not None:
                    self.writer.track(loss.detach().item(),
                            name = 'loss_weighted', dataset = 'train')
                    self.writer.track(loss_unw.detach().item(),
                            name = 'loss_unweighted', dataset = 'train')
                    self.writer.track(acc,
                            name = 'acc', dataset = 'train')
                    self.writer.track(f1,
                            name = 'f1', dataset = 'train')
                    if has_aux:
                        self.writer.track(ent_acc,
                                name = 'ent_acc', dataset = 'train')
                        self.writer.track(ent_f1,
                                name = 'ent_f1', dataset = 'train')
                    split = self._split_by_entropy_class(conf, batch['entropy_class'])
                    for subset, conf_c in zip(subsets, split):
                        if conf_c.numel() != 0:
                            self.writer.track(conf_c.mean().item(),
                                    name = 'conf', dataset = 'train', subset = subset)
                if test_samples is not None\
                        and self.writer is not None\
                        and (self.step + 1) % vis_every == 0:
                    vis = {s: self.eval_test_samples(test_samples[s]) for s in subsets}
                    for idx, metric in ((0, 'loss_weighted'), (1, 'loss_unweighted')):
                        for s in subsets:
                            self.writer.track(vis[s][idx].item(), name = metric,
                                    dataset = 'test-vis', subset = s)
                    if vis['easy'][2] is not None:
                        for s in subsets:
                            self.writer.track(vis[s][2].item(), name = 'conf',
                                    dataset = 'test-vis', subset = s)
                    self.model.train()

                if (self.step + 1) % eval_every == 0:
                    res = self.evaluate(dev_dataloader, return_loss = self.save_losses)
                    if has_aux:
                        loss_unw, loss, acc, conf, f1, ent_acc, ent_f1, _, _ = res[:9]
                    else:
                        loss_unw, loss, acc, conf, f1, _ = res[:6]
                    if self.save_losses:
                        res_train = self.evaluate(train_ns, return_loss = True)
                        self.losses['train'].append(res_train[-1])
                        self.losses['dev'].append(res[-1])

                    if self.writer is not None:
                        self.writer.track(loss_unw, name = 'loss_unweighted',
                                dataset = 'val')
                        self.writer.track(loss, name = 'loss_weighted',
                                dataset = 'val')
                        self.writer.track(acc, name = 'acc',
                                dataset = 'val')
                        self.writer.track(f1, name = 'f1',
                                dataset = 'val')
                        if has_aux:
                            self.writer.track(ent_acc, name = 'ent_acc',
                                    dataset = 'val')
                            self.writer.track(ent_f1, name = 'ent_f1',
                                    dataset = 'val')
                        for subset, conf_vals in zip(subsets, conf):
                            self.writer.track(mean(conf_vals),
                                    name = 'conf', dataset = 'val', subset = subset)
                    if acc > self.best_acc:
                        self.best_acc = acc
                        self.best_step = self.step
                        self.save()

                self.step += 1

if __name__ == '__main__':
    args = parse_args()

    device = torch.device('cuda:%d'%args.gpu if torch.cuda.is_available() else 'cpu')
    train_dataloader, dev_dataloader, test_dataloader, test_samples,\
            train_dataset, _ = get_dataloaders(args)

    if args.curr == 'dp':
        # dp_tao is set to the median of the training 'diff' column.
        args.dp_tao = np.percentile(train_dataset['diff'], 50)

    for seed in args.seed:
        torch.manual_seed(seed)
        np.random.seed(seed)

        model, tokenizer, curr, name, step = init_model(args, device)
        optimizer = init_opt(model, args)
        # Per-sample losses: the trainer applies its own curriculum weighting.
        crit = nn.CrossEntropyLoss(reduction='none')

        writer = aim.Session(experiment=args.aim_exp, system_tracking_interval=0) if not args.debug else None

        epoch_size = len(train_dataloader)
        trainer = Trainer(model, tokenizer, crit, optimizer, curr, args.epochs,
                writer, name, step, epoch_size, args.debug, device, args)

        if not args.eval_only:
            print('[Starting Training]')
            # Unshuffled train loader, only needed when recording per-sample losses.
            train_ns = DataLoader(train_dataset, args.batch_size) if args.save_losses else None
            trainer.train(train_dataloader, dev_dataloader, test_samples, train_ns)

        print('[Testing]')
        if args.save_losses:
            os.makedirs('losses', exist_ok=True)
            np.savez('losses/%s_%d.npz'%(args.data, seed), **trainer.losses)
        _, best_step = trainer.load_best()
        results = {}
        if args.aux_ent or args.aux_ent2:
            # evaluate()[2:] is [acc, confs, f1, ent_acc, ent_f1, class_acc,
            # ent_class_acc]; confs is not reported here.
            acc, _, f1, ent_acc, ent_f1, class_acc, ent_class_acc = \
                    trainer.evaluate(test_dataloader)[2:]
            print('Ent Acc:', ent_acc)
            print('Ent F1:', ent_f1)
            print("0: {:.4f}\n1: {:.4f}\n2: {:.4f}".format(*ent_class_acc))
            results['ent_acc'] = ent_acc
            results['ent_f1'] = ent_f1
            results['ent_acc_easy'] = ent_class_acc[0]
            results['ent_acc_med'] = ent_class_acc[1]
            results['ent_acc_hard'] = ent_class_acc[2]
        else:
            acc, _, f1, class_acc = trainer.evaluate(test_dataloader)[2:]
        print('Acc:', acc)
        print('F1:', f1)
        print("0: {:.4f}\n1: {:.4f}\n2: {:.4f}".format(*class_acc))
        results['acc'] = acc
        results['f1'] = f1
        results['acc_easy'] = class_acc[0]
        results['acc_med'] = class_acc[1]
        results['acc_hard'] = class_acc[2]
        results['best_step'] = best_step
        if writer is not None:
            writer.set_params(args.__dict__, name = 'hparams')
            writer.set_params(results, name='result')
        if not args.eval_only:
            # Only delete checkpoints this run wrote. In eval-only mode the
            # loaded checkpoint is the user's input and later seeds reload it.
            trainer.cleanup()


def output_loss_conf(trainer=None, dataset=None, ckpt=None, out_dir='cache'):
    """Dump per-sample unweighted losses and curriculum confidences to ``.npy`` files.

    Runs ``dataset`` one sample at a time through ``trainer.get_loss`` (model
    in eval mode, no grads) and writes ``<out_dir>/<run>_loss.npy`` and
    ``<out_dir>/<run>_conf.npy``, where ``<run>`` is the basename of ``ckpt``
    without ``_best_meta.pt``.

    Arguments left as None fall back to the module globals ``trainer``,
    ``train_dataset`` and ``args.ckpt`` (as set by the ``__main__`` section),
    so the original zero-argument call still works.
    """
    try:
        from tqdm import tqdm
    except ImportError:  # the progress bar is optional
        def tqdm(iterable):
            return iterable

    module_globals = globals()
    if trainer is None:
        trainer = module_globals['trainer']
    if dataset is None:
        dataset = module_globals['train_dataset']
    if ckpt is None:
        ckpt = module_globals['args'].ckpt

    losses = []
    confs = []
    # Disable dropout so the dumped losses are deterministic; restore afterwards.
    was_training = trainer.model.training
    trainer.model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(DataLoader(dataset, 1, False)):
                _, conf, _, loss_unw = trainer.get_loss(batch)
                losses.append(loss_unw.item())
                # `is not None`, not truthiness: a confidence of exactly 0 must
                # still be recorded or confs would misalign with losses.
                if conf is not None:
                    confs.append(conf.item())
    finally:
        trainer.model.train(was_training)

    out_file = os.path.basename(ckpt).replace('_best_meta.pt', '')
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, '%s_loss.npy'%out_file), losses)
    np.save(os.path.join(out_dir, '%s_conf.npy'%out_file), confs)
