from semantic_aug.datasets.coco import COCODataset
from semantic_aug.datasets.spurge import SpurgeDataset
from semantic_aug.datasets.imagenet import ImageNetDataset
from semantic_aug.datasets.pascal import PASCALDataset
from semantic_aug.datasets.caltech101 import CalTech101Dataset
from semantic_aug.datasets.flowers102 import Flowers102Dataset
from semantic_aug.datasets.pets import PetsDataset
from semantic_aug.datasets.cars import CarsDataset
from semantic_aug.datasets.lvis import LVISDataset
from semantic_aug.augmentations.compose import ComposeParallel
from semantic_aug.augmentations.compose import ComposeSequential
# from semantic_aug.augmentations.real_guidance import RealGuidance
# from semantic_aug.augmentations.textual_inversion import TextualInversion
from torch.utils.data import DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from transformers import AutoImageProcessor, DeiTModel
from itertools import product
from tqdm import trange
from typing import List
from config import get_exp_results

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as distributed
from torch.distributions import Categorical

import time
import json
import argparse
import pandas as pd
import numpy as np
import random
import os


class AugBatchSampler:
    """Per-class chooser between image variants for a training batch.

    Holds one categorical distribution per class (one row of ``weights`` per
    class) and a current "winner" variant index per class. ``get_aug_batch``
    then takes, for every sample, the variant that won for that sample's class.
    """

    def __init__(self, weights):
        # Building the distribution also draws the first set of winners.
        self.update_weights(weights)

    def update_weights(self, weights):
        """Install a new categorical distribution and immediately redraw winners."""
        self.m = Categorical(weights)
        self.resample()

    def resample(self):
        """Redraw the winner for every class from the current distribution."""
        self.winner = self.m.sample()

    def set_winner(self, winner):
        """Override the winners with externally chosen indices (one per class)."""
        self.winner = winner

    def get_aug_batch(self, input, target):
        """Assemble a batch by picking, per sample, its class's winning variant.

        ``weights`` (given at construction) is laid out as::

            [[A_real, A_same, A_other],
             [B_real, B_same, B_other], ...]     # num_classes x num_variants

        ``input`` is a sequence of ``num_variants`` batched tensors of equal
        shape, each with batch size ``len(target)``; ``target`` holds integer
        class labels. Row ``b`` of the result is
        ``input[self.winner[target[b]]][b]``.
        """
        chosen = self.winner[target]  # winning variant index for each sample
        variants = torch.stack(input)  # (num_variants, B, ...)
        return variants[chosen, torch.arange(len(target))]
    
# Optional dependency: CutMix is only used when ``use_cutmix`` is requested.
try:
    from cutmix.cutmix import CutMix
    IS_CUTMIX_INSTALLED = True
except ImportError:
    # Catch only import failures. A bare ``except`` would also swallow
    # KeyboardInterrupt/SystemExit and unrelated errors at import time.
    IS_CUTMIX_INSTALLED = False


DEFAULT_MODEL_PATH = "CompVis/stable-diffusion-v1-4"
DEFAULT_PROMPT = "a photo of a {name}"

# Path templates; the CLI fills them with str.format(**hyperparameters).
DEFAULT_SYNTHETIC_DIR = "demo/data_aug/{dataset}-{seed}-{examples_per_class}"
DEFAULT_EMBED_PATH = (
    "demo/embed/{dataset}-tokens/{dataset}-{seed}-{examples_per_class}.pt")

# CLI dataset name -> dataset class.
DATASETS = dict(
    spurge=SpurgeDataset,
    coco=COCODataset,
    pascal=PASCALDataset,
    imagenet=ImageNetDataset,
    caltech=CalTech101Dataset,
    flowers=Flowers102Dataset,
    pets=PetsDataset,
    cars=CarsDataset,
    lvis=LVISDataset,
)

# CLI compose mode -> augmentation composer class.
COMPOSERS = dict(parallel=ComposeParallel, sequential=ComposeSequential)

