import os
import time
from torch.utils.data import DataLoader, Subset
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, matchloss_idc
import random


from utils import orthogonality_loss_from_group_grads, _sanitize_grads


def decode_zoom(img, target, factor, size=(32, 32)):
    """
    IDC multi-formation: split each image into a factor x factor grid of
    patches and upsample every patch to ``size``.

    Args:
        img: 4-D tensor indexed as (N, C, H, W).
        target: tensor of per-image targets; concatenated along dim 0.
        factor: number of patches along each spatial axis.
        size: (height, width) each patch is bilinearly upsampled to.

    Returns:
        (data_dec, target_dec): ``data_dec`` holds N * factor**2 patches of
        spatial size ``size``; ``target_dec`` is ``target`` repeated
        factor**2 times. Patches are ordered grid-cell-major (all images for
        cell (0, 0), then all images for cell (0, 1), ...), which lines up
        with the repeated targets.
    """
    # Derive the crop size per axis so non-square inputs are tiled correctly;
    # for square inputs both values are equal.
    crop_h = img.shape[-2] // factor
    crop_w = img.shape[-1] // factor
    resize = torch.nn.Upsample(size=size, mode='bilinear', align_corners=True)

    patches = [
        img[:, :, row * crop_h:(row + 1) * crop_h, col * crop_w:(col + 1) * crop_w]
        for row in range(factor)
        for col in range(factor)
    ]

    # (N * factor^2, C, crop_h, crop_w) -> (N * factor^2, C, size[0], size[1])
    data_dec = resize(torch.cat(patches, dim=0))
    target_dec = torch.cat([target for _ in range(factor ** 2)])

    return data_dec, target_dec


def _group_balanced_real_grads(net, criterion, net_parameters, img_real, lab_real, group_real):
    """Return real-data gradients averaged with equal weight per group.

    One loss is computed per distinct value in ``group_real`` (restricted to
    the samples of that group), its gradient w.r.t. ``net_parameters`` is
    taken, and the per-group gradients are averaged parameter by parameter.
    Every group therefore contributes equally, regardless of how many samples
    it has in the batch.

    Returns:
        A list with one detached tensor per entry of ``net_parameters``.
    """
    output_real = net(img_real)
    per_group_grads = []
    # torch.unique only yields values present in the batch, so no mask is empty.
    for grp in torch.unique(group_real):
        mask = group_real == grp
        loss_grp = criterion(output_real[mask], lab_real[mask])
        # retain_graph: the same forward graph is differentiated once per group.
        grads = torch.autograd.grad(loss_grp, net_parameters, retain_graph=True)
        per_group_grads.append([g.detach().clone() for g in grads])
    # zip(*...) regroups the gradients from per-group lists to per-parameter tuples.
    return [torch.stack(layer_grads).mean(dim=0) for layer_grads in zip(*per_group_grads)]


