import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
import random
import shutil
import math
import wandb
import h5py

from sklearn.cluster import KMeans
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, evaluate_synset, evaluate_synset_grl, get_time, ParamDiffAug 
from supplement import get_param_by_method, get_real_data, training_loop, wandb_init_project 


def _cache_train_set_to_h5(args, dst_train, channel):
    """Dump the full training set to ``saved_data/<dataset>_data.h5`` once.

    Does nothing if the file already exists. Datasets whose name contains
    OH/PACS/VLCS/DomainNet are unpacked as (image, label, domain) triples and
    also get a ``domains_all`` dataset; all others are (image, label) pairs.
    Per-channel mean/std of the images are printed before saving. Tensors stay
    on the CPU, so this works without a GPU and does not copy the whole
    training set to device memory just to print statistics.
    """
    h5_file_path = f'saved_data/{args.dataset}_data.h5'
    if os.path.exists(h5_file_path):
        return
    os.makedirs('saved_data', exist_ok=True)

    has_domains = any(tag in args.dataset for tag in ('OH', 'PACS', 'VLCS', 'DomainNet'))
    if has_domains:
        images, labels, domains = zip(*dst_train)
    else:
        images, labels = zip(*dst_train)
        domains = None

    images_all = torch.stack(images, dim=0)
    for ch in range(channel):
        print(f'real images channel {ch}, mean = {torch.mean(images_all[:, ch]):.4f}, std = {torch.std(images_all[:, ch]):.4f}')

    print('Saving Starts')
    with h5py.File(h5_file_path, 'w') as f:
        f.create_dataset('images_all', data=images_all.numpy())
        f.create_dataset('labels_all', data=torch.tensor(labels, dtype=torch.long).numpy())
        if domains is not None:
            f.create_dataset('domains_all', data=torch.tensor(domains, dtype=torch.long).numpy())


def _pseudo_domain_cache_path(args):
    """Return the ``saved_data`` path that caches the pseudo-domain partition for this configuration."""
    return (f'saved_data/{args.dataset}_indices_pseudo_domain_{args.nopd}_'
            f'{args.pseudo_domain_method}_{args.pseudo_domain_clustering}.pt')


def _fft_low_freq(img):
    """Return the centred low-frequency window of the shifted 2-D FFT of one image.

    ``img`` is one row of ``images_all``, indexed as ``[C, H, W]``. The window
    spans ``ceil(0.09 * size)`` bins on each side of the spectrum centre along
    the last two axes; the channel axis is kept whole.
    """
    spectrum = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1))
    height, width = img.shape[1], img.shape[2]
    rx, ry = int(np.ceil(height * 0.09)), int(np.ceil(width * 0.09))
    cx, cy = height // 2, width // 2
    return spectrum[:, cx - rx:cx + rx, cy - ry:cy + ry]


def _logvar_features(net, img):
    """Return log-variances of ``net``'s intermediate feature maps for one image.

    Each map returned by ``net.extract_intermediate_features`` is visited in
    sorted-key order. It is reduced to its per-channel variance over the last
    two axes, and ``log(var + 1e-8)`` is taken. The results are concatenated
    into a ``[1, total_channels]`` tensor.
    """
    feats = net.extract_intermediate_features(img.unsqueeze(0))  # input shape: [1, C, H, W]
    return torch.cat([torch.log(feats[key].var(dim=(2, 3)) + 1e-8) for key in sorted(feats.keys())], dim=1)


def _split_by_rank(metric, num_groups):
    """Partition image indices into ``num_groups`` groups by ascending ``metric``.

    Groups have ``len(metric) // num_groups`` members each. The remainder goes
    to the last group, so every image belongs to exactly one pseudo domain.
    Each group is a numpy index array.
    """
    order = torch.argsort(metric).numpy()
    group_size = len(order) // num_groups
    return [order[g * group_size:(len(order) if g == num_groups - 1 else (g + 1) * group_size)]
            for g in range(num_groups)]


