import einops
import ml_collections
from dataloader import get_dataset
    
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from train_utils import train_one_epoch
import pickle
import numpy as np
from tqdm import tqdm

import os
import fire
from cov_stats import CovStatsHook, get_stats_conv, scale_cov_hooks, scale_cov_hooks_bwd #, matrix_pow
from modified_resnets import resnet18_silu
from collections import OrderedDict
# import wandb
from wide_resnets2 import WideResNet
import torch.nn.functional as F
import einops

from pytorch_metric_learning import losses


# Module-level NT-Xent loss instance, built once at import time.
# NOTE(review): no function visible in this file references `loss_func` —
# confirm whether another module imports it before removing.
loss_func = losses.NTXentLoss()

def get_config():
    """Build the default experiment configuration.

    Returns:
        An ``ml_collections.ConfigDict`` holding the ``train_log`` /
        ``train_img`` / ``resume`` settings, four placeholder fields
        initialised to ``None``, and a nested ``dataset`` section with the
        dataset name and data path.
    """
    cfg = ml_collections.ConfigDict()

    cfg.train_log = 'train_log'
    cfg.train_img = 'train_img'
    cfg.resume = True

    # Placeholders: created here with no value so later code can fill them in.
    for unset_field in ('img_size', 'img_channels', 'num_prototypes', 'train_size'):
        setattr(cfg, unset_field, None)

    # Dataset section.
    dataset_cfg = ml_collections.ConfigDict()
    dataset_cfg.name = 'imagenet'
    dataset_cfg.data_path = 'data/tensorflow_datasets'
    cfg.dataset = dataset_cfg

    return cfg




def get_cov_stat_hooks(model, proj_config):
    """Attach a ``CovStatsHook`` to every selected ``BatchNorm2d`` layer.

    ``BatchNorm2d`` modules are numbered from 0 in ``model.modules()`` order.
    A hook is created for each module whose number appears in
    ``proj_config.include_indices``. The k-th hook created (k counts only the
    selected layers) receives the k-th entries of
    ``proj_config.main_proj_mats`` and ``proj_config.class_proj_mats``.

    Args:
        model: Module tree to scan for ``nn.BatchNorm2d`` layers.
        proj_config: Object exposing ``include_indices``, ``main_proj_mats``,
            ``class_proj_mats``, ``sigmoid_scale``, ``class_sigmoid_scale``,
            ``bwd_sigmoid_scale`` and ``copy_main``.

    Returns:
        The created hooks as a list, in layer order (empty if nothing is
        selected).
    """
    selected = set(proj_config.include_indices)
    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))

    cov_stat_hooks = []
    for bn_i, module in enumerate(bn_layers):
        if bn_i not in selected:
            continue

        # Projection matrices are stored per *selected* layer, so they are
        # indexed by how many hooks exist so far rather than by bn_i.
        csh_i = len(cov_stat_hooks)
        cov_stat_hooks.append(CovStatsHook(module, main_proj_mat = proj_config.main_proj_mats[csh_i], class_proj_mat = proj_config.class_proj_mats[csh_i],
                                           bwd_sigmoid_scale = proj_config.bwd_sigmoid_scale, bwd_proj_mat = None, sigmoid_scale = proj_config.sigmoid_scale,
                                           class_sigmoid_scale = proj_config.class_sigmoid_scale, copy_main = proj_config.copy_main))

    return cov_stat_hooks
    


def _draw_projection(n_channels, n_features, device):
    """Draw one random channel-projection matrix.

    Returns ``(proj_mat, feature_dim)``. ``n_features == -1`` disables the
    projection: ``proj_mat`` is ``None`` and ``feature_dim`` is the raw
    channel count. Otherwise ``proj_mat`` is an ``[n_channels, n_features]``
    float32 tensor whose entries are standard normal draws divided by
    ``sqrt(n_channels)``, and ``feature_dim`` is ``n_features``.
    """
    if n_features == -1:
        return None, n_channels
    weights = np.random.normal(size = [n_channels, n_features])/np.sqrt(n_channels)
    return torch.tensor(weights).to(device).type(torch.float32), n_features