# Placeholders only: the concrete augmentation imports are commented out above.
AUGMENTATIONS = dict.fromkeys(("real-guidance", "textual-inversion"))


def _report_synthetic_counts(dataset, tag=""):
    """Print how many synthetic images were loaded for the two synthetic banks.

    Reads ``dataset.synthetic_examples['samecls']`` and ``['othercls']``
    (mappings, presumably keyed by source sample) and reports, for each, the
    number of images under the first key and the number of keys. An empty
    bank is reported as 0 instead of raising.
    """
    for kind in ("samecls", "othercls"):
        bank = dataset.synthetic_examples[kind]
        per_sample = len(bank[next(iter(bank))]) if len(bank) > 0 else 0
        print(f"Loaded {per_sample} synthetic {tag}{kind} examples "
              f"for {len(bank)} samples.")


def _per_class_zeros(num_classes):
    """Return fresh (loss, accuracy, count) float32 CUDA accumulators, one slot per class."""
    return tuple(
        torch.zeros(num_classes, dtype=torch.float32, device='cuda')
        for _ in range(3))


def _epoch_records(seed, examples_per_class, epoch, class_names, metrics):
    """Flatten one epoch's per-class metric arrays into long-format rows.

    ``metrics`` is a sequence of ``(metric, split, per_class_values)``. The
    class-averaged rows come first (in ``metrics`` order), followed by one
    row per class and metric, labelled ``"<metric> <Class Name>"``.
    """
    def row(value, metric, split):
        return dict(
            seed=seed,
            examples_per_class=examples_per_class,
            epoch=epoch,
            value=value,
            metric=metric,
            split=split)

    rows = [row(values.mean(), metric, split) for metric, split, values in metrics]
    for i, name in enumerate(class_names):
        rows.extend(
            row(values[i], f"{metric} {name.title()}", split)
            for metric, split, values in metrics)
    return rows


