
import torch
from myvisdom import Visdom
import os, argparse
import numpy as np
import yaml, re

from make_model import make_GAN


def get_loader(args, img_size=None):
    """Build the training DataLoader for ``args.dset``.

    Args:
        args: parsed CLI namespace; reads ``dset``, ``anchor_per_batch``,
            ``data_per_anchor`` and ``exclude_label``.
        img_size: square side length images are resized to. Required for the
            ImageFolder datasets (``chairs``, ``zap50k``); unused for
            ``mnist`` / ``fashion_mnist``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training split with batches
        of ``anchor_per_batch * data_per_anchor`` samples, randomly ordered.
        When ``args.exclude_label`` is set, samples with that label are left out.

    Raises:
        ValueError: ``img_size`` is None for a dataset that needs resizing.
        NotImplementedError: ``args.dset`` is not a supported dataset name.
    """
    from torchvision import datasets, transforms

    batch_size = args.anchor_per_batch * args.data_per_anchor
    # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 rescales them to [-1, 1].
    normalize = transforms.Normalize(mean=(0.5,), std=(0.5,))

    if args.dset == 'mnist':
        ds = datasets.MNIST(root='datasets/mnist',
                            train=True,
                            transform=transforms.Compose([transforms.ToTensor(), normalize]),
                            download=True)
    elif args.dset == 'fashion_mnist':
        ds = datasets.FashionMNIST(root='datasets/fashion_mnist',
                                   train=True,
                                   transform=transforms.Compose([transforms.ToTensor(), normalize]),
                                   download=True)
    elif args.dset == 'chairs':
        if img_size is None:
            raise ValueError("img_size is required for dataset 'chairs'")
        ds = datasets.ImageFolder(root='datasets/chairs/',
                                  transform=transforms.Compose([
                                      transforms.Grayscale(),
                                      transforms.Resize((img_size, img_size)),
                                      transforms.ToTensor(),
                                      normalize,
                                  ]))
    elif args.dset == 'zap50k':
        if img_size is None:
            raise ValueError("img_size is required for dataset 'zap50k'")
        ds = datasets.ImageFolder(root='datasets/ut-zap50k-images-square',
                                  transform=transforms.Compose([
                                      transforms.Resize((img_size, img_size)),
                                      transforms.ToTensor(),
                                      normalize,
                                  ]))
    else:
        raise NotImplementedError(f"Unsupported dataset: {args.dset!r}")

    if args.exclude_label is None:
        return torch.utils.data.DataLoader(dataset=ds,
                                           batch_size=batch_size,
                                           shuffle=True)

    # Read labels from the dataset's `targets` array when it has one (torchvision's
    # MNIST, FashionMNIST and ImageFolder do). This avoids loading and transforming
    # every image just to look at its label. Otherwise fall back to iterating samples.
    targets = getattr(ds, 'targets', None)
    if targets is None:
        labels = (label for _, label in ds)
    else:
        labels = (int(t) for t in targets)
    idx = [i for i, label in enumerate(labels) if label != args.exclude_label]
    # SubsetRandomSampler already draws the kept indices in random order, so
    # `shuffle` must not also be passed.
    return torch.utils.data.DataLoader(dataset=ds,
                                       batch_size=batch_size,
                                       sampler=torch.utils.data.SubsetRandomSampler(idx))



def _resolve_arch(args):
    """Return ``(arch, arch_rev)``: the effective architecture and the CLI overrides applied to it.

    The base architecture depends on the run type:
    - Fresh run: ``archs/<args.arch>.yaml``.
    - Resumed run (``args.load_from`` set): the ``arch.yaml`` saved with that
      run. This keeps the dumped config and ``img_size`` consistent with the
      weights that ``load`` restores.

    Every non-None ``gen__<key>`` / ``dis__<key>`` argument overrides ``arch['gen'][key]`` /
    ``arch['dis'][key]``. ``gen__n_latent`` is stored as ``in_shape = [n_latent]``.

    Raises:
        ValueError: an architecture override is combined with ``args.load_from``.
    """
    if args.load_from is None:
        arch_path = os.path.join('archs', args.arch + '.yaml')
    else:
        arch_path = os.path.join('dumps', args.load_from, 'arch.yaml')
    with open(arch_path, 'r') as yf:
        arch = yaml.safe_load(yf)

    arch_rev = {'gen': {}, 'dis': {}}
    for k, v in vars(args).items():
        part, sep, key = k.partition('__')
        if not sep or part not in arch_rev or v is None:
            continue
        if args.load_from is not None:
            raise ValueError("No changes are allowed from the loaded architecture")
        if part == 'gen' and key == 'n_latent':
            key, v = 'in_shape', [v]
        arch_rev[part][key] = v
        arch[part][key] = v
    return arch, arch_rev