def compute_projections_and_clips_init(model, device = 'cuda', aug = nn.Identity(), include_indices = None, ev_ratio = 1.0, clip_power = 4.0, n_features_main = 512, n_features_class = 128, bwd_proj_dim = -1, bwd_sigmoid_scale = 1.0,
                                       sigmoid_scale = 1.0, class_sigmoid_scale = 1.0, clip_ratio = 16.0, class_include_indices = None, class_clip_scale = 1.0,
                                       clip_type = 'global', cov_scale_multiplier = 1.0, jit_multiplier = 1.0, bwd_clip_scale = 16, use_class_scale = True):
    """Draw per-layer random projections and derive clipping constants.

    Every ``nn.BatchNorm2d`` in ``model.modules()`` is numbered from 0. For
    the layers listed in ``include_indices`` this records the projected
    feature dimensions (``m1`` = features, ``m2`` = features*(features+1)/2),
    the per-layer clips (``sqrt(m1)/3`` and ``m1/9``) and the projection
    matrices, then combines them into ``total_clip`` and ``class_scalar``.

    Args:
        model: Module tree to scan for ``nn.BatchNorm2d`` layers.
        include_indices: BatchNorm2d numbers to record (``None`` -> none).
        n_features_main / n_features_class: Projected width for the global /
            per-class statistics; ``-1`` means "no projection, use the raw
            channel count".
        class_include_indices: BatchNorm2d numbers whose class dimensions
            count towards the totals; other included layers contribute 0.
        clip_ratio: ``total_clip = sqrt(total_dim) / clip_ratio``.
        class_clip_scale: Multiplier applied to ``class_scalar``.
        clip_type: ``'global'`` or ``'hybrid_all'``; selects the
            ``class_scalar`` formula.
        cov_scale_multiplier: Weight of the second-moment dimensions in
            ``total_dim``.
        bwd_clip_scale: ``total_clip_bwd = sqrt(sum(2*channels)) / bwd_clip_scale``.
        use_class_scale: When false, ``class_scalar`` is forced to 1.
        clip_power, sigmoid_scale, class_sigmoid_scale, bwd_sigmoid_scale,
        jit_multiplier: Stored on the result unchanged.
        device, aug, ev_ratio, bwd_proj_dim: Accepted for call compatibility;
            not read by this function.

    Returns:
        An ``ml_collections.ConfigDict`` with the dimensions, clips,
        projection matrices and scalar settings described above.

    Raises:
        ValueError: ``use_class_scale`` is set and ``clip_type`` is not one of
            the two supported values.
    """
    if use_class_scale and clip_type not in ('global', 'hybrid_all'):
        raise ValueError(f"clip_type must be 'global' or 'hybrid_all' when use_class_scale is set, got {clip_type!r}")

    include_indices = [] if include_indices is None else include_indices
    class_include_indices = [] if class_include_indices is None else class_include_indices

    m1_dims, m2_dims = [], []
    m1_clips, m2_clips = [], []
    m1_dims_class, m2_dims_class = [], []
    m1_clips_class, m2_clips_class = [], []
    m1_dims_bwd = []
    main_proj_mats, class_proj_mats = [], []

    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn_i, module in enumerate(bn_layers):
        n_channels = module.weight.shape[0]
        layer_device = module.weight.device

        # Both matrices are drawn for every BatchNorm2d layer (main first,
        # then class), whether or not the layer is included, so the NumPy
        # random stream does not depend on include_indices.
        main_proj_mat, m1_dim = _draw_projection(n_channels, n_features_main, layer_device)
        # The class branch now mirrors the main branch for -1: without a
        # projection the class feature width is the channel count, not -1.
        class_proj_mat, m1_dim_class = _draw_projection(n_channels, n_features_class, layer_device)

        if bn_i not in include_indices:
            continue

        m2_dim = m1_dim * (m1_dim + 1) /2
        m2_dim_class = m1_dim_class * (m1_dim_class + 1) /2
        counts_for_class = bn_i in class_include_indices

        m1_dims.append(m1_dim)
        m2_dims.append(m2_dim)
        m1_dims_class.append(m1_dim_class if counts_for_class else 0)
        m2_dims_class.append(m2_dim_class if counts_for_class else 0)

        main_proj_mats.append(main_proj_mat)
        class_proj_mats.append(class_proj_mat)

        m1_clips.append(np.sqrt(m1_dim)/3)
        m2_clips.append(m1_dim/9)
        # Class clips are recorded for every included layer, even those whose
        # class dimensions were zeroed above.
        m1_clips_class.append(np.sqrt(m1_dim_class)/3)
        m2_clips_class.append(m1_dim_class/9)

        m1_dims_bwd.append(2 * n_channels)

    if not use_class_scale:
        class_scalar = 1
    elif clip_type == 'global':
        main_weight = np.sum(m1_dims) + np.sum(m2_dims)
        class_weight = np.sum(m1_dims_class) + np.sum(m2_dims_class)
        # With no class dimensions the ratio would be x/0 -> inf and then
        # inf * 0 -> nan in total_dim below; fall back to a neutral scale.
        class_scalar = class_clip_scale * main_weight/class_weight if class_weight > 0 else 1.0
    else:
        main_weight = np.sum(np.array(m1_clips) * np.sqrt(m1_dims)) + np.sum(np.array(m2_clips) * np.sqrt(m2_dims))
        class_weight = np.sum(np.array(m1_clips_class) * np.sqrt(m1_dims_class)) + np.sum(np.array(m2_clips_class) * np.sqrt(m2_dims_class))
        class_scalar = class_clip_scale * (main_weight/class_weight)**2 if class_weight > 0 else 1.0

    total_dim = np.sum(m1_dims) + cov_scale_multiplier * np.sum(m2_dims) + class_scalar * (np.sum(m1_dims_class) + cov_scale_multiplier * np.sum(m2_dims_class))
    total_clip = np.sqrt(total_dim)/clip_ratio

    if clip_type == 'hybrid_all':
        m1_clips_class = np.sqrt(class_scalar) * np.array(m1_clips_class)
        m2_clips_class = np.sqrt(class_scalar) * np.array(m2_clips_class)

    proj_config = ml_collections.ConfigDict()

    proj_config.m1_dims = m1_dims
    proj_config.m2_dims = m2_dims
    proj_config.m1_dims_class = m1_dims_class
    proj_config.m2_dims_class = m2_dims_class

    proj_config.m1_clips = np.array(m1_clips)
    proj_config.m2_clips = np.array(m2_clips)
    proj_config.m1_clips_class = np.array(m1_clips_class)
    proj_config.m2_clips_class = np.array(m2_clips_class)

    proj_config.main_proj_mats = main_proj_mats
    proj_config.class_proj_mats = class_proj_mats
    proj_config.clip_power = clip_power
    proj_config.sigmoid_scale = sigmoid_scale
    proj_config.class_sigmoid_scale = class_sigmoid_scale
    proj_config.include_indices = include_indices
    proj_config.bwd_sigmoid_scale = bwd_sigmoid_scale
    proj_config.total_clip = total_clip
    proj_config.class_include_indices = class_include_indices

    proj_config.m1_dims_bwd = m1_dims_bwd
    proj_config.total_clip_bwd = np.sqrt(np.sum(m1_dims_bwd))/bwd_clip_scale

    proj_config.clip_type = clip_type
    proj_config.jit_multiplier = jit_multiplier

    proj_config.class_scalar = class_scalar
    proj_config.cov_scale_multiplier = cov_scale_multiplier

    # True when both statistic sets use the same projected width.
    proj_config.copy_main = (n_features_class == n_features_main)

    return proj_config

