import os
import time
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, \
    match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug

def adjust_learning_rate(optimizer, epoch, init_lr, *, milestones=(1200, 1600, 1800), gamma=0.5):
    """Set a step-decayed learning rate on every parameter group of ``optimizer``.

    The rate is ``init_lr`` multiplied by ``gamma`` once for every milestone
    that ``epoch`` has reached (``epoch >= milestone``). The defaults reproduce
    the original fixed schedule (halve at 1200, 1600 and 1800).

    Args:
        optimizer: optimizer whose ``param_groups`` entries get their ``'lr'``
            overwritten in place.
        epoch: current epoch / iteration counter.
        init_lr: base learning rate before any decay.
        milestones: counters at which the rate is multiplied by ``gamma``.
        gamma: multiplicative decay applied per reached milestone.

    Returns:
        The learning rate that was written to the optimizer.
    """
    lr = init_lr
    for milestone in milestones:
        if epoch >= milestone:
            lr *= gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def _color_balanced_means(features, color, layer_ids):
    """Average per-colour means of selected feature layers with equal colour weight.

    For every index in ``layer_ids`` the batch rows are grouped by the values
    in ``color``. The mean activation of each group is computed, and those
    group means are averaged so that every colour present in the batch
    contributes equally, regardless of how many rows it has.

    Args:
        features: sequence of tensors indexed by layer id; dim 0 of each
            selected tensor is the batch dimension.
        color: 1-D tensor of colour/group ids, one per batch row.
        layer_ids: indices into ``features`` to process.

    Returns:
        dict mapping each layer id to a tensor of shape ``features[idx].shape[1:]``.
    """
    unique_colors = torch.unique(color)
    balanced = {idx: torch.zeros_like(features[idx][0]) for idx in layer_ids}
    for col in unique_colors:
        mask = color == col
        for idx in layer_ids:
            balanced[idx] += torch.mean(features[idx][mask], dim=0)
    for idx in layer_ids:
        balanced[idx] /= len(unique_colors)
    return balanced


def _probe_class_accuracy(net, get_images, num_classes, n_per_class, device):
    """Return ``net``'s accuracy on freshly sampled real images, averaged over classes.

    ``get_images(c, n)`` must return ``(images, labels, colors)`` for class ``c``,
    and ``net`` must return ``(logits, features)``. The probe only feeds the
    early-stopping bookkeeping, so it runs under ``torch.no_grad()``. The
    network's train/eval mode is left untouched.
    """
    acc = torch.zeros((), device=device)
    with torch.no_grad():
        for c in range(num_classes):
            images, _, _ = get_images(c, n_per_class)
            logits, _ = net(images)
            acc += (logits.argmax(dim=1) == c).float().mean()
    return (acc / num_classes).item()


def _save_synthetic_grid(image_syn, mean, std, save_name, nrow):
    """Undo per-channel normalisation, clamp to [0, 1] and save ``image_syn`` as a PNG grid."""
    vis = image_syn.detach().cpu().clone()
    for ch in range(vis.shape[1]):
        vis[:, ch] = vis[:, ch] * std[ch] + mean[ch]
    vis.clamp_(0.0, 1.0)
    save_image(vis, save_name, nrow=nrow)