def train(args, nick, arg_diff):
    """Set up run ``nick``, build or resume a GAN, train it and return it.

    Writes the following into ``dumps/<nick>/``:
    - ``args.yaml``: all CLI args.
    - ``arch.yaml``: the effective architecture.
    - ``arch_rev.yaml``: the CLI overrides.

    It then posts the argument list to Visdom env ``nick``, with keys found in
    ``arg_diff`` in bold. The model comes from ``make_GAN`` for a fresh run or
    from ``load(args.load_from, ...)`` when resuming. Finally it calls
    ``gan.train`` with the real-data loader.

    Raises:
        FileExistsError: ``dumps/<nick>`` already exists.
        ValueError: architecture overrides were given together with ``--load_from``.
    """
    # Resolve the architecture first so a bad arch name or an invalid override
    # fails before an empty run directory is created.
    arch, arch_rev = _resolve_arch(args)

    dump_dir = os.path.join('dumps', nick)
    # makedirs also creates 'dumps/' on first use. exist_ok stays False so an
    # existing run is never overwritten.
    os.makedirs(dump_dir)
    with open(os.path.join(dump_dir, 'args.yaml'), 'w') as yf:
        yaml.safe_dump(vars(args), yf)
    if args.log_to_file:
        vis = Visdom(env=nick, log_to_filename=os.path.join(dump_dir, 'vis_log'))
    else:
        vis = Visdom(env=nick)

    with open(os.path.join(dump_dir, 'arch_rev.yaml'), 'w') as yf:
        yaml.safe_dump(arch_rev, yf)
    with open(os.path.join(dump_dir, 'arch.yaml'), 'w') as yf:
        yaml.safe_dump(arch, yf)

    arg_text = [f"<b>{k}={v}</b>" if k in arg_diff else f"{k}={v}"
                for k, v in vars(args).items()]
    vis.text('<br>'.join(arg_text))

    if args.load_from is None:
        gan = make_GAN(arch, vis, nick)
    else:
        gan = load(args.load_from, load_to_train=True, vis=vis, new_nick=nick)

    out_shape = arch['gen']['out_shape']
    # Images are assumed square; out_shape is presumably (C, H, W), so H serves
    # as the side length (TODO confirm against the arch yamls).
    img_size = out_shape[1] if len(out_shape) > 1 else None
    real_loader = get_loader(args, img_size=img_size)
    gan.train(real_loader, args.anchor_per_batch, args.data_per_anchor, args.lamb_bias_align,
              n_epoch=args.n_epoch, n_dis=args.n_dis, n_gen=args.n_gen, anchor_reset_period=args.anchor_reset_period,
              img_size=img_size, vis_period=args.vis_period, dump_period=args.dump_period)
    return gan


def load(nick, load_ep=None, load_to_train=False, vis=None, new_nick=None):
    """Rebuild the GAN saved in ``dumps/<nick>/`` and restore its weights.

    Args:
        nick: run directory name under ``dumps/``.
        load_ep: epoch to restore. Defaults to the latest epoch for which both
            ``gen_ep<N>.dump`` and ``dis_ep<N>.dump`` exist.
        load_to_train: if True, the model is built against the caller's ``vis``
            under ``new_nick`` and is not switched to eval mode. Otherwise a
            Visdom env named ``'test'`` is created and both nets are set to eval.
        vis: Visdom handle; required when ``load_to_train`` is True.
        new_nick: run name for the model; required when ``load_to_train`` is True.

    Returns:
        The object returned by ``make_GAN`` with generator and discriminator
        state dicts loaded.

    Raises:
        ValueError: ``load_to_train`` is True but ``vis`` or ``new_nick`` is missing.
        FileNotFoundError: ``load_ep`` is None and no complete gen/dis dump pair exists.
    """
    if load_to_train:
        if vis is None or new_nick is None:
            raise ValueError("load_to_train=True requires both vis and new_nick")
    else:
        vis = Visdom(env='test', use_incoming_socket=False)

    run_dir = os.path.join('dumps', nick)
    if load_ep is None:
        # Only consider epochs where both nets were dumped, so a run interrupted
        # between the two saves still loads a consistent pair.
        dump_re = re.compile(r'^(gen|dis)_ep(\d+)\.dump$')
        epochs = {'gen': set(), 'dis': set()}
        for fname in os.listdir(run_dir):
            mo = dump_re.match(fname)
            if mo is not None:
                epochs[mo.group(1)].add(int(mo.group(2)))
        complete = epochs['gen'] & epochs['dis']
        if not complete:
            raise FileNotFoundError(f"No gen/dis epoch dumps found in {run_dir}")
        load_ep = max(complete)
        print("Loading ep{}".format(load_ep))

    with open(os.path.join(run_dir, 'arch.yaml'), 'r') as yf:
        arch = yaml.safe_load(yf)
    gan = make_GAN(arch, vis, new_nick if load_to_train else nick)
    # Map to CPU first so checkpoints saved on a GPU also load on CPU-only hosts.
    # load_state_dict then copies the tensors onto whatever device the modules use.
    gen_state = torch.load(os.path.join(run_dir, 'gen_ep{}.dump'.format(load_ep)), map_location='cpu')
    dis_state = torch.load(os.path.join(run_dir, 'dis_ep{}.dump'.format(load_ep)), map_location='cpu')
    gan.gen.load_state_dict(gen_state)
    gan.dis.load_state_dict(dis_state)
    if not load_to_train:
        gan.gen.eval()
        gan.dis.eval()
        print('Set as eval mode')
    return gan


