import os
import time
import random
import copy
import argparse
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_network, get_time, DiffAugment, ParamDiffAug,get_args
from dataset import get_classes, dirichlet_distribution
from loss_fn import Distance_loss
def _dataset_spec(data_name):
    """Return ``(channel, im_size, num_classes, classes_per_node)`` for a dataset.

    Raises:
        ValueError: if ``data_name`` is not one of the supported dataset names.
    """
    specs = {
        'cifar10': (3, (32, 32), 10, 2),
        'cifar100': (3, (32, 32), 100, 10),
        'tiny_imageNet': (3, (32, 32), 200, 20),
    }
    try:
        return specs[data_name]
    except KeyError:
        raise ValueError(
            'Unsupported data_name %r; expected one of %s' % (data_name, sorted(specs))
        ) from None


def _save_denormalized(images, mean, std, channel, save_name, nrow):
    """Undo per-channel normalisation, clamp to [0, 1] and write an image grid."""
    vis = images.detach().cpu().clone()  # work on a copy; never touch the trained tensor
    for ch in range(channel):
        vis[:, ch] = vis[:, ch] * std[ch] + mean[ch]
    vis.clamp_(0.0, 1.0)
    save_image(vis, save_name, nrow=nrow)


def main():
    """Distill per-client synthetic images and aggregate them per class.

    For every experiment and every client, ``ipc`` synthetic images per class
    are optimised so that the mean embedding of the synthetic images matches
    the mean embedding of real images of the same class under a randomly
    initialised, frozen network, plus a weighted ``Distance_loss`` term.
    Classes for which a client holds fewer than ``ipc`` real images are not
    trained on that client. Afterwards the synthetic images of each class are
    averaged over the clients that trained it, and the result is saved as a
    PNG grid and a ``.pt`` file under ``args.save_path``.

    Raises:
        ValueError: if ``args.data_name`` is not a supported dataset.
        RuntimeError: if no class could be distilled by any client.
    """
    parser = argparse.ArgumentParser()
    args = get_args(parser)
    args.method = 'DM'
    args.device = torch.device("cpu" if args.cuda == -1 else "cuda:%d" % args.cuda)
    args.dsa_param = ParamDiffAug()
    args.dsa = args.dsa_strategy not in ['none', 'None']

    # Fail fast on an unknown dataset instead of hitting a NameError later.
    channel, im_size, num_classes, classes_per_node = _dataset_spec(args.data_name)

    created_root = not os.path.isdir('distill_data')
    os.makedirs('distill_data', exist_ok=True)
    if created_root:
        print("Created directory distill_data for distilled data storage.")

    args.save_path = 'distill_data/' + args.data_name +'_'+ str(args.num_nodes) +'clients_'+ args.init +'_'+ str(args.ipc) +'ipc_'+ str(args.alpha) +'_'+ str(args.Iteration)
    os.makedirs(args.save_path, exist_ok=True)

    if args.data_distribution == 'incomplete_label':
        dst_train, mean, std, _ = get_classes(args.data_name, args.data_path, args.num_nodes, classes_per_node, args.seed)
    else:
        dst_train, mean, std, _ = dirichlet_distribution(args.data_name, args.data_path, args.num_nodes, args.seed, args.least_nums, args.alpha)

    data_save = []

    # Index every client's data once: images, per-class sample indices, and
    # whether the class has enough real samples (>= ipc) to be distilled.
    images_all = {}
    indices_class = {}
    indices_class_flag = {}
    for client_id in range(args.num_nodes):
        client_data = dst_train[client_id]
        client_images = []
        per_class = [[] for _ in range(num_classes)]
        for i in range(len(client_data)):
            sample = client_data[i]  # single fetch per item (image and label together)
            client_images.append(torch.unsqueeze(sample[0], dim=0))
            per_class[sample[1]].append(i)
        images_all[client_id] = client_images
        indices_class[client_id] = per_class
        indices_class_flag[client_id] = [bool(idx) and len(idx) >= args.ipc for idx in per_class]

    for exp in range(args.num_exp):
        print('\n================== Exp %d ==================\n '%exp)
        print('Hyper-parameters: \n', vars(args))

        # Random, frozen feature extractor; only the synthetic images are optimised.
        net = get_network(args.model, channel, num_classes).to(args.device)
        net.train()
        for param in net.parameters():
            param.requires_grad = False

        embed = net.get_features

        image_syn = {client_id: None for client_id in range(args.num_nodes)}
        optimizer_img = {client_id: None for client_id in range(args.num_nodes)}

        for client_id in range(args.num_nodes):
            class_indices = indices_class[client_id]
            for c in range(num_classes):
                print('class c = %d: %d real images'%(c, len(class_indices[c])))

            active_classes = [c for c in range(num_classes) if indices_class_flag[client_id][c]]
            if not active_classes:
                # Nothing to match against: skipping avoids torch.cat on empty lists below.
                print('client %d has no class with at least %d real images, skipping' % (client_id, args.ipc))
                continue

            images = torch.cat(images_all[client_id], dim=0).to(args.device)

            def get_images(c, n, images=images, class_indices=class_indices):
                """Return up to ``n`` randomly chosen real images of class ``c``."""
                idx_shuffle = np.random.permutation(class_indices[c])[:n]
                return images[idx_shuffle]

            for ch in range(channel):
                print('real images channel %d, mean = %.4f, std = %.4f'%(ch, torch.mean(images[:, ch]), torch.std(images[:, ch])))

            # Initialise the synthetic data from noise, optionally overwritten by real samples.
            image_syn[client_id] = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)

            if args.init == 'real':
                print('initialize synthetic data from random real images')
                for c in active_classes:
                    image_syn[client_id].data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().data
            else:
                print('initialize synthetic data from random noise')

            # Training
            optimizer_img[client_id] = torch.optim.SGD([image_syn[client_id], ], lr=args.lr_img, momentum=0.5)
            optimizer_img[client_id].zero_grad()
            print('%s training begins'%get_time())

            distance_loss = Distance_loss(device=args.device)
            for it in range(args.Iteration+1):
                if it%500 == 0:
                    # Visualise and save the current synthetic set.
                    save_name = os.path.join(args.save_path, 'client%d_vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(client_id,args.method, args.data_name, args.model, args.ipc, exp, it))
                    _save_denormalized(image_syn[client_id], mean, std, channel, save_name, args.ipc)

                # Update synthetic data
                all_real_features, all_syn_features = [], []
                all_real_labels, all_syn_labels = [], []

                loss = torch.tensor(0.0, device=args.device)
                for c in active_classes:
                    img_real = get_images(c, args.batch_real)
                    img_syn = image_syn[client_id][c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

                    if args.dsa:
                        # Same seed for both calls so real and synthetic batches get the same augmentation.
                        seed = int(time.time() * 1000) % 100000
                        img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    output_real = embed(img_real).detach()
                    output_syn = embed(img_syn)

                    # Squared distance between class-mean embeddings.
                    loss = loss + torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

                    all_real_features.append(output_real)
                    all_syn_features.append(output_syn)
                    all_real_labels.append(torch.full((output_real.size(0),), c, device=output_real.device, dtype=torch.long))
                    all_syn_labels.append(torch.full((output_syn.size(0),), c, device=output_syn.device, dtype=torch.long))

                real_features = torch.cat(all_real_features, dim=0)
                real_labels = torch.cat(all_real_labels, dim=0)

                syn_features = torch.cat(all_syn_features, dim=0)
                syn_labels = torch.cat(all_syn_labels, dim=0)

                loss_Sup = 0.0005 * distance_loss(syn_features, real_features.detach(), syn_labels, real_labels)

                loss = loss + loss_Sup

                optimizer_img[client_id].zero_grad()
                loss.backward()
                optimizer_img[client_id].step()
                loss_sum = loss.item()
                loss_Sup_sum = loss_Sup.item()

                if it%10 == 0:
                    print('%s iter = %05d, loss = %.4f, loss_Sup = %.4f' % (get_time(), it, loss_sum, loss_Sup_sum))

        # Average each class's synthetic images over the clients that trained it.
        virtual_images = []
        kept_classes = []
        for c in range(num_classes):
            virtual_class_images = [
                image_syn[client_id][c * args.ipc : (c + 1) * args.ipc].detach().cpu().clone()
                for client_id in range(args.num_nodes)
                if indices_class_flag[client_id][c]
            ]
            if virtual_class_images:
                kept_classes.append(c)
                virtual_images.append(torch.mean(torch.stack(virtual_class_images, dim=0), dim=0))
        if not virtual_images:
            raise RuntimeError('No class has at least %d real images on any client; nothing to aggregate.' % args.ipc)
        virtual_images = torch.cat(virtual_images, dim=0)
        # Labels laid out as [0,0,0, 1,1,1, ..., 9,9,9], restricted to the kept classes.
        virtual_images_label = torch.tensor(kept_classes, dtype=torch.long, device=args.device).repeat_interleave(args.ipc)

        # The training loop always ends at iteration args.Iteration; use it
        # directly rather than the loop variable, which may be unset.
        last_it = args.Iteration
        save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.data_name, args.model, args.ipc, exp, last_it))
        _save_denormalized(virtual_images, mean, std, channel, save_name, args.ipc)

        data_save.append([virtual_images.detach().cpu().clone(), virtual_images_label.detach().cpu().clone()])
        torch.save({'data': data_save}, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc_iter%d_exp%d_alpha%s.pt'%(args.method,args.data_name,args.model,args.ipc,last_it,exp, str(args.alpha))))
    
       
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()