def run_experiment(examples_per_class: float = 0, 
                   seed: int = 0, 
                   dataset: str = "spurge", 
                   num_synthetic: int = 100, 
                   iterations_per_epoch: int = 200, 
                   num_epochs: int = 50, 
                   batch_size: int = 32, 
                   aug: str = None,
                   strength: List[float] = None, 
                   guidance_scale: List[float] = None,
                   mask: List[bool] = None,
                   inverted: List[bool] = None, 
                   probs: List[float] = None,
                   compose: str = "parallel",
                   synthetic_probability: float = 0.5, 
                   synthetic_dir: dict = None, 
                   embed_path: str = DEFAULT_EMBED_PATH,
                   model_path: str = DEFAULT_MODEL_PATH,
                   prompt: str = DEFAULT_PROMPT,
                   use_randaugment: bool = False,
                   use_cutmix: bool = False,
                   erasure_ckpt_path: str = None,
                   image_size: int = 256,
                   classifier_backbone: str = "resnet50",
                   aug_prob: str = None,
                   search_lr: float = 0.001,
                   resample_freq: int = 1,
                   search_dir: str = None,
                   num_trials: int = None):
    """Train and validate a classifier for one (dataset, seed, shots) setting.

    Each training sample yields three image variants (indices 0/1/2, logged
    as identity / samecls / othercls). For every class, one variant is used
    per step, chosen according to ``aug_prob``:

    * ``"uniform"``: sample from a per-class (0.33, 0.33, 0.33) categorical,
      redrawn every ``resample_freq`` steps.
    * ``"fixed"``: take a hard Gumbel-softmax draw over per-class logits
      loaded from ``search_dir``.
    * ``"search"``: like ``"fixed"``, but the logits are learned. Every
      ``resample_freq`` epochs they are optimised with Adam (``search_lr``)
      on a second dataset that uses the synthetic images of seed
      ``(seed + 1) % num_trials``.

    The generative-augmentation arguments (``strength``, ``guidance_scale``,
    ``mask``, ``inverted``, ``probs``, ``compose``, ``embed_path``,
    ``model_path``, ``prompt``, ``erasure_ckpt_path``) are accepted for
    interface compatibility but are not used here. Requires CUDA.

    Args:
        num_trials: Number of seeds in the sweep. Only used with
            ``aug_prob="search"`` to pick the search seed. When None, it falls
            back to the CLI's module-level ``args.num_trials``.

    Returns:
        ``(records, search_history, search_parameters)``.
        ``records`` is a list of dicts with keys seed / examples_per_class /
        epoch / value / metric / split. ``search_history`` holds per-class
        logit snapshots (``"search"`` only). ``search_parameters`` holds the
        per-class logits (a plain list for ``"uniform"``).

    Raises:
        ValueError: unknown ``aug_prob``; ``"fixed"`` without ``search_dir``;
            ``"search"`` without ``num_trials`` outside the CLI.
    """
    # Validate before any (expensive) dataset/model construction.
    if aug_prob not in ("uniform", "fixed", "search"):
        raise ValueError(
            f"aug_prob must be 'uniform', 'fixed' or 'search', got {aug_prob!r}")
    if aug_prob == 'fixed' and search_dir is None:
        raise ValueError("aug_prob='fixed' requires search_dir")
    if aug_prob == 'search' and num_trials is None:
        # Backward compatibility: the CLI entry point does not pass num_trials.
        cli_args = globals().get("args")
        if cli_args is None:
            raise ValueError("num_trials is required when aug_prob='search'")
        num_trials = cli_args.num_trials

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    train_dataset = DATASETS[dataset](
        split="train", examples_per_class=examples_per_class, 
        synthetic_probability=synthetic_probability, 
        synthetic_dir=synthetic_dir,
        use_randaugment=use_randaugment,
        generative_aug='autodiff', seed=seed,
        image_size=(image_size, image_size))
    num_classes = train_dataset.num_classes

    if num_synthetic > 0 and aug == 'autodiff':
        train_dataset.load_augmentations(num_synthetic)
        _report_synthetic_counts(train_dataset)

    uniform_weights = [[0.33, 0.33, 0.33]] * num_classes
    search_history = []
    search_parameters = uniform_weights
    if aug_prob == 'uniform':
        sampler_weights = torch.tensor(uniform_weights)
    elif aug_prob == 'fixed':
        # The sampler's own distribution is never used in this mode: winners
        # are overwritten every step via set_winner. This only builds it.
        sampler_weights = torch.tensor([synthetic_probability] * num_classes)
        load_path = os.path.join(
            search_dir, f"searchparams-{dataset}-{seed}-{examples_per_class}.pt")
        search_parameters = torch.load(load_path)
    else:  # 'search'
        sampler_weights = torch.tensor(uniform_weights)
        search_parameters = torch.Tensor(uniform_weights).cuda()
        search_parameters.requires_grad = True
        search_optim = torch.optim.Adam(
            [{'params': search_parameters, 'lr': search_lr}])

    cutmix_dataset = None
    if use_cutmix and IS_CUTMIX_INSTALLED:
        cutmix_dataset = CutMix(
            train_dataset, beta=1.0, prob=0.5, num_mix=2, 
            num_class=num_classes)
    train_source = cutmix_dataset if cutmix_dataset is not None else train_dataset

    train_sampler = torch.utils.data.RandomSampler(
        train_source, replacement=True, 
        num_samples=batch_size * iterations_per_epoch)

    train_dataloader = DataLoader(
        train_source, batch_size=batch_size, 
        sampler=train_sampler, num_workers=8)

    if aug_prob == 'search':
        # Search on a different seed's synthetic images than the ones trained on.
        search_seed = (seed + 1) % num_trials
        search_synthetic_dir = {
            kind: synthetic_dir[kind].replace(
                f'{dataset}-{seed}', f'{dataset}-{search_seed}')
            for kind in ('samecls', 'othercls')
        }
        search_dataset = DATASETS[dataset](
            split="train", examples_per_class=examples_per_class, 
            synthetic_probability=synthetic_probability, 
            synthetic_dir=search_synthetic_dir,
            use_randaugment=use_randaugment,
            generative_aug='autodiff', seed=search_seed,
            image_size=(image_size, image_size))

        search_dataset.load_augmentations(num_synthetic)
        _report_synthetic_counts(search_dataset, tag="*search* ")

        search_sampler = torch.utils.data.RandomSampler(
            search_dataset, replacement=True, 
            num_samples=batch_size * iterations_per_epoch)

        search_dataloader = DataLoader(
            search_dataset, batch_size=batch_size, 
            sampler=search_sampler, num_workers=8)

    val_dataset = DATASETS[dataset](
        split="val", seed=seed,
        image_size=(image_size, image_size))

    val_sampler = torch.utils.data.RandomSampler(
        val_dataset, replacement=True, 
        num_samples=batch_size * iterations_per_epoch)

    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, 
        sampler=val_sampler, num_workers=4)

    model = ClassificationModel(
        num_classes, 
        backbone=classifier_backbone
    ).cuda()

    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    records = []

    # Built after the model on purpose: both consume the torch RNG, and this
    # keeps seeded runs reproducible with earlier results.
    sampler = AugBatchSampler(sampler_weights)
    desc = (f"Training Classifier for {dataset}-{seed}-{examples_per_class} "
            f"w/ {aug_prob}{search_lr}")
    for epoch in trange(num_epochs, desc=desc):

        model.train()
        epoch_loss, epoch_accuracy, epoch_size = _per_class_zeros(num_classes)

        for step, (views, label) in enumerate(train_dataloader):
            if aug_prob == 'uniform':
                images = sampler.get_aug_batch(views, label)
                if step % resample_freq == 0:
                    sampler.resample()
            else:  # 'fixed' / 'search'
                # Hard Gumbel draw of one variant per class. Only the argmax is
                # used, so no autograd graph on search_parameters is needed.
                with torch.no_grad():
                    one_hot = F.gumbel_softmax(
                        search_parameters, tau=1, hard=True).cpu()
                sampler.set_winner(torch.argmax(one_hot, dim=1))
                images = sampler.get_aug_batch(views, label)

            images, label = images.cuda(), label.cuda()
            logits = model(images)
            prediction = logits.argmax(dim=1)

            loss = F.cross_entropy(logits, label, reduction="none")
            # Soft (e.g. mixed) labels: reduce to hard labels for bookkeeping.
            if len(label.shape) > 1:
                label = label.argmax(dim=1)

            accuracy = (prediction == label).float()

            optim.zero_grad()
            loss.mean().backward()
            optim.step()

            with torch.no_grad():
                epoch_size.scatter_add_(0, label, torch.ones_like(loss))
                epoch_loss.scatter_add_(0, label, loss)
                epoch_accuracy.scatter_add_(0, label, accuracy)

        training_loss = (epoch_loss / epoch_size.clamp(min=1)).cpu().numpy()
        training_accuracy = (epoch_accuracy / epoch_size.clamp(min=1)).cpu().numpy()

        model.eval()
        epoch_loss, epoch_accuracy, epoch_size = _per_class_zeros(num_classes)

        if aug_prob == 'search' and epoch % resample_freq == 0:
            for views, label in search_dataloader:
                # Stack variants along dim 0: (3, B, ...) -- the original noted
                # (3, 32, 3, 256, 256) for the default settings.
                views = torch.stack(views).cuda()
                label = label.cuda()
                soft_y = F.gumbel_softmax(search_parameters, tau=1, hard=False).cuda()
                weights = soft_y[label]  # per-sample variant weights, (B, 3)
                mix = weights.T[:, :, None, None, None]  # (3, B, 1, 1, 1)
                img_sum = torch.sum(views * mix, dim=0)

                # search=True keeps gradients through the backbone so they
                # reach search_parameters via the mixed image.
                logits = model(img_sum, search=True)
                loss = F.cross_entropy(logits, label, reduction="none")

                search_optim.zero_grad()
                loss.mean().backward()
                search_optim.step()

            gumbel_param = search_parameters.detach().cpu().numpy()
            for i, name in enumerate(train_dataset.class_names):
                # Row i = class i; columns = (real, samecls, othercls).
                search_history.append(dict(
                    seed=seed, 
                    examples_per_class=examples_per_class,
                    epoch=epoch,
                    label=f"{name.title()}", 
                    identity_p=gumbel_param[i][0],
                    samecls_p=gumbel_param[i][1],
                    othercls_p=gumbel_param[i][2]
                ))

        # Validation needs no autograd graph.
        with torch.no_grad():
            for image, label in val_dataloader:
                image, label = image.cuda(), label.cuda()

                logits = model(image)
                prediction = logits.argmax(dim=1)

                loss = F.cross_entropy(logits, label, reduction="none")
                accuracy = (prediction == label).float()

                epoch_size.scatter_add_(0, label, torch.ones_like(loss))
                epoch_loss.scatter_add_(0, label, loss)
                epoch_accuracy.scatter_add_(0, label, accuracy)

        validation_loss = (epoch_loss / epoch_size.clamp(min=1)).cpu().numpy()
        validation_accuracy = (epoch_accuracy / epoch_size.clamp(min=1)).cpu().numpy()

        records.extend(_epoch_records(
            seed, examples_per_class, epoch, train_dataset.class_names,
            [("Loss", "Training", training_loss),
             ("Loss", "Validation", validation_loss),
             ("Accuracy", "Training", training_accuracy),
             ("Accuracy", "Validation", validation_accuracy)]))

    return records, search_history, search_parameters


