import os
import time
import random
import argparse
import numpy as np
from tqdm import trange

import wandb

import torch
from torchvision.utils import save_image, make_grid
import torch.distributed as dist
from torch.multiprocessing import Process

from data import get_dataset, DiffAugment, ParamDiffAug
from models.wrapper import get_model
from utils import evaluate, Logger, sum_params, get_linear_schedule_warmup_constant, default_args
from generator import SyntheticImageGenerator


def run_single_process(rank, args):
    """Train the synthetic-image generator on one GPU as one rank of a process group.

    Each rank joins an NCCL process group, takes its share of the classes
    (``np.array_split`` over ``range(num_classes)``) and, every iteration,
    matches the mean embedding of its augmented synthetic images against the
    precomputed real feature means stored under ``args.feature_path``.
    Per-rank gradients are passed through ``sum_params`` before the optimizer
    step (presumably an all-reduce sum across ranks -- defined in ``utils``,
    TODO confirm). Rank 0 alone logs, evaluates, saves images/results and
    writes the final result files.

    Args:
        rank: Index of this process in the group; also indexes ``args.gpus``.
        args: Parsed command-line namespace. This function sets
            ``args.device``, ``args.dsa_param`` and ``args.dsa_strategy`` and
            rewrites ``args.feature_path``; rank 0 also turns
            ``args.model_eval_pool`` into a list.
    """
    dist.init_process_group("nccl", rank=rank, world_size=len(args.gpus))
    args.device = torch.device(f"cuda:{args.gpus[rank]}")

    # --- differentiable augmentation ---
    args.dsa_param = ParamDiffAug()
    # SVHN gets every augmentation except horizontal flip (presumably because
    # mirrored digits are not valid digits).
    if args.dataset == 'SVHN':
        args.dsa_strategy = 'color_crop_cutout_scale_rotate'
    else:
        args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'

    # --- paths ---
    if rank == 0:
        os.makedirs(args.data_path, exist_ok=True)
    args.feature_path = f'{args.feature_path}/{args.dataset}/{args.model}'

    # --- dataset: only the test loader is used (for evaluation) ---
    channel, im_size, num_classes, normalize, _, _, testloader = get_dataset(args.dataset, args.data_path, load_train=False)
    # Classes this rank is responsible for; the split covers every class once.
    own_classes = np.array_split(range(num_classes), len(args.gpus))[rank]

    # --- evaluation bookkeeping (rank 0 only) ---
    if rank == 0:
        if args.model_eval_pool is None:
            args.model_eval_pool = [args.model]
        else:
            args.model_eval_pool = args.model_eval_pool.split("_")
        # Accuracies of the final evaluation, keyed by evaluation model name.
        accs_all_exps = {key: [] for key in args.model_eval_pool}
        data_save = []

    # --- save path and logger ---
    save_path = f'{args.save_path}/{args.dataset}/{args.exp_name}'
    os.makedirs(save_path, exist_ok=True)
    if rank == 0:
        logger = Logger(
            exp_name=args.exp_name,
            save_dir=save_path,
            print_every=args.print_every,
            save_every=args.buffer_every,
            total_step=args.iteration,
            print_to_stdout=True,
            wandb_project_name=f'DD-clean-{args.dataset}',
            wandb_tags=[],
            wandb_config=args,
        )

    # --- generator, optimizer and LR schedule ---
    generator = SyntheticImageGenerator(
            num_classes, im_size, args.num_seed_vec, args.num_decoder, args.hdims,
            args.kernel_size, args.stride, args.padding).to(args.device)
    optimizer_gen = torch.optim.Adam(
        [{'params': generator.seed_vec, 'lr': args.lr_seed_vec},
         {'params': generator.decoders.parameters(), 'lr': args.lr_decoder}])
    if args.linear_schedule:
        scheduler_gen = get_linear_schedule_warmup_constant(
            optimizer_gen, 2000, args.iteration)
    else:
        scheduler_gen = torch.optim.lr_scheduler.MultiStepLR(
                optimizer_gen, milestones=args.lr_iteration, gamma=0.2)

    # --- load the pretrained autoencoder on rank 0 and sync it to all ranks ---
    if rank == 0:
        generator.load_state_dict(torch.load(args.ae_path))
    # The encoders are dropped; only seed_vec and the decoders are optimized.
    del generator.encoders
    generator.broadcast_decoder()
    # Non-zero ranks zero their parameters so that summing across ranks leaves
    # every rank with rank 0's weights (assumes sum_params is an all-reduce
    # sum -- TODO confirm in utils).
    if rank != 0:
        for p in generator.parameters():
            p.data.zero_()
    sum_params(generator.state_dict().values())

    if rank == 0:
        logger.register_model_to_save(generator, "generator")
        logger.register_model_to_save(optimizer_gen, "optimizer_gen")
        logger.start()

    buffer = []
    for it in range(1, args.iteration + 1):
        # --- buffer: precomputed real features, read in chunks ---
        # A chunk covers at most `buffer_every` iterations and is clamped to
        # `args.iteration`, so no file past the last iteration is requested
        # when `iteration` is not a multiple of `buffer_every`.
        if not buffer:
            if rank == 0:
                print("-------------------------Buffer Loading-------------------------")
            last_it = min(it + args.buffer_every, args.iteration + 1)
            for buffer_it in trange(it, last_it, disable=rank != 0):
                buffer.append(torch.load(f"{args.feature_path}/dsa_{buffer_it}.pth", map_location='cpu'))

        features_dict = buffer.pop(0)

        # --- frozen feature extractor for this iteration ---
        net = get_model(args, args.model, channel, num_classes, im_size)
        # Drop the classification head; the state dict is loaded with
        # strict=False, so keys of removed/absent modules are tolerated.
        if hasattr(net, "classifier"):
            del net.classifier
        if hasattr(net, 'fc'):
            del net.fc
        net.load_state_dict(features_dict['state_dict'], strict=False)
        net = net.to(args.device)
        net.train()
        net.requires_grad_(False)
        embed = net.embed

        # --- accumulate gradients over this rank's classes ---
        # Detached device tensor: it does not keep autograd graphs alive and is
        # a valid all_reduce operand even if this rank owns no classes.
        loss_avg = torch.zeros((), device=args.device)
        sum_grad = [torch.zeros_like(p) for p in generator.parameters()]
        for c in own_classes:
            # Augmentation seed stored with the real features, so the synthetic
            # batch is augmented with the same seed.
            seed = features_dict[c]['seed']

            # real
            output_real_mean = features_dict[c]['mean'].to(args.device)

            # syn
            img_syn = generator.get_sample(c)[0]
            img_syn = DiffAugment(normalize(img_syn), args.dsa_strategy, seed=seed, param=args.dsa_param)
            output_syn_mean = torch.mean(embed(img_syn), dim=0)

            # squared L2 distance between mean embeddings
            loss = torch.sum((output_real_mean - output_syn_mean) ** 2)

            # autograd.grad returns the gradients without writing into .grad
            grad = torch.autograd.grad(loss, generator.parameters())
            sum_grad = [g + fg for g, fg in zip(sum_grad, grad)]

            loss_avg += loss.detach()

            # release per-class tensors early to bound peak memory
            del grad, output_real_mean, img_syn, output_syn_mean, loss

        # --- combine gradients across ranks and step ---
        sum_params(sum_grad)
        optimizer_gen.zero_grad()
        for p, g in zip(generator.parameters(), sum_grad):
            if p.grad is None:
                p.grad = g.data
            else:
                p.grad.copy_(g.data)
        optimizer_gen.step()
        scheduler_gen.step()

        # Sum per-rank losses, then average over all classes.
        dist.all_reduce(loss_avg)
        loss_avg = loss_avg.item() / num_classes

        if rank == 0:
            logger.meter("loss", "train", loss_avg)

            # --- evaluate synthetic data ---
            # NOTE(review): other ranks proceed to the next collective and wait
            # there while rank 0 evaluates; very long evaluations could hit the
            # process-group timeout -- confirm.
            if it % args.eval_every == 0:
                image_syn, label_syn = generator.get_all_cpu()
                image_syn, label_syn = image_syn.detach(), label_syn.detach()

                if not args.not_eval:
                    for model_eval in args.model_eval_pool:
                        print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))
                        accs = []
                        # Several random models only at the final iteration.
                        num_eval = args.num_eval if it == args.iteration else 1
                        for _ in range(num_eval):
                            net_eval = get_model(args, model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                            _, acc = evaluate(args, net_eval, image_syn, label_syn, testloader, normalize)
                            accs.append(acc)

                        if it == args.iteration: # record the final results
                            accs_all_exps[model_eval] += accs

                        logger.meter("accuracy", model_eval, np.mean(accs))

                        print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))

                # --- visualize and save ---
                save_name = os.path.join(save_path, f'{it}.png')
                grid = make_grid(image_syn, nrow=args.num_seed_vec*args.num_decoder)
                wandb.log({"images": wandb.Image(grid.detach().cpu())}, step=it)
                save_image(image_syn, save_name, nrow=args.num_seed_vec*args.num_decoder)

            if it == args.iteration: # only record the final results
                image_syn, label_syn = generator.get_all_cpu()
                image_syn, label_syn = image_syn.detach(), label_syn.detach()

                data_save.append([image_syn, label_syn])
                torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(save_path, 'res.pth'))

            logger.step()

    if rank == 0:
        logger.finish()

        print('\n==================== Final Results ====================\n')
        for key in args.model_eval_pool:
            accs = accs_all_exps[key]
            print('Train on %s, Evaluate on %s for %d: mean  = %.2f%%  std = %.2f%%'%(args.model, key, len(accs), np.mean(accs), np.std(accs)))
            with open(f'{save_path}/{key}_final_results.txt', 'w') as f:
                f.write(f'mean = {np.mean(accs)}, std = {np.std(accs)}')

    # Release the NCCL communicator created by init_process_group.
    dist.destroy_process_group()