if __name__ == "__main__":

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('dset', type=str, help='dataset name: (mnist | fashion_mnist | chairs | zap50k) ')
    parser.add_argument('arch', type=str, help='architecture yaml name')

    # gen options (dest prefix `gen__` marks them as overrides of arch['gen'])
    parser.add_argument('--nz', dest='gen__n_latent', type=int, help='number of latent dimension')
    parser.add_argument('--z_dist', dest='gen__latent_dist', type=str, help='latent distribution')
    parser.add_argument('--na', dest='gen__n_anchor', type=int, help='number of total anchors')
    parser.add_argument('--as', dest='gen__anchor_scale', type=float, help='scale multiplied by each anchor')
    parser.add_argument('--ada_scale', dest='gen__adaptive_scale', action='store_true', default=None, help='scale adaptive to num features')
    parser.add_argument('--bn_gen', dest='gen__bn', action='store_true', default=None, help='Use batch norm in Gen')
    parser.add_argument('--no_bn_gen', dest='gen__bn', action='store_false', default=None, help='No batch norm in Gen')

    # dis options (dest prefix `dis__` marks them as overrides of arch['dis'])
    parser.add_argument('--bn_dis', dest='dis__bn', action='store_true', default=None, help='Use batch norm in Dis')
    parser.add_argument('--no_bn_dis', dest='dis__bn', action='store_false', default=None, help='No batch norm in Dis')
    parser.add_argument('--dis_form', dest='dis__form', type=str, help='discriminator type')

    # training options
    parser.add_argument('--exclude_label', type=int, help='exclude certain label from training. only int label is supported')
    parser.add_argument('--apb', dest='anchor_per_batch', type=int, default=10, help='number of anchors per batch')
    parser.add_argument('--dpa', dest='data_per_anchor', type=int, default=10, help='number of data per anchor')
    parser.add_argument('--n_epoch', type=int, default=30, help='number of epochs')
    parser.add_argument('--n_dis', type=int, default=1, help='number of discriminator update per iteration')
    parser.add_argument('--n_gen', type=int, default=1, help='number of generator update per iteration')
    parser.add_argument('--arp', dest='anchor_reset_period', type=int, default=1, help='anchor reset period')
    parser.add_argument('--lambda', dest='lamb_bias_align', type=float, default=0.0005, help='lambda multiplied by bias align term')
    parser.add_argument('--load_from', dest='load_from', type=str, help='start training from the pretrained dump file given by nick')

    # visdom
    parser.add_argument('--vis_period', dest='vis_period', type=int, default=20, help='Visualization period (in iter)')
    parser.add_argument('--dump_period', dest='dump_period', type=int, default=1, help='Dumping period (in epoch)')
    parser.add_argument('--no_log_to_file', dest='log_to_file', action='store_false', help='prevent logging visdom events')

    args = parser.parse_args()

    from datetime import datetime

    # The positional args always differ from their (None) default and are
    # already spelled out at the front of the nick. Leave them out of the diff,
    # otherwise "DEFAULT" could never be produced and they would appear twice.
    positional = {'dset', 'arch'}
    arg_diff = {key: value for key, value in vars(args).items()
                if key not in positional and value != parser.get_default(key)}
    if not arg_diff:
        arg_diff_str = "DEFAULT"
    else:
        arg_diff_str = " ".join("{}={}".format(k, v) for k, v in arg_diff.items())

    # NOTE(review): the nick is used as a directory name; ':' from the timestamp
    # is not a valid path character on Windows.
    nick = f"{args.dset}__{args.arch}__{datetime.now().strftime('%m-%d %H:%M:%S')}__{arg_diff_str}"

    train(args, nick, arg_diff)
