import os
import time
import copy
import numpy as np
import torch
import torch.nn as nn
from utils import *
import random

from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter

def main(args):
    """Distill a dataset in an autoencoder's latent space by mean-embedding matching (DM).

    Pipeline:
      1. Seed ``torch`` / ``numpy`` / ``random`` with 0 and resolve runtime
         options (device, differentiable-augmentation settings).
      2. Load the dataset and the autoencoder (OmegaConf config + checkpoint),
         then build the per-class pool of real latents with ``build_dataset``.
      3. For ``args.Iteration + 1`` steps: sample a freshly initialised network
         whose weights are frozen, compare per-class mean embeddings of real and
         synthetic latents, and step the latent optimizer on the summed squared
         difference. At iterations in the evaluation pool the synthetic set is
         evaluated via ``eval_and_save`` (best acc/std go to TensorBoard);
         otherwise it is optionally saved every ``args.save_image_it`` steps.

    Args:
        args: parsed command-line namespace (built by ``shared_args``). It is
            mutated in place: ``device``, ``dsa``, ``dsa_param``, dataset
            metadata (``channel``, ``im_size``, ``num_classes``),
            ``latent_size``, ``lpc``, ``convnet_pooling`` and ``save_path`` are
            filled in here.
    """
    torch.set_num_threads(args.torch_num_threads)
    torch.random.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()
    # Only the exact string 'True' enables DSA; a 'none' strategy disables it.
    args.dsa = args.dsa == 'True'
    if args.dsa_strategy in ['none', 'None']:
        args.dsa = False

    # makedirs also creates missing parent directories (os.mkdir would fail).
    os.makedirs(args.data_path, exist_ok=True)

    (args.channel, args.im_size, args.num_classes, _, class_map, _, _, _,
     dst_train, _, testloader, _) = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)
    eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
    eval_it_set = set(eval_it_pool)  # O(1) membership checks inside the training loop
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    # Construct autoencoder
    ae_config = OmegaConf.load(args.ae_config)
    ae_model = load_autoencoder_from_config(ae_config, args.ae_ckpt).to(args.device)

    # Latent spatial size = image size integer-divided by args.f
    # (presumably the autoencoder's downsampling factor).
    args.latent_size = (args.im_size[0] // args.f, args.im_size[1] // args.f)
    if args.lpc is None:
        args.lpc = get_lpc(args)
    # Average pooling only when the latent side is at least 2 ** (train_depth + 1).
    args.convnet_pooling = 'avgpooling' if args.latent_size[0] >= 2 ** (args.train_depth + 1) else 'none'
    # Smaller batch for large latents (presumably to bound memory while encoding).
    latent_all, label_all, indices_class = build_dataset(args, ae_model, dst_train, class_map, batch_size=16 if args.latent_size[0] <= 64 else 4)

    args.save_path = os.path.join(args.save_path, 'DM', f'{args.dataset}-{args.im_size[0]}', get_run_name(args))
    os.makedirs(args.save_path, exist_ok=True)
    logger = Logger(os.path.join(args.save_path, 'log.txt'))

    def get_latent(c, n):
        """Return up to ``n`` randomly chosen real latents of class ``c`` on ``args.device``."""
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return latent_all[idx_shuffle].to(args.device)

    def augment_pair(lat_real, lat_syn):
        """Apply DSA to a real/synthetic pair with one shared seed (no-op when DSA is off)."""
        if not args.dsa:
            return lat_real, lat_syn
        # NOTE(review): a wall-clock seed makes augmentation non-reproducible
        # despite the fixed seeds above; kept to preserve existing behavior.
        seed = int(time.time() * 1000) % 100000
        return (DiffAugment(lat_real, args.dsa_strategy, seed=seed, param=args.dsa_param),
                DiffAugment(lat_syn, args.dsa_strategy, seed=seed, param=args.dsa_param))

    latent_syn, label_syn = prepare_latent(args, get_latent)
    optimizer_latent = get_optimizer_latent(args, latent_syn)

    logger.log('%s training begins' % get_time())
    logger.log('Evaluation iteration pool: ' + print_eval_it_pool(eval_it_pool))
    logger.log('Evaluation model pool: ' + str(model_eval_pool))
    logger.log(f'Dataset info: {args.dataset}, {args.channel} * {args.im_size[0]} * {args.im_size[1]}, {args.num_classes} classes')
    logger.log('Args: ' + str(args.__dict__))

    # The context manager closes the writer, flushing the final iterations'
    # scalars, even if training raises.
    with SummaryWriter(args.save_path) as tb_writer:
        for it in range(args.Iteration + 1):
            if it in eval_it_set:
                best_acc, best_std = eval_and_save(args, latent_syn, label_syn, ae_model, logger, testloader=testloader, model_eval_pool=model_eval_pool, it=it)
                for model_eval, acc in best_acc.items():
                    tb_writer.add_scalar(f'best_acc/{model_eval}', acc, it)
                    tb_writer.add_scalar(f'best_std/{model_eval}', best_std[model_eval], it)
                tb_writer.flush()
            elif args.save_image_it and it % args.save_image_it == 0:
                save(args, latent_syn, ae_model, it=it)

            ''' Train synthetic data '''
            net = get_network(args.model, args.C, args.num_classes, args.latent_size, depth=args.train_depth, width=args.train_width, convnet_pooling=args.convnet_pooling).to(args.device)  # get a random model
            net.train()
            for param in net.parameters():
                param.requires_grad = False

            # Unwrap a (Distributed)DataParallel container if get_network returned one.
            is_wrapped = isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParallel))
            embed = (net.module if is_wrapped else net).embed

            ''' update synthetic data '''
            if 'BN' not in args.model:  # for ConvNet
                # Without BN in the model name, each class is embedded separately.
                loss = 0.0
                for c in range(args.num_classes):
                    lat_real, lat_syn = augment_pair(get_latent(c, args.batch_real), latent_syn[c * args.lpc:(c + 1) * args.lpc])
                    output_real = embed(lat_real).detach()
                    output_syn = embed(lat_syn)
                    loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0)) ** 2)
            else:  # for ConvNetBN
                # All classes go through the network in a single batch
                # (presumably so BatchNorm statistics span every class).
                real_batches, syn_batches = [], []
                for c in range(args.num_classes):
                    lat_real, lat_syn = augment_pair(get_latent(c, args.batch_real), latent_syn[c * args.lpc:(c + 1) * args.lpc])
                    real_batches.append(lat_real)
                    syn_batches.append(lat_syn)

                output_real = embed(torch.cat(real_batches, dim=0)).detach()
                output_syn = embed(torch.cat(syn_batches, dim=0))

                # Split by the actual per-class counts instead of reshaping to
                # (num_classes, batch_real, -1): get_latent returns fewer than
                # batch_real latents for small classes, which broke the reshape.
                real_counts = [batch.shape[0] for batch in real_batches]
                syn_counts = [batch.shape[0] for batch in syn_batches]
                mean_real = torch.stack([chunk.reshape(chunk.shape[0], -1).mean(dim=0) for chunk in output_real.split(real_counts)])
                mean_syn = torch.stack([chunk.reshape(chunk.shape[0], -1).mean(dim=0) for chunk in output_syn.split(syn_counts)])
                loss = torch.sum((mean_real - mean_syn) ** 2)

            optimizer_latent.zero_grad()
            loss.backward()
            optimizer_latent.step()
            # Report the matching loss averaged over classes.
            loss_avg = loss.item() / args.num_classes
            tb_writer.add_scalar('avg_match_loss/dm', loss_avg, it)

            if it % args.log_it == 0:
                logger.log('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))


if __name__ == '__main__':
    # Script entry point: parse the CLI options shared across the
    # distillation scripts and hand them to main().
    import shared_args

    main(shared_args.add_shared_args().parse_args())


