import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
import random
import csv
import statistics

from utils import (
    get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset,
    get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment,
    ParamDiffAug, epoch2
)

# -----------------------------
# Profiling helpers
# -----------------------------
def _cuda_sync_all():
    if not torch.cuda.is_available():
        return
    for d in range(torch.cuda.device_count()):
        torch.cuda.synchronize(d)

def _cuda_reset_peak_all():
    if not torch.cuda.is_available():
        return
    for d in range(torch.cuda.device_count()):
        torch.cuda.reset_peak_memory_stats(d)

def _cuda_mem_alloc_all():
    if not torch.cuda.is_available():
        return 0
    return sum(torch.cuda.memory_allocated(d) for d in range(torch.cuda.device_count()))

def _cuda_max_mem_alloc_all():
    if not torch.cuda.is_available():
        return 0
    return sum(torch.cuda.max_memory_allocated(d) for d in range(torch.cuda.device_count()))

def profile_block(device: str, fn):
    """Run ``fn()`` once and measure its wall time and memory footprint.

    Args:
        device: Device string. A value starting with ``"cuda"`` selects the CUDA
            path when CUDA is available; anything else uses the CPU fallback.
        fn: Zero-argument callable to profile.

    Returns:
        (out, dt_seconds, extra_peak_bytes, delta_alloc_bytes), where
          extra_peak_bytes  = peak_alloc_during_fn - alloc_at_start
          delta_alloc_bytes = alloc_at_end - alloc_at_start

    On CUDA the counters are ``torch.cuda.memory_allocated`` /
    ``max_memory_allocated``, summed over all visible devices. On CPU they come
    from ``tracemalloc`` and only cover Python-level allocations it traces; they
    do NOT reflect total RSS.

    Exceptions raised by ``fn`` propagate after the tracing state has been
    restored. A trace that was already running is left running.
    """
    if device.startswith("cuda") and torch.cuda.is_available():
        _cuda_sync_all()  # flush earlier kernels so they are not billed to fn
        _cuda_reset_peak_all()
        start_alloc = _cuda_mem_alloc_all()
        t0 = time.perf_counter()

        out = fn()

        _cuda_sync_all()  # kernels run asynchronously: wait for fn's work before stopping the clock
        dt = time.perf_counter() - t0
        peak_alloc = _cuda_max_mem_alloc_all()
        end_alloc = _cuda_mem_alloc_all()

        return out, dt, (peak_alloc - start_alloc), (end_alloc - start_alloc)

    # CPU fallback (Python allocs only; does NOT reflect total RSS)
    import tracemalloc

    was_tracing = tracemalloc.is_tracing()
    if not was_tracing:
        tracemalloc.start()
    elif hasattr(tracemalloc, "reset_peak"):
        # Python 3.9+: make the peak relative to this call rather than the outer trace.
        tracemalloc.reset_peak()
    try:
        start_alloc, _ = tracemalloc.get_traced_memory()
        t0 = time.perf_counter()
        out = fn()
        dt = time.perf_counter() - t0
        end_alloc, peak_alloc = tracemalloc.get_traced_memory()
    finally:
        # Only stop a trace we started ourselves, even if fn raised.
        if not was_tracing:
            tracemalloc.stop()
    return out, dt, (peak_alloc - start_alloc), (end_alloc - start_alloc)