class ClassificationModel(nn.Module):
    """Linear classification head on top of a pretrained image backbone.

    Supported backbones:

    * ``"resnet50"``: torchvision ResNet-50 with default weights, whose
      pooled 2048-d features feed the head.
    * ``"deit"``: ``facebook/deit-base-distilled-patch16-224``, whose first
      output's token 0 (768-d) feeds the head.
    """

    def __init__(self, num_classes: int, backbone: str = "resnet50"):
        """Build the backbone and an ``nn.Linear`` head with ``num_classes`` outputs.

        Raises:
            ValueError: if ``backbone`` is not ``"resnet50"`` or ``"deit"``.
                Previously, an unknown name left ``base_model``/``out`` unset
                and failed later with an AttributeError.
        """
        super(ClassificationModel, self).__init__()

        if backbone not in ("resnet50", "deit"):
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected 'resnet50' or 'deit'")

        self.backbone = backbone
        self.image_processor = None  # NOTE(review): not referenced in this file

        if backbone == "resnet50":
            self.base_model = resnet50(weights=ResNet50_Weights.DEFAULT)
            self.out = nn.Linear(2048, num_classes)
        else:  # "deit"
            self.base_model = DeiTModel.from_pretrained(
                "facebook/deit-base-distilled-patch16-224")
            self.out = nn.Linear(768, num_classes)

    def img_to_feat(self, image):
        """Map a batch of images to flat backbone features (no head applied)."""
        x = image

        if self.backbone == "resnet50":
            # torchvision ResNet forward, stopping before the final fc layer.
            x = self.base_model.conv1(x)
            x = self.base_model.bn1(x)
            x = self.base_model.relu(x)
            x = self.base_model.maxpool(x)

            x = self.base_model.layer1(x)
            x = self.base_model.layer2(x)
            x = self.base_model.layer3(x)
            x = self.base_model.layer4(x)

            x = self.base_model.avgpool(x)
            x = torch.flatten(x, 1)

        elif self.backbone == "deit":
            # Token 0 of the model's first output.
            x = self.base_model(x)[0][:, 0, :]

        return x

    def forward(self, image, search=False):
        """Return class logits for ``image``.

        With ``search=False`` the backbone runs under ``torch.no_grad()``, so
        only ``self.out`` receives gradients and the backbone is effectively
        frozen. With ``search=True`` gradients flow through the backbone back
        to ``image``; the augmentation-weight search uses this path.
        """
        if search:
            x = self.img_to_feat(image)
        else:
            with torch.no_grad():
                x = self.img_to_feat(image)
        return self.out(x)


