import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
import random


from utils import orthogonality_loss_from_group_grads, _sanitize_grads

def main():
    """Report which distilled-result directories are missing.

    Parses the command line once, restores the saved RNG state once per
    dataset, and then walks every (method, dataset, fairness variant, ipc)
    combination. A combination is skipped when its ``./T-SNE/dump_*.pt`` file
    already exists; otherwise the expected result directory under
    ``./results-pt/`` is checked and a message is printed when it is absent.

    Relies on the module-level names ``load_random_state`` and
    ``random_state`` being defined before the call (the script entry point
    sets them up).

    NOTE: ``args.outer_loop`` / ``args.inner_loop`` are derived from the
    command-line ``--ipc`` only; they are not recomputed when ``args.ipc`` is
    overridden in the sweep below.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, default='DC', help='DC/DSA')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,

    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=1, help='epochs to train a model with synthetic data')

    parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=0.1, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    parser.add_argument('--FairDD', action='store_true', help='Enable FairDD')


    ALL_DATA = [
        "CIFAR10_S_90",
        "Colored_FashionMNIST_foreground",
        "Colored_FashionMNIST_background",
        "Colored_MNIST_foreground",
        "Colored_MNIST_background",
        "UTKface",
        "BFFHQ",
    ]

    # The command line does not change between datasets, so parse it (and
    # derive the loop counts / device from it) once instead of per iteration.
    args = parser.parse_args()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for dataset in ALL_DATA:
        args.dataset = dataset

        # if not os.path.exists(args.save_path):
        #     os.mkdir(args.save_path)

        # channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)

        # Rewind every RNG to the state captured at start-up so each dataset
        # starts from the same random stream.
        load_random_state(random_state)


        # images_all = []
        # labels_all = []
        # indices_class = [[] for c in range(num_classes)]

        # images_all = [torch.unsqueeze(dst_train[i][0], dim=0) for i in range(len(dst_train))]
        # labels_all = [dst_train[i][1] for i in range(len(dst_train))]
        # color_all = [dst_train[i][2] for i in range(len(dst_train))]
        # for i, lab in enumerate(labels_all):
        #     indices_class[lab].append(i)
        # images_all = torch.cat(images_all, dim=0).to(args.device)
        # labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
        # color_all = torch.tensor(color_all, dtype=torch.long, device=args.device)

        # args.num_classes = len(torch.unique(labels_all))
        # args.num_groups = len(torch.unique(color_all))


        # def get_images(c, n): # get random n images from class c
        #     idx_shuffle = np.random.permutation(indices_class[c])[:n]
        #     return images_all[idx_shuffle], labels_all[idx_shuffle], color_all[idx_shuffle]


        # image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
        # for c in range(num_classes):
        #     image_data, _, color_data = get_images(c, args.ipc)
        #     image_syn.data[c * args.ipc:(c + 1) * args.ipc] = image_data.detach().data


    # --- Sweep over method x dataset x fairness variant x ipc ---
    NAMES = ['DC', 'DM', 'IDC', 'CAFE']
    # Directory-name prefix contributed by each fairness variant ('NoFair' adds none).
    fair_prefixes = {'FairDD': 'FairDD_', 'NoOrtho': 'Fair_NoOrtho_'}

    for name in NAMES:
        # Both settings depend only on the method, so set them once per method.
        args.testMetric = name
        args.dsa = name != 'DC'

        for dataset in ALL_DATA:
            # Keep args in sync with the dataset whose paths are being checked;
            # otherwise it would stay at the last entry of ALL_DATA.
            args.dataset = dataset

            for fair_crt in ['NoFair', 'FairDD', 'NoOrtho']:
                for ipc in [10, 50, 100]:
                    args.ipc = ipc

                    # Skip combinations that already have a dump file.
                    dump_name = f"dump_{name}_{dataset}_ipc{args.ipc}_{fair_crt}.pt"
                    if os.path.exists(os.path.join('./T-SNE', dump_name)):
                        continue

                    prefix = fair_prefixes.get(fair_crt, '')
                    save_name = f"./results-pt/{name}/{name}-{fair_crt}/{prefix}{name}_{dataset}_ipc{args.ipc}/"
                    # save_path = save_path + 'res_'+name+'_' + dataset + '_ConvNet_'  + str(args.ipc) + 'ipc.pt'
                    # checkpoint = torch.load(save_path, map_location=args.device, weights_only=False)
                    if not os.path.exists(save_name):
                        print('No such file: %s'%save_name)
                        continue

                    #  + 'vis_%s_%s_%s_%dipc_exp%d_initial.png'%(args.method, args.dataset, args.model, args.ipc, 1)

                    # image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                    # for ch in range(channel):
                    #     image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]
                    # image_syn_vis[image_syn_vis<0] = 0.0
                    # image_syn_vis[image_syn_vis>1] = 1.0
                    # try:
                    #     save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.
                    #     print('Saving initial synthetic images to %s'%save_name)
                    # except:
                    #     print('Error in saving initial synthetic images to %s'%save_name)


def save_random_state():
    """Capture the current RNG state of torch (CPU and all CUDA devices), NumPy and ``random``.

    Returns:
        dict with keys ``'torch'``, ``'np'``, ``'random'`` and ``'cuda'``,
        suitable for passing to ``load_random_state``.
    """
    return {
        'torch': torch.get_rng_state(),
        'np': np.random.get_state(),
        'random': random.getstate(),
        'cuda': torch.cuda.get_rng_state_all()
    }


def load_random_state(state):
    """Restore every RNG from a snapshot produced by ``save_random_state``.

    Args:
        state: dict with the keys ``'torch'``, ``'np'``, ``'random'`` and ``'cuda'``.
    """
    torch.set_rng_state(state['torch'])
    np.random.set_state(state['np'])
    random.setstate(state['random'])
    torch.cuda.set_rng_state_all(state['cuda'])


if __name__ == '__main__':
    # Seed every RNG and force deterministic cuDNN kernels for reproducibility.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Save the current random state so it can be restored later.
    random_state = save_random_state()

    main()