def _kmeans_groups(features, num_groups):
    """Cluster ``features`` (one row per image) with KMeans; return one index list per cluster."""
    clusters = KMeans(n_clusters=num_groups, random_state=0).fit_predict(features)
    return [np.where(clusters == g)[0].tolist() for g in range(num_groups)]


def _load_or_compute_pseudo_domains(args, images_all, channel, num_classes, im_size):
    """Return ``args.nopd`` collections of image indices that form pseudo domains.

    The partition is loaded from the ``saved_data`` cache when present.
    Otherwise each image is described either by FFT low-frequency magnitudes
    (FFT) or by log-variances of a freshly initialised network's intermediate
    features (LOGVAR). Images are then grouped either by sorting a scalar
    metric into equal-size bins (METRIC) or by KMeans (KMEANS), and the result
    is cached.

    Raises:
        ValueError: on an unknown method or clustering name.
    """
    cache_path = _pseudo_domain_cache_path(args)
    if os.path.exists(cache_path):
        return torch.load(cache_path)

    if args.pseudo_domain_clustering not in ('METRIC', 'KMEANS'):
        raise ValueError(f'Unknown pseudo_domain_clustering: {args.pseudo_domain_clustering}')
    by_metric = args.pseudo_domain_clustering == 'METRIC'

    if args.pseudo_domain_method == 'FFT':
        if by_metric:
            values = [torch.mean(torch.abs(_fft_low_freq(img))).item() for img in images_all]
        else:
            values = [torch.abs(_fft_low_freq(img)).flatten() for img in images_all]
    elif args.pseudo_domain_method == 'LOGVAR':
        temp_net = get_network(args, args.model, channel, num_classes, im_size, domain=True, num_domains=args.nopd).to(args.device)
        temp_net.eval()
        with torch.no_grad():
            feats = [_logvar_features(temp_net, img) for img in images_all]
        if by_metric:
            values = [f.mean(dim=1).item() for f in feats]  # one scalar per image
        else:
            values = [f.squeeze(0) for f in feats]  # [total_channels] per image
    else:
        raise ValueError(f'Unknown pseudo_domain_method: {args.pseudo_domain_method}')

    if by_metric:
        indices = _split_by_rank(torch.tensor(values), args.nopd)
    else:
        indices = _kmeans_groups(torch.stack(values).cpu().numpy(), args.nopd)
    torch.save(indices, cache_path)
    return indices


