import einops
import ml_collections
from dataloader import get_dataset
    
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import pickle
import numpy as np
from tqdm import tqdm
import math

import os
import fire
from PIL import Image
from cov_stats import CovStatsHook, get_stats_conv, scale_cov_hooks, scale_cov_hooks_bwd #, matrix_pow
from modified_resnets import resnet18_silu
from collections import OrderedDict
# import wandb
from wide_resnets2 import WideResNet
import torch.nn.functional as F
import einops
from compute_ds_stats import get_cov_stat_hooks


from pytorch_metric_learning import losses


# Module-level instance of pytorch_metric_learning's NTXentLoss.
# NOTE(review): `loss_func` is not referenced anywhere in the visible code of
# this file — possibly a leftover; confirm before removing.
loss_func = losses.NTXentLoss()

def get_config():
    """Build the default experiment configuration.

    Returns:
        An ``ml_collections.ConfigDict`` with log/image output names, a resume
        flag, four placeholder fields initialised to ``None`` and a nested
        ``dataset`` section (name and data path).
    """
    cfg = ml_collections.ConfigDict()

    # Output locations and resume behaviour.
    cfg.train_log = 'train_log'
    cfg.train_img = 'train_img'
    cfg.resume = True

    # Placeholder fields, left unset (None) here.
    for field_name in ('img_size', 'img_channels', 'num_prototypes', 'train_size'):
        setattr(cfg, field_name, None)

    # Dataset section.
    cfg.dataset = ml_collections.ConfigDict()
    cfg.dataset.name = 'imagenet'
    cfg.dataset.data_path = 'data/tensorflow_datasets'

    return cfg


def get_cached_inverse_info(stats_dict):
    """Precompute regularised inverses and log-determinants of target statistics.

    For every hooked layer ``i`` this reads ``stats_dict.global_means[i]``,
    ``global_covs[i]``, ``class_means[i]``, ``class_covs[i]``, ``jits[i]`` and
    ``class_jits[i]``. It then sets four list attributes on ``stats_dict``:
    ``global_invs``, ``global_logdets``, ``class_invs`` and ``class_logdets``.

    The ``*_covs`` inputs are treated as uncentred second moments: the outer
    product of the corresponding mean is subtracted before regularising.

    Regularisation:
      * global: ``cov + jit * I * trace(cov) / D`` (jitter scaled by the mean
        diagonal entry);
      * per class: ``cov_c + class_jit_c * I + 1e-6 * I``.
    """
    with torch.no_grad():
        global_invs, global_logdets = [], []
        class_invs, class_logdets = [], []

        layer_stats = zip(stats_dict.global_means, stats_dict.global_covs,
                          stats_dict.class_means, stats_dict.class_covs)
        for csh_i, (ds_mean, ds_cov, ds_class_mean, ds_class_cov) in enumerate(layer_stats):
            # Global statistics: centre the second moment, then add trace-scaled jitter.
            cov = ds_cov - ds_mean.reshape(-1, 1) @ ds_mean.reshape(1, -1)
            jit = stats_dict.jits[csh_i]
            dim = cov.shape[0]
            cov = cov + jit * torch.eye(dim, device = cov.device) * torch.trace(cov) / dim
            global_invs.append(torch.linalg.inv(cov))
            global_logdets.append(torch.linalg.slogdet(cov)[1])

            # Per-class statistics (C, D, D): centre each class, add its own jitter plus a 1e-6 floor.
            class_jits = torch.tensor(stats_dict.class_jits[csh_i]).to(ds_cov.device)
            class_covs = ds_class_cov - ds_class_mean[:, :, None] * ds_class_mean[:, None, :]
            eye = torch.eye(class_covs.shape[1], device = ds_cov.device)[None]
            class_covs = class_covs + class_jits[:, None, None] * eye + 1e-6 * eye
            class_invs.append(torch.linalg.inv(class_covs).detach())
            class_logdets.append(torch.linalg.slogdet(class_covs)[1].detach())

        stats_dict.global_invs = global_invs
        stats_dict.global_logdets = global_logdets

        stats_dict.class_invs = class_invs
        stats_dict.class_logdets = class_logdets

    