def main():
    """Run colour-balanced feature-matching dataset condensation experiments.

    For every (ipc, dataset) pair in the hard-coded sweep below, the script:
      1. loads the real training set (items are ``(image, label, color)``) and
         initialises ``ipc`` synthetic images per class;
      2. for ``args.Iteration + 1`` iterations, samples a fresh network and
         updates the synthetic images so that their per-layer class means match
         the real class means averaged with equal weight over colour groups.
         Short training runs of the network on the synthetic set are
         interleaved, and the outer loop stops early once the probe accuracy
         over the last 10 outer steps varies by less than ``--lambda_1``;
      3. at the final iteration, evaluates the synthetic set with
         ``evaluate_synset`` (accuracy and equalized-odds numbers), then saves
         a visualisation grid and a ``.pt`` checkpoint under ``--save_path``.

    Note: the ``--ipc``, ``--dataset`` and ``--Iteration`` command-line values
    are overridden by the sweep below.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--method', type=str, default='DC', help='DC/DSA')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--ipc', type=int, default=1, help='image(s) per class')
    parser.add_argument('--eval_mode', type=str, default='S', help='eval_mode')
    parser.add_argument('--num_exp', type=int, default=1, help='the number of experiments')
    parser.add_argument('--num_eval', type=int, default=2, help='the number of evaluating randomly initialized models')
    parser.add_argument('--epoch_eval_train', type=int, default=1000, help='epochs to train a model with synthetic data')
    parser.add_argument('--Iteration', type=int, default=2000, help='training iterations')
    parser.add_argument('--lr_img', type=float, default=0.1, help='learning rate for updating synthetic images')
    parser.add_argument('--lr_net', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real data')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--init', type=str, default='real', help='noise/real: initialize synthetic images from random noise or randomly sampled real images.')
    parser.add_argument('--dsa_strategy', type=str, default='None', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='../data', help='dataset path')
    parser.add_argument('--save_path', type=str, default='result', help='path to save results')
    parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric')
    parser.add_argument('--fourth_weight', type=float, default=1.0, help='weight for layer params')
    parser.add_argument('--third_weight', type=float, default=1.0, help='weight for layer params')
    parser.add_argument('--second_weight', type=float, default=1.0, help='weight for layer params')
    parser.add_argument('--first_weight', type=float, default=1.0, help='weight for layer params')
    parser.add_argument('--lambda_1', type=float, default=0.05, help='break outlooper threshold')
    parser.add_argument('--lambda_2', type=float, default=0.05, help='break innerlooper threshold')

    # Feature layers whose class means are matched, with the weight of each MSE term.
    layer_weights = {0: 1.0, -1: 1.0, -2: 1.0, -3: 0.1, -4: 0.1}
    # Number of consecutive outer steps inspected by the early-stopping rule.
    early_stop_window = 10

    # for ipc in [100,50,10]:
    for ipc in [50]:
        for dataset in [
            # "CIFAR10_S_90",
            # "UTKface",
            # "Colored_FashionMNIST_foreground",
            "Colored_FashionMNIST_background",
            # "Colored_MNIST_foreground",
            # "Colored_MNIST_background",
        ]:
            # Re-parse for every run so that per-run state stored on ``args``
            # (e.g. num_groups, dc_aug_param) never leaks between runs.
            args = parser.parse_args()
            args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            args.dsa_param = ParamDiffAug()
            args.dsa = False
            args.dataset = dataset
            args.ipc = ipc
            # Per-ipc iteration budget; overrides --Iteration.
            args.Iteration = {100: 220, 50: 210}.get(ipc, 350)
            # Only query loop sizes for the ipc actually used (not the CLI value).
            args.outer_loop, args.inner_loop = get_loops(args.ipc)
            os.makedirs(args.data_path, exist_ok=True)
            os.makedirs(args.save_path, exist_ok=True)

            eval_it_pool = [args.Iteration]
            channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader = get_dataset(args.dataset, args.data_path)
            model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

            accs_all_exps = {key: [] for key in model_eval_pool}
            data_save = []

            for exp in range(args.num_exp):
                print(f'================== Exp {exp} ==================')
                print(f'Hyper-parameters: \n{args.__dict__}')
                print(f'Evaluation model pool: {model_eval_pool}')

                # --- Organize the real dataset (one pass; items are (image, label, color)) ---
                images_all, labels_all, color_all = [], [], []
                for i in range(len(dst_train)):
                    sample = dst_train[i]
                    images_all.append(torch.unsqueeze(sample[0], dim=0))
                    labels_all.append(sample[1])
                    color_all.append(sample[2])

                indices_class = [[] for _ in range(num_classes)]
                for i, lab in enumerate(labels_all):
                    indices_class[lab].append(i)

                images_all = torch.cat(images_all, dim=0).to(args.device)
                labels_all = torch.tensor(labels_all, dtype=torch.long, device=args.device)
                color_all = torch.tensor(color_all, dtype=torch.long, device=args.device)
                args.num_groups = len(torch.unique(color_all))
                args.num_classes = len(torch.unique(labels_all))

                for c in range(num_classes):
                    print(f'class c = {c}: {len(indices_class[c])} real images')

                def get_images(c, n):
                    """Return up to ``n`` random real (images, labels, colors) of class ``c``."""
                    idx_shuffle = np.random.permutation(indices_class[c])[:n]
                    return images_all[idx_shuffle], labels_all[idx_shuffle], color_all[idx_shuffle]

                # --- Initialize the synthetic data ---
                image_syn = torch.randn(size=(num_classes * args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
                # Labels [0]*ipc + [1]*ipc + ...; built in torch directly (a list of ndarrays converts slowly).
                label_syn = torch.arange(num_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

                if args.init == 'real':
                    print('Initialize synthetic data from random real images')
                    for c in range(num_classes):
                        image_data, _, _ = get_images(c, args.ipc)
                        image_syn.data[c*args.ipc:(c+1)*args.ipc] = image_data.detach().data
                else:
                    print('Initialize synthetic data from random noise')

                # --- Training ---
                optimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5)
                optimizer_img.zero_grad()
                criterion = nn.CrossEntropyLoss().to(args.device)
                criterion_sum = nn.CrossEntropyLoss(reduction='sum').to(args.device)
                MSE_Loss = nn.MSELoss(reduction='sum').to(args.device)

                print(f'{get_time()} training begins')

                for it in range(args.Iteration + 1):
                    adjust_learning_rate(optimizer_img, it, args.lr_img)

                    # --- Evaluate synthetic data ---
                    if it in eval_it_pool:
                        for model_eval in model_eval_pool:
                            print(f'-------------------------\nEvaluation\nmodel_train = {args.model}, model_eval = {model_eval}, iteration = {it}')

                            if args.dsa:
                                args.dc_aug_param = None
                                print(f'DSA augmentation strategy: \n{args.dsa_strategy}')
                            else:
                                args.dc_aug_param = get_daparam(args.dataset, args.model, model_eval, args.ipc)
                                print(f'DC augmentation parameters: \n{args.dc_aug_param}')

                            accs = []
                            max_eo_list = []
                            mean_eo_list = []

                            for it_eval in range(args.num_eval):
                                net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device)
                                image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                                # The last two returned values (max/mean sufficiency) are not reported here.
                                _, acc_train, acc_test, max_eo, mean_eo, _, _ = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)
                                accs.append(acc_test)
                                max_eo_list.append(max_eo)
                                mean_eo_list.append(mean_eo)

                            print(f'Evaluate {len(accs)} random {model_eval}, mean = {np.mean(accs):.4f} std = {np.std(accs):.4f}')
                            print(f'accs: {np.mean(accs)}, max_EO: {np.mean(max_eo_list)}, mean_EO: {np.mean(mean_eo_list)}')

                            if it == args.Iteration:
                                accs_all_exps[model_eval] += accs

                        # --- Visualize and save ---
                        save_name = os.path.join(args.save_path, f'vis_{args.method}_{args.dataset}_{args.model}_{args.ipc}ipc_exp{exp}_iter{it}.png')
                        _save_synthetic_grid(image_syn, mean, std, save_name, nrow=args.ipc)

                    # --- Train synthetic data ---
                    net = get_network(args.model, channel, num_classes, im_size).to(args.device)
                    net.train()
                    optimizer_net = torch.optim.SGD(net.parameters(), lr=args.lr_net)
                    optimizer_net.zero_grad()
                    loss_avg = 0
                    args.dc_aug_param = None

                    # Probe accuracies of the most recent outer steps. Must persist
                    # across outer steps, otherwise the window never fills and the
                    # lambda_1 early stop can never trigger.
                    acc_watcher = []
                    outer_steps = 0

                    for ol in range(args.outer_loop):
                        loss = torch.tensor(0.0).to(args.device)
                        img_real_gather = []
                        lab_real_gather = []

                        for c in range(num_classes):
                            img_real, label, color = get_images(c, args.batch_real)
                            img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                            if args.dsa:
                                seed = int(time.time() * 1000) % 100000
                                img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                                img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                            img_real_gather.append(img_real)
                            lab_real_gather.append(torch.full((img_real.shape[0],), c, device=args.device, dtype=torch.long))

                            # net returns (logits, list of intermediate features).
                            output_real, real_features = net(img_real)
                            output_syn, syn_features = net(img_syn)

                            # Real targets: class means averaged equally over colour groups.
                            avg_real_features = _color_balanced_means(real_features, color, list(layer_weights))

                            for idx, weight in layer_weights.items():
                                loss += weight * MSE_Loss(avg_real_features[idx], torch.mean(syn_features[idx], dim=0)) / num_classes

                            # Alignment term between the balanced real embedding and the
                            # synthetic embedding. NOTE(review): if features[0] is
                            # (batch, dim), output_align has shape (1,), so
                            # CrossEntropyLoss sees a single class and this term is 0;
                            # with the 1e-18 weight it is effectively disabled. Kept as-is.
                            last_syn_feature = torch.mean(syn_features[0], dim=0)
                            output_align = avg_real_features[0].unsqueeze(0) @ last_syn_feature.t()
                            target_var = torch.tensor([c], dtype=torch.long, device=args.device)
                            loss += 0.000000000000000001 * criterion_sum(output_align, target_var.float()) / num_classes

                        # Classification loss on the gathered real batch. torch.cat works for
                        # any channel/size and for classes with fewer than batch_real images.
                        img_real_gather = torch.cat(img_real_gather, dim=0)
                        lab_real_gather = torch.cat(lab_real_gather, dim=0)
                        output_real, _ = net(img_real_gather)
                        loss_real = criterion(output_real, lab_real_gather)
                        # Weight 1e-20: effectively disabled. loss_real also depends only on
                        # real images and the network, not on image_syn.
                        loss += 0.00000000000000000001 * loss_real

                        loss.backward()
                        optimizer_img.step()
                        optimizer_img.zero_grad()
                        loss_avg += loss.item()
                        outer_steps += 1

                        # --- Outer-loop probe and early stopping ---
                        acc_test = _probe_class_accuracy(net, get_images, num_classes, 256, args.device)
                        acc_watcher.append(acc_test)
                        if len(acc_watcher) == early_stop_window:
                            if max(acc_watcher) - min(acc_watcher) < args.lambda_1:
                                break
                            acc_watcher.pop(0)

                        # --- Update network on the current synthetic set ---
                        image_syn_train, label_syn_train = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach())
                        dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
                        trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

                        for il in range(args.inner_loop):
                            epoch('train', trainloader, net, optimizer_net, criterion, args, aug=bool(args.dsa))

                    # Normalise by the outer steps actually taken (early stopping may cut the loop short).
                    loss_avg /= (num_classes * max(outer_steps, 1))

                    if it % 10 == 0:
                        print(f'{get_time()} iter = {it:04d}, loss = {loss_avg:.4f}')

                    if it == args.Iteration:
                        data_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])
                        torch.save({'data': data_save, 'accs_all_exps': accs_all_exps}, os.path.join(args.save_path, f'res_CAFE_{args.dataset}_{args.model}_{args.ipc}ipc-MO.pt'))

            print('\n==================== Final Results ====================\n')
            for key in model_eval_pool:
                accs = accs_all_exps[key]
                print(f'Run {args.num_exp} experiments, train on {args.model}, evaluate {len(accs)} random {key}, mean = {np.mean(accs)*100:.2f}%  std = {np.std(accs)*100:.2f}%')

if __name__ == '__main__':
    # Run the experiment sweep only when executed as a script, not when imported.
    main()