def _class_balanced_indices(group_indices, indices_class, num_classes, n):
    """Sample indices from ``group_indices`` with the same count for every class.

    Each class contributes ``min(n // num_classes, smallest per-class overlap)``
    randomly permuted indices. Results are concatenated in class order.
    """
    members = set(group_indices)
    per_class = [np.random.permutation(list(members.intersection(indices_class[cls])))
                 for cls in range(num_classes)]
    take = min(n // num_classes, min(len(p) for p in per_class))
    return [idx for p in per_class for idx in p[:take]]


def _load_whole_dataset_feature_net(args, channel, num_classes, im_size):
    """Load the CPU feature network trained on the whole dataset, in eval mode (used by k-center/herding)."""
    net = get_network(args, args.model, channel, num_classes, im_size).cpu()
    net.load_state_dict(torch.load(f'results/WHOLE_DATASET_ConvNet/{args.dataset}/0/0_net_latest.pth', map_location='cpu'))
    # eval() so any BatchNorm/Dropout layers do not depend on the selection batch.
    net.eval()
    return net


def _k_center_indices(features, k):
    """Greedy k-center selection over ``features`` rows; returns ``k`` row indices.

    Starts from the row closest to the mean. It then repeatedly adds the row
    whose distance to its nearest chosen center is largest.
    """
    dis = torch.norm(features - torch.mean(features, dim=0, keepdim=True), dim=1)
    centers = torch.argsort(dis)[:1].tolist()
    for _ in range(k - 1):
        nearest, _ = torch.min(torch.cdist(features, features[centers]), dim=-1)
        centers.append(torch.argmax(nearest).item())
    return centers


def _herding_indices(features, k):
    """Herding selection: pick ``k`` rows whose running mean best tracks the mean of ``features``."""
    mean = torch.mean(features, dim=0, keepdim=True)
    selected = []
    remaining = list(range(features.shape[0]))
    for i in range(k):
        target = mean * (i + 1)
        if selected:
            target = target - torch.sum(features[selected], dim=0)
        dis = torch.norm(target - features[remaining], dim=1)
        selected.append(remaining.pop(torch.argmin(dis).item()))
    return selected


def _save_synthetic(image_syn, label_syn, path):
    """Write a detached CPU copy of the synthetic set to ``path`` as ``{'data': [images, labels]}``."""
    torch.save({'data': [image_syn.detach().cpu().clone(), label_syn.detach().cpu().clone()]}, path)


def main():
    """Parse CLI arguments and run ``--num_exp`` experiments.

    With ``--coreset``, a real-image coreset (random / k-center / herding /
    whole dataset) is evaluated on ``--num_eval`` fresh networks. Otherwise a
    synthetic set with per-pseudo-domain masks is optimised via
    ``training_loop`` and periodically saved to ``--save_path``.
    """
    # - Original arguments
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, help='DC/DSA')
    parser.add_argument('--dataset', type=str, help='dataset')
    parser.add_argument('--model', type=str, help='model')
    parser.add_argument('--ipc', type=int, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=20, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=None, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='noise', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')

    # - Noname arguments
    parser.add_argument('--eval_iter', type=int)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--checkpoint_file', type=str, default='None')
    parser.add_argument('--coreset', action='store_true')
    parser.add_argument('--coreset_type', type=str, default='whole_dataset')
    parser.add_argument('--embedding_weight', type=float, default=1.0)
    parser.add_argument('--domain_mask_init', type=float, default=1.0)
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--pseudo_domain_method', type=str, default='fft', choices=['fft', 'logvar'])
    parser.add_argument('--pseudo_domain_clustering', type=str, default='metric', choices=['metric', 'kmeans'])
    parser.add_argument('--nopd', type=int, default=4, help='number of pseudo domains')
    parser.add_argument('--normalize_method', type=str, default='softmax', choices=['softmax', 'sigmoid'])

    # - Wandb arguments
    parser.add_argument('--wandb', action='store_true')
    parser.add_argument('--wandb_project_name', type=str, default='hi')
    parser.add_argument('--wandb_group_name', type=str, default='None')

    # - Start
    args = parser.parse_args()
    # Must be set before CUDA is initialised for the device mask to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    assert args.data_path is not None, 'Please provide the path to the dataset'
    wandb_save_iter = 50
    start_exp = 0
    args.method = args.method.upper()
    args.pseudo_domain_method = args.pseudo_domain_method.upper()
    args.pseudo_domain_clustering = args.pseudo_domain_clustering.upper()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Both branches below write into save_path (log file / synthetic checkpoints).
    os.makedirs(args.save_path, exist_ok=True)

    if not args.coreset:
        args.Iteration, args.eval_iter, args.lr_img, args.dsa_strategy = get_param_by_method(args)
        args.outer_loop, args.inner_loop = get_loops(args.ipc)
        args.dsa_param = ParamDiffAug()
        args.dsa = args.dsa_strategy not in ['none', 'None']
        # The list of iterations at which the synthetic set is saved under its iteration number.
        if args.eval_mode in ('S', 'SS'):
            eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_iter).tolist()
        else:
            eval_it_pool = [args.Iteration]

    channel, im_size, num_classes, num_domains, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
    _cache_train_set_to_h5(args, dst_train, channel)

    for exp in range(start_exp, args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', args.__dict__)

        # - For wandb
        args.wandb_group_name = f'{args.method}_{args.model}_{args.pseudo_domain_method}_{args.pseudo_domain_clustering}_{args.dataset}_{args.ipc}_{exp}'
        wandb_run = wandb_init_project(args) if args.wandb else None

        # - Get real data (modified to save time)
        images_all, labels_all, indices_class, *extra = get_real_data(args, dst_train, num_classes, num_domains)
        # Domain labels/indices are only returned for domain-labelled datasets.
        domains_all, indices_domain = extra if extra else (None, None)

        indices_pseudo_domain = _load_or_compute_pseudo_domains(args, images_all, channel, num_classes, im_size)

        for c in range(num_classes):
            print('class c = %d: %d real images'%(c, len(indices_class[c])))
        for ch in range(channel):
            print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

        def get_images(c, n, by_domain=False, balanced_class=False, pseudo_domain=False):
            """Sample up to ``n`` random real images.

            ``c`` is a class index unless ``by_domain`` is set. In that case it
            indexes a pseudo domain (``pseudo_domain=True``) or a real domain.
            With ``balanced_class``, every class contributes the same number
            of images from that domain.
            """
            if not by_domain:
                idx_shuffle = np.random.permutation(indices_class[c])[:n]
                return images_all[idx_shuffle]
            if pseudo_domain:
                group = indices_pseudo_domain[c]
            else:
                if indices_domain is None:
                    raise ValueError('No domain indices')
                group = indices_domain[c]
            if balanced_class:
                idx_shuffle = _class_balanced_indices(group, indices_class, num_classes, n)
            else:
                idx_shuffle = np.random.permutation(group)[:n]
            return images_all[idx_shuffle]

        if args.coreset:
            domain_label_for_train = None
            if args.coreset_type in ('random', 'k-center', 'herding'):
                image_for_train = torch.zeros(num_classes*args.ipc, channel, im_size[0], im_size[1]).to(args.device)
                label_for_train = torch.zeros(num_classes*args.ipc, dtype=torch.long, requires_grad=False, device=args.device)
                if args.coreset_type != 'random':
                    net_for_feature = _load_whole_dataset_feature_net(args, channel, num_classes, im_size)
                    if args.coreset_type == 'k-center':
                        print("# - Net trained on WHOLE DATASET Loaded !!")
                for c in range(num_classes):
                    rows = slice(c * args.ipc, (c + 1) * args.ipc)
                    if args.coreset_type == 'random':
                        image_for_train[rows] = get_images(c, args.ipc)
                    else:
                        imgs = images_all[indices_class[c]].cpu()
                        with torch.no_grad():
                            features = net_for_feature.embed(imgs).detach()
                        if args.coreset_type == 'k-center':
                            picked = _k_center_indices(features, args.ipc)
                        else:
                            picked = _herding_indices(features, args.ipc)
                        image_for_train[rows] = imgs[picked].to(args.device)
                    label_for_train[rows] = c
            else: # whole dataset
                image_for_train = images_all
                label_for_train = labels_all
                domain_label_for_train = domains_all

            use_grl = 'GRL' in args.method
            if use_grl and domain_label_for_train is None:
                raise ValueError('GRL evaluation requires domain labels (only available with --coreset_type whole_dataset on a domain-labelled dataset)')

            accs = []
            args.epoch_eval_train = 300
            log_path = f'{args.save_path}/log_{"_".join(list(map(str, time.localtime()[0:-4])))}.txt'
            with open(log_path, 'w+') as log_file:
                # Exposed on args while evaluating (the evaluate_* helpers receive args).
                args.log_file = log_file
                for it_eval in range(args.num_eval):
                    if use_grl:
                        net_eval = get_network(args, args.model, channel, num_classes, im_size, True).to(args.device)
                        _, acc_train, acc_test, loss_train, loss_test = evaluate_synset_grl(it_eval, net_eval, image_for_train, label_for_train, domain_label_for_train, testloader, args)
                    else:
                        net_eval = get_network(args, args.model, channel, num_classes, im_size).to(args.device)
                        _, acc_train, acc_test, loss_train, loss_test = evaluate_synset(it_eval, net_eval, image_for_train, label_for_train, testloader, args)
                    accs.append(acc_test)
                print(f'\n\n# - {args.coreset_type} {args.dataset}\n')
                print(f'evaluate {len(accs)} random {args.model}, mean = {np.mean(accs)*100:.4f} std = {np.std(accs)*100:.4f}\n\n')
                log_file.write(f'\n\n# - {args.coreset_type} {args.dataset}\n')
                log_file.write(f'evaluate {len(accs)} random {args.model}, mean = {np.mean(accs)*100:.4f} std = {np.std(accs)*100:.4f}\n\n')

        else:
            # - Initialize the synthetic data
            image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
            label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc) # [0,0,0, 1,1,1, ..., 9,9,9]

            if args.init == 'real':
                print('initialize synthetic data from random real images')
                for c in range(num_classes):
                    image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data

            start_iter = 0
            _checkpoint = None
            if args.checkpoint_file != 'None':
                _checkpoint = torch.load(args.checkpoint_file)
                image_syn = _checkpoint['data'][0].to(args.device).requires_grad_(True)
                label_syn = _checkpoint['data'][1].to(args.device).requires_grad_(False)
                # Checkpoints are named '..._<iteration>.pt'; '..._latest.pt' carries no iteration.
                suffix = os.path.splitext(os.path.basename(args.checkpoint_file))[0].split('_')[-1]
                if suffix.isdigit():
                    start_iter = int(suffix)
                else:
                    print(f'# - could not infer iteration from {args.checkpoint_file}; starting at 0')

            # - Train
            # Built AFTER any checkpoint load so the optimizer updates the tensor actually being trained.
            optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
            optimizer_img.zero_grad()
            if _checkpoint is not None and 'checkpointing' in _checkpoint:
                optimizer_img.load_state_dict(_checkpoint['checkpointing']['optimizer'])

            # One learnable mask (shaped like image_syn) plus its own optimizer per pseudo domain.
            domain_masks = {}
            for d in range(args.nopd):
                mask = torch.full(image_syn.size(), args.domain_mask_init, device=args.device, requires_grad=True)
                domain_masks[d] = [mask, torch.optim.SGD([mask], lr=args.lr_img, momentum=0.5)]

            print(f'# --- {args.embedding_weight}')

            ckpt_prefix = f'{args.method}_{args.model}_{args.pseudo_domain_method}_{args.pseudo_domain_clustering}_{args.dataset}_{args.ipc}'
            print('%s training begins'%get_time())
            for it in range(start_iter, args.Iteration+1):

                # -- Train synthetic images
                net_class = get_network(args, args.model, channel, num_classes, im_size, domain=False).to(args.device)
                net_domain = get_network(args, args.model, channel, num_classes, im_size, domain=True, num_domains=args.nopd).to(args.device)
                net_class.train()
                net_domain.train()

                image_syn, label_syn, optimizer_img, loss_avg, domain_loss_avg, domain_masks = training_loop(args, net_class, net_domain, image_syn, label_syn, optimizer_img, get_images, num_classes, channel, im_size, it, domain_masks, wandb_run=wandb_run)

                if it % wandb_save_iter == 0:
                    print('%s iter = %04d, loss = %.4f, domain loss = %.4f' % (get_time(), it, loss_avg, domain_loss_avg))
                    if wandb_run is not None:
                        wandb_run.log({'Loss': loss_avg, 'Domain Loss': domain_loss_avg}, step=it)
                if it % 10 == 0:
                    _save_synthetic(image_syn, label_syn, os.path.join(args.save_path, f'{ckpt_prefix}_latest.pt'))
                if it in eval_it_pool and it != 0:
                    _save_synthetic(image_syn, label_syn, os.path.join(args.save_path, f'{ckpt_prefix}_{it}.pt'))


if __name__ == '__main__':
    main()


