import os
import pdb
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from tqdm import tqdm
from mtt_utils import get_dataset, get_network, get_eval_pool, evaluate_synset, get_time, DiffAugment, ParamDiffAug, get_real_data
import wandb
import copy
import random
from natsort import natsorted as nt
from glob import glob
from mtt_reparam_module import ReparamModule

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Number of pseudo-domains (and of learnable domain masks) used throughout the distillation.
_NUM_DOMAINS = 4

# Datasets whose images are large enough to preview without 4x nearest-neighbour upsampling.
_NATIVE_RES_DATASETS = ("ImageNet", "OH_in_dist", "VLCS_in_dist", "PACS_in_dist")

# Clipping widths (in standard deviations) used for the "Clipped_*" wandb previews.
_PREVIEW_CLIP_STDS = (2.5,)


def _flatten_params(params, device):
    """Concatenate a sequence of parameter tensors into one flat 1-D tensor on ``device``."""
    return torch.cat([p.data.to(device).reshape(-1) for p in params], 0)


def _forward_params(flat_params, distributed):
    """Return the flat parameter vector in the form the (possibly DataParallel) student expects.

    Under ``DataParallel`` the vector is repeated along a new leading dim, one row per visible GPU.
    """
    if distributed:
        return flat_params.unsqueeze(0).expand(torch.cuda.device_count(), -1)
    return flat_params


def _balanced_domain_labels(batch_size, num_domains, device):
    """Return ``batch_size`` shuffled domain ids, each id used at most ceil(batch_size / num_domains) times.

    The ids are shuffled as a plain Python list: ``random.shuffle`` on a torch tensor swaps through
    views and silently duplicates entries instead of permuting them.
    """
    repeats = (batch_size + num_domains - 1) // num_domains
    labels = list(range(num_domains)) * repeats
    random.shuffle(labels)
    return torch.tensor(labels[:batch_size], dtype=torch.long, device=device)


def _upsample_for_display(images, dataset):
    """Repeat every pixel 4x along dims 2 and 3, except for datasets in ``_NATIVE_RES_DATASETS``."""
    if dataset in _NATIVE_RES_DATASETS:
        return images
    images = torch.repeat_interleave(images, repeats=4, dim=2)
    return torch.repeat_interleave(images, repeats=4, dim=3)


def _display_grid(images, dataset):
    """Build a per-image-normalised preview grid with 10 images per row."""
    upsampled = _upsample_for_display(images, dataset)
    return torchvision.utils.make_grid(upsampled, nrow=10, normalize=True, scale_each=True)


def _log_image_grids(images, dataset, name, it):
    """Log ``{name}_Images``, ``{name}_Pixels`` and ``Clipped_{name}_Images/std_*`` to wandb at step ``it``."""
    grid = _display_grid(images, dataset)
    wandb.log({"{}_Images".format(name): wandb.Image(torch.nan_to_num(grid.detach().cpu()))}, step=it)
    wandb.log({"{}_Pixels".format(name): wandb.Histogram(torch.nan_to_num(images.detach().cpu()))}, step=it)

    for clip_val in _PREVIEW_CLIP_STDS:
        std = torch.std(images)
        mean = torch.mean(images)
        clipped = torch.clip(images, min=mean - clip_val * std, max=mean + clip_val * std)
        grid = _display_grid(clipped, dataset)
        wandb.log({"Clipped_{}_Images/std_{}".format(name, clip_val): wandb.Image(
            torch.nan_to_num(grid.detach().cpu()))}, step=it)


