import os, time, copy
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from utils import *
import random

# Added
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter

def main(args):
    """Optimize synthetic latents by matching network gradients on real vs. synthetic data.

    The routine seeds the RNGs, resolves run settings on ``args``, loads the
    dataset and the autoencoder, encodes the training set through
    ``build_dataset``, and then runs ``args.Iteration + 1`` iterations. Each
    iteration optionally evaluates (or saves) the current synthetic latents,
    samples a fresh network, and for ``args.outer_loop`` rounds updates the
    synthetic latents with the per-class gradient-matching loss, training the
    network in between rounds on either synthetic or real latents.

    Args:
        args: Parsed command-line namespace. It is mutated in place: this
            function sets ``outer_loop``, ``inner_loop``, ``device``,
            ``dsa_param``, ``dsa``, ``channel``, ``im_size``, ``num_classes``,
            ``latent_size``, ``lpc`` (when it is ``None``),
            ``convnet_pooling`` and ``save_path``.

    Side effects:
        Creates ``args.data_path`` and the run directory under
        ``args.save_path``, writes ``log.txt`` and TensorBoard event files
        there, and calls ``eval_and_save`` / ``save`` on the synthetic latents.
    """
    torch.set_num_threads(args.torch_num_threads)
    # Fixed seeds for the torch, numpy and stdlib random generators.
    torch.random.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    # '--loop' is an explicit "outer-inner" pair; otherwise derive both from ipc.
    if args.loop is None:
        args.outer_loop, args.inner_loop = get_loops(args.ipc)
    else:
        args.outer_loop, args.inner_loop = [int(x) for x in args.loop.split('-')]
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    args.dsa_param = ParamDiffAug()
    # args.dsa is compared against the string 'True'; a 'none' strategy disables augmentation.
    args.dsa = args.dsa == 'True'
    if args.dsa_strategy in ['none', 'None']:
        args.dsa = False

    # makedirs (not mkdir) so that a nested data path with missing parents works;
    # exist_ok=True keeps a concurrent creation from raising.
    if not os.path.exists(args.data_path):
        os.makedirs(args.data_path, exist_ok=True)

    args.channel, args.im_size, args.num_classes, _, class_map, _, _, _, dst_train, _, testloader, _ = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)
    eval_it_pool = np.arange(0, args.Iteration + 1, args.eval_it).tolist()
    model_eval_pool = get_eval_pool(args.eval_mode, args.model, args.model)

    # Construct autoencoder
    ae_config = OmegaConf.load(args.ae_config)
    ae_model = load_autoencoder_from_config(ae_config, args.ae_ckpt).to(args.device)

    # Latent spatial size is the image size divided by the factor args.f.
    args.latent_size = (args.im_size[0] // args.f, args.im_size[1] // args.f)
    if args.lpc is None:
        args.lpc = get_lpc(args)
    # Pooling is only enabled when the latent is large enough for the network depth.
    args.convnet_pooling = 'avgpooling' if args.latent_size[0] >= 2 ** (args.train_depth + 1) else 'none'
    latent_all, label_all, indices_class = build_dataset(args, ae_model, dst_train, class_map, batch_size=16 if args.latent_size[0] <= 64 else 4)

    def get_latent(c, n):  # get random n real latents from class c
        idx_shuffle = np.random.permutation(indices_class[c])[:n]
        return latent_all[idx_shuffle].to(args.device)

    args.save_path = os.path.join(args.save_path, 'DC', f'{args.dataset}-{args.im_size[0]}', get_run_name(args))
    if not os.path.exists(args.save_path):
        # exist_ok=True: tolerate the directory appearing between the check and the call.
        os.makedirs(args.save_path, exist_ok=True)
    logger = Logger(os.path.join(args.save_path, 'log.txt'))
    tb_writer = SummaryWriter(args.save_path)

    # Everything below runs under try/finally so the TensorBoard writer is
    # closed even when training is interrupted by an exception.
    try:
        latent_syn, label_syn = prepare_latent(args, get_latent)
        optimizer_latent = get_optimizer_latent(args, latent_syn)

        criterion = nn.CrossEntropyLoss().to(args.device)

        logger.log('%s training begins' % get_time())
        logger.log('Evaluation iteration pool: ' + print_eval_it_pool(eval_it_pool))
        logger.log('Evaluation model pool: ', model_eval_pool)
        logger.log(f'Dataset info: {args.dataset}, {args.channel} * {args.im_size[0]} * {args.im_size[1]}, {args.num_classes} classes')
        logger.log('Args: ' + str(args.__dict__))

        # Number of real latents per class used to estimate BatchNorm statistics.
        bn_size_per_class = 16

        for it in range(args.Iteration + 1):
            if it in eval_it_pool:
                best_acc, best_std = eval_and_save(args, latent_syn, label_syn, ae_model, logger, testloader=testloader, model_eval_pool=model_eval_pool, it=it)
                for model_eval in best_acc.keys():
                    tb_writer.add_scalar(f'best_acc/{model_eval}', best_acc[model_eval], it)
                    tb_writer.add_scalar(f'best_std/{model_eval}', best_std[model_eval], it)
                tb_writer.flush()
            elif args.save_image_it and it % args.save_image_it == 0:
                save(args, latent_syn, ae_model, it=it)

            # ---- Train synthetic data ----
            net = get_network(args.model, args.C, args.num_classes, args.latent_size, depth=args.train_depth, width=args.train_width, convnet_pooling=args.convnet_pooling).to(args.device)  # get a random model
            net.train()
            net_parameters = list(net.parameters())
            optimizer_net = get_optimizer_net(args, net)
            optimizer_net.zero_grad()
            # The module set is fixed for this network, so detect BatchNorm once per iteration.
            has_batchnorm = any('BatchNorm' in module._get_name() for module in net.modules())

            loss_avg, gw_real_l1_avg, gw_syn_l1_avg, gl_syn_l1_avg = 0, 0, 0, 0
            for ol in range(args.outer_loop):
                # Freeze the running mu and sigma for BatchNorm layers.
                # Synthetic data batch, e.g. only 1 image/batch, is too small to obtain stable mu and sigma.
                # So, we calculate and freeze mu and sigma for BatchNorm layer with real data batch ahead.
                # This would make the training with BatchNorm layers easier.
                if has_batchnorm:
                    latent_real = torch.cat([get_latent(c, bn_size_per_class) for c in range(args.num_classes)], dim=0)
                    net.train()  # for updating the mu, sigma of BatchNorm
                    output_real = net(latent_real)  # get running mu, sigma
                    for module in net.modules():
                        if 'BatchNorm' in module._get_name():
                            module.eval()  # fix mu and sigma of every BatchNorm layer

                # Accumulate the gradient-matching loss over all classes.
                loss = 0
                for c in range(args.num_classes):
                    lat_real = get_latent(c, args.batch_real)
                    lab_real = torch.ones((lat_real.shape[0],), device=args.device, dtype=torch.long) * c
                    lat_syn = latent_syn[c * args.lpc:(c + 1) * args.lpc].reshape((args.lpc, args.C, *args.latent_size))
                    lab_syn = torch.ones((args.lpc,), device=args.device, dtype=torch.long) * c

                    if args.dsa:
                        # The same seed is passed for the real and the synthetic batch.
                        seed = int(time.time() * 1000) % 100000
                        lat_real = DiffAugment(lat_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                        lat_syn = DiffAugment(lat_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

                    # Real-data gradients are targets only, so they are detached.
                    output_real = net(lat_real)
                    loss_real = criterion(output_real, lab_real)
                    gw_real = torch.autograd.grad(loss_real, net_parameters)
                    gw_real = [g.detach().clone() for g in gw_real]

                    # create_graph=True keeps the graph so the matching loss can be
                    # back-propagated through these gradients to the synthetic latents.
                    output_syn = net(lat_syn)
                    loss_syn = criterion(output_syn, lab_syn)
                    gw_syn = torch.autograd.grad(loss_syn, net_parameters, create_graph=True)

                    loss += match_loss(gw_syn, gw_real, dis_metric=args.dis_metric)
                    with torch.no_grad():
                        gw_syn_l1, gw_real_l1 = compute_grad_l1(gw_syn), compute_grad_l1(gw_real)
                        gw_syn_l1_avg += gw_syn_l1.item()
                        gw_real_l1_avg += gw_real_l1.item()

                # Update the synthetic latents.
                optimizer_latent.zero_grad()
                loss.backward()
                optimizer_latent.step()
                loss_avg += loss.item()
                with torch.no_grad():
                    gl_syn_l1_avg += compute_grad_l1(latent_syn.grad).item()

                # No network update after the last outer loop; drop references to
                # the network and per-class tensors before the next iteration.
                if ol == args.outer_loop - 1:
                    del net, loss, loss_real, loss_syn, output_real, output_syn, gw_real, gw_syn, lat_real, lat_syn, lab_real, lab_syn
                    break

                # ---- Update network ----
                if args.data_train_net == 'syn':
                    latent_syn_train, label_syn_train = copy.deepcopy(latent_syn.detach()), copy.deepcopy(label_syn.detach())  # avoid any unaware modification
                    dst_syn_train = TensorDataset(latent_syn_train, label_syn_train)
                    trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                    for il in range(args.inner_loop):
                        epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa, target_iter=1)
                elif args.data_train_net == 'real':
                    dst_real_train = TensorDataset(latent_all, label_all)
                    trainloader = torch.utils.data.DataLoader(dst_real_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
                    epoch('train', trainloader, net, optimizer_net, criterion, args, aug=args.dsa, target_iter=args.inner_loop)

            # Per-class sums are averaged over classes and outer loops; the
            # latent-gradient norm is recorded once per outer loop.
            loss_avg /= (args.num_classes * args.outer_loop)
            gw_syn_l1_avg /= (args.num_classes * args.outer_loop)
            gw_real_l1_avg /= (args.num_classes * args.outer_loop)
            gl_syn_l1_avg /= args.outer_loop
            tb_writer.add_scalar('avg_match_loss/mse', loss_avg, it)
            tb_writer.add_scalar('avg_gw_l1/syn', gw_syn_l1_avg, it)
            tb_writer.add_scalar('avg_gw_l1/real', gw_real_l1_avg, it)
            tb_writer.add_scalar('avg_gl_l1/syn', gl_syn_l1_avg, it)
            tb_writer.flush()

            if it % args.log_it == 0:
                logger.log('%s iter = %04d, loss = %.4f' % (get_time(), it, loss_avg))
    finally:
        tb_writer.close()

if __name__ == '__main__':
    import shared_args

    # Start from the parser returned by shared_args, then register the
    # options that only this script defines.
    parser = shared_args.add_shared_args()
    script_options = (
        ('--loop', {'type': str, 'help': 'manually set outer_loop-inner_loop, e.g. 10-50'}),
        ('--dis_metric', {'type': str, 'default': 'mse', 'help': 'distance metric'}),
        ('--data_train_net', {'type': str, 'choices': ['real', 'syn'], 'default': 'real'}),
    )
    for flag, options in script_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)


