import os
import argparse
import torch
import torch.nn as nn
from tqdm import tqdm
from utils import *
import copy
from omegaconf import OmegaConf

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

def _snapshot_params(net):
    """Return an independent, detached CPU copy of every parameter of *net*.

    ``p.detach().cpu()`` is not enough: for a tensor already on the CPU,
    ``.cpu()`` returns the same tensor, so the "snapshot" would share storage
    with the live parameter that the optimizer later updates in place.
    ``copy=True`` forces a real copy on every device.
    """
    return [p.detach().to('cpu', copy=True) for p in net.parameters()]


def _next_buffer_path(save_path):
    """Return the first ``replay_buffer_{n}.pt`` path under *save_path* that does not exist yet."""
    n = 0
    while os.path.exists(os.path.join(save_path, "replay_buffer_{}.pt".format(n))):
        n += 1
    return os.path.join(save_path, "replay_buffer_{}.pt".format(n))


def _save_trajectories(save_path, trajectories):
    """Write the list of expert *trajectories* to the next free replay-buffer file in *save_path*."""
    out_path = _next_buffer_path(save_path)
    print("{} Saving {}".format(get_time(), out_path))
    torch.save(trajectories, out_path)


def main(args):
    """Train ``args.num_experts`` expert networks on autoencoder latents and save their trajectories.

    The train/test splits are encoded by the autoencoder. Each expert is then
    trained from a fresh random init for ``args.train_epochs`` epochs. Its
    parameters are recorded after init and after every epoch. Every
    ``args.save_interval`` experts, the collected trajectories are written to
    ``replay_buffer_{n}.pt`` under ``args.save_path``. Any remaining experts
    are written once training finishes.

    Side effects on *args*: sets ``dsa`` (parsed to bool), ``device``,
    ``dsa_param``, ``channel``, ``im_size``, ``num_classes``,
    ``latent_size``, ``lpc`` (if None), ``convnet_pooling`` and ``save_path``.
    """
    torch.set_num_threads(args.torch_num_threads)
    args.dsa = args.dsa == 'True'
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.dsa_param = ParamDiffAug()

    args.channel, args.im_size, args.num_classes, _, class_map, _, _, _, dst_train, dst_test, _, _ = get_dataset(args.dataset, args.data_path, args.batch_real, args.res, args=args)

    ae_config = OmegaConf.load(args.ae_config)
    ae_model = load_autoencoder_from_config(ae_config, args.ae_ckpt).to(args.device)
    args.latent_size = (args.im_size[0] // args.f, args.im_size[1] // args.f)
    if args.lpc is None:
        args.lpc = get_lpc(args)
    args.convnet_pooling = 'avgpooling' if args.latent_size[0] >= 2 ** (args.train_depth + 1) else 'none'

    # Smaller encoding batches for large latents (presumably to bound memory use).
    encode_batch_size = 16 if args.latent_size[0] <= 64 else 4
    latent_all, label_all, _ = build_dataset(args, ae_model, dst_train, class_map, batch_size=encode_batch_size)
    test_latent_all, test_label_all, _ = build_dataset(args, ae_model, dst_test, class_map, batch_size=encode_batch_size)

    args.save_path = os.path.join(args.buffer_path, f'{args.dataset}-{args.res}-l{args.latent_size[0]}_{args.model}-d{args.train_depth}w{args.train_width}')
    os.makedirs(args.save_path, exist_ok=True)

    criterion = nn.CrossEntropyLoss().to(args.device)

    trajectories = []

    # From here on the experts train on the encoded latents, not on the raw images.
    dst_train = TensorDataset(latent_all.detach().clone(), label_all.detach().clone())
    trainloader = torch.utils.data.DataLoader(dst_train, batch_size=args.batch_train, shuffle=True, num_workers=0)
    dst_test = TensorDataset(test_latent_all.detach().clone(), test_label_all.detach().clone())
    testloader = torch.utils.data.DataLoader(dst_test, batch_size=args.batch_test, shuffle=False, num_workers=0)

    print('%s training begins' % get_time())
    print(f'Dataset info: {args.dataset}, {args.channel} * {args.im_size[0]} * {args.im_size[1]}, {args.num_classes} classes')
    print('Args: ' + str(args.__dict__))

    for it in range(args.num_experts):
        # Train one expert (teacher) network from a fresh random initialization.
        teacher_net = get_network(args.model, args.C, args.num_classes, args.latent_size, depth=args.train_depth, width=args.train_width, convnet_pooling=args.convnet_pooling).to(args.device)
        teacher_net.train()
        lr = args.lr_net
        teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom_net, weight_decay=args.weight_decay)
        teacher_optim.zero_grad()

        # timestamps[0] is the initialization; timestamps[e + 1] is the state after epoch e.
        timestamps = [_snapshot_params(teacher_net)]

        # Single 10x step decay at this epoch, applied only when --decay is set.
        lr_schedule = [args.train_epochs // 2 + 1]

        for e in range(args.train_epochs):
            _, train_acc = epoch("train", dataloader=trainloader, net=teacher_net, optimizer=teacher_optim,
                                 criterion=criterion, args=args, aug=args.dsa)

            with torch.no_grad():
                _, test_acc = epoch("latent_test", dataloader=testloader, net=teacher_net, optimizer=None,
                                    criterion=criterion, args=args, aug=False)

            print("{} Itr: {}\tEpoch: {}\tTrain Acc: {}\tTest Acc: {}".format(get_time(), it, e, train_acc, test_acc))

            timestamps.append(_snapshot_params(teacher_net))

            if args.decay and e in lr_schedule:
                lr *= 0.1
                # Rebuild with the same momentum / weight decay as the initial optimizer.
                # The new SGD instance starts with fresh momentum buffers.
                teacher_optim = torch.optim.SGD(teacher_net.parameters(), lr=lr, momentum=args.mom_net, weight_decay=args.weight_decay)
                teacher_optim.zero_grad()

        trajectories.append(timestamps)

        if len(trajectories) == args.save_interval:
            _save_trajectories(args.save_path, trajectories)
            trajectories = []

    # Flush experts left over when num_experts is not a multiple of save_interval.
    if trajectories:
        _save_trajectories(args.save_path, trajectories)


if __name__ == '__main__':
    import shared_args

    # Command-line options specific to expert-trajectory (replay buffer)
    # generation, layered on top of the parser from shared_args.
    extra_options = (
        ('--num_experts', dict(type=int, default=100, help='training iterations')),
        ('--buffer_path', dict(type=str, default='./latent_mtt_buffer', help='buffer path')),
        ('--train_epochs', dict(type=int, default=50)),
        ('--decay', dict(action='store_true')),
        ('--save_interval', dict(type=int, default=10)),
    )

    parser = shared_args.add_shared_args()
    for flag, spec in extra_options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)