def add_noise_to_mean_cov(mean, cov, b0, noise_ratios, count, scale_ratio = 1.0, cov_scale_multiplier = 1.0):
    """Add Gaussian noise to a mean / second-moment pair and repair the result.

    ``cov`` is handled as an uncentred second moment: after noising, the
    outer product of the noised mean is subtracted, negative eigenvalues of
    the remainder are clipped to zero, and the outer product is added back.

    Args:
        mean: 1-D tensor (it is broadcast against itself to form an outer
            product).
        cov: Square tensor matching ``mean``.
        b0: Overall noise multiplier; ``0`` leaves the inputs un-noised.
        noise_ratios: Pair ``(mean_ratio, cov_ratio)``.
        count: Divisor applied to both noise terms.
        scale_ratio: Both ratios are divided by ``sqrt(scale_ratio)``.
        cov_scale_multiplier: The covariance ratio is additionally divided by
            ``sqrt(cov_scale_multiplier)``.

    Returns:
        ``(mean_noised, fixed_cov, jit)`` where ``jit`` is the magnitude of
        the most negative eigenvalue found before clipping, floored at 3e-3.

    Note:
        The noise is drawn with NumPy (float64), so the returned tensors
        follow torch's type promotion rather than the input dtype.
    """
    mean_ratio, cov_ratio = noise_ratios[0], noise_ratios[1]

    # The mean noise is drawn before the covariance noise; this order fixes
    # how the global NumPy random stream is consumed.
    mean_draw = torch.tensor(np.random.normal(size = mean.shape), device = mean.device)
    mean_sigma = b0 * (mean_ratio/np.sqrt(scale_ratio))
    mean_noised = mean + mean_sigma * mean_draw/count

    cov_draw = torch.tensor(np.random.normal(size = cov.shape), device = cov.device)
    cov_sigma = b0 * (cov_ratio/np.sqrt(scale_ratio * cov_scale_multiplier))
    raw_cov_noise = cov_sigma * cov_draw/count

    # Symmetrise by mirroring the strict upper triangle below the diagonal.
    strict_upper = torch.triu(raw_cov_noise, diagonal = 1)
    sym_cov_noise = torch.triu(raw_cov_noise) + strict_upper.T

    mean_outer = mean_noised.unsqueeze(1) * mean_noised.unsqueeze(0)
    centred = (cov + sym_cov_noise) - mean_outer

    # Project the centred matrix onto the PSD cone.
    eigvals, eigvecs = torch.linalg.eigh(centred)
    clipped_spectrum = torch.diag(torch.clamp(eigvals, min = 0))
    psd_cov = eigvecs @ clipped_spectrum @ eigvecs.T.conj()

    jit = max(-eigvals.min().item(), 3e-3)

    return mean_noised, psd_cov + mean_outer, jit

def add_noise_to_mean(mean, b0, clip, count):
    """Return ``mean`` plus Gaussian noise scaled by ``b0 * clip / count``.

    One standard-normal value per element of ``mean`` is drawn from NumPy's
    global random state (float64) and placed on ``mean``'s device.
    """
    gaussian_draw = torch.tensor(np.random.normal(size = mean.shape), device = mean.device)
    perturbation = b0 * clip * gaussian_draw/count
    return mean + perturbation