def _log_synthetic_snapshot(image_syn, label_syn, args, it):
    """Save the synthetic images/labels of iteration ``it`` and log previews and histograms to wandb.

    Previews are only logged for ``args.ipc < 50`` or ``args.force_save``; with ZCA enabled the
    inverse-transformed images are saved and previewed as well.
    """
    with torch.no_grad():
        image_save = image_syn.detach()

        save_dir = os.path.join(".", "logged_files", args.dataset, wandb.run.name)
        os.makedirs(save_dir, exist_ok=True)

        torch.save(image_save.cpu(), os.path.join(save_dir, "images_{}.pt".format(it)))
        torch.save(label_syn.cpu(), os.path.join(save_dir, "labels_{}.pt".format(it)))

        wandb.log({"Pixels": wandb.Histogram(torch.nan_to_num(image_save.cpu()))}, step=it)

        if not (args.ipc < 50 or args.force_save):
            return

        _log_image_grids(image_save, args.dataset, "Synthetic", it)

        if args.zca:
            image_save = args.zca_trans.inverse_transform(image_save.to(args.device))
            torch.save(image_save.cpu(), os.path.join(save_dir, "images_zca_{}.pt".format(it)))
            _log_image_grids(image_save, args.dataset, "Reconstructed", it)


def _load_or_compute_pseudo_domains(images_all, cache_path, num_domains=_NUM_DOMAINS):
    """Split the real images into ``num_domains`` equal-size pseudo-domains by low-frequency energy.

    Each image's centred 2-D FFT is cropped to a window of about +/-9% of the spatial size around
    the centre, and the mean magnitude of that window is the image's score. Images are sorted by
    score and cut into contiguous groups of ``len(images_all) // num_domains`` indices; the remainder
    (fewer than ``num_domains`` images) is left unassigned. The list of index arrays is cached at
    ``cache_path`` (its directory is created if needed) and loaded from there on later runs.
    """
    if os.path.exists(cache_path):
        return torch.load(cache_path)

    height, width = images_all.shape[-2], images_all.shape[-1]
    half_h, half_w = np.ceil(height * 0.09), np.ceil(width * 0.09)
    rows = slice(int(height // 2 - half_h), int(height // 2 + half_h))
    cols = slice(int(width // 2 - half_w), int(width // 2 + half_w))

    scores = []
    for image in images_all:
        spectrum = torch.fft.fftshift(torch.fft.fft2(image), dim=(-2, -1))
        scores.append(torch.mean(torch.abs(spectrum[:, rows, cols])))
    order = torch.argsort(torch.stack(scores)).cpu().numpy()

    group = len(images_all) // num_domains
    indices = [order[d * group:(d + 1) * group] for d in range(num_domains)]

    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    torch.save(indices, cache_path)
    return indices


def _list_buffer_files(expert_dir, kind):
    """Return ``replay_buffer_{n}_{kind}.pt`` paths for n = 0, 1, ... up to the first missing file.

    Raises:
        AssertionError: if not even ``replay_buffer_0_{kind}.pt`` exists.
    """
    files = []
    while os.path.exists(os.path.join(expert_dir, "replay_buffer_{}_{}.pt".format(len(files), kind))):
        files.append(os.path.join(expert_dir, "replay_buffer_{}_{}.pt".format(len(files), kind)))
    if not files:
        raise AssertionError("No buffers detected at {}".format(expert_dir))
    return files


class _ExpertStream:
    """Serve expert trajectories from the replay-buffer files of one kind ("class" or "domain").

    With ``load_all`` every file is concatenated in memory and each ``sample()`` picks a trajectory
    uniformly at random. Otherwise the (shuffled, optionally ``max_files``-truncated) files are
    visited one at a time: the current file's trajectories are served in shuffled order, and when
    they run out the next file is loaded; the file list is reshuffled after each full pass.
    """

    def __init__(self, expert_dir, kind, load_all, max_files=None, max_experts=None):
        self.load_all = load_all
        self.max_files = max_files
        self.max_experts = max_experts
        files = _list_buffer_files(expert_dir, kind)

        if load_all:
            self.buffer = []
            for path in files:
                self.buffer.extend(torch.load(path))
            return

        random.shuffle(files)
        if max_files is not None:
            files = files[:max_files]
        self.files = files
        self.file_idx = 0
        self.expert_idx = 0
        print("loading file {}".format(self.files[self.file_idx]))
        self.buffer = torch.load(self.files[self.file_idx])
        self._prepare_buffer()

    def _prepare_buffer(self):
        """Apply the ``max_experts`` cap and shuffle the trajectories held in memory."""
        if self.max_experts is not None:
            self.buffer = self.buffer[:self.max_experts]
        random.shuffle(self.buffer)

    def _advance_file(self):
        """Move to the next file (reshuffling the file list after a full pass) and load it."""
        self.expert_idx = 0
        self.file_idx += 1
        if self.file_idx == len(self.files):
            self.file_idx = 0
            random.shuffle(self.files)
        print("loading file {}".format(self.files[self.file_idx]))
        # With max_files == 1 the buffer already in memory is reused and only reshuffled.
        if self.max_files != 1:
            self.buffer = None  # release the previous file before loading the next one
            self.buffer = torch.load(self.files[self.file_idx])
        self._prepare_buffer()

    def sample(self):
        """Return the next expert trajectory."""
        if self.load_all:
            return self.buffer[np.random.randint(0, len(self.buffer))]
        trajectory = self.buffer[self.expert_idx]
        self.expert_idx += 1
        if self.expert_idx == len(self.buffer):
            self._advance_file()
        return trajectory


def _next_run_index(root):
    """Return 1 + the largest all-digit entry name directly under ``root``, or 0 if there is none."""
    names = (os.path.basename(path) for path in glob(os.path.join(root, '*')))
    run_ids = [int(name) for name in names if name.isdecimal()]
    return max(run_ids) + 1 if run_ids else 0


def main(args):
    """Distill a synthetic dataset by matching class- and domain-expert training trajectories.

    Every iteration (1) on evaluation iterations saves/logs the current synthetic set, (2) unrolls
    ``args.syn_steps`` SGD steps of a fresh class student on the synthetic images from a random
    class-expert checkpoint and updates the images and the class learning rate so the student lands
    near the expert checkpoint ``args.expert_epochs`` later, then (3) does the same for a domain
    student whose inputs are the synthetic images multiplied by softmax-normalised per-domain masks,
    additionally updating the masks and the domain learning rate. Snapshots are written to a fresh
    numbered directory under ``../MDDC/results/...`` every 1000 iterations.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block). It is logged as the wandb
            config and then rebuilt from ``wandb.config``; the rest of the function uses that copy.

    Raises:
        AssertionError: on incompatible options, a missing wandb run name, or missing expert buffers.
    """
    if args.zca and args.texture:
        raise AssertionError("Cannot use zca and texture together")

    # The argparse default is None (a literal 'None' string is rejected too); fail before loading data.
    if args.wandb_name is None or args.wandb_name == 'None':
        raise AssertionError('Please specify the wandb name.')

    if args.texture and args.pix_init == "real":
        print("WARNING: Using texture with real initialization will take a very long time to smooth out the boundaries between images.")

    if args.max_experts is not None and args.max_files is not None:
        args.total_experts = args.max_experts * args.max_files

    print("CUDNN STATUS: {}".format(torch.backends.cudnn.enabled))

    args.dsa = args.dsa == 'True'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    eval_it_pool = set(range(0, args.Iteration + 1, args.eval_it))
    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args)
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    args.im_size = im_size

    if args.dsa:
        args.dc_aug_param = None

    args.dsa_param = ParamDiffAug()

    # These objects are re-attached after ``args`` is rebuilt from wandb.config below
    # (presumably because they do not round-trip through the wandb config).
    dsa_params = args.dsa_param
    zca_trans = args.zca_trans if args.zca else None

    wandb.login()
    wandb.init(
        project="Domain Condensation Domain Test!",
        name=args.wandb_name,
        config=args,
    )

    # From here on ``args`` is a plain attribute bag rebuilt from the wandb config.
    args = type('', (), {})()
    for key, value in wandb.config._items.items():
        setattr(args, key, value)

    args.dsa_param = dsa_params
    args.zca_trans = zca_trans

    if args.batch_syn is None:
        args.batch_syn = num_classes * args.ipc

    args.distributed = torch.cuda.device_count() > 1

    print('Hyper-parameters: \n', args.__dict__)
    print('Evaluation model pool: ', model_eval_pool)

    images_all, labels_all, indices_class, *_ = get_real_data(args, dst_train, num_classes)

    # Pseudo-domain split of the real images; it is only computed/cached here, not used further below.
    _load_or_compute_pseudo_domains(images_all, f'saved_data/{args.dataset}_indices_pseudo_domain_250109.pt')

    for c in range(num_classes):
        print('class c = %d: %d real images' % (c, len(indices_class[c])))

    for ch in range(channel):
        print('real images channel %d, mean = %.4f, std = %.4f' % (ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

    def get_images(c, n):  # get random n images from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return images_all[idx_shuffle]

    # ---- initialize the synthetic data ----
    # [0, ..., 0, 1, ..., 1, ..., num_classes - 1]: ``args.ipc`` consecutive labels per class.
    label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

    image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float)

    if args.pix_init == 'real':
        print('initialize synthetic data from random real images')
        for c in range(num_classes):
            image_syn.data[c * args.ipc:(c + 1) * args.ipc] = get_images(c, args.ipc).detach().data
    else:
        print('initialize synthetic data from random noise')

    # ---- learnable tensors and their optimizers ----
    image_syn = image_syn.detach().to(args.device).requires_grad_(True)
    syn_lr = torch.tensor(args.lr_teacher, device=args.device).requires_grad_(True)
    syn_domain_lr = torch.tensor(args.lr_teacher, device=args.device).requires_grad_(True)
    optimizer_img = torch.optim.SGD([image_syn], lr=args.lr_img, momentum=0.5)
    optimizer_lr = torch.optim.SGD([syn_lr], lr=args.lr_lr, momentum=0.5)
    optimizer_domain_lr = torch.optim.SGD([syn_domain_lr], lr=args.lr_lr, momentum=0.5)
    optimizer_img.zero_grad()

    # One mask per pseudo-domain, shaped like the whole synthetic set, each with its own optimizer.
    domain_mask_params = [
        (torch.ones(image_syn.size(), device=args.device) * args.mask_init).requires_grad_(True)
        for _ in range(_NUM_DOMAINS)
    ]
    domain_mask_optimizers = [
        torch.optim.SGD([mask], lr=args.lr_img, momentum=0.5) for mask in domain_mask_params
    ]

    criterion = nn.CrossEntropyLoss().to(args.device)
    print('%s training begins' % get_time())

    expert_dir = args.buffer_path
    if args.dataset == "ImageNet":
        expert_dir = os.path.join(expert_dir, args.subset, str(args.res))
    if args.dataset in ["CIFAR10", "CIFAR100"] and not args.zca:
        expert_dir += "_NO_ZCA"
    expert_dir = os.path.join(expert_dir, args.model)
    print("Expert Dir: {}".format(expert_dir))

    class_experts = _ExpertStream(expert_dir, "class", args.load_all, args.max_files, args.max_experts)
    domain_experts = _ExpertStream(expert_dir, "domain", args.load_all, args.max_files, args.max_experts)

    save_root = f'../MDDC/results/ICML_MTT_DOMAIN_EMBEDDING_ConvNet/{args.dataset}/ipc{args.ipc}'
    save_pth = f'{save_root}/{_next_run_index(save_root)}'
    os.makedirs(save_pth, exist_ok=True)

    for it in range(0, args.Iteration + 1):
        wandb.log({"Progress": it}, step=it)

        if it in eval_it_pool:
            _log_synthetic_snapshot(image_syn, label_syn, args, it)

        wandb.log({"Synthetic_LR": syn_lr.detach().cpu()}, step=it)
        wandb.log({"Synthetic_LR_Domain": syn_domain_lr.detach().cpu()}, step=it)

        student_net_class = get_network(args.model, channel, num_classes, im_size, dist=False, domain=False).to(args.device)  # get a random model
        student_net_domain = get_network(args.model, channel, num_classes, im_size, dist=False, domain=True).to(args.device)  # get a random model

        student_net_class = ReparamModule(student_net_class)
        student_net_domain = ReparamModule(student_net_domain)

        if args.distributed:
            student_net_class = torch.nn.DataParallel(student_net_class)
            student_net_domain = torch.nn.DataParallel(student_net_domain)

        student_net_class.train()
        student_net_domain.train()

        num_params_class = sum(p.numel() for p in student_net_class.parameters())
        num_params_domain = sum(p.numel() for p in student_net_domain.parameters())

        expert_trajectory_class = class_experts.sample()
        expert_trajectory_domain = domain_experts.sample()

        start_epoch = np.random.randint(0, args.max_start_epoch)
        # Domain trajectories are always matched starting from their first checkpoint.
        start_epoch_domain = 0

        starting_params_class = _flatten_params(expert_trajectory_class[start_epoch], args.device)
        target_params_class = _flatten_params(expert_trajectory_class[start_epoch + args.expert_epochs], args.device)
        starting_params_domain = _flatten_params(expert_trajectory_domain[start_epoch_domain], args.device)
        target_params_domain = _flatten_params(expert_trajectory_domain[start_epoch_domain + args.expert_epochs], args.device)

        student_params_class = [starting_params_class.clone().requires_grad_(True)]
        student_params_domain = [starting_params_domain.clone().requires_grad_(True)]

        # ---- class branch: unroll the class student on the raw synthetic images ----
        indices_chunks = []
        for _ in range(args.syn_steps):
            if not indices_chunks:
                indices = torch.randperm(len(image_syn))
                indices_chunks = list(torch.split(indices, args.batch_syn))
            these_indices = indices_chunks.pop()

            x = image_syn[these_indices]
            this_y = label_syn[these_indices]

            if args.dsa and (not args.no_aug):
                x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)

            out = student_net_class(x, flat_param=_forward_params(student_params_class[-1], args.distributed))
            ce_loss = criterion(out, this_y)

            grad = torch.autograd.grad(ce_loss, student_params_class[-1], create_graph=True, allow_unused=True)[0]
            student_params_class.append(student_params_class[-1] - syn_lr * grad)

        # Distance to the target checkpoint, normalised by the expert's own distance over the window.
        param_loss = F.mse_loss(student_params_class[-1], target_params_class, reduction="sum") / num_params_class
        param_dist = F.mse_loss(starting_params_class, target_params_class, reduction="sum") / num_params_class
        grand_loss = param_loss / param_dist

        optimizer_img.zero_grad()
        optimizer_lr.zero_grad()
        # The class graph is not backpropagated through again, so its buffers are freed here.
        grand_loss.backward()
        optimizer_img.step()
        optimizer_lr.step()

        if it % 10 == 0:
            print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item()))

        # ---- domain branch: unroll the domain student on mask-weighted synthetic images ----
        # Softmax over the domain axis: at every pixel the masks of all domains sum to 1.
        domain_masks = F.softmax(torch.stack(domain_mask_params, dim=0) / args.temperature, dim=0)

        indices_chunks = []
        for _ in range(args.syn_steps):
            if not indices_chunks:
                indices = torch.randperm(len(image_syn))
                indices_chunks = list(torch.split(indices, args.batch_syn))
            these_indices = indices_chunks.pop().to(args.device)
            domain_labels = _balanced_domain_labels(len(these_indices), _NUM_DOMAINS, args.device)

            # Sample k is weighted by the mask of its assigned domain at its own synthetic index.
            x = image_syn[these_indices] * domain_masks[domain_labels, these_indices]

            if args.dsa and (not args.no_aug):
                x = DiffAugment(x, args.dsa_strategy, param=args.dsa_param)

            out = student_net_domain(x, flat_param=_forward_params(student_params_domain[-1], args.distributed), domain=True)
            ce_loss = criterion(out, domain_labels) * 0.01

            grad = torch.autograd.grad(ce_loss, student_params_domain[-1], create_graph=True, allow_unused=True)[0]
            student_params_domain.append(student_params_domain[-1] - syn_domain_lr * grad)

        param_loss = F.mse_loss(student_params_domain[-1], target_params_domain, reduction="sum") / num_params_domain
        param_dist = F.mse_loss(starting_params_domain, target_params_domain, reduction="sum") / num_params_domain
        grand_loss = param_loss / param_dist * 0.1

        # One backward pass fills the gradients of every mask, the images and the domain lr;
        # each optimizer then steps on exactly that gradient.
        for mask_optimizer in domain_mask_optimizers:
            mask_optimizer.zero_grad()
        optimizer_img.zero_grad()
        optimizer_domain_lr.zero_grad()

        grand_loss.backward()

        for mask_optimizer in domain_mask_optimizers:
            mask_optimizer.step()
        optimizer_domain_lr.step()
        optimizer_img.step()

        wandb.log({"Grand_Loss": grand_loss.detach().cpu(),
                   "Start_Epoch": start_epoch}, step=it)

        # Release the unrolled trajectories (and their autograd history) before the next iteration.
        del student_params_class, student_params_domain

        if it % 10 == 0:
            print('%s iter = %04d, loss = %.4f' % (get_time(), it, grand_loss.item()))
        if it % 1000 == 0:
            torch.save({'data': [copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())]}, os.path.join(save_pth, f'{it}.pt'))

    print(save_pth)
    wandb.finish()