def find_free_port():
    """Return a TCP port number that the OS reports as currently unused.

    Binds a socket to port 0 so the OS assigns a free port, reads it back, and
    closes the socket. Another process could in principle take the port
    before the process group binds it, but this avoids the privileged and
    already-occupied ports a random pick can hit.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def main(args):
    """Spawn one ``run_single_process`` worker per GPU listed in ``args.gpus``.

    Sets up the environment for a single-node process group (localhost
    master, OS-chosen free port), converts ``args.gpus`` from a
    comma-separated string into a list of stripped ids, starts the workers
    with the ``spawn`` start method and waits for them.

    Raises:
        RuntimeError: if any worker exits with a non-zero exit code.
    """
    os.environ["WANDB_SILENT"] = "true"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())
    # Strip so "0, 1" yields "cuda:1" rather than "cuda: 1".
    args.gpus = [gpu.strip() for gpu in args.gpus.split(",")]

    # CUDA cannot be re-initialized in forked children, hence "spawn".
    torch.multiprocessing.set_start_method('spawn')
    processes = []
    for rank in range(len(args.gpus)):
        p = Process(target=run_single_process, args=(rank, args))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface worker crashes instead of exiting as if training succeeded.
    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise RuntimeError(f"worker process(es) failed as (rank, exitcode): {failed}")

def build_parser():
    """Build the command-line parser for this distillation script.

    ``--lr_iteration`` and ``--hdims`` take space-separated integers
    (e.g. ``--lr_iteration 5000 15000``). ``--ae_path`` selects the
    pretrained autoencoder checkpoint; when omitted, the script falls back to
    ``./pretrained_ae/{dataset}_{ipc}_default.pth``.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # data
    parser.add_argument('--data_path', type=str, default='ANONYMIZED')
    parser.add_argument('--feature_path', type=str, default='ANONYMIZED')
    parser.add_argument('--dataset', type=str, default='CIFAR10')

    # save
    parser.add_argument('--save_path', type=str, default='results')
    parser.add_argument('--exp_name', type=str, default=None)

    # repeat
    parser.add_argument('--num_eval', type=int, default=3)

    # training
    parser.add_argument('--iteration', type=int, default=20000)
    # `type=list` would split a single token into characters; parse ints instead.
    parser.add_argument('--lr_iteration', type=int, nargs='*', default=[])
    parser.add_argument('--model', type=str, default='ConvNet')
    parser.add_argument('--ipc', type=int, default=1)

    # hparms for ours
    parser.add_argument('--lr_seed_vec', type=float, default=1e-2)
    parser.add_argument('--lr_decoder', type=float, default=1e-3)
    parser.add_argument('--linear_schedule', action='store_true')
    parser.add_argument('--hdims', type=int, nargs='*', default=[])
    parser.add_argument('--num_seed_vec', type=int, default=16)
    parser.add_argument('--num_decoder', type=int, default=8)
    parser.add_argument('--ae_path', type=str, default=None)

    # evaluation
    parser.add_argument('--model_eval_pool', type=str, default=None)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--batch', type=int, default=128)

    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--eval_every', type=int, default=5000)
    parser.add_argument('--buffer_every', type=int, default=100)
    parser.add_argument('--gpus', type=str, default="0")
    parser.add_argument('--not_reproduce', action='store_true')
    parser.add_argument('--not_eval', action='store_true')
    return parser