def compute_full_ds_stats(model, loader, proj_config, device = 'cuda', aug = nn.Identity(), n_forward = 1, b0 = 0.0, ratio = 1.0, aug_groups = 1, target_resolution = 224, means_only = False, class_weights = -1, selected_classes = [], mmd_proj_mats = []):
    """Accumulate hook statistics over ``loader`` and return noised versions.

    Hooks from ``get_cov_stat_hooks`` are attached to ``model``; the model is
    put in eval mode and run over ``loader`` ``n_forward`` times. After every
    batch the hooks' ``mean`` / ``cov`` (global) and ``mean_unreduced_class``
    / ``cov_unreduced_class`` / ``d_shift`` / ``d_scale`` (summed per class
    via ``y.T @ ...``) are accumulated. The sums are normalised, the hooks
    are closed, and Gaussian noise is added with ``add_noise_to_mean_cov`` /
    ``add_noise_to_mean``.

    Args:
        model: Called as ``model(x, embeddings_and_out = True)`` and expected
            to return ``(outputs, int_embeddings, embeddings)``.
        loader: Iterable of ``(x, y)`` batches. ``y`` is used as a 2-D matrix
            (``y.T @ ...``, ``argmax(y, 1)``) — presumably one-hot labels;
            TODO confirm.
        proj_config: Result of ``compute_projections_and_clips_init``.
        device: Device the batches are moved to.
        aug: Ignored — overwritten at the top of the function.
        n_forward: Number of passes over ``loader``.
        b0: Noise multiplier forwarded to the noise helpers.
        ratio: Ignored — overwritten with 1.0 at the top of the function.
        aug_groups: When > 1, each batch is concatenated with
            ``aug_groups - 1`` augmented copies of itself.
        target_resolution: Output size of the random resized crop.
        means_only, mmd_proj_mats: Not read by the current code.
        class_weights: When != -1 a per-sample weight is computed but never
            used afterwards.
        selected_classes: Only ``len(selected_classes)`` is used, as the
            divisor that turns ``n_total`` into a per-class count.
            NOTE(review): the default ``[]`` therefore divides by zero.

    Returns:
        ``ml_collections.ConfigDict`` with ``global_means``, ``global_covs``,
        ``jits``, ``class_jits``, ``class_means``, ``class_covs``,
        ``class_dbnb_means`` and ``class_dbnw_means`` (one entry per hook).
    """
    # The caller-supplied ``ratio`` is discarded, which makes the
    # ``ratio != 1.0`` blending branch further down unreachable.
    ratio = 1.0

    # The caller-supplied ``aug`` is replaced by a fixed crop + flip pipeline.
    aug = transforms.Compose([
        # transforms.RandomResizedCrop(distilled_images_batch.shape[-1], scale = (min_scale, 1.0)),
        transforms.RandomResizedCrop(target_resolution, scale = (0.3, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])

    cov_stat_hooks = get_cov_stat_hooks(model, proj_config)

    model.eval()

    # Running sums, one list entry per hook. They are filled on the first
    # batch and incremented on every later one.
    data_means = []

    data_covs = []
    # data_covs2 = []
    data_class_covs = []

    data_class_means = []
    data_class_mmd_means = []

    data_class_dbnb_means = []
    data_class_dbnw_means = []
    data_class_dbnb_covs = []
    data_class_dbnw_covs = []

    # Last-layer accumulators. They are normalised below but are not part of
    # the returned measurements.
    data_ll_grads = None
    data_ll_f_mean = None
    data_ll_f_cov = None

    for i in range(n_forward):
        # Reset on every pass, so after the loop it holds the sample count of
        # a single pass (the sums themselves carry a 1/n_forward factor).
        n_total = 0

        for x, y in tqdm(loader):
            x_device = x.to(device)
            y_device = y.to(device)

            if aug_groups > 1:
                # Original batch followed by aug_groups - 1 augmented copies.
                auged = [x_device]
                for a_i in range(aug_groups - 1):
                    auged.append(aug(x_device))

                x_device = torch.concatenate(auged, 0)

                # Labels repeated to line up with the enlarged batch.
                y_device_lab = torch.concatenate([y_device for sdfs in range(aug_groups)], 0)
            else:
                y_device_lab = y_device

            # print(torch.min(x))
            # print(torch.max(x))

            with torch.set_grad_enabled(True):
                # Gradients w.r.t. the input are requested below.
                x_device.requires_grad_(True)
                outputs, int_embeddings, embeddings = model(x_device, embeddings_and_out = True)

                scale_cov_hooks(cov_stat_hooks, proj_config)

                if class_weights != -1:
                    # NOTE(review): output_weight is computed but never read;
                    # both branches count samples identically.
                    output_weight = (1 - F.softmax(outputs, 1)[:, class_weights]).detach()
                    # n_total += torch.sum(output_weight).detach()
                    n_total += x.shape[0]
                else:
                    n_total += x.shape[0]

                T = 1

                probs_value = F.softmax(outputs/T, 1)

                ll_grad = (y_device_lab - probs_value[:, :y_device.shape[1]]).T  @ embeddings

                loss = F.cross_entropy(outputs/T, torch.argmax(y_device_lab, 1), reduction = 'sum')

                # The returned gradient is not used; presumably this call
                # exists to drive the hooks' backward-side statistics
                # (d_shift / d_scale) — verify against CovStatsHook.
                g_whiz = torch.autograd.grad(loss, [x_device], create_graph = True)

                # All-zero placeholder with a hard-coded width of 100, so the
                # data_ll_f_* accumulators stay zero.
                ll_f_g = torch.zeros([x.shape[0], 100], device = x_device.device)

                scale_cov_hooks_bwd(cov_stat_hooks, proj_config)

                # NOTE(review): the per-class sums below multiply by y_device
                # (original batch rows), not y_device_lab. If the hooks'
                # unreduced tensors have one row per model input, the shapes
                # will not match when aug_groups > 1 — confirm.
                if len(data_means) == 0:
                    # First batch: create the accumulators.
                    for csh_i, cov_stat_hook in enumerate(cov_stat_hooks):
                        # if class_weights == -1:
                        data_means.append(cov_stat_hook.mean.detach() * x.shape[0]/n_forward)
                        data_covs.append(cov_stat_hook.cov.detach() * x.shape[0]/n_forward)
                        # data_covs2.append(cov_stat_hook.cov2.detach() * x.shape[0]/n_forward)
                        data_class_means.append(y_device.T @ cov_stat_hook.mean_unreduced_class.detach()) #C N @ N D
                        data_class_covs.append(einops.einsum(y_device.T, cov_stat_hook.cov_unreduced_class.detach(), 'c n, n d1 d2 -> c d1 d2')) #C N @ N D
                        # data_class_mmd_means.append(y_device.T @ cov_stat_hook.mmd_features.detach()) #C N @ N D

                        data_class_dbnb_means.append(y_device.T @ cov_stat_hook.d_shift.detach()) 
                        data_class_dbnw_means.append(y_device.T @ cov_stat_hook.d_scale.detach()) 
                        # data_class_dbnb_covs.append(einops.einsum(y_device.T, cov_stat_hook.d_shift_cov.detach(), 'c n, n d1 d2 -> c d1 d2')) #C N @ N D
                        # data_class_dbnw_covs.append(einops.einsum(y_device.T, cov_stat_hook.d_scale_cov.detach(), 'c n, n d1 d2 -> c d1 d2')) #C N @ N D

                    data_ll_grads = ll_grad.detach()
                    # data_ll_f_mean = torch.sum(ll_f_g, 0).detach()
                    data_ll_f_mean = (y_device.T @ ll_f_g).detach()  #C N, N D -> C D
                    data_ll_f_cov = einops.einsum(ll_f_g, ll_f_g, 'n d1, n d2 -> n d1 d2').sum(0).detach()

                else:
                    # Later batches: add to the existing accumulators.
                    for csh_i, cov_stat_hook in enumerate(cov_stat_hooks):
                        # if class_weights == -1:
                        data_means[csh_i] += cov_stat_hook.mean.detach() * x.shape[0]/n_forward
                        data_covs[csh_i] += cov_stat_hook.cov.detach() * x.shape[0]/n_forward
                        # data_covs2[csh_i] += cov_stat_hook.cov2.detach() * x.shape[0]/n_forward
                        data_class_means[csh_i] += y_device.T @ cov_stat_hook.mean_unreduced_class.detach() #C N @ N D
                        data_class_covs[csh_i] += einops.einsum(y_device.T, cov_stat_hook.cov_unreduced_class.detach(), 'c n, n d1 d2 -> c d1 d2') #C N @ N D

                        # data_class_mmd_means[csh_i] += y_device.T @ cov_stat_hook.mmd_features.detach() #C N @ N D
                        # data_class_dbn_means[csh_i] += y_device.T @ cov_stat_hook.d_combined.detach()

                        data_class_dbnb_means[csh_i] += y_device.T @ cov_stat_hook.d_shift.detach()
                        data_class_dbnw_means[csh_i] += y_device.T @ cov_stat_hook.d_scale.detach()
                        # data_class_dbnb_covs[csh_i] += einops.einsum(y_device.T, cov_stat_hook.d_shift_cov.detach(), 'c n, n d1 d2 -> c d1 d2') #C N @ N D
                        # data_class_dbnw_covs[csh_i] += einops.einsum(y_device.T, cov_stat_hook.d_scale_cov.detach(), 'c n, n d1 d2 -> c d1 d2') #C N @ N D

                    data_ll_grads += ll_grad.detach()
                    # data_ll_f_mean += torch.sum(ll_f_g, 0).detach()
                    data_ll_f_mean +=  (y_device.T @ ll_f_g).detach()
                    data_ll_f_cov += einops.einsum(ll_f_g, ll_f_g, 'n d1, n d2 -> n d1 d2').sum(0).detach()

        print(n_total)

    # Normalise the sums. Global statistics divide by the sample count;
    # per-class statistics divide by aug_groups * n_total / number of
    # selected classes, i.e. they assume every selected class has the same
    # number of samples.
    data_ll_grads = data_ll_grads/(aug_groups * n_total)
    # data_ll_f_mean = data_ll_f_mean/n_total
    data_ll_f_mean = data_ll_f_mean/(aug_groups * n_total/len(selected_classes))

    data_ll_f_cov = data_ll_f_cov/(aug_groups * n_total)

    for csh_i in range(len(data_means)):
        data_means[csh_i] = data_means[csh_i]/n_total
        data_covs[csh_i] = data_covs[csh_i]/n_total
        # data_covs2[csh_i] = data_covs2[csh_i]/n_total
        data_class_means[csh_i] = data_class_means[csh_i]/(aug_groups * n_total/len(selected_classes))
        data_class_covs[csh_i] = data_class_covs[csh_i]/(aug_groups * n_total/len(selected_classes))
        # data_class_mmd_means[csh_i] = data_class_mmd_means[csh_i]/(aug_groups * n_total/len(selected_classes))

        data_class_dbnb_means[csh_i] = data_class_dbnb_means[csh_i]/(aug_groups * n_total/len(selected_classes))
        data_class_dbnw_means[csh_i] = data_class_dbnw_means[csh_i]/(aug_groups * n_total/len(selected_classes))
        # data_class_dbnb_covs[csh_i] = data_class_dbnb_covs[csh_i]/(aug_groups * n_total/len(selected_classes))
        # data_class_dbnw_covs[csh_i] = data_class_dbnw_covs[csh_i]/(aug_groups * n_total/len(selected_classes))

        # The hook is finished with once its statistics are normalised.
        cov_stat_hooks[csh_i].close()

        # print(torch.mean(torch.diag(data_covs[csh_i])))

    m1_clips = np.array(proj_config.m1_clips)
    m1_dims = np.array(proj_config.m1_dims)
    m2_clips = np.array(proj_config.m2_clips)
    m2_dims = np.array(proj_config.m2_dims)

    m1_dims_class = np.array(proj_config.m1_dims_class)
    m2_dims_class = np.array(proj_config.m2_dims_class)
    m1_clips_class = np.array(proj_config.m1_clips_class)
    m2_clips_class = np.array(proj_config.m2_clips_class)

    print(n_total)
    # clip_sqrt_sum = np.sum(m1_clips * np.sqrt(m1_dims)) 
    # if not means_only:
    #     clip_sqrt_sum += np.sum(m2_clips * np.sqrt(m2_dims))
    # else:
    #     clip_sqrt_sum += np.sum(m1_clips * np.sqrt(m1_dims)) 
    print(m1_dims)
    print(m2_dims)

    # NOTE(review): unreachable because ratio is forced to 1.0 above, and
    # old_data_means / old_data_covs are not defined anywhere in this
    # function.
    if ratio != 1.0:
        for csh_i in range(len(data_means)):
            data_means[csh_i] = ratio * data_means[csh_i] + (1-ratio) * old_data_means[csh_i]
            data_covs[csh_i] = ratio * data_covs[csh_i] + (1-ratio) * old_data_covs[csh_i]

    jits = []

    print((n_total/len(selected_classes)))

    # Sum over layers of clip * sqrt(dim) for all four statistic families;
    # used by the 'hybrid_all' noise ratios below.
    clip_sqrt_sum = np.sum(m1_clips * np.sqrt(m1_dims))  + np.sum(m2_clips * np.sqrt(m2_dims)) + np.sum(m1_clips_class * np.sqrt(m1_dims_class)) + np.sum(m2_clips_class * np.sqrt(m2_dims_class))

    class_jits = []
    for csh_i in range(len(data_means)):

        # NOTE(review): noise_ratios and scale_ratio are only assigned for
        # the two clip types handled here; any other value leaves them unset.
        if proj_config.clip_type == 'global':
            noise_ratios = (proj_config.total_clip, proj_config.total_clip)
            scale_ratio = proj_config.class_scalar
        elif proj_config.clip_type == 'hybrid_all':
            noise_ratios = (np.sqrt(m1_clips[csh_i] * clip_sqrt_sum/np.sqrt(m1_dims[csh_i])), np.sqrt(m2_clips[csh_i] * clip_sqrt_sum/np.sqrt(m2_dims[csh_i])))

            # scale_ratio = 1.0
            scale_ratio = proj_config.class_scalar

        # Global statistics are noised with a literal scale_ratio of 1.0; the
        # scale_ratio variable set above is only applied to the per-class
        # statistics in the inner loop.
        mean_noised, cov_noised, jit = add_noise_to_mean_cov(data_means[csh_i], data_covs[csh_i], b0, noise_ratios, n_total, scale_ratio = 1.0, cov_scale_multiplier = proj_config.cov_scale_multiplier)

        data_means[csh_i] = mean_noised
        data_covs[csh_i] = cov_noised

        jits.append(proj_config.jit_multiplier * jit)

        layer_jits = []

        # Per-class statistics: one noise draw per class row, using the
        # per-class sample count as the divisor.
        for c in range(data_class_means[csh_i].shape[0]):
            if proj_config.clip_type == 'global':
                noise_ratios = (proj_config.total_clip, proj_config.total_clip)
            elif proj_config.clip_type == 'hybrid_all':
                if m1_dims_class[csh_i] == 0:
                    # Layer carries no class dimensions: add no noise.
                    noise_ratios = (0,0)
                else:
                    noise_ratios = (np.sqrt(m1_clips_class[csh_i] * clip_sqrt_sum/np.sqrt(m1_dims_class[csh_i])), np.sqrt(m2_clips_class[csh_i] * clip_sqrt_sum/np.sqrt(m2_dims_class[csh_i])))

            mean_noised, cov_noised, class_jit = add_noise_to_mean_cov(data_class_means[csh_i][c], data_class_covs[csh_i][c], b0, noise_ratios, n_total/len(selected_classes), scale_ratio, proj_config.cov_scale_multiplier)
            data_class_means[csh_i][c] = mean_noised
            data_class_covs[csh_i][c] = cov_noised

            layer_jits.append(proj_config.jit_multiplier * class_jit)

            data_class_dbnb_means[csh_i][c] = add_noise_to_mean(data_class_dbnb_means[csh_i][c], b0, proj_config.total_clip_bwd, n_total/len(selected_classes))
            data_class_dbnw_means[csh_i][c] = add_noise_to_mean(data_class_dbnw_means[csh_i][c], b0, proj_config.total_clip_bwd, n_total/len(selected_classes))

        # print(layer_jits)
        class_jits.append(np.array(layer_jits))

    # print(data_means)
    print("jits")
    # print(jits)

    # if means_only:
    #     return data_means, data_covs2, jits

    print(class_jits)

    data_measurements = ml_collections.ConfigDict()

    data_measurements.global_means = data_means
    data_measurements.global_covs = data_covs
    data_measurements.jits = jits
    data_measurements.class_jits = class_jits
    data_measurements.class_means = data_class_means
    data_measurements.class_covs = data_class_covs
    data_measurements.class_dbnb_means = data_class_dbnb_means
    data_measurements.class_dbnw_means = data_class_dbnw_means

    return data_measurements
    

# python compute_ds_stats.py --model_save_path ./trained_models/wrn22_8_silu/model_0.pt  --distilled_save_path ./../D3S_again/example_path/eps_8/n_models_1/feature_count_all_512/feature_count_class_512/layer_indices_16_17_18/clip_ratio_8/class_clip_ratio_1.0/cm_coef_30.0/sigmoid_scale_1.0/class_sigmoid_scale_1.0/bwd_sigmoid_scale_16.0/iter_mode_per_iter/clip_type_global/cov_scale_multiplier_1.0/jit_multiplier_2.0/activation_silu/use_gsam_False/ \
# --dataset_name cifar10 --source_n_per_class 1000 --activation silu --model_name wrn22-8 --target_resolution 32 \
# --b0 0.637  --sigmoid_scale 1.0 --clip_power 8.0  --selected_classes 0,1,2,3,4,5,6,7,8,9 \
# --n_features_class 512 --ll_grad_model_type wrn22-8 --stage 0 --bwd_sigmoid_scale 16.0 --clip_ratio 8 --class_include_indices 16,17,18

# WideResNet variants selectable through ``model_name``: name -> (depth, widen factor).
_WRN_DEPTH_WIDTH = {
    'wrn28-8': (28, 8),
    'wrn28-10': (28, 10),
    'wrn28-4': (28, 4),
    'wrn22-8': (22, 8),
    'wrn16-8': (16, 8),
    'wrn40-4': (40, 4),
    'wrn10-4': (10, 4),
}


def _as_index_list(value):
    """Normalise a command-line index argument.

    ``None`` -> ``[]``; a single int -> ``[int]``; a comma-separated string ->
    list of ints; any other value (e.g. a tuple or list) is returned as is.
    """
    if value is None:
        return []
    if isinstance(value, int):
        return [value]
    if isinstance(value, str):
        return [int(part) for part in value.split(',') if part.strip()]
    return value


def _build_model(model_name, num_classes, activation_fn):
    """Instantiate the (untrained) network named by ``model_name``."""
    if model_name == 'resnet18_silu':
        return resnet18_silu(num_classes=num_classes, activation = activation_fn)
    depth, widen = _WRN_DEPTH_WIDTH[model_name]
    return WideResNet(depth, num_classes, widen, activation = activation_fn)


def main(model_save_path = './trained_models/model.pt', distilled_save_path = './distilled_images/', dataset_name = 'cifar10', data_folder = 'lskdjfs', aug_in_cache = False, skip_model_eval = False, cache_batch_size = 64, 
         source_n_per_class = 1000, b0 = 0.638, activation = 'silu', stages = [0], subsample = 1.0, aug_groups = 1, target_resolution = 224, model_name = 'resnet18_silu', first_layer_only = False, clip_power = 3.5, selected_classes = None,
         bwd_sigmoid_scale = 1.0, stage = 0, sigmoid_scale = 1.0, class_sigmoid_scale = 1.0, class_include_indices = None, n_features_main = 512, 
         n_features_class = 64, class_clip_scale = 1.0, clip_type = 'global', cov_scale_multiplier = 1.0, jit_multiplier = 1.0, bwd_clip_scale = 4.0, use_class_scale = True):
    """Compute dataset statistics for a trained model and pickle them.

    Steps: load the dataset, build ``model_name`` and load its weights from
    ``model_save_path`` (stripping a leading ``module.`` from parameter
    names), optionally evaluate it on the test split, draw the projection
    configuration, run ``compute_full_ds_stats`` over the training split
    restricted to ``selected_classes``, and write everything to
    ``<distilled_save_path>/proj_info/stats_stage_<stage>.pkl``.

    Args:
        model_save_path: Checkpoint holding the model's ``state_dict``.
        distilled_save_path: Root directory for the output pickle.
        dataset_name, data_folder, cache_batch_size, subsample,
        target_resolution, aug_in_cache: Forwarded to ``get_dataset``;
            ``aug_in_cache`` also switches the number of data passes from 1
            to 5.
        skip_model_eval: Skip the test-set evaluation of the loaded model.
        source_n_per_class: Passed to the model constructor as its number of
            output classes.
        activation: ``'relu'`` selects ``nn.ReLU``; anything else ``nn.SiLU``.
        model_name: ``'resnet18_silu'`` or one of the ``wrn<depth>-<width>``
            names in ``_WRN_DEPTH_WIDTH``.
        first_layer_only: Hook only BatchNorm2d layer 0 instead of 0..59.
        selected_classes: Classes to compute statistics for (int, sequence of
            ints, or comma-separated string). Required.
        class_include_indices: BatchNorm2d numbers that carry class
            statistics (same accepted forms; ``None`` -> none).
        stage: Only used in the output file name.
        stages: Accepted for call compatibility; not used.
        b0, aug_groups: Forwarded to ``compute_full_ds_stats``.
        Remaining arguments are forwarded to
        ``compute_projections_and_clips_init``.

    Raises:
        ValueError: ``selected_classes`` is missing/empty or ``model_name``
            is not recognised. Both are checked before any data is loaded.
    """
    # Fail fast: without these checks the run only crashes after the model
    # has been loaded (unknown model) or after the whole dataset has been
    # processed (len() of a missing selected_classes).
    class_include_indices = _as_index_list(class_include_indices)
    selected_classes = _as_index_list(selected_classes)
    if len(selected_classes) == 0:
        raise ValueError('selected_classes must name at least one class, e.g. --selected_classes 0,1,2')
    if model_name != 'resnet18_silu' and model_name not in _WRN_DEPTH_WIDTH:
        raise ValueError(f'unknown model_name {model_name!r}; expected resnet18_silu or one of {sorted(_WRN_DEPTH_WIDTH)}')

    activation_fn = nn.ReLU if activation == 'relu' else nn.SiLU

    config = get_config()
    config.dataset.name = dataset_name
    device = 'cuda'

    print('preparing data')
    # Unfiltered dataset: only the test split and the preprocessing op are
    # needed, for the optional evaluation below.
    (_, ds_test), preprocess_op, _, _ = get_dataset(config.dataset, apply_aug = aug_in_cache, data_folder = data_folder, batch_size = cache_batch_size, target_class = None, target_resolution = target_resolution)

    include_indices = [0] if first_layer_only else list(range(60))

    print("preparing model")
    model = _build_model(model_name, source_n_per_class, activation_fn)

    if dataset_name == 'tiny_imagenet':
        # Replace the stem with a stride-1 3x3 conv and drop the max-pool.
        model.conv1 = nn.Conv2d(3,64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
        model.maxpool = nn.Identity()

    print(f'loading from {model_save_path}')
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_save_path)
    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # 'module.'; strip it so the keys match the bare model.
    new_state_dict = OrderedDict(
        (k[7:] if k.startswith('module.') else k, v) for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model.eval()
    model = model.to(device)

    if not skip_model_eval:
        print("evaluating loaded model")
        test_loss, test_acc, _ = train_one_epoch(model, ds_test, None, None, preprocess_op, train = False)
        print("loaded model acc: {}".format(test_acc))

    # Training split restricted to the selected classes (and optionally
    # subsampled); this is what the statistics are computed on.
    (ds_train, _), _, _, (_, _) = get_dataset(config.dataset, apply_aug = aug_in_cache, data_folder = data_folder, batch_size = cache_batch_size, target_class = selected_classes, subsample = subsample, target_resolution = target_resolution)

    proj_config =  compute_projections_and_clips_init(model, include_indices = include_indices, clip_power = clip_power, n_features_main = n_features_main, n_features_class = n_features_class, bwd_proj_dim = -1, bwd_sigmoid_scale = bwd_sigmoid_scale, 
                                       sigmoid_scale = sigmoid_scale, class_sigmoid_scale = class_sigmoid_scale, class_include_indices = class_include_indices, class_clip_scale= class_clip_scale, 
                                       clip_type = clip_type, cov_scale_multiplier = cov_scale_multiplier, jit_multiplier = jit_multiplier, bwd_clip_scale = bwd_clip_scale, use_class_scale = use_class_scale)

    print('Computing Full data Stats')

    data_measurements = compute_full_ds_stats(model, ds_train, proj_config, n_forward = 5 if aug_in_cache else 1, b0 = b0, ratio = 1.0, aug_groups = aug_groups,
            target_resolution = target_resolution, selected_classes = selected_classes)

    full_dict = ml_collections.ConfigDict()
    full_dict.proj_config = proj_config
    full_dict.data_stats = data_measurements
    full_dict.model_save_path = model_save_path
    full_dict.model_name = model_name
    full_dict.source_n_per_class = model.fc.weight.shape[0]
    full_dict.activation = activation

    place_to_store = os.path.expanduser(distilled_save_path +'/proj_info/stats_stage_{}.pkl'.format(stage))
    # Creates distilled_save_path and proj_info in one go; exist_ok avoids
    # the check-then-create race.
    os.makedirs(os.path.dirname(place_to_store), exist_ok = True)

    with open(place_to_store, 'wb') as out_file:
        pickle.dump(full_dict, out_file)

    # wandb.finish()
    return 

# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --dataset_name cifar10).
if __name__ == '__main__':
    fire.Fire(main)