import sys
import os
import argparse
import warnings
sys.path.append('./lib_ddif')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import torch
import torch.nn as nn
from tqdm import tqdm

from lib_ddif.utils import get_dataset, get_network, get_daparam, TensorDataset, epoch, ParamDiffAug, set_seed, save_and_print, build_tensor_dataset


def _resolve_output_paths(args):
    """Return ``(save_dir, log_path)`` for this run's expert buffers and log file.

    Both are derived from ``args.buffer_path``, ``args.dataset``, ``args.subset``
    (only for ``imagenet``), ``args.zca`` and ``args.model``; nothing is created here.
    """
    save_dir = os.path.join(args.buffer_path, args.dataset)
    log_path = f"{args.buffer_path}/log"
    if args.dataset == "imagenet":
        save_dir = os.path.join(save_dir, args.subset)
        log_path += f"_{args.subset}"
    if not args.zca:
        save_dir += "_NO_ZCA"
        log_path += "_NO_ZCA"
    save_dir = os.path.join(save_dir, args.model)
    log_path += f"_{args.model}.txt"
    return save_dir, log_path


def _snapshot(net):
    """Return an independent CPU copy of every parameter of *net*.

    ``Tensor.cpu()`` returns the very same tensor when it already lives on the CPU and
    ``detach()`` shares storage, so without ``copy=True`` a CPU-only run would store
    aliases of the live weights (updated in place by the optimizer) instead of snapshots.
    """
    return [p.detach().to("cpu", copy=True) for p in net.parameters()]


def _train_expert(it, args, trainloader, testloader, criterion, channel, num_classes, im_size):
    """Train one freshly initialised teacher network and return its parameter trajectory.

    The returned list holds ``args.train_epochs + 1`` snapshots: the initial parameters
    followed by one snapshot after every training epoch.
    """
    teacher_net = get_network(args.model, channel, num_classes, im_size).to(args.device)
    teacher_net.train()
    lr = args.lr_teacher
    teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
    teacher_optim.zero_grad()

    timestamps = [_snapshot(teacher_net)]

    # With --decay, the learning rate is cut by 10x once, after this epoch index.
    lr_schedule = [args.train_epochs // 2 + 1]

    for e in range(args.train_epochs):
        _, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
                             criterion=criterion, args=args, aug=True, dsa_param=args.dsa_param)
        _, test_acc = epoch("test", dataloader=testloader, net=teacher_net, optimizer=None,
                            criterion=criterion, args=args, aug=False, dsa_param=args.dsa_param)

        save_and_print(args.log_path, "Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(it, e, train_acc, test_acc))

        timestamps.append(_snapshot(teacher_net))

        if e in lr_schedule and args.decay:
            lr *= 0.1
            # A new optimizer is built rather than editing param_groups, which also
            # discards any momentum state accumulated so far.
            teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom, weight_decay=args.l2)
            teacher_optim.zero_grad()

    return timestamps


def _save_trajectories(trajectories, save_dir, log_path):
    """Save *trajectories* with ``torch.save`` to the first unused ``replay_buffer_<n>.pt``."""
    n = 0
    while os.path.exists(os.path.join(save_dir, f"replay_buffer_{n}.pt")):
        n += 1
    path = os.path.join(save_dir, f"replay_buffer_{n}.pt")
    save_and_print(log_path, f"Saving {path}")
    torch.save(trajectories, path)


