import os.path
import sys, importlib
import argparse

import numpy as np

import torch
from torch.nn import functional as F
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
import torch.utils.data as data_utils

from util import *

sys.path.append('')
from mnist_loader import FMnistRotated
import datetime

from feddirt import Central




def build_parser():
    """Build the command-line parser for the FedDIRT training script.

    Only registers options; nothing is parsed here.

    Returns:
        argparse.ArgumentParser: parser exposing every training, model,
        data and logging option consumed by the script body below.
    """
    parser = argparse.ArgumentParser(description='FedDIRT')

    # Training settings
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: %(default)s)')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--iters', type=int, default=1500,
                        help='number of training iterations (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--num-supervised', default=1000, type=int,
                        help="number of supervised examples, /10 = samples per class")

    # Basic setting
    parser.add_argument('--list_train_domains', '--list', nargs='+',
                        default=['0', '15', '30', '45', '60'],
                        help='domains used during training')
    parser.add_argument('--target_domain', type=str, default='75',
                        help='domain used during testing')
    parser.add_argument('--model', type=str, default='dirt')
    parser.add_argument('--dataset', type=str, default='RotatedFMnist')

    # StarGAN Model
    parser.add_argument('--d-dim', type=int, default=5,
                        help='number of domains')
    parser.add_argument('--x-dim', type=int, default=784,
                        help='input size after flattening')
    parser.add_argument('--y-dim', type=int, default=10,
                        help='number of classes')
    parser.add_argument('--zd-dim', type=int, default=64,
                        help='size of latent space 1')
    parser.add_argument('--zx-dim', type=int, default=64,
                        help='size of latent space 2')
    parser.add_argument('--zy-dim', type=int, default=64,
                        help='size of latent space 3')

    # Aux multipliers
    parser.add_argument('--aux_loss_multiplier_y', type=float, default=3500.,
                        help='multiplier for y classifier')
    parser.add_argument('--aux_loss_multiplier_d', type=float, default=2000.,
                        help='multiplier for d classifier')

    # Beta VAE part
    parser.add_argument('--beta_d', type=float, default=1.,
                        help='multiplier for KL d')
    parser.add_argument('--beta_x', type=float, default=1.,
                        help='multiplier for KL x')
    parser.add_argument('--beta_y', type=float, default=1.,
                        help='multiplier for KL y')

    parser.add_argument('-w', '--warmup', type=int, default=100, metavar='N',
                        help='number of epochs for warm-up. Set to 0 to turn warmup off.')
    parser.add_argument('--max_beta', type=float, default=1., metavar='MB',
                        help='max beta for warm-up')
    parser.add_argument('--min_beta', type=float, default=0.0, metavar='MB',
                        help='min beta for warm-up')

    # INB
    parser.add_argument('--use_shared', action='store_true', default=False,
                        help='Use shared space of INB')

    # AE
    # Restricting the choices rejects an unsupported activation at parse time
    # instead of failing with a KeyError after the data has been loaded.
    parser.add_argument('--activation', default='sigmoid',
                        choices=('tanh', 'sigmoid'))
    parser.add_argument('--ae_dir', default='')

    # DIRT training
    parser.add_argument('--mnist_subset', type=str, default='0')
    parser.add_argument('--all-data', action='store_true', default=False,
                        help='whether to use all MNIST in the training')
    parser.add_argument('--sync_step', type=int, default=1)
    parser.add_argument('--eval_step', type=int, default=50)

    parser.add_argument('--extra', action='store_true', default=False)
    parser.add_argument('--reg', default=1, type=float)
    parser.add_argument('--uni-target', action='store_true', default=False)
    parser.add_argument('--sup-uni-target', action='store_true', default=False)

    # log
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--model_dir', type=str, default='')
    parser.add_argument('--trans', type=str, default='stargan')
    parser.add_argument('--tn', type=str, default='inb')
    parser.add_argument('--outpath', type=str, default='./saved/',
                        help='where to save')
    parser.add_argument('--note', default='')

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Resolve the compute device. From here on ``args.device`` holds the
    # torch.device object rather than the integer GPU index from the CLI.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device(f"cuda:{args.device}" if args.cuda else "cpu")
    args.device = device
    kwargs = {'num_workers': 1, 'pin_memory': False} if args.cuda else {}

    # Set seed
    torch.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # =================================================================================== #
    #                                    Saving and logging                               #
    # =================================================================================== #

    if args.tn == 'indaeinb':
        args.ae_dir += '/indae'

    if args.trans != 'stargan':
        # Follow the selected target domain instead of a hard-coded 75
        # (identical to the previous path when --target_domain is the default).
        args.model_dir = args.model_dir + args.trans + f'/{args.target_domain}/inb.pt'

    # Model name
    print(args.outpath)
    model_name = (args.outpath + args.trans + '_domain_' + str(args.target_domain)
                  + '_seed_' + str(args.seed))
    print(model_name)

    # Destination of the tracker returned by training. The directory is created
    # up front so that a missing/unwritable output location fails before the
    # training run rather than after it.
    save_name = args.trans.replace('/', '-') + args.note
    if args.use_shared:
        save_name += 'share'
    if args.outpath:
        os.makedirs(args.outpath, exist_ok=True)
    save_path = os.path.join(args.outpath, f'{save_name}.pt')

    # =================================================================================== #
    #                                     Prepare data                                    #
    # =================================================================================== #

    print(args.target_domain, args.list_train_domains)
    args.n_domains = len(args.list_train_domains)

    # One loader per training domain, keyed by the domain string.
    train_loader_dict = dict()
    for i, domain in enumerate(args.list_train_domains):
        train_set = FMnistRotated([domain], [args.target_domain], args.data_dir,
                                  train=True, mnist_subset=args.mnist_subset,
                                  all_data=args.all_data)
        # change the domain label: every sample of this set gets the
        # position of its domain in --list_train_domains
        train_set.train_domain = torch.ones_like(train_set.train_domain) * i

        train_loader = data_utils.DataLoader(train_set,
                                             batch_size=args.batch_size,
                                             shuffle=True, **kwargs)
        train_loader_dict[domain] = train_loader

    test_set = FMnistRotated(args.list_train_domains, [args.target_domain], args.data_dir,
                             train=False, mnist_subset=args.mnist_subset,
                             all_data=args.all_data)
    test_loader = data_utils.DataLoader(test_set,
                                        batch_size=args.batch_size,
                                        shuffle=True, **kwargs)
    print(len(train_set))
    print(len(test_set))

    # Per-class sample shapes of the last training domain's set. The mask must
    # use the class index of this loop, not the leftover domain index ``i``.
    xx = train_set.train_data
    xy = train_set.train_labels
    for label in range(args.y_dim):
        print(xx[xy == label].shape)

    # =================================================================================== #
    #                                     Prepare Model                                   #
    # =================================================================================== #

    # Replace the activation name by the corresponding module instance.
    activations = {'tanh': nn.Tanh(),
                   'sigmoid': nn.Sigmoid()}
    args.activation = activations[args.activation]

    central = Central(train_loader_dict, test_loader, args)

    tracker = central.train()
    torch.save(tracker, save_path)


    #