if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args()

    if args.buffer_every <= 0:
        args.buffer_every = args.eval_every

    # --- default (reproduction) settings override the command line ---
    if not args.not_reproduce:
        default_args(args)
        args.ae_path = f'./pretrained_ae/{args.dataset}_{args.ipc}_default.pth'

        # iteration
        args.iteration = 20000

        # model
        args.model = "ConvNet"

        # lr
        args.lr_seed_vec = 1e-1
        args.lr_decoder = 1e-2

        # evaluation
        args.epoch = 200
        args.lr = 0.01
        args.batch = 256
    elif args.ae_path is None:
        # Without this, --not_reproduce would crash when loading the autoencoder.
        args.ae_path = f'./pretrained_ae/{args.dataset}_{args.ipc}_default.pth'

    # Validate the final values (after any overrides). The final evaluation
    # only runs when `iteration` is a multiple of `eval_every`.
    if args.eval_every <= 0 or args.iteration % args.eval_every != 0:
        parser.error(
            f'--iteration ({args.iteration}) must be a multiple of '
            f'--eval_every ({args.eval_every}), which must be positive')

    if args.exp_name is None:
        args.exp_name = f'{args.model}_{args.ipc}_{args.num_seed_vec}_{args.num_decoder}'
        if args.linear_schedule:
            args.exp_name += "_linear_schedule"
    print(args.exp_name)

    main(args)