def main():
    """Condense a dataset by gradient matching and evaluate the result.

    Parses the command line, loads the dataset via ``get_dataset`` and, for
    each experiment, optimises a set of synthetic images so that the network
    gradients obtained on their ``decode_zoom`` patches match group-balanced
    gradients obtained on real batches. At the iterations listed in
    ``eval_it_pool`` the decoded synthetic set is passed to
    ``evaluate_synset`` and accuracy / Equalized-Odds statistics are printed.

    Notes:
        * ``args.method`` is forced to 'DC' and ``args.dsa`` to False after
          parsing, so the DSA augmentation branch is currently inactive.
        * Relies on the module-level names ``load_random_state`` and
          ``random_state``, which this file defines only under the
          ``if __name__ == '__main__'`` guard; calling ``main()`` from an
          importing module raises NameError unless they are provided.
        * Each dataset sample is indexed as (image, label, group).
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, default='DC', help='DC/DSA')
    parser.add_argument('--dataset', type=str, default='CIFAR10_S_90', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=50, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,

    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=4, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')

    parser.add_argument('--Iteration', type=int, default=1000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=1, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.05, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    parser.add_argument('--FairDD', action='store_true', help='Enable FairDD')

    args = parser.parse_args()
    args.outer_loop, args.inner_loop = get_loops(args.ipc)
    # Override applied once, up front, so the hyper-parameters printed below
    # are the ones actually used by the training loop.
    if args.ipc in [50, 500]:
        args.outer_loop = 40
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    args.method = 'DC'
    args.dsa = False

    # makedirs handles nested paths and an already-existing directory.
    os.makedirs(args.save_path, exist_ok=True)

    for skew in [0.7]:
        args.skew = skew

        eval_it_pool = [args.Iteration]
        print('eval_it_pool: ', eval_it_pool)
        channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path, skew)
        model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

        # Restore the RNG state saved at start-up so that whatever randomness
        # get_dataset consumed does not affect the experiments below.
        load_random_state(random_state)

        # Test accuracies of all experiments, keyed by evaluation model.
        accs_all_exps = {key: [] for key in model_eval_pool}

        for exp in range(args.num_exp):
            print('\n================== Exp %d ==================\n '%exp)
            print('Hyper-parameters: \n', args.__dict__)
            print('Evaluation model pool: ', model_eval_pool)

            ''' organize the real dataset '''
            # Single pass: each sample is fetched once instead of once per field.
            image_list, label_list, group_list = [], [], []
            for i in range(len(dst_train)):
                sample = dst_train[i]
                image_list.append(torch.unsqueeze(sample[0], dim=0))
                label_list.append(sample[1])
                group_list.append(sample[2])

            indices_class = [[] for _ in range(num_classes)]
            for i, lab in enumerate(label_list):
                indices_class[lab].append(i)

            images_all = torch.cat(image_list, dim=0).to(args.device)
            labels_all = torch.tensor(label_list, dtype=torch.long, device=args.device)
            color_all = torch.tensor(group_list, dtype=torch.long, device=args.device)

            args.num_classes = len(torch.unique(labels_all))
            args.num_groups = len(torch.unique(color_all))

            def get_images(c, n): # get random n images from class c
                idx_shuffle = np.random.permutation(indices_class[c])[:n]
                return images_all[idx_shuffle], labels_all[idx_shuffle], color_all[idx_shuffle]

            ''' initialize the synthetic data '''
            # [0,0,0, 1,1,1, ..., 9,9,9]
            label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

            # NOTE(review): the results of these draws are unused. They are kept
            # only so the NumPy RNG stream, and therefore the seeded
            # initialisation below, stays the same; drop them if that
            # reproducibility is not required.
            for c in range(num_classes):
                get_images(c, args.ipc)

            if args.dataset in ['BFFHQ', 'UTKface']:
                args.factor = 4
            else:
                args.factor = 2

            image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]),
                                    dtype=torch.float, requires_grad=True, device=args.device)

            # Overwrite the random noise with randomly sampled real images.
            for c in range(num_classes):
                img_real_init, _, _ = get_images(c, args.ipc)
                image_syn.data[c*args.ipc : (c+1)*args.ipc] = img_real_init.detach()

            ''' training '''
            optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic data
            optimizer_img.zero_grad()
            criterion = nn.CrossEntropyLoss().to(args.device)
            print('%s training begins'%get_time())

            for it in range(args.Iteration+1):

                ''' Evaluate synthetic data '''
                if it in eval_it_pool:
                    for model_eval in model_eval_pool:
                        print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                        args.dc_aug_param = None

                        accs = []
                        max_Equalized_Odds_list = []
                        mean_Equalized_Odds_list = []
                        for it_eval in range(args.num_eval):
                            net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                            # Evaluate on the decoded patches, not on the raw
                            # synthetic images; detach so the synthetic set is
                            # not modified by evaluation.
                            with torch.no_grad():
                                image_syn_eval, label_syn_eval = decode_zoom(image_syn.detach(), label_syn.detach(), args.factor, size=im_size)

                            _, acc_train, acc_test, max_Equalized_Odds, mean_Equalized_Odds, max_Sufficiency, mean_Sufficiency = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                            accs.append(acc_test)
                            max_Equalized_Odds_list.append(max_Equalized_Odds)
                            mean_Equalized_Odds_list.append(mean_Equalized_Odds)
                        print('IDC-MO skew: ', skew)
                        print('accs, max_Equalized_Odds, mean_Equalized_Odds', np.mean(accs),np.mean(max_Equalized_Odds_list), np.mean(mean_Equalized_Odds_list))
                        print('std accs, max_Equalized_Odds, mean_Equalized_Odds', np.std(accs),np.std(max_Equalized_Odds_list), np.std(mean_Equalized_Odds_list))

                        if it == args.Iteration: # record the final results
                            accs_all_exps[model_eval] += accs

                ''' Train synthetic data '''
                net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random model
                net.train()
                net_parameters = list(net.parameters())
                optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)  # optimizer_net for the network weights
                optimizer_net.zero_grad()
                loss_avg = 0
                args.dc_aug_param = None

                # Briefly train the fresh network: one epoch over 200 randomly
                # chosen real samples.
                idx = torch.randperm(len(dst_train))[:200]
                trainloader = DataLoader(
                    Subset(dst_train, idx),
                    batch_size=args.batch_train,
                    shuffle=True,
                    num_workers=0
                )
                epoch('train', trainloader, net, optimizer_net, criterion, args, aug = True if args.dsa else False)

                for ol in range(args.outer_loop):

                    ''' update synthetic data '''
                    loss = torch.tensor(0.0).to(args.device)
                    optimizer_img.zero_grad()
                    for c in range(num_classes):
                        # 1. Real batch of class c together with its group labels.
                        img_real, lab_real, color_rr = get_images(c, args.batch_real)

                        # 2. Synthetic images of class c, decoded into
                        #    ipc * factor^2 patches (the IDC step).
                        raw_syn = image_syn[c*args.ipc:(c+1)*args.ipc]
                        raw_lab = label_syn[c*args.ipc:(c+1)*args.ipc]
                        img_syn_decoded, lab_syn_decoded = decode_zoom(raw_syn, raw_lab, args.factor, size=im_size)

                        # 3. Optional augmentation; the shared seed is passed to
                        #    both the real and the synthetic DiffAugment call.
                        if args.dsa:
                            seed = int(time.time() * 1000) % 100000
                            img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                            img_syn_decoded = DiffAugment(img_syn_decoded, args.dsa_strategy, seed=seed, param=args.dsa_param)

                        # 4. Synthetic gradients; create_graph keeps them
                        #    differentiable w.r.t. image_syn for loss.backward().
                        output_syn = net(img_syn_decoded)
                        loss_syn = criterion(output_syn, lab_syn_decoded)
                        gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                        # 5. Real gradients, balanced across groups.
                        gw_real_balanced = _group_balanced_real_grads(
                            net, criterion, net_parameters, img_real, lab_real, color_rr)

                        # NOTE(review): args.dis_metric is left exactly as parsed
                        # from the CLI (default 'ours'); presumably match_loss
                        # reads it from args - confirm in utils.
                        loss += match_loss(gw_syn, gw_real_balanced, args)

                    loss.backward()
                    loss_avg += loss.item()
                    optimizer_img.step()

                loss_avg /= (num_classes*args.outer_loop)
                if it % 20 == 0:
                    print('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))



if __name__ == '__main__':
    def save_random_state():
        """Snapshot the torch (CPU and all CUDA devices), NumPy and `random` RNG states."""
        snapshot = {}
        snapshot['torch'] = torch.get_rng_state()
        snapshot['np'] = np.random.get_state()
        snapshot['random'] = random.getstate()
        snapshot['cuda'] = torch.cuda.get_rng_state_all()
        return snapshot

    def load_random_state(state):
        """Restore every RNG from a snapshot produced by `save_random_state`."""
        restorers = (
            ('torch', torch.set_rng_state),
            ('np', np.random.set_state),
            ('random', random.setstate),
            ('cuda', torch.cuda.set_rng_state_all),
        )
        for key, restore in restorers:
            restore(state[key])

    # Seed every RNG the script touches.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Save the current random state; main() restores it via load_random_state.
    random_state = save_random_state()

    main()