def str2bool(value):
    """Parse a command-line boolean: true/false, yes/no, t/f, y/n, 1/0 (case-insensitive).

    Needed because argparse's ``type=bool`` turns every non-empty string, including "False",
    into True. Actual ``bool`` values are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameter Processing')

    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')

    parser.add_argument('--subset', type=str, default='imagenette', help='ImageNet subset. This only does anything when --dataset=ImageNet')

    parser.add_argument('--model', type=str, default='ConvNet', help='model')

    parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')

    parser.add_argument('--eval_mode', type=str, default='S',
                        help='eval_mode, check utils.py for more info')

    parser.add_argument('--num_eval', type=int, default=20, help='how many networks to evaluate on')

    parser.add_argument('--eval_it', type=int, default=1000, help='how often to evaluate')

    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=5000, help='how many distillation steps to perform')

    parser.add_argument('--lr_img', type=float, default=1000, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_lr', type=float, default=1e-05, help='learning rate for updating... learning rate')
    parser.add_argument('--lr_teacher', type=float, default=0.01, help='initialization for synthetic learning rate')

    parser.add_argument('--lr_init', type=float, default=0.01, help='how to init lr (alpha)')

    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_syn', type=int, default=None, help='should only use this if you run out of VRAM')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')

    parser.add_argument('--pix_init', type=str, default='real', choices=["noise", "real"],
                        help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')

    parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'],
                        help='whether to use differentiable Siamese augmentation.')

    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate',
                        help='differentiable Siamese augmentation strategy')

    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--buffer_path', type=str, default='./buffers', help='buffer path')

    parser.add_argument('--expert_epochs', type=int, default=3, help='how many expert epochs the target params are')
    parser.add_argument('--syn_steps', type=int, default=20, help='how many steps to take on synthetic data')
    parser.add_argument('--max_start_epoch', type=int, default=25, help='max epoch we can start at')

    parser.add_argument('--zca', action='store_true', help="do ZCA whitening")

    parser.add_argument('--load_all', action='store_true', help="only use if you can fit all expert trajectories into RAM")

    # str2bool instead of bool: with type=bool, "--no_aug False" would parse as True.
    parser.add_argument('--no_aug', type=str2bool, default=False, help='this turns off diff aug during distillation')

    parser.add_argument('--texture', action='store_true', help="will distill textures instead")
    parser.add_argument('--canvas_size', type=int, default=2, help='size of synthetic canvas')
    parser.add_argument('--canvas_samples', type=int, default=1, help='number of canvas samples per iteration')

    parser.add_argument('--max_files', type=int, default=None, help='number of expert files to read (leave as None unless doing ablations)')
    parser.add_argument('--max_experts', type=int, default=None, help='number of experts to read per file (leave as None unless doing ablations)')

    parser.add_argument('--force_save', action='store_true', help='this will save images for 50ipc')

    # Required in practice: main() rejects a missing run name.
    parser.add_argument('--wandb_name', type=str, default=None)
    parser.add_argument('--exp', type=int, default=None)
    # Softmax temperature applied to the stacked domain masks.
    parser.add_argument('--temperature', type=float, default=0.1)
    # Initial value of every domain-mask entry.
    parser.add_argument('--mask_init', type=float, default=0.01)

    args = parser.parse_args()

    main(args)