if __name__ == "__main__":

    parser = argparse.ArgumentParser("Few-Shot Baseline")

    parser.add_argument("--logdir", type=str, default="few_shot_combined")
    parser.add_argument("--model-path", type=str, default="CompVis/stable-diffusion-v1-4")

    parser.add_argument("--prompt", type=str, default="a photo of a {name}")

    parser.add_argument("--synthetic-probability", type=float, default=0.5)
    parser.add_argument("--synthetic-samecls-dir", type=str, default=DEFAULT_SYNTHETIC_DIR)
    parser.add_argument("--synthetic-othercls-dir", type=str, default=DEFAULT_SYNTHETIC_DIR)

    parser.add_argument("--search-dir", type=str, default=None)

    parser.add_argument("--image-size", type=int, default=256)
    parser.add_argument("--classifier-backbone", type=str, 
                        default="resnet50", choices=["resnet50", "deit"])

    parser.add_argument("--iterations-per-epoch", type=int, default=200)
    parser.add_argument("--num-epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=32)

    parser.add_argument("--num-synthetic", type=int, default=15)
    parser.add_argument("--num-trials", type=int, default=8)
    parser.add_argument("--start-seed", type=int, default=0)
    parser.add_argument("--examples-per-class", nargs='+', type=float, default=[1, 2, 4, 8, 16])

    parser.add_argument("--embed-path", type=str, default=DEFAULT_EMBED_PATH)

    parser.add_argument("--dataset", type=str, default="pascal", 
                        choices=["spurge", "imagenet", "coco", "pascal", "flowers", "caltech",  "pets", "cars", "lvis"])

    parser.add_argument("--aug", type=str, default=None, 
                        choices=["real-guidance", "textual-inversion", "autodiff"])

    parser.add_argument("--strength", nargs="+", type=float, default=None)
    parser.add_argument("--guidance-scale", nargs="+", type=float, default=None)

    parser.add_argument("--mask", nargs="+", type=int, default=None, choices=[0, 1])
    parser.add_argument("--inverted", nargs="+", type=int, default=None, choices=[0, 1])

    parser.add_argument("--probs", nargs="+", type=float, default=None)

    parser.add_argument("--compose", type=str, default="parallel", 
                        choices=["parallel", "sequential"])

    parser.add_argument("--aug-prob", type=str, default="uniform", 
                        choices=["uniform", "fixed", "search"])
    parser.add_argument("--search-lr", type=float, default=0.001)
    parser.add_argument("--resample-freq", type=int, default=1)

    parser.add_argument("--erasure-ckpt-path", type=str, default=None)

    parser.add_argument("--use-randaugment", action="store_true")
    parser.add_argument("--use-cutmix", action="store_true")

    # Module-level on purpose: run_experiment may fall back to this global.
    args = parser.parse_args()

    # RANK / WORLD_SIZE come from the launcher environment when present;
    # otherwise run as a single process.
    try:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    except KeyError:
        rank, world_size = 0, 1

    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)

    print(f'Initialized process {rank} / {world_size}')
    os.makedirs(args.logdir, exist_ok=True)

    all_trials = []
    all_history = []

    # Split the (seed, examples_per_class) grid into contiguous shards, one per rank.
    options = product(range(args.num_trials), args.examples_per_class)
    options = np.array(list(options))
    options = np.array_split(options, world_size)[rank]

    exp_time = '{}'.format(time.strftime("%Y%m%d-%H%M%S"))
    out_dir = os.path.join(args.logdir, args.dataset, f'{exp_time}-{args.dataset}-{args.examples_per_class}-{args.aug_prob}{args.search_lr}')
    os.makedirs(out_dir, exist_ok=True)
    config_path = os.path.join(out_dir, "config.json")
    with open(config_path, 'wt') as f:
        json.dump(vars(args), f, indent=4)

    # Loop-invariant tag describing the augmentation setup, used in file names.
    # Each suffix is appended once. The cutmix/randaugment checks were
    # previously duplicated, which produced 'cc'/'rr' in file names.
    aug_tag = 'd'
    if args.aug == 'textual-inversion' and args.num_synthetic > 0:
        aug_tag += 't'
    elif args.aug == 'autodiff' and args.num_synthetic > 0:
        aug_tag += f'a({args.aug_prob})'
    if args.use_cutmix:
        aug_tag += 'c'
    if args.use_randaugment:
        aug_tag += 'r'

    method_key = f'search{args.search_lr}' if args.aug_prob == 'search' else args.aug_prob

    for seed, examples_per_class in options.tolist():
        seed = int(seed)
        if seed < args.start_seed:
            continue
        # Whole-number shot counts are used as ints (e.g. in paths); fractions stay floats.
        examples_per_class = int(examples_per_class) if examples_per_class >=1 else examples_per_class
        exp_path = get_exp_results(args.dataset, examples_per_class, method_key)
        if exp_path != '':
            print(f'Experiment exists in {exp_path}')
            continue

        hyperparameters = dict(
            examples_per_class=examples_per_class,
            seed=seed, 
            dataset=args.dataset,
            num_epochs=args.num_epochs,
            iterations_per_epoch=args.iterations_per_epoch, 
            batch_size=args.batch_size,
            model_path=args.model_path,
            synthetic_probability=args.synthetic_probability, 
            num_synthetic=args.num_synthetic, 
            prompt=args.prompt, 
            aug=args.aug,
            strength=args.strength, 
            guidance_scale=args.guidance_scale,
            mask=args.mask, 
            inverted=args.inverted,
            probs=args.probs,
            compose=args.compose,
            use_randaugment=args.use_randaugment,
            use_cutmix=args.use_cutmix,
            erasure_ckpt_path=args.erasure_ckpt_path,
            image_size=args.image_size,
            classifier_backbone=args.classifier_backbone,
            aug_prob=args.aug_prob,
            search_lr=args.search_lr,
            resample_freq=args.resample_freq,
            search_dir=args.search_dir)

        synthetic_dir = {
            'samecls': args.synthetic_samecls_dir.format(**hyperparameters),
            'othercls': args.synthetic_othercls_dir.format(**hyperparameters)
        }
        embed_path = args.embed_path.format(**hyperparameters)
        record, history, search_parameters = run_experiment(
            synthetic_dir=synthetic_dir, 
            embed_path=embed_path, **hyperparameters)

        all_trials.extend(record)
        all_history.extend(history)

        # Each CSV holds every trial this rank has finished so far.
        results_path = os.path.join(
            out_dir,
            f"results_s{seed}_e{examples_per_class}_m{args.num_synthetic}_p{args.synthetic_probability}_{aug_tag}.csv")
        pd.DataFrame.from_records(all_trials).to_csv(results_path)

        if args.aug_prob == 'search':
            history_path = os.path.join(
                out_dir,
                f"history_s{seed}_e{examples_per_class}_rf{args.resample_freq}_searchlr{args.search_lr}_{aug_tag}.csv")
            pd.DataFrame.from_records(all_history).to_csv(history_path)

            # The 'fixed' mode later loads this file name from --search-dir.
            weight_path = os.path.join(
                out_dir, f"searchparams-{args.dataset}-{seed}-{examples_per_class}.pt")
            torch.save(search_parameters, weight_path)

        print(f"[rank {rank}] n={examples_per_class} saved to: {results_path}")