def main(args):
    """Train ``args.num_experts`` teacher networks and save their parameter trajectories.

    Each trajectory is a list of per-epoch parameter snapshots (lists of CPU tensors).
    Trajectories are written in groups of ``args.save_interval`` to
    ``<save_dir>/replay_buffer_<n>.pt``. A final, smaller group is written once all
    experts are trained, so no trained expert is lost.

    ``args`` is mutated in place (``dsa``, ``device``, ``dsa_param``, ``log_path``,
    ``dc_aug_param``); the project helpers that receive ``args`` may read these fields.
    """
    args.dsa = args.dsa == 'True'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()

    save_dir, args.log_path = _resolve_output_paths(args)
    os.makedirs(save_dir, exist_ok=True)

    channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test, testloader, loader_train_dict, class_map, class_map_inv = get_dataset(args.dataset, args.data_path, args.batch_real, args.subset, args=args)

    save_and_print(args.log_path, f'Hyper-parameters: {args.__dict__}')

    ''' organize the real dataset '''
    save_and_print(args.log_path, "BUILDING DATASET")
    images_all, labels_all = build_tensor_dataset(dst_train, batch_size=args.batch_real, workers=args.workers, class_map=class_map)

    # Only the per-class totals are reported, so count instead of storing indices.
    class_counts = [0] * num_classes
    for lab in tqdm(labels_all):
        class_counts[int(lab)] += 1

    for c, count in enumerate(class_counts):
        save_and_print(args.log_path, 'class c = %d: %d real images' % (c, count))

    for ch in range(channel):
        save_and_print(args.log_path, 'real images channel %d, mean = %.4f, std = %.4f' % (ch, torch.mean(images_all[:, ch]), torch.std(images_all[:, ch])))

    criterion = nn.CrossEntropyLoss().to(args.device)

    dst_train = TensorDataset(images_all, labels_all)
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)

    ''' set augmentation for whole-dataset training '''
    args.dc_aug_param = get_daparam(args.dataset, args.model, args.model, None)
    args.dc_aug_param['strategy'] = 'crop_scale_rotate'
    save_and_print(args.log_path, f'DC augmentation parameters: {args.dc_aug_param}')

    trajectories = []
    for it in range(args.num_experts):
        trajectories.append(_train_expert(it, args, trainloader, testloader, criterion, channel, num_classes, im_size))

        if len(trajectories) == args.save_interval:
            _save_trajectories(trajectories, save_dir, args.log_path)
            trajectories = []

    # Flush the remainder when num_experts is not a multiple of save_interval;
    # previously these trained experts were silently discarded.
    if trajectories:
        _save_trajectories(trajectories, save_dir, args.log_path)


def parse_args(argv=None):
    """Build the command-line parser for expert-trajectory generation and parse *argv*.

    Args:
        argv: Argument list to parse; ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. ``dsa`` is still the string
        ``'True'``/``'False'`` at this point; ``main`` converts it to a bool.
    """
    parser = argparse.ArgumentParser(description='Parameter Processing')
    parser.add_argument('--dataset', type=str, default='CIFAR10', help='dataset')
    parser.add_argument('--subset', type=str, default='imagenette', help='subset')
    parser.add_argument('--res', type=int)
    parser.add_argument('--model', type=str, default='ConvNet', help='model')
    parser.add_argument('--num_experts', type=int, default=100, help='number of expert (teacher) trajectories to train')
    parser.add_argument('--lr_teacher', type=float, default=0.01, help='learning rate for updating network parameters')
    parser.add_argument('--batch_train', type=int, default=256, help='batch size for training networks')
    parser.add_argument('--batch_real', type=int, default=256, help='batch size for real loader')
    parser.add_argument('--batch_test', type=int, default=128, help='batch size for testing')
    parser.add_argument('--dsa', type=str, default='True', choices=['True', 'False'], help='whether to use differentiable Siamese augmentation.')
    parser.add_argument('--dsa_strategy', type=str, default='color_crop_cutout_flip_scale_rotate', help='differentiable Siamese augmentation strategy')
    parser.add_argument('--data_path', type=str, default='../data', help='dataset path')
    parser.add_argument('--buffer_path', type=str, default='../buffers', help='buffer path')
    parser.add_argument('--train_epochs', type=int, default=50, help='training epochs per expert')
    parser.add_argument('--zca', action='store_true')
    parser.add_argument('--decay', action='store_true', help='cut the teacher learning rate by 10x once, about halfway through training')
    parser.add_argument('--mom', type=float, default=0, help='momentum')
    parser.add_argument('--l2', type=float, default=0, help='l2 regularization')
    parser.add_argument('--save_interval', type=int, default=10, help='number of expert trajectories per saved replay buffer file')
    parser.add_argument('--workers', type=int, default=8, help='number of workers used when building the tensor dataset')

    parser.add_argument('--seed', type=int, default=0)
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.seed)

    main(args)