import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug, epoch2
import random
# def barycenter_general(group_means, D, num_steps=20, lr=0.1, normalize=False):
#     """
#     Generic barycenter solver:
#     M* = argmin_m ∑_a D(Φ_a, m)
#     group_means: (G, d)
#     D: function (group_means, m) -> scalar
#     """
#     # init at arithmetic mean (good starting point)
#     m = group_means.mean(dim=0, keepdim=True).detach().clone()  # (1, d)
#     m.requires_grad_(True)

#     optimizer = torch.optim.SGD([m], lr=lr)

#     for _ in range(num_steps):
#         optimizer.zero_grad()
#         loss = D(group_means, m)
#         loss.backward()
#         optimizer.step()

#         if normalize:
#             with torch.no_grad():
#                 m /= (m.norm(dim=1, keepdim=True) + 1e-8)

#     return m.detach().squeeze(0)  # (d,)



def barycenter_general_lbfgs(group_means, D, max_iter=25):
    """
    Solve M* = argmin_m sum_a D(Phi_a, m) using L-BFGS.

    Args:
        group_means: tensor of shape (G, d), one row Phi_a per group.
        D: callable ``D(group_means, m) -> scalar tensor`` where ``m`` has
            shape (1, d) and is broadcast against ``group_means``.
        max_iter: maximum L-BFGS iterations (forwarded to ``torch.optim.LBFGS``).

    Returns:
        Detached tensor of shape (d,) holding the minimiser. The search is
        initialised at the arithmetic mean of ``group_means``.

    Raises:
        ValueError: if ``group_means`` is not 2-D or has no rows (the mean
            initialisation would otherwise be NaN).
    """
    if group_means.dim() != 2 or group_means.shape[0] == 0:
        raise ValueError(
            "group_means must have shape (G, d) with G >= 1, got %s"
            % (tuple(group_means.shape),)
        )

    # The barycenter is a post-processing step on fixed targets. Detach them so
    # the closure (called many times by L-BFGS) never back-propagates into, or
    # frees, whatever graph produced ``group_means``.
    targets = group_means.detach()

    # init at arithmetic mean
    m = targets.mean(dim=0, keepdim=True).clone()  # (1, d)
    m.requires_grad_(True)

    optimizer = torch.optim.LBFGS([m], lr=1.0, max_iter=max_iter, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = D(targets, m)
        loss.backward()
        return loss

    optimizer.step(closure)

    return m.detach().squeeze(0)  # (d,)

def D_cosine_l2(group_means, m, alpha=1.0, beta=0.1, eps=1e-8):
    """
    Hybrid discrepancy averaged over groups:

        alpha * (1 - mean_a cos(Phi_a, m)) + beta * mean_a ||Phi_a - m||_2^2

    Rows of ``group_means`` are the Phi_a; ``m`` broadcasts against them.
    ``eps`` is added to every norm before dividing, guarding zero vectors.
    """
    # cosine part: normalise both sides, then average the row-wise dot products
    unit_groups = group_means / (group_means.norm(dim=1, keepdim=True) + eps)
    unit_center = m / (m.norm(dim=1, keepdim=True) + eps)
    mean_cos = (unit_groups * unit_center).sum(dim=1).mean()

    # squared-Euclidean part, averaged over groups
    residual = group_means - m
    mean_sq_dist = residual.pow(2).sum(dim=1).mean()

    cosine_part = 1.0 - mean_cos
    return alpha * cosine_part + beta * mean_sq_dist


def D_l2_squared(group_means, m):
    """Mean over groups of the squared Euclidean distance ||Phi_a - m||_2^2."""
    per_group = (group_means - m).pow(2).sum(dim=1)  # (G,), m broadcasts
    return per_group.mean()

def D_l2(group_means, m, eps=1e-8):
    """
    Mean over groups of the Euclidean distance ||Phi_a - m||_2.

    ``eps`` is added under the square root so the gradient stays finite
    when a group coincides with ``m``.
    """
    sq_norms = (group_means - m).pow(2).sum(dim=1)  # (G,)
    smoothed = torch.sqrt(sq_norms + eps)
    return smoothed.mean()

def D_cosine(group_means, m, eps=1e-8):
    """Mean over groups of the cosine dissimilarity 1 - cos(Phi_a, m)."""
    def _unit(v):
        # row-wise normalisation; eps guards against zero-norm rows
        return v / (v.norm(dim=1, keepdim=True) + eps)

    similarity = (_unit(group_means) * _unit(m)).sum(dim=1)  # (G,)
    return (1.0 - similarity).mean()


def D_huber(group_means, m, delta=1.0):
    """
    Robust per-dimension Huber discrepancy.

    Each coordinate contributes 0.5 * r^2 when |r| <= delta and
    delta * |r| - 0.5 * delta^2 otherwise. The contributions are summed over
    dimensions and averaged over groups.
    """
    residual = group_means - m  # (G, d)
    magnitude = residual.abs()
    quadratic = 0.5 * residual.pow(2)
    linear = delta * magnitude - 0.5 * delta ** 2
    in_core = magnitude <= delta
    per_dim = torch.where(in_core, quadratic, linear)
    return per_dim.sum(dim=1).mean()

def D_l1(group_means, m):
    """Mean over groups of the Manhattan distance ||Phi_a - m||_1."""
    absolute = torch.abs(group_means - m)  # (G, d)
    return absolute.sum(dim=1).mean()


def D_linf(group_means, m):
    """
    Mean over groups of the Chebyshev distance ||Phi_a - m||_inf.

    Non-smooth: only subgradients exist, so L-BFGS may be less stable on it.
    """
    absolute = torch.abs(group_means - m)  # (G, d)
    return absolute.amax(dim=1).mean()


def D_linf_smooth(group_means, m, tau=1e-2):
    """
    Log-sum-exp surrogate for the L-inf distance, averaged over groups:

        ||x||_inf ~= tau * log(sum_i exp(|x_i| / tau))

    A smaller ``tau`` tracks the true maximum more closely but yields a
    sharper, harder-to-optimise objective.
    """
    scaled = (group_means - m).abs() / tau  # (G, d)
    return (tau * scaled.logsumexp(dim=1)).mean()


def D_l1_smooth(group_means, m, eps=1e-6):
    """
    Differentiable L1 surrogate, averaged over groups.

    Each coordinate uses |x| ~= sqrt(x^2 + eps) (a pseudo-Huber style
    smoothing) so the gradient is defined at zero.
    """
    smoothed_abs = torch.sqrt((group_means - m).pow(2) + eps)  # (G, d)
    return smoothed_abs.sum(dim=1).mean()




# Barycenter discrepancies selectable through ``bary_dist`` in main().
# 'avg' is not listed here: it is handled directly as the arithmetic mean.
_BARY_DISTANCES = {
    "l2_squared": D_l2_squared,
    "l2": D_l2,
    "cosine": D_cosine,
    "cosine_l2": D_cosine_l2,
    "huber": D_huber,
    "l1": D_l1,
    "l1_smooth": D_l1_smooth,
    "linf": D_linf,
    "linf_smooth": D_linf_smooth,
}

# Per-distance keyword arguments; distances not listed get no extra kwargs.
_BARY_DISTANCE_KWARGS = {
    "l2": {"eps": 1e-8},
    "cosine": {"eps": 1e-8},
    "cosine_l2": {"alpha": 1.0, "beta": 0.1, "eps": 1e-8},
    "huber": {"delta": 1.0},
    "l1_smooth": {"eps": 1e-6},
    "linf_smooth": {"tau": 1e-2},
}


def _str2bool(value):
    """Parse a command-line boolean (``type=bool`` would turn "False" into True)."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def _capture_rng_state():
    """Return a snapshot of the Python, NumPy, torch CPU and torch CUDA RNG states."""
    return {
        'torch': torch.get_rng_state(),
        'np': np.random.get_state(),
        'random': random.getstate(),
        'cuda': torch.cuda.get_rng_state_all(),
    }


def _restore_rng_state(state):
    """Restore a snapshot produced by ``_capture_rng_state``."""
    torch.set_rng_state(state['torch'])
    np.random.set_state(state['np'])
    random.setstate(state['random'])
    torch.cuda.set_rng_state_all(state['cuda'])


def _organize_real_dataset(dst_train, device):
    """
    Stack a training set into device tensors and bucket sample indices by class.

    Each ``dst_train[i]`` is indexed as ``(image, label, color)``; images get a
    new leading dimension and are concatenated.

    Returns:
        ``(images_all, labels_all, color_all, indices_class)`` where
        ``indices_class[c]`` lists the dataset indices whose label is ``c``.
        Labels are used directly as list indices, so they are expected to be
        0..K-1 with K the number of distinct labels.
    """
    images, labels, colors = [], [], []
    for i in range(len(dst_train)):
        sample = dst_train[i]  # read each sample once, not three times
        images.append(torch.unsqueeze(sample[0], dim=0))
        labels.append(int(sample[1]))
        colors.append(int(sample[2]))

    # Bucket using host-side ints: iterating a device tensor element by element
    # forces one device->host transfer per sample.
    indices_class = [[] for _ in range(len(set(labels)))]
    for i, lab in enumerate(labels):
        indices_class[lab].append(i)

    images_all = torch.cat(images, dim=0).to(device)
    labels_all = torch.tensor(labels, dtype=torch.long, device=device)
    color_all = torch.tensor(colors, dtype=torch.long, device=device)
    return images_all, labels_all, color_all, indices_class


def _real_barycenter(group_means, bary_dist):
    """
    Combine per-group mean embeddings ``(G, d)`` into one ``(d,)`` target.

    ``'avg'`` is the arithmetic mean. Any other name selects a discrepancy from
    ``_BARY_DISTANCES`` and solves for its barycenter with L-BFGS (50 iterations).
    Raises KeyError for unknown names.
    """
    if bary_dist == 'avg':
        return torch.mean(group_means, dim=0)
    D_fn = _BARY_DISTANCES[bary_dist]
    kwargs = _BARY_DISTANCE_KWARGS.get(bary_dist, {})
    return barycenter_general_lbfgs(
        group_means,
        lambda G, m: D_fn(G, m, **kwargs),
        max_iter=50,
    )


def main():
    """
    Run DM-style distillation that matches a barycenter of per-group real embeddings.

    For each dataset in ``datasets_to_run`` and each name in ``bary_dists``:
    synthetic images are initialised from real images. Then, for
    ``args.Iteration`` steps, a freshly created network is briefly trained on
    the synthetic set and frozen. Per class, the mean synthetic embedding is
    pulled (L1) toward a barycenter of the per-color-group mean real embeddings.
    The synthetic set is evaluated at the iterations in ``eval_it_pool``, and
    results are saved to ``args.save_path`` at the final iteration.

    The RNG state captured on entry is restored before every ``bary_dist`` run,
    so each compared distance starts from the same random state.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10_S_90', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=10, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode') # S: the same to training model, M: multi architectures,  W: net width, D: net depth, A: activation function, P: pooling layer, N: normalization layer,
    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=6, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=2000, help='epochs to train a model with synthetic data') # it can be small for speeding up with little performance drop
    parser.add_argument('--Iteration', type=int, default=3000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=1000, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result-DM-exact', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    parser.add_argument('--shuffle', type=_str2bool, default=False, help='distance metric')
    parser.add_argument('--FairDD', action='store_true', help='Enable FairDD')
    parser.add_argument('--group_balance', type=_str2bool, default=False, help='distance metric')

    datasets_to_run = [
        # "CIFAR10_S_90",
        # "UTKface",
        # "BFFHQ",
        "Colored_FashionMNIST_foreground",
        # "Colored_FashionMNIST_background",
        # "Colored_MNIST_foreground",
        # "Colored_MNIST_background",
    ]
    # Available: 'avg' plus every key of _BARY_DISTANCES
    # ('l2_squared', 'l2', 'cosine', 'cosine_l2', 'huber', 'l1', 'l1_smooth', 'linf', 'linf_smooth').
    bary_dists = ['huber', 'linf', 'l1']

    # Fail fast on a typo instead of after loading data and training.
    unknown = [name for name in bary_dists if name != 'avg' and name not in _BARY_DISTANCES]
    if unknown:
        raise ValueError('unknown bary_dist value(s): %s' % (unknown,))

    # Snapshot all RNGs once, so every (dataset, bary_dist) run can be restarted
    # from the same state without relying on globals set by the __main__ guard.
    rng_state = _capture_rng_state()

    for dataset_name in datasets_to_run:
        args = parser.parse_args()
        args.method = 'DM'
        args.outer_loop, args.inner_loop = get_loops(args.ipc)
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        args.dsa_param = ParamDiffAug()
        args.dsa = True
        args.dataset = dataset_name

        os.makedirs(args.save_path, exist_ok=True)

        eval_it_pool = [args.Iteration]
        print('eval_it_pool: ', eval_it_pool)
        channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
        model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

        ''' organize the real dataset (once per dataset; shared by every bary_dist) '''
        images_all, labels_all, color_all, indices_class = _organize_real_dataset(dst_train, args.device)
        unique_groups = torch.unique(color_all)
        args.num_classes = len(indices_class)
        args.num_groups = len(unique_groups)

        def get_images(c, n):  # get random n images from class c
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle], labels_all[idx_shuffle], color_all[idx_shuffle]

        for bary_dist in bary_dists:
            # Restore the RNG snapshot so each distance starts from the same state.
            _restore_rng_state(rng_state)

            accs_all_exps = {key: [] for key in model_eval_pool}  # record performances of all experiments
            data_save = []

            ''' initialize the synthetic data '''
            image_syn = torch.randn(size=(args.num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
            # Labels laid out as [0,0,0, 1,1,1, ..., 9,9,9]; sized by args.num_classes to match image_syn.
            label_syn = torch.arange(args.num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)
            # Initialised from real images regardless of args.init.
            for c in range(args.num_classes):
                image_data, _, _ = get_images(c, args.ipc)
                image_syn.data[c*args.ipc:(c+1)*args.ipc] = image_data.detach().data

            ''' training '''
            optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)  # optimizer_img for synthetic data
            optimizer_img.zero_grad()
            print()
            print('%s training begins'%get_time())

            for it in range(args.Iteration+1):
                ''' Evaluate synthetic data '''
                if it in eval_it_pool:
                    model_eval = model_eval_pool[0]
                    print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                    accs = []
                    max_Equalized_Odds_list = []
                    mean_Equalized_Odds_list = []
                    for it_eval in range(args.num_eval):
                        net_eval = get_network(model_eval, channel, args.num_classes, im_size).to(args.device)  # get a random model
                        image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                        _, acc_train, acc_test, max_Equalized_Odds, mean_Equalized_Odds, max_Sufficiency, mean_Sufficiency = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                        accs.append(acc_test)
                        max_Equalized_Odds_list.append(max_Equalized_Odds)
                        mean_Equalized_Odds_list.append(mean_Equalized_Odds)
                    print('dataset: %s, bary_dist: %s' % (args.dataset, bary_dist))
                    print('\taccs, max_Equalized_Odds, mean_Equalized_Odds',np.round(np.mean(accs), 4), np.round(np.mean(max_Equalized_Odds_list), 4), np.round(np.mean(mean_Equalized_Odds_list), 4))
                    print('\tSTD: accs, max_Equalized_Odds, mean_Equalized_Odds',np.round(np.std(accs), 3), np.round(np.std(max_Equalized_Odds_list), 3), np.round(np.std(mean_Equalized_Odds_list), 3),'\n\n')

                    if it == args.Iteration:  # record the final results
                        accs_all_exps[model_eval] += accs

                    ''' visualize and save '''
                    # bary_dist is part of the name so successive distances do not overwrite each other.
                    save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d-%s.png'%(args.method, args.dataset, args.model, args.ipc, 1, it, bary_dist))
                    image_syn_vis = copy.deepcopy(image_syn.detach().cpu())
                    for ch in range(channel):
                        image_syn_vis[:, ch] = image_syn_vis[:, ch] * std[ch] + mean[ch]  # undo normalisation
                    image_syn_vis.clamp_(0.0, 1.0)
                    save_image(image_syn_vis, save_name, nrow=args.ipc)  # Trying normalize = True/False may get better visual effects.

                ''' Train synthetic data '''
                net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)  # get a random model
                net.train()
                criterion = nn.CrossEntropyLoss().to(args.device)
                optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

                # Briefly fit the fresh network to the current synthetic set (10 epoch2 passes).
                for _ in range(10):
                    _, _, net = epoch2('train', trainloader, net, optimizer_net, criterion, args, aug=bool(args.dsa))

                # Freeze the network: only the synthetic images receive gradients below.
                for param in net.parameters():
                    param.requires_grad = False

                embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed  # for GPU parallel

                ''' update synthetic data '''
                loss = torch.tensor(0.0).to(args.device)
                for c in range(args.num_classes):
                    img_real, _, color = get_images(c, args.batch_real)
                    img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                    # Same augmentation seed for real and synthetic batches. NOTE(review): the seed is
                    # wall-clock based, so augmentations are not reproducible across runs.
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    syn_mean = torch.mean(embed(img_syn), dim=0)

                    # Mean real embedding of every color group present in this batch.
                    group_means = []
                    for g in unique_groups:
                        mask = (color == g)
                        if mask.sum().item() == 0:
                            continue
                        group_means.append(torch.mean(embed(img_real[mask]), dim=0))
                    group_means = torch.stack(group_means, dim=0)  # (G, d)

                    real_barycenter = _real_barycenter(group_means, bary_dist)
                    L_vec = real_barycenter.detach() - syn_mean
                    loss += L_vec.abs().sum()  # or L2, etc.

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()
                loss_avg = loss.item() / args.num_classes

                if it % 100 == 0:
                    print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))

                if it == args.Iteration:  # only record the final results
                    data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                    torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc-%s.pt'%(args.method, args.dataset, args.model, args.ipc, bary_dist)))




if __name__ == '__main__':
    def save_random_state():
        """Return a snapshot of the torch CPU, NumPy, Python and torch CUDA RNG states."""
        state = dict(
            torch=torch.get_rng_state(),
            np=np.random.get_state(),
            random=random.getstate(),
            cuda=torch.cuda.get_rng_state_all(),
        )
        return state

    def load_random_state(state):
        """Restore every RNG from a snapshot produced by save_random_state()."""
        restorers = (
            (torch.set_rng_state, 'torch'),
            (np.random.set_state, 'np'),
            (random.setstate, 'random'),
            (torch.cuda.set_rng_state_all, 'cuda'),
        )
        for restore, key in restorers:
            restore(state[key])

    # Seed every RNG source used by the script.
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Prefer deterministic cuDNN kernels over autotuned ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Snapshot the freshly seeded RNG state as the module-level global `random_state`.
    random_state = save_random_state()

    main()