def synthesize_images(fwd_models, fwd_stats, bwd_models, bwd_stats, randomize_mode = 'per_batch', n_classes = 200, n_per_class = 50, batch_size = 200, image_res = [3, 64, 64], device = 'cuda', 
                   steps_per_batch = 1000, do_iterative = True, aug = nn.Identity(), kl_loss = True, cache_inverses = False, first_layer_scale = 1.0, raw_image_mean = None, raw_image_std = None, 
                   oob_coef = 10.0, distilled_save_path = './distilled_images/', ema_reset = 100000, ema_min_value = 0.0, apply_ema_min_when_saving = False, start_lr = 0.25, 
                   random_shift_amount = 32, ipc_offset = 0, target_class = 0,
                    target_resolution = 224, class_mean_coef = 0.1, selected_classes = [], dbn_coef = 0.0, train_aug_groups = 1, temp = 1.0):
    """Optimise random images so their hooked feature statistics match stored targets, then save them.

    For each ``ipc_id`` in ``range(0, n_per_class, batch_size)`` a batch of
    images is drawn from a standard normal at ``target_resolution``. The batch
    is optimised for ``steps_per_batch`` steps with Adam and a cosine LR
    schedule, minimising ``compute_loss_from_images``.

    Each step uses one forward model, chosen per batch when
    ``randomize_mode == 'per_batch'`` and per step otherwise. If any backward
    models are given, it also uses one of them, chosen round-robin by step.

    After each batch, running (EMA) statistics are refreshed for every forward
    model, and for the backward models when ``dbn_coef > 0``. The images are
    then de-normalised with ``raw_image_std`` / ``raw_image_mean`` and written
    with ``save_images`` under ``distilled_save_path``.

    Target labels cycle through ``selected_classes``.

    Side effects:
      * puts every model in eval mode;
      * attaches covariance hooks, which are all closed before returning;
      * adds cached inverse info to each ``fwd_stats[i].data_stats``;
      * moves models between devices;
      * writes image files.

    NOTE(review): ``image_res``, ``random_shift_amount`` and ``target_class``
    are accepted but unused. The ``aug`` argument is replaced at every step by
    a RandomResizedCrop + HorizontalFlip pipeline. The last such pipeline is
    also the one used for the EMA statistics pass.

    Raises:
        ValueError: if ``dbn_coef > 0`` while ``bwd_models`` is empty.
    """
    
    print("Synthesizing Images with:")
    print(f"n_fwd_models: {len(fwd_models)}")
    print(f"n_bwd_models: {len(bwd_models)}")
    print(f"do_iterative: {do_iterative}")
    print(f"kl_loss: {kl_loss}")
    print(f"randomize_mode: {randomize_mode}")
    print(f"n_classes: {n_classes}")
    print(f"n_per_class: {n_per_class}")
    print(f"steps_per_batch: {steps_per_batch}")
    print(f"batch_size: {batch_size}")
    print(f"cache_inverses: {cache_inverses}")
    print(f"first_layer_scale: {first_layer_scale}")
    print(f"class_mean_coef: {class_mean_coef}")
    print()
    print(f"ema_reset: {ema_reset}")
    print(f"ema_min_value: {ema_min_value}")
    print(f"apply_ema_min_when_saving: {apply_ema_min_when_saving}")
    print(f"start_lr: {start_lr}")
    print(f"random_shift_amount: {random_shift_amount}")
    print(f"ipc_offset: {ipc_offset}")
    print(f"train_aug_groups: {train_aug_groups}")
    print(f"temperature: {temp}")
    print(f"dbn_coef: {dbn_coef}")
    print()

    # Validate before attaching any hooks so nothing is leaked on failure.
    if dbn_coef > 0.0 and len(bwd_models) == 0:
        raise ValueError('dbn_coef > 0 requires at least one backward model (bwd_models is empty)')

    # Forward models: eval mode, covariance hooks, cached inverses of the target statistics.
    fwd_cov_stat_hooks_all = []
    for m_i, model in enumerate(fwd_models):
        model.eval()
        fwd_cov_stat_hooks_all.append(get_cov_stat_hooks(model, fwd_stats[m_i].proj_config))
        get_cached_inverse_info(fwd_stats[m_i].data_stats)

    # Backward models only need hooks.
    bwd_cov_stat_hooks_all = []
    for m_i, model in enumerate(bwd_models):
        model.eval()
        bwd_cov_stat_hooks_all.append(get_cov_stat_hooks(model, bwd_stats[m_i].proj_config))

    n_fwd_models = len(fwd_models)
    n_bwd_models = len(bwd_models)

    # One (initially empty) slot of running statistics per model.
    fwd_ema_stats_all = [[] for _ in range(n_fwd_models)]
    bwd_ema_stats_all = [[] for _ in range(n_bwd_models)]

    batch_index = 0
    n_batches = int(math.ceil(n_per_class/batch_size))

    for ipc_id in range(0, n_per_class, batch_size):
        target_labels = torch.tensor([selected_classes[i%len(selected_classes)] for i in range(max(len(selected_classes), batch_size))], dtype = torch.long)
        # NOTE(review): this inner loop always runs exactly once (bi_start == 0).
        for bi_start in range(0, batch_size, batch_size):
            bi_end = min(n_per_class, bi_start + batch_size)
            n_images_batch = bi_end - bi_start

            target_labels_batch = target_labels[bi_start: bi_end].to(device)

            distilled_images_batch = torch.randn((n_images_batch, 3, target_resolution, target_resolution), requires_grad=True, device=device,dtype=torch.float)

            optimizer = optim.Adam([distilled_images_batch], lr=start_lr, betas=[0.5, 0.9], eps = 1e-8)
            lr_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer, steps_per_batch + 1, eta_min=0.00)

            # Weight of the newest batch in the EMA; 1/(k+1) yields a running mean since the last reset.
            ema_coef = 1/((batch_index%ema_reset) + 1) if do_iterative else 1.0
            ema_coef_opt = max(ema_coef, ema_min_value)
            ema_coef = ema_coef_opt if apply_ema_min_when_saving else ema_coef
            
            print("EMA EMA EMA EMA")
            print(ema_coef)

            pbar = tqdm(total=steps_per_batch)

            for distill_step in range(steps_per_batch):
                
                if randomize_mode == 'per_batch':
                    selected_model_index = ipc_id%n_fwd_models
                else:
                    selected_model_index = distill_step%n_fwd_models

                if n_bwd_models > 0:
                    selected_grad_model_index = distill_step%n_bwd_models
                    ll_grad_model = bwd_models[selected_grad_model_index].to(device)
                    cov_stat_hooks_grad = bwd_cov_stat_hooks_all[selected_grad_model_index]
                    proj_config_bwd = bwd_stats[selected_grad_model_index].proj_config
                else:
                    # Without a backward model, compute_loss_from_images takes embeddings from the
                    # forward model (assumes it accepts embeddings_and_out=True — confirm).
                    selected_grad_model_index = None
                    ll_grad_model = None
                    cov_stat_hooks_grad = None
                    proj_config_bwd = None

                optimizer.zero_grad()

                # Annealing: crop scale shrinks from 1.0 towards 0.3, out-of-range penalty ramps up from 0.
                min_scale = 1 - (1 - 0.3) * distill_step/steps_per_batch
                oob_coef_annealed = (distill_step/steps_per_batch) * oob_coef

                aug = transforms.Compose([
                        transforms.RandomResizedCrop(target_resolution, scale = (min_scale, 1.0)),
                        transforms.RandomHorizontalFlip(),
                ])

                selected_stats = fwd_stats[selected_model_index]
                data_stats = selected_stats.data_stats
                loss = compute_loss_from_images(fwd_models[selected_model_index].to(device), distilled_images_batch, target_labels_batch, fwd_cov_stat_hooks_all[selected_model_index],
                                                data_stats.global_means, data_stats.global_covs, ema_stats = fwd_ema_stats_all[selected_model_index],
                                                ema_coef = ema_coef_opt, aug = aug, kl_loss = kl_loss, cache_inverses = cache_inverses, target_invs = data_stats.global_invs, target_logdets = data_stats.global_logdets, first_layer_scale = first_layer_scale, 
                                                oob_coef = oob_coef_annealed, raw_image_mean = raw_image_mean, raw_image_std = raw_image_std, jits = data_stats.jits, train_aug_groups = train_aug_groups, class_means = data_stats.class_means, class_mean_coef = class_mean_coef, 
                                                ll_grad_model = ll_grad_model, selected_classes = selected_classes,
                                                step_alpha = distill_step/steps_per_batch, class_covs = data_stats.class_covs, class_invs = data_stats.class_invs, class_logdets = data_stats.class_logdets, temp = temp, 
                                                dbn_coef = dbn_coef, cov_stat_hooks_grad = cov_stat_hooks_grad, class_jits = data_stats.class_jits,
                                                class_include_indices = selected_stats.proj_config.class_include_indices,
                                                proj_config = selected_stats.proj_config, proj_config_bwd = proj_config_bwd)

                pbar.set_description(f'batch: {batch_index+1}/{n_batches}, step: {distill_step}, loss: {loss.item()}')
                pbar.update(1)

                loss.backward()
                optimizer.step()
                lr_schedule.step()
                
                clear_cov_stat_hooks(fwd_cov_stat_hooks_all[selected_model_index])
                if dbn_coef > 0.0:
                    clear_cov_stat_hooks(bwd_cov_stat_hooks_all[selected_grad_model_index])

            pbar.close()

            soft_labels = None

            fwd_ema_stats_all = compute_ema_stats_from_model_list(fwd_models, fwd_cov_stat_hooks_all , fwd_stats, fwd_ema_stats_all, aug, distilled_images_batch, ema_coef, selected_classes, target_labels_batch)
            
            if dbn_coef > 0.0:
                bwd_ema_stats_all = compute_ema_stats_from_model_list(bwd_models, bwd_cov_stat_hooks_all , bwd_stats, bwd_ema_stats_all, aug, distilled_images_batch, ema_coef, selected_classes, target_labels_batch, bwd = True)

            # Undo input normalisation before saving.
            images_raw = (distilled_images_batch * raw_image_std[None, :, None, None]) + raw_image_mean[None, :, None, None]
            images_raw = images_raw.detach().cpu().numpy()
            save_images(distilled_save_path, images_raw, one_hot(target_labels_batch.detach().cpu().numpy(), n_classes), soft_labels, ipc_id//len(selected_classes) + ipc_offset, n_classes, selected_classes)

            batch_index += 1
            print()

    # Detach every hook exactly once (forward AND backward models).
    for hooks in fwd_cov_stat_hooks_all + bwd_cov_stat_hooks_all:
        for cov_stat_hook in hooks:
            cov_stat_hook.close()

    return


def compute_ema_stats_from_model_list(fwd_models, fwd_cov_stat_hooks_all, fwd_stats, fwd_ema_stats_all, aug, distilled_images_batch, ema_coef, selected_classes, labels, bwd = False, n_forward = 10, device = 'cuda'):
    """Update each model's running (EMA) hook statistics from the current image batch.

    For every model, ``n_forward`` augmented passes are run. Their hook
    statistics are averaged with ``average_multi_stats`` and then blended into
    ``fwd_ema_stats_all[m_i]`` with ``convert_to_ema_list`` using ``ema_coef``.
    An empty slot is simply initialised with the averaged statistics.

    With ``bwd=True`` each pass also backpropagates a summed cross-entropy
    against ``labels`` to the augmented images. The gradient itself is
    discarded; the call is made so the backward hooks fire. ``scale_cov_hooks_bwd``
    is then used instead of ``scale_cov_hooks``.

    Each model is moved to ``device`` for the passes and back to CPU afterwards.
    Its hooks are cleared at the end.

    Args:
        n_forward: number of augmented passes per model (must be >= 1).
        device: device the models run on during the passes.

    Returns:
        ``fwd_ema_stats_all`` (the same list object, with entries replaced).

    Raises:
        ValueError: if ``n_forward < 1``.
    """
    if n_forward < 1:
        raise ValueError(f'n_forward must be >= 1, got {n_forward}')

    with torch.autograd.set_grad_enabled(False):
        for m_i in range(len(fwd_models)):
            model = fwd_models[m_i].to(device)
            hooks = fwd_cov_stat_hooks_all[m_i]
            proj_config = fwd_stats[m_i].proj_config

            model_stats = []
            for _ in range(n_forward):
                # Gradients are only needed when backward hooks must be triggered.
                with torch.autograd.set_grad_enabled(bwd):
                    images_aug = aug(distilled_images_batch)
                    output = model(images_aug)

                    if bwd:
                        true_label_loss = F.cross_entropy(output, labels, reduction = 'sum')
                        # Result unused: this call exists to run the backward hooks.
                        torch.autograd.grad(true_label_loss, [images_aug])
                        scale_cov_hooks_bwd(hooks, proj_config)
                    else:
                        scale_cov_hooks(hooks, proj_config)

                model_stats.append(get_stats_from_hook_multi(hooks, len(selected_classes), bwd = bwd, detach = True))

            model.to('cpu')

            averaged_model_stats = average_multi_stats(model_stats)

            if len(fwd_ema_stats_all[m_i]) == 0:
                fwd_ema_stats_all[m_i] = averaged_model_stats
            else:
                fwd_ema_stats_all[m_i] = convert_to_ema_list(averaged_model_stats, fwd_ema_stats_all[m_i], ema_coef)

            clear_cov_stat_hooks(hooks)

    return fwd_ema_stats_all


def relabel_images(models, images, device = 'cuda', n_classes = 200, aug = nn.Identity(), batch_size = 200, n_forward = 20):
    """Average model outputs over augmented passes and over an ensemble of models.

    Args:
        models: sequence of callables mapping a batch to ``(batch, n_classes)``
            outputs; each is put in eval mode and moved to ``device``.
        images: array indexable by row slices and convertible with
            ``torch.tensor`` (e.g. a NumPy array of images).
        device: device used for the forward passes.
        n_classes: width of the output array.
        aug: augmentation applied independently on each of the ``n_forward`` passes.
        batch_size: number of images per forward pass.
        n_forward: number of augmented passes averaged per batch.

    Returns:
        ``np.ndarray`` of shape ``(len(images), n_classes)`` holding the raw
        model outputs (no softmax is applied), averaged over passes and models.
    """
    for model in models:
        model.eval()

    n_total = images.shape[0]
    labels_all = np.zeros([n_total, n_classes])

    with torch.no_grad():
        for model in models:
            model = model.to(device)
            labels_model = np.zeros([n_total, n_classes])
            for bi_start in tqdm(range(0, n_total, batch_size)):
                bi_end = min(n_total, bi_start + batch_size)
                # Copy the batch to the device once and reuse it for every augmented pass.
                images_batch = torch.tensor(images[bi_start: bi_end]).to(device)
                for _ in range(n_forward):
                    outputs = model(aug(images_batch)).detach().cpu().numpy()
                    labels_model[bi_start:bi_end] += outputs/n_forward
            labels_all += labels_model
    return labels_all/len(models)



def get_cosine_crap_loss(g_real, g_fake):
    """Mean cosine distance between corresponding rows (dim 1) of two tensors.

    Each row is scaled by ``1 / sqrt(sum(row ** 2) + 1e-6)``, an L2
    normalisation with a small stabiliser. The result is the mean over rows
    of ``1 - <real_row, fake_row>``.
    """
    inv_norm_real = torch.rsqrt(torch.sum(g_real ** 2, 1, keepdim = True) + 1e-6)
    inv_norm_fake = torch.rsqrt(torch.sum(g_fake ** 2, 1, keepdim = True) + 1e-6)
    unit_real = g_real * inv_norm_real
    unit_fake = g_fake * inv_norm_fake
    cosine = (unit_real * unit_fake).sum(1)
    return (1 - cosine).mean()

def compute_loss_from_images(model, images, labels, cov_stat_hooks, target_means, target_covs, ema_stats = [], ema_coef = 1.0, aug = nn.Identity(), 
    kl_loss = True, cache_inverses = False, target_invs = None, target_logdets = None, first_layer_scale = 1.0, raw_image_mean = None, raw_image_std = None, oob_coef = 10.0, 
    jits = [], train_aug_groups = 1, class_means = None, class_mean_coef = 0.1, 
    ll_grad_model = None, selected_classes = [], step_alpha = 0.0,
    class_covs = [], class_invs = [], class_logdets = [], temp = 1.0,  dbn_coef = 0.0, cov_stat_hooks_grad = None, 
                                                class_jits = [], class_include_indices = [], proj_config = None, proj_config_bwd = None):
    """Distribution-matching loss for a batch of images being synthesised.

    Steps:
      1. Augment ``images`` with ``aug``. If ``train_aug_groups > 1``, that
         many independent augmentations are concatenated along dim 0.
      2. Run the augmented batch through ``model``, and through
         ``ll_grad_model`` when given. Then call ``scale_cov_hooks`` on
         ``cov_stat_hooks``.
      3. If ``dbn_coef > 0``:
         * backpropagate a summed cross-entropy of ``outputs_grad`` against
           ``labels`` to the augmented images, with ``create_graph=True``;
         * call ``scale_cov_hooks_bwd`` on ``cov_stat_hooks_grad``.
      4. For each hook, read its batch statistics. When ``ema_stats`` is
         non-empty, blend them with ``ema_stats[i]`` using ``ema_coef``. Then
         add a divergence to the targets, chosen as follows:
         * ``kl_loss`` and ``cache_inverses``: a KL that uses the cached
           inverses. Hooks listed in ``class_include_indices`` also get a
           per-class KL scaled by ``class_mean_coef``.
         * ``kl_loss`` only: a full KL.
         * otherwise: an L2 mean/variance loss.
         ``first_layer_scale`` multiplies only the first hook's terms.
      5. Add ``oob_coef`` times a penalty on de-normalised pixel values
         outside [0, 1]: the sum of squared excess divided by the batch size.

    ``class_means`` must not be None. ``class_means[0].shape[0]`` sets the
    one-hot width on every code path.

    Returns:
        Scalar loss ``dist_loss + oob_coef * oob_loss``.

    NOTE(review): in step 3, neither the gradient ``g_whiz`` nor the statistics
    on ``cov_stat_hooks_grad`` are read anywhere in this function. ``dbn_coef``
    therefore only triggers those side effects and does not change the
    returned value. Confirm this is intended. ``fac`` and
    ``n_classes_to_synth`` are computed but unused, and ``T`` is fixed at 1.
    """
    
    images_aug = aug(images)

    # Optionally stack several independent augmentations of the same images.
    if train_aug_groups > 1:
        images_aug = [images_aug]
        for q in range(train_aug_groups - 1):
            images_aug.append(aug(images))
        images_aug = torch.concatenate(images_aug, 0)

    # One-hot width comes from the number of rows in the first layer's class means.
    y_oh = nn.functional.one_hot(labels, class_means[0].shape[0])
    
    
    if ll_grad_model is None:    
        outputs, int_embeddings, embeddings = model(images_aug, embeddings_and_out = True)
        outputs_grad = outputs
    else:
        # Statistics come from `model`; logits for the label loss come from `ll_grad_model`.
        outputs = model(images_aug)
        outputs_grad, int_embeddings, embeddings = ll_grad_model(images_aug, embeddings_and_out = True)
    
    # Post-process the statistics captured by the forward hooks.
    scale_cov_hooks(cov_stat_hooks, proj_config)
    
        
    T = 1
        
    
    if dbn_coef > 0:
        # Backward pass through the label loss so backward hooks on the grad model fire.
        true_label_loss = F.cross_entropy(outputs_grad/T, torch.argmax(y_oh, 1), reduction = 'sum')
        g_whiz = torch.autograd.grad(true_label_loss, [images_aug], create_graph = True)
        
        scale_cov_hooks_bwd(cov_stat_hooks_grad, proj_config_bwd)
        
    n_classes_to_synth = len(selected_classes)
    
    dist_loss = 0
    

    outputs = outputs
    

    for csh_i, cov_stat_hook in enumerate(cov_stat_hooks):
        batch_stats = get_stats_from_hook(cov_stat_hook, len(selected_classes))
        
        # Blend the current batch statistics with the running statistics, when available.
        if len(ema_stats) == 0:
            smoothed_stats = batch_stats
        else:
            smoothed_stats = convert_to_ema(batch_stats, ema_stats[csh_i], ema_coef)
        smoothed_mean, smoothed_cov, smoothed_class_mean, smoothed_class_cov = smoothed_stats
        


        if kl_loss:
            if cache_inverses:
                fac = (1-step_alpha**3)
                
                dist_loss += first_layer_scale * get_kl_div_cached_inverse(smoothed_mean, target_means[csh_i], smoothed_cov, target_invs[csh_i], target_logdets[csh_i], jits[csh_i]).to(outputs.device)
                
                
                # Per-class term. Only the first selected class's jitter is used for every class.
                # NOTE(review): `[[selected_classes]]` is nested-list indexing; confirm it selects
                # the rows `selected_classes` in the torch version in use.
                if csh_i in class_include_indices:
                    dist_loss += class_mean_coef * first_layer_scale * get_kl_div_cached_inverse_multi(smoothed_class_mean, class_means[csh_i][[selected_classes]], smoothed_class_cov, class_invs[csh_i][[selected_classes]], class_logdets[csh_i][[selected_classes]], class_jits[csh_i][selected_classes[0]], temp = temp).to(outputs.device)
                
     
            else:
                dist_loss += first_layer_scale * get_kl_div(smoothed_mean, target_means[csh_i], smoothed_cov, target_covs[csh_i]).to(outputs.device)
        else:
            dist_loss += first_layer_scale * get_l2_bn_loss(smoothed_mean, target_means[csh_i], smoothed_cov, target_covs[csh_i]).to(outputs.device)
        # The extra scale applies only to the first hook.
        first_layer_scale = 1.0
        
    

    # Penalise de-normalised pixel values outside [0, 1] (squared excess, averaged per image).
    images_raw = (images * raw_image_std[None, :, None, None]) + raw_image_mean[None, :, None, None]
    oob_loss = torch.sum(torch.nn.functional.relu(images_raw - 1.0) ** 2) + torch.sum(torch.nn.functional.relu(-1 * images_raw)**2)
    oob_loss = oob_loss/images_raw.shape[0]
    

    return dist_loss + oob_coef * oob_loss



def get_stats_from_hook(cov_stat_hook, n_classes, bwd = False, detach = False):
    """Collect statistics from one hook, folding per-sample class rows down to ``n_classes`` rows.

    Per-sample tensors of leading size ``B`` are reshaped to
    ``(B // n_classes, n_classes, ...)`` and averaged over the first axis.

    Forward mode (``bwd=False``) returns::

        [mean, cov, class_mean (n_classes, D), class_cov (n_classes, D, D)]

    where ``D`` is ``mean_unreduced_class.shape[1]``.

    Backward mode (``bwd=True``) returns the folded ``[d_shift, d_scale]``.

    With ``detach=True`` every returned tensor is detached from the graph.
    """
    if bwd:
        n_groups = cov_stat_hook.d_shift.shape[0] // n_classes
        stats = [
            getattr(cov_stat_hook, attr_name).reshape(n_groups, n_classes, -1).mean(0)
            for attr_name in ('d_shift', 'd_scale')
        ]
    else:
        per_sample_mean = cov_stat_hook.mean_unreduced_class
        n_groups = per_sample_mean.shape[0] // n_classes
        feat_dim = per_sample_mean.shape[1]
        folded_cov = cov_stat_hook.cov_unreduced_class.reshape(n_groups, n_classes, feat_dim, feat_dim)
        stats = [
            cov_stat_hook.mean,
            cov_stat_hook.cov,
            per_sample_mean.reshape(n_groups, n_classes, -1).mean(0),
            folded_cov.mean(0),
        ]

    return [t.detach() for t in stats] if detach else stats
        
        
def clear_cov_stat_hook(cov_stat_hook):
    """Set the cached forward-statistic attributes of one hook to ``None``.

    Only attributes that already exist on the hook are reset; none are created.

    NOTE(review): the backward attributes read by
    ``get_stats_from_hook(..., bwd=True)`` (``d_shift``, ``d_scale``) are not
    in this list, so they survive a clear. Confirm whether that is intended.
    """
    stale_fields = (
        'mean_unreduced_unclipped', 'cov_unreduced_unclipped',
        'mean_unreduced_unclipped_class', 'cov_unreduced_unclipped_class',
        'mean_unreduced', 'cov_unreduced',
        'mean_unreduced_class', 'cov_unreduced_class',
        'mean', 'cov',
    )
    for field_name in stale_fields:
        if not hasattr(cov_stat_hook, field_name):
            continue
        setattr(cov_stat_hook, field_name, None)
    
def clear_cov_stat_hooks(cov_stat_hooks):
    """Apply ``clear_cov_stat_hook`` to each hook in the iterable ``cov_stat_hooks``."""
    for hook in cov_stat_hooks:
        clear_cov_stat_hook(hook)
    
def get_stats_from_hook_multi(cov_stat_hooks, n_classes, bwd = False, detach = False):
    """Return ``get_stats_from_hook`` results for every hook, in hook order."""
    return [
        get_stats_from_hook(hook, n_classes, bwd = bwd, detach = detach)
        for hook in cov_stat_hooks
    ]

def convert_to_ema(batch_stats, ema_stats, ema_coef):
    """Blend two parallel sequences element-wise: ``ema_coef * new + (1 - ema_coef) * old``.

    Elements are paired with ``zip``, so the result is as long as the shorter input.
    """
    keep = 1.0 - ema_coef
    return [ema_coef * new + keep * old for new, old in zip(batch_stats, ema_stats)]

def convert_to_ema_list(batch_stats, ema_stats, ema_coef):
    """Apply ``convert_to_ema`` layer by layer to nested ``[layer][stat]`` lists."""
    n_layers = len(list(batch_stats))
    return [
        convert_to_ema(batch_stats[layer], ema_stats[layer], ema_coef)
        for layer in range(n_layers)
    ]

def average_multi_stats(stats_list):
    """Average nested statistics over the outermost index.

    ``stats_list[k][layer][stat]`` are stacked over ``k`` and averaged, giving
    ``result[layer][stat]``. The nesting structure is taken from
    ``stats_list[0]``. Computed under ``torch.no_grad()``.
    """
    n_entries = len(stats_list)
    with torch.no_grad():
        return [
            [
                torch.stack([stats_list[k][layer][s] for k in range(n_entries)]).mean(0)
                for s in range(len(layer_stats))
            ]
            for layer, layer_stats in enumerate(stats_list[0])
        ]
        
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between logits ``preds`` and soft (probability) ``targets``.

    The loss for each row is ``-sum(targets * log_softmax(preds))`` over the last axis.

    Args:
        reduction: ``'none'`` returns the per-row losses, ``'mean'`` their mean,
            and ``'sum'`` their sum.

    Raises:
        ValueError: for any other ``reduction``. Previously such values
            silently returned ``None``.
    """
    loss = (-targets * torch.log_softmax(preds, dim=-1)).sum(-1)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(f"reduction must be 'none', 'mean' or 'sum', got {reduction!r}")
    
    

def get_kl_div_cached_inverse_multi(m1, m2, v1, v2_inv, v2_logdet, jit: float, temp = 1.0, grouped = False):
    """Per-class Gaussian KL(N(m1_c, S1_c) || N(m2_c, S2_c)), averaged over classes.

    Inputs:
      * ``m1`` / ``v1``: first and (uncentred) second moments of the synthetic
        features. S1 is formed as ``v1 - m1 m1^T + jit * I``.
      * ``v2_inv`` / ``v2_logdet``: precomputed per-class inverse and
        log-determinant of the target covariance S2. ``m2`` has one row per
        class, so C = ``m2.shape[0]``.

    Unless ``grouped`` is True, the leading axis of ``m1`` / ``v1`` is first
    reshaped to ``(N // C, C, ...)`` and averaged down to C rows. The
    Mahalanobis term is divided by ``temp``.

    Arithmetic is float64. The float32 result is the mean over classes.
    """
    n_cls = m2.shape[0]
    if not grouped:
        n_groups = m1.shape[0] // n_cls
        feat = v1.shape[-1]
        m1 = m1.reshape(n_groups, n_cls, -1).mean(0)
        v1 = v1.reshape(n_groups, n_cls, feat, feat).mean(0)

    prec_q = v2_inv.type(torch.float64)
    mu_p = m1.type(torch.float64)
    mu_q = m2.type(torch.float64)

    # Centre the synthetic second moment and add isotropic jitter.
    cov_p = v1.type(torch.float64) - mu_p[:, :, None] * mu_p[:, None, :]
    cov_p = cov_p + jit * torch.eye(cov_p.shape[1], device = cov_p.device)[None]

    dim = mu_p.shape[1]
    log_ratio = v2_logdet - torch.linalg.slogdet(cov_p)[1]
    product = einops.einsum(cov_p, prec_q, 'c d1 d2, c d2 d3 -> c d1 d3')
    trace_term = torch.diagonal(product, dim1 = 1, dim2 = 2).sum(1)
    delta = mu_p - mu_q
    prec_delta = einops.einsum(prec_q, delta, 'c d1 d2, c d2 -> c d1')
    mahal = einops.einsum(delta, prec_delta, 'c d1, c d1 -> c')/temp

    kl = 0.5 * (log_ratio - dim + trace_term + mahal)
    return kl.type(torch.float32).mean(0)

# KL(N(m1, S1) || N(m2, S2)) for a single layer, where S2 is supplied through
# its cached inverse `v2_inv` and log-determinant `v2_logdet`.
# `v1` is an uncentred second moment: S1 = v1 - m1 m1^T, then jittered by
# `jit` times the mean diagonal entry (trace / D).
# Computed in float64; returns a float32 scalar tensor.
# Compiled with TorchScript: keep the body within the TorchScript subset.
@torch.jit.script
def get_kl_div_cached_inverse(m1, m2, v1, v2_inv, v2_logdet, jit: float):
    v1 = v1.type(torch.float64)
    v2_inv = v2_inv.type(torch.float64)
    m1 = m1.type(torch.float64)
    m2 = m2.type(torch.float64)

    # Centre the second moment into a covariance.
    v1 = v1 - m1.reshape(-1, 1) @ m1.reshape(1, -1)

    # Trace-scaled jitter keeps S1 well conditioned.
    v1 = v1 + jit * torch.eye(v1.shape[0], device = v1.device) * torch.trace(v1)/v1.shape[0]
    logdet = v2_logdet - torch.linalg.slogdet(v1)[1]
    d = m1.shape[0]
    trace = torch.trace(v1 @ v2_inv)
    # Mahalanobis term (m1 - m2)^T S2^{-1} (m1 - m2) as a 1x1 matrix, then unwrapped.
    mean = (m1- m2).reshape(1, -1)  @ v2_inv @ (m1- m2).reshape(-1, 1)
    mean = mean[0,0]

    kl = logdet - d + trace + mean
    kl = 0.5 * kl
    kl = kl.type(torch.float32)
    return kl


def get_kl_div(m1, m2, v1, v2, device = 'cuda'):
    """Gaussian KL(N(m1, S1) || N(m2, S2)) computed from uncentred second moments.

    Each covariance is formed as ``v - m m^T`` and then jittered by
    ``3e-3 * trace / D`` on the diagonal. Arithmetic is float64; the result is
    a float32 scalar tensor. ``device`` is accepted but unused.
    """
    def _centred_jittered(mean, second_moment):
        cov = second_moment - mean.reshape(-1, 1) @ mean.reshape(1, -1)
        dim = cov.shape[0]
        return cov + 3e-3 * torch.eye(dim, device = cov.device) * torch.trace(cov)/dim

    mu_p = m1.type(torch.float64)
    mu_q = m2.type(torch.float64)
    cov_p = _centred_jittered(mu_p, v1.type(torch.float64))
    cov_q = _centred_jittered(mu_q, v2.type(torch.float64))

    log_ratio = torch.linalg.slogdet(cov_q)[1] - torch.linalg.slogdet(cov_p)[1]
    prec_q = torch.linalg.inv(cov_q)
    trace_term = torch.trace(cov_p @ prec_q)
    delta = (mu_p - mu_q).reshape(-1, 1)
    mahal = (delta.T @ prec_q @ delta)[0, 0]

    kl = 0.5 * (log_ratio - mu_p.shape[0] + trace_term + mahal)
    return kl.type(torch.float32)

def get_l2_bn_loss(m1, m2, v1, v2, device = 'cuda'):
    """L2 distance between per-dimension variances plus L2 distance between means.

    Variances are the diagonals of ``v - m m^T`` (``v`` is an uncentred second
    moment). Arithmetic and result are float64. ``device`` is accepted but unused.
    """
    mean_a = m1.type(torch.float64)
    mean_b = m2.type(torch.float64)
    var_a = torch.diag(v1.type(torch.float64) - mean_a.reshape(-1, 1) @ mean_a.reshape(1, -1))
    var_b = torch.diag(v2.type(torch.float64) - mean_b.reshape(-1, 1) @ mean_b.reshape(1, -1))

    var_dist = torch.sqrt(torch.sum((var_a - var_b) ** 2))
    mean_dist = torch.sqrt(torch.sum((mean_a - mean_b) ** 2))
    return var_dist + mean_dist

def get_stats_from_path(stage, distilled_save_path):
    """Load a stage's pickled statistics and rebuild the model they describe.

    Reads ``<distilled_save_path>/proj_info/stats_stage_<stage>.pkl`` (``~`` is
    expanded). The unpickled object must provide ``model_name``,
    ``activation``, ``source_n_per_class`` and ``model_save_path``.

    The architecture named by ``model_name`` is built with
    ``source_n_per_class`` outputs. It uses SiLU activations, or ReLU when
    ``activation == 'relu'``. Weights are loaded from ``model_save_path``, and
    any ``module.`` key prefix (e.g. from ``nn.DataParallel`` checkpoints) is
    stripped.

    Returns:
        ``(model, stats_dict)``; the model is in eval mode on CPU.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    stats_dict_path = os.path.expanduser(distilled_save_path +'/proj_info/stats_stage_{}.pkl'.format(stage))

    # NOTE: pickle.load can execute arbitrary code; only load stats files from trusted sources.
    with open(stats_dict_path, 'rb') as f:
        stats_dict = pickle.load(f)

    model_name = stats_dict.model_name
    activation = stats_dict.activation
    source_n_per_class = stats_dict.source_n_per_class

    activation_fn = nn.ReLU if activation == 'relu' else nn.SiLU

    # WideResNet variants: name -> (depth, widen factor).
    wrn_configs = {
        'wrn28-8': (28, 8),
        'wrn28-10': (28, 10),
        'wrn28-4': (28, 4),
        'wrn22-8': (22, 8),
        'wrn16-8': (16, 8),
        'wrn40-4': (40, 4),
        'wrn10-4': (10, 4),
    }

    if model_name == 'resnet18_silu':
        model = resnet18_silu(num_classes=source_n_per_class, activation = activation_fn)
    elif model_name in wrn_configs:
        depth, widen = wrn_configs[model_name]
        model = WideResNet(depth, source_n_per_class, widen, activation = activation_fn)
    else:
        raise ValueError(f'Unsupported model_name {model_name!r} in {stats_dict_path}')

    model_save_path = stats_dict.model_save_path

    print(f'loading model from {model_save_path}')
    print(f'loading from {model_save_path}')
    # The model ends up on CPU anyway, so don't require the device the checkpoint was saved from.
    state_dict = torch.load(model_save_path, map_location='cpu')
    new_state_dict = OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model.eval()
    model = model.to('cpu')

    return model, stats_dict


def _parse_int_list(value):
    """Normalise a CLI list argument: int -> [int], str -> ints split on commas, else unchanged."""
    if type(value) is int:
        return [value]
    if isinstance(value, str):
        return [int(s) for s in value.split(',')]
    return value


def main(forward_stages = [0], bwd_stages= [], steps_per_batch = 1000, distilled_save_path = './distilled_images/', n_per_class = 50, kl_loss = True, do_iterative = True, cache_inverses = True, dataset_name = 'cifar10', 
         data_folder = 'lskdjfs', aug_in_cache = False, first_layer_scale = 10.0, oob_coef = 10.0, batch_size = 200, ema_reset = 1000000, ema_min_value = 0.0, apply_ema_min_when_saving = False, cache_batch_size = 128, 
         start_lr = 0.25, randomize_mode = 'per_iter', random_shift_amount = 32, ipc_offset = 0, target_class = 0,
         target_resolution = 224,
         class_mean_coef = 0.1, temperature = 1.0, dbn_coef = 0.0, stage = 0, selected_classes = None):
    """Entry point: load per-stage models and statistics, then synthesise images.

    ``forward_stages``, ``bwd_stages`` and ``selected_classes`` may each be an
    int, a comma-separated string, or a list/tuple of ints.
    ``selected_classes`` defaults to ``range(100)``.

    Each stage is loaded with ``get_stats_from_path`` from
    ``distilled_save_path``. Images are written under
    ``<distilled_save_path>/images/stage_<stage>``.
    """
    config = get_config()
    config.dataset.name = dataset_name
    device = 'cuda'

    if selected_classes is None:
        selected_classes = list(range(100))
    selected_classes = _parse_int_list(selected_classes)
    forward_stages = _parse_int_list(forward_stages)
    bwd_stages = _parse_int_list(bwd_stages)

    print('preparing data')
    (ds_train, ds_test), preprocess_op, rev_preprocess_op, (raw_image_mean, raw_image_std) = get_dataset(config.dataset, apply_aug = aug_in_cache, data_folder = data_folder, batch_size = cache_batch_size, target_class = None, target_resolution = target_resolution)
    raw_image_mean = torch.from_numpy(raw_image_mean).type(torch.float32).view(-1).to(device)
    raw_image_std = torch.from_numpy(raw_image_std).type(torch.float32).view(-1).to(device)

    # NOTE(review): config.dataset.img_shape / num_classes are not set by get_config;
    # presumably get_dataset fills them in — confirm.
    aug = transforms.Compose([
            transforms.RandomResizedCrop(config.dataset.img_shape[0]),
            transforms.RandomHorizontalFlip(),
    ])

    fwd_models = []
    fwd_stats = []
    for forward_stage in forward_stages:
        model, stats_dict = get_stats_from_path(forward_stage, distilled_save_path)
        fwd_models.append(model.to('cpu'))
        fwd_stats.append(stats_dict)

    bwd_models = []
    bwd_stats = []
    for bwd_stage in bwd_stages:
        model, stats_dict = get_stats_from_path(bwd_stage, distilled_save_path)
        bwd_models.append(model.to('cpu'))
        bwd_stats.append(stats_dict)

    # dirname('out') is '' and os.makedirs('') raises, so only create a non-empty parent.
    save_dir = os.path.dirname(os.path.expanduser(distilled_save_path))
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    print('Synthesizing Images!')
    
    synthesize_images(fwd_models, fwd_stats, bwd_models, bwd_stats, steps_per_batch = steps_per_batch, n_classes = config.dataset.num_classes, n_per_class = n_per_class, aug = aug, do_iterative = do_iterative, kl_loss = kl_loss, 
                                                        cache_inverses = cache_inverses, image_res = [3, config.dataset.img_shape[0], config.dataset.img_shape[1]], first_layer_scale = first_layer_scale, raw_image_mean = raw_image_mean, raw_image_std = raw_image_std, oob_coef = oob_coef, 
                                                        distilled_save_path = distilled_save_path + f'/images/stage_{stage}', batch_size = batch_size, ema_reset = ema_reset, ema_min_value = ema_min_value, apply_ema_min_when_saving = apply_ema_min_when_saving, start_lr = start_lr, randomize_mode = randomize_mode, 
                                                        random_shift_amount = random_shift_amount, ipc_offset = ipc_offset, target_class = target_class, target_resolution = target_resolution, selected_classes = selected_classes, class_mean_coef = class_mean_coef, dbn_coef = dbn_coef, temp = temperature)
    
    return 

def one_hot(a, num_classes):
    """One-hot encode integer labels as rows of a ``num_classes`` identity matrix.

    ``a`` is flattened first, and size-1 axes are squeezed from the result, so
    a single label gives a 1-D vector.
    """
    identity = np.eye(num_classes)
    encoded = identity[a.reshape(-1)]
    return encoded.squeeze()

def save_images(image_save_path, images, targets, soft_labels, ipc_id, n_classes, selected_classes):
    """Write each image as ``<image_save_path>/newCCC/classCCC_idIII.jpg``.

    Args:
        images: array indexed as ``images[i]`` with a channel-first layout
            (transposed to HWC). Values are clipped to [0, 1] and scaled to 8-bit.
        targets: either 1-D class ids, or 2-D rows whose argmax is the class.
        soft_labels: currently unused; the soft-label pickle dump is disabled.
        ipc_id: base image index; image ``i`` gets
            ``ipc_id + i // len(selected_classes)``.
        n_classes: unused here.

    For each class this also creates ``<image_save_path>/../labels/newCCC/``,
    but nothing is written there. ``~`` in ``image_save_path`` is expanded for
    both the directories and the files.
    """
    n_selected = len(selected_classes)
    for idx in range(images.shape[0]):
        class_id = targets[idx] if len(targets.shape) == 1 else targets[idx].argmax()
        image_id = ipc_id + idx // n_selected

        # Expand '~' on the directory as well, so makedirs and save use the same path.
        img_dir = os.path.expanduser('{}/new{:03d}'.format(image_save_path, class_id))
        os.makedirs(img_dir, exist_ok=True)
        image_np = np.clip(images[idx].transpose((1, 2, 0)), 0, 1)
        pil_image = Image.fromarray(np.around(image_np * 255).astype(np.uint8))
        pil_image.save(img_dir + '/class{:03d}_id{:03d}.jpg'.format(class_id, image_id), quality=100, subsampling=0)

        # Label directory is still created; the soft-label pickle for it is not written.
        label_dir = os.path.expanduser('{}/../labels/new{:03d}'.format(image_save_path, class_id))
        os.makedirs(label_dir, exist_ok=True)

if __name__ == '__main__':
    # Hand `main` to python-fire, which builds the command-line interface for it.
    fire.Fire(main)