# -----------------------------
# Barycenter + discrepancies
# -----------------------------
def barycenter_general_lbfgs(group_means, D, max_iter=25):
    """
    Solve M* = argmin_m D(Φ, m) using LBFGS, warm-started from the plain average.

    group_means: tensor indexed as (G, d), one row per group.
    D: callable (group_means, m) -> scalar loss tensor, where m is (1, d) and is
       broadcast over the G rows.
    max_iter: maximum L-BFGS iterations (forwarded to torch.optim.LBFGS).

    Returns a detached tensor of shape (d,).

    group_means is treated as a constant target: it is detached, so the repeated
    closure evaluations L-BFGS performs never backpropagate into (or free) an
    autograd graph the caller attached to it, and never touch the caller's .grad.
    """
    targets = group_means.detach()
    m = targets.mean(dim=0, keepdim=True).clone()  # (1, d) warm start at the arithmetic mean
    m.requires_grad_(True)

    optimizer = torch.optim.LBFGS([m], lr=1.0, max_iter=max_iter, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = D(targets, m)
        loss.backward()
        return loss

    optimizer.step(closure)
    return m.detach().squeeze(0)  # (d,)


def D_cosine_l2(group_means, m, alpha=1.0, beta=0.1, eps=1e-8):
    """Weighted mix of cosine dissimilarity and mean squared L2 distance.

    Computes ``alpha * (1 - mean_cos) + beta * mean ||group_means - m||^2``.
    Rows are compared along dim 1, and ``m`` broadcasts against ``group_means``.
    ``eps`` guards the norm divisions.
    """
    unit_groups = group_means / (group_means.norm(dim=1, keepdim=True) + eps)
    unit_m = m / (m.norm(dim=1, keepdim=True) + eps)
    mean_cos = (unit_groups * unit_m).sum(dim=1).mean()

    mean_sq_dist = (group_means - m).pow(2).sum(dim=1).mean()

    return alpha * (1.0 - mean_cos) + beta * mean_sq_dist


def D_l2_squared(group_means, m):
    """Mean over rows of the squared Euclidean distance between ``group_means`` and ``m``."""
    squared_dists = ((group_means - m) ** 2).sum(dim=1)
    return squared_dists.mean()

def D_l2(group_means, m, eps=1e-8):
    """Mean over rows of the Euclidean distance to ``m``.

    ``eps`` is added under the square root so the gradient stays finite at zero distance.
    """
    squared_dists = (group_means - m).pow(2).sum(dim=1)
    return torch.sqrt(squared_dists + eps).mean()

def D_cosine(group_means, m, eps=1e-8):
    """Mean over rows of the cosine dissimilarity ``1 - cos(group_means_i, m)``.

    ``eps`` guards the norm divisions.
    """
    unit_groups = group_means / (group_means.norm(dim=1, keepdim=True) + eps)
    unit_m = m / (m.norm(dim=1, keepdim=True) + eps)
    dissimilarity = 1.0 - (unit_groups * unit_m).sum(dim=1)
    return dissimilarity.mean()

def D_huber(group_means, m, delta=1.0):
    """Mean over rows of the per-dimension Huber loss, summed across dimensions.

    The loss is quadratic (``0.5 * r**2``) where ``|r| <= delta`` and linear beyond that.
    """
    residual = group_means - m
    magnitude = residual.abs()
    quadratic = 0.5 * residual.pow(2)
    linear = delta * magnitude - 0.5 * delta ** 2
    per_dim = torch.where(magnitude <= delta, quadratic, linear)
    return per_dim.sum(dim=1).mean()

def D_l1(group_means, m):
    """Mean over rows of the L1 (Manhattan) distance to ``m``."""
    return torch.abs(group_means - m).sum(dim=1).mean()

def D_linf(group_means, m):
    """Mean over rows of the L-infinity (max-abs) distance to ``m``."""
    per_row_max = torch.abs(group_means - m).amax(dim=1)
    return per_row_max.mean()

def D_linf_smooth(group_means, m, tau=1e-2):
    """Smooth L-infinity distance: mean over rows of ``tau * logsumexp(|r| / tau)``.

    This upper-bounds the true max and approaches it as ``tau`` goes to 0.
    """
    abs_residual = torch.abs(group_means - m)
    soft_max = tau * torch.logsumexp(abs_residual / tau, dim=1)
    return soft_max.mean()

def D_l1_smooth(group_means, m, eps=1e-6):
    """Smooth L1 distance: mean over rows of ``sum_j sqrt(r_j**2 + eps)``.

    The result is differentiable at zero.
    """
    squared = (group_means - m).pow(2)
    return torch.sqrt(squared + eps).sum(dim=1).mean()


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _snapshot_rng_state():
    """Capture the global torch-CPU, NumPy and Python RNG states, plus CUDA states when available."""
    return {
        'torch': torch.get_rng_state(),
        'np': np.random.get_state(),
        'random': random.getstate(),
        'cuda': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }


def _restore_rng_state(state):
    """Restore RNG states captured by ``_snapshot_rng_state``."""
    torch.set_rng_state(state['torch'])
    np.random.set_state(state['np'])
    random.setstate(state['random'])
    if torch.cuda.is_available() and state.get('cuda') is not None:
        torch.cuda.set_rng_state_all(state['cuda'])


def _build_arg_parser():
    """Build the CLI parser (same flags and defaults as before)."""
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10_S_90', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=10, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=1, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=1, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=300, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=1.0, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=1224, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    # NOTE(review): synthetic images below are always initialised from real samples, regardless of --init.
    parser.add_argument('--init', type=str, default='real', help='noise/real')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='dsa strategy')
    parser.add_argument('--data_path', type=str, default='data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result-time', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument('--shuffle', type=_str2bool, default=False, help='shuffle (true/false)')
    parser.add_argument('--FairDD', action='store_true', help='Enable FairDD')
    parser.add_argument('--group_balance', type=_str2bool, default=False, help='group balance (true/false)')

    # ---- profiling args ----
    parser.add_argument('--profile_metrics', action='store_true', help='Profile barycenter runtime/memory')
    parser.add_argument('--profile_every', type=int, default=10, help='Profile every N iters (reduces overhead)')
    parser.add_argument('--profile_warmup', type=int, default=50, help='Skip first N iters for profiling')
    parser.add_argument('--profile_prefix', type=str, default='metric_profile', help='CSV prefix')
    return parser


def _load_real_data(dst_train, device):
    """Materialise the training set as (images, labels, groups) tensors on ``device``.

    Assumes each ``dst_train[i]`` yields ``(image_tensor, label, group)``
    (TODO confirm against utils.get_dataset). Each sample is fetched once.
    """
    images, labels, groups = [], [], []
    for i in range(len(dst_train)):
        sample = dst_train[i]
        images.append(torch.unsqueeze(sample[0], dim=0))
        labels.append(int(sample[1]))
        groups.append(int(sample[2]))
    images_all = torch.cat(images, dim=0).to(device)
    labels_all = torch.tensor(labels, dtype=torch.long, device=device)
    color_all = torch.tensor(groups, dtype=torch.long, device=device)
    return images_all, labels_all, color_all


def _report_profile(args, bary_dist, iters, times_s, extra_peak_B, delta_alloc_B):
    """Print and save per-iteration barycenter profiling for one metric.

    Writes ``{profile_prefix}_{dataset}_{metric}_detail.csv`` under ``args.save_path``
    and returns the row used for the per-dataset summary CSV.
    """
    bytes_per_mb = 1024 ** 2

    def _mean_pstd(values):
        # Population std; a single sample has no spread.
        return statistics.mean(values), (statistics.pstdev(values) if len(values) > 1 else 0.0)

    mean_t, std_t = _mean_pstd(times_s)
    mean_peak_B, std_peak_B = _mean_pstd(extra_peak_B)
    mean_delta_B, std_delta_B = _mean_pstd(delta_alloc_B)
    mean_peak_MB, std_peak_MB = mean_peak_B / bytes_per_mb, std_peak_B / bytes_per_mb
    mean_delta_MB, std_delta_MB = mean_delta_B / bytes_per_mb, std_delta_B / bytes_per_mb

    print(
        f"\n[PROFILE] dataset={args.dataset} metric={bary_dist} "
        f"(barycenter-only; {len(times_s)} sampled iters)\n"
        f"  time/iter: {mean_t * 1000:.3f} ± {std_t * 1000:.3f} ms\n"
        f"  extra peak alloc: {mean_peak_MB:.3f} ± {std_peak_MB:.3f} MB\n"
        f"  delta alloc end-start: {mean_delta_MB:.3f} ± {std_delta_MB:.3f} MB\n"
    )

    detail_csv = os.path.join(args.save_path, f"{args.profile_prefix}_{args.dataset}_{bary_dist}_detail.csv")
    with open(detail_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["dataset", "metric", "iter", "bary_time_ms", "extra_peak_MB", "delta_alloc_MB"])
        for it_i, t_i, p_i, d_i in zip(iters, times_s, extra_peak_B, delta_alloc_B):
            writer.writerow([args.dataset, bary_dist, it_i, t_i * 1000.0, p_i / bytes_per_mb, d_i / bytes_per_mb])

    return [
        args.dataset, bary_dist, len(times_s),
        mean_t * 1000.0, std_t * 1000.0,
        mean_peak_MB, std_peak_MB,
        mean_delta_MB, std_delta_MB,
    ]


def main():
    """Distil a synthetic set per barycenter metric, optionally profiling the barycenter step.

    For each dataset and each metric in ['avg', 'l2', 'cosine', 'huber', 'l1', 'linf'],
    the run initialises ``ipc`` synthetic images per class from real samples. It then
    performs ``args.Iteration + 1`` updates that pull each class's mean synthetic
    embedding (L1) towards a barycenter of the per-group mean real embeddings.

    With ``--profile_metrics``, the barycenter computation is timed and
    memory-profiled on sampled iterations, and the results are written as CSV
    files under ``args.save_path``.
    """
    parser = _build_arg_parser()
    # Snapshot once; every metric is restarted from this state for a fair comparison.
    initial_rng_state = _snapshot_rng_state()

    # Candidate discrepancies for the barycenter problem and their fixed hyper-parameters.
    Ds = {
        "l2_squared": D_l2_squared,
        "l2":         D_l2,
        "cosine":     D_cosine,
        "cosine_l2":  D_cosine_l2,
        "huber":      D_huber,
        "l1":         D_l1,
        "l1_smooth":  D_l1_smooth,
        "linf":       D_linf,
        "linf_smooth": D_linf_smooth,
    }
    D_kwargs = {
        "l2": {"eps": 1e-8},
        "cosine": {"eps": 1e-8},
        "cosine_l2": {"alpha": 1.0, "beta": 0.1, "eps": 1e-8},
        "huber": {"delta": 1.0},
        "l1_smooth": {"eps": 1e-6},
        "linf_smooth": {"tau": 1e-2},
    }

    for dataset_name in ["UTKface"]:
        args = parser.parse_args()
        args.method = 'DM'
        args.outer_loop, args.inner_loop = get_loops(args.ipc)
        args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        args.dsa_param = ParamDiffAug()
        args.dsa = True
        args.dataset = dataset_name

        os.makedirs(args.save_path, exist_ok=True)

        eval_it_pool = [args.Iteration]
        print('eval_it_pool: ', eval_it_pool)

        channel, im_size, _num_classes, _class_names, _mean, _std, dst_train, _dst_test, _testloader = get_dataset(
            args.dataset, args.data_path
        )

        # The real data is identical for every metric, so materialise and index it once.
        images_all, labels_all, color_all = _load_real_data(dst_train, args.device)
        args.num_classes = len(torch.unique(labels_all))
        unique_groups = torch.unique(color_all)
        args.num_groups = len(unique_groups)

        indices_class = [[] for _ in range(args.num_classes)]
        for i, lab in enumerate(labels_all.tolist()):
            indices_class[lab].append(i)

        def get_images(c, n):
            """Return up to ``n`` random real (images, labels, groups) samples of class ``c``."""
            idx_shuffle = np.random.permutation(indices_class[c])[:n]
            return images_all[idx_shuffle], labels_all[idx_shuffle], color_all[idx_shuffle]

        # summary rows for this dataset (one row per metric)
        profile_summary_rows = []

        for bary_dist in ['avg', 'l2', 'cosine', 'huber', 'l1', 'linf']:
            # Restart from the same RNG state so each metric sees the same synthetic
            # initialisation and real-batch sampling sequence.
            _restore_rng_state(initial_rng_state)

            # per-metric profiling containers
            profile_iters = []
            profile_time_s = []
            profile_extra_peak_B = []
            profile_delta_alloc_B = []

            # Initialise the synthetic data (overwritten with real samples just below).
            image_syn = torch.randn(
                size=(args.num_classes * args.ipc, channel, im_size[0], im_size[1]),
                dtype=torch.float, requires_grad=True, device=args.device
            )
            # Labels [0]*ipc, [1]*ipc, ...; use the same class count as image_syn.
            label_syn = torch.arange(
                args.num_classes, dtype=torch.long, device=args.device
            ).repeat_interleave(args.ipc)

            color_syn = torch.zeros_like(label_syn)
            for c in range(args.num_classes):
                image_data, _, color_data = get_images(c, args.ipc)
                image_syn.data[c * args.ipc:(c + 1) * args.ipc] = image_data.detach().data
                color_syn.data[c * args.ipc:(c + 1) * args.ipc] = color_data.detach().data

            # training
            optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)
            optimizer_img.zero_grad()
            print()
            print('%s training begins' % get_time())

            for it in range(args.Iteration + 1):
                # decide once per iteration (applies to all classes)
                do_profile = (
                    args.profile_metrics
                    and it >= args.profile_warmup
                    and it % args.profile_every == 0
                )

                # per-iteration accumulators (barycenter-only)
                iter_bary_time_s = 0.0
                iter_bary_extra_peak_B = 0
                iter_bary_delta_alloc_B = 0

                # Train a freshly initialised network on the current synthetic set.
                net = get_network(args.model, channel, args.num_classes, im_size).to(args.device)
                net.train()
                criterion = nn.CrossEntropyLoss().to(args.device)
                optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)

                image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                trainloader = torch.utils.data.DataLoader(
                    dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0
                )

                for _ in range(10):
                    _, _, net = epoch2('train', trainloader, net, optimizer_net, criterion, args, aug=bool(args.dsa))

                # Freeze the network: only the synthetic images are optimised below.
                for param in net.parameters():
                    param.requires_grad = False

                embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed
                loss_avg = 0

                # update synthetic data
                loss = torch.tensor(0.0, device=args.device)
                for c in range(args.num_classes):
                    img_real, _, color = get_images(c, args.batch_real)
                    img_syn = image_syn[c * args.ipc:(c + 1) * args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                    # A shared seed pairs the real/synthetic augmentations.
                    # NOTE(review): a time-based seed is not reproducible across runs or metrics.
                    seed = int(time.time() * 1000) % 100000
                    img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                    img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    output_real = embed(img_real).detach()
                    output_syn = embed(img_syn)
                    syn_mean = torch.mean(output_syn, dim=0)

                    # Per-group means reuse the already computed real embeddings instead of a
                    # second forward pass (identical for per-sample-normalised nets such as the
                    # default instance-norm ConvNet).
                    group_means = []
                    for g in unique_groups:
                        mask = color == g
                        if not bool(mask.any()):
                            continue
                        group_means.append(output_real[mask].mean(dim=0))
                    group_means = torch.stack(group_means, dim=0)  # (G, d)

                    def compute_barycenter(group_means=group_means):
                        if bary_dist == 'avg':
                            return torch.mean(group_means, dim=0)
                        D_fn = Ds[bary_dist]
                        kwargs = D_kwargs.get(bary_dist, {})
                        return barycenter_general_lbfgs(
                            group_means, lambda G, mm: D_fn(G, mm, **kwargs), max_iter=80
                        )

                    if do_profile:
                        real_barycenter, dt, extra_peak_B, delta_alloc_B = profile_block(args.device, compute_barycenter)
                        iter_bary_time_s += dt
                        iter_bary_extra_peak_B = max(iter_bary_extra_peak_B, int(extra_peak_B))
                        iter_bary_delta_alloc_B = max(iter_bary_delta_alloc_B, int(delta_alloc_B))
                    else:
                        real_barycenter = compute_barycenter()

                    loss += (real_barycenter.detach() - syn_mean).abs().sum()

                if do_profile:
                    profile_iters.append(it)
                    profile_time_s.append(iter_bary_time_s)
                    profile_extra_peak_B.append(iter_bary_extra_peak_B)
                    profile_delta_alloc_B.append(iter_bary_delta_alloc_B)

                optimizer_img.zero_grad()
                loss.backward()
                optimizer_img.step()

                loss_avg += loss.item()
                loss_avg /= args.num_classes

                if it % 100 == 0:
                    print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))

            # After this metric's training loop: print/save profiling.
            if args.profile_metrics and profile_time_s:
                profile_summary_rows.append(_report_profile(
                    args, bary_dist, profile_iters, profile_time_s,
                    profile_extra_peak_B, profile_delta_alloc_B,
                ))

        # write per-dataset summary csv (all metrics)
        if args.profile_metrics and profile_summary_rows:
            summary_csv = os.path.join(args.save_path, f"{args.profile_prefix}_{args.dataset}_summary.csv")
            with open(summary_csv, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    "dataset", "metric", "n_iters_sampled",
                    "mean_bary_time_ms", "std_bary_time_ms",
                    "mean_extra_peak_MB", "std_extra_peak_MB",
                    "mean_delta_alloc_MB", "std_delta_alloc_MB"
                ])
                writer.writerows(profile_summary_rows)

            print(f"[PROFILE] Wrote summary CSV: {summary_csv}\n")


if __name__ == '__main__':
    def save_random_state():
        """Capture the torch-CPU, NumPy and Python RNG states, plus CUDA states when CUDA is available."""
        snapshot = {
            'torch': torch.get_rng_state(),
            'np': np.random.get_state(),
            'random': random.getstate(),
        }
        snapshot['cuda'] = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        return snapshot

    def load_random_state(state):
        """Restore RNG states previously captured by save_random_state()."""
        torch.set_rng_state(state['torch'])
        np.random.set_state(state['np'])
        random.setstate(state['random'])
        cuda_states = state.get('cuda')
        if cuda_states is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(cuda_states)

    # Seed every RNG source the experiment touches.
    base_seed = 42
    random.seed(base_seed)
    np.random.seed(base_seed)
    torch.manual_seed(base_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(base_seed)
        torch.cuda.manual_seed_all(base_seed)

    # Ask cuDNN for deterministic kernels and disable its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    random_state = save_random_state()
    main()
