import yaml
from model import Gen, Dis, GAN
import mynn
import torch
import torch.nn as nn
from utils import CUDA


def append_activ(activ, seq):
    """Append the activation module named by ``activ`` to ``seq`` in place.

    Args:
        activ: activation name taken from the network spec. Supported values:
            ``'LeakyReLU'`` (negative slope 0.2), ``'Tanh'``, and ``'None'``
            (the string) or ``None`` (e.g. a YAML ``null``) for no activation.
        seq: list of modules; extended in place.

    Raises:
        NotImplementedError: if ``activ`` is not one of the supported values.
    """
    if activ == 'LeakyReLU':
        seq.append(mynn.LeakyReLU(0.2))
    elif activ == 'Tanh':
        seq.append(mynn.Tanh())
    elif activ is None or activ == 'None':
        # Explicit "no activation": nothing to append.
        pass
    else:
        raise NotImplementedError(
            "unsupported activation %r; expected 'LeakyReLU', 'Tanh' or 'None'"
            % (activ,))


def make_Linear(ns, ls, curr_dim, is_last):
    """Build a fully-connected layer followed by its optional BN/activation.

    Args:
        ns: network spec. Reads 'last_activation' when ``is_last``, otherwise
            'bn' and 'activation'; anchored layers also read 'anchor_scale',
            'adaptive_scale' and 'n_anchor'.
        ls: layer spec. Reads 'n_features' and the optional 'anchored' flag.
        curr_dim: number of input features.
        is_last: whether this layer is the final entry of the network spec.

    Returns:
        ``(modules, out_features)``: the module list for this layer and the
        feature width it produces.
    """
    out_features = ls['n_features']

    # Anchored layers use the project's custom Linear; others use torch's.
    if ls.get('anchored'):
        linear = mynn.Linear(curr_dim, out_features,
                             anchor_scale=ns['anchor_scale'],
                             adaptive_scale=ns['adaptive_scale'],
                             n_anchor=ns['n_anchor'])
    else:
        linear = nn.Linear(curr_dim, out_features)
    modules = [linear]

    # The final layer never gets batch-norm, and uses 'last_activation'.
    if not is_last and ns['bn']:
        modules.append(mynn.BatchNorm1d(out_features))
    activation = ns['last_activation'] if is_last else ns['activation']
    append_activ(activation, modules)

    return modules, out_features


def make_ConvTranspose2d(ns, ls, curr_dim, is_last):
    """Build a transposed-convolution layer followed by its optional BN/activation.

    Args:
        ns: network spec. Reads 'last_activation' when ``is_last``, otherwise
            'bn' and 'activation'.
        ls: layer spec. Reads 'n_channels', 'kernel_size', 'stride', 'padding'.
        curr_dim: number of input channels.
        is_last: whether this layer is the final entry of the network spec.

    Returns:
        ``(modules, out_channels)``: the module list for this layer and the
        channel count it produces.
    """
    out_channels = ls['n_channels']
    deconv = mynn.ConvTranspose2d(curr_dim, out_channels, ls['kernel_size'],
                                  stride=ls['stride'], padding=ls['padding'])
    modules = [deconv]

    # The final layer never gets batch-norm, and uses 'last_activation'.
    if not is_last and ns['bn']:
        modules.append(mynn.BatchNorm2d(out_channels))
    activation = ns['last_activation'] if is_last else ns['activation']
    append_activ(activation, modules)

    return modules, out_channels


def make_Conv2d(ns, ls, curr_dim, is_last):
    """Build a 2-D convolution layer followed by its optional BN/activation.

    Args:
        ns: network spec. Reads 'last_activation' when ``is_last``, otherwise
            'bn' and 'activation'.
        ls: layer spec. Reads 'n_channels', 'kernel_size', 'stride', 'padding'.
        curr_dim: number of input channels.
        is_last: whether this layer is the final entry of the network spec.

    Returns:
        ``(modules, out_channels)``: the module list for this layer and the
        channel count it produces.
    """
    out_channels = ls['n_channels']
    conv = nn.Conv2d(curr_dim, out_channels, ls['kernel_size'],
                     stride=ls['stride'], padding=ls['padding'])
    modules = [conv]

    # The final layer never gets batch-norm, and uses 'last_activation'.
    if not is_last and ns['bn']:
        modules.append(mynn.BatchNorm2d(out_channels))
    activation = ns['last_activation'] if is_last else ns['activation']
    append_activ(activation, modules)

    return modules, out_channels


def make_Net(ns):
    """Translate a network spec into a flat, ordered list of modules.

    Args:
        ns: network spec. ``ns['in_shape'][0]`` is the starting feature/channel
            count and ``ns['net']`` is the ordered list of layer specs. Each
            layer spec has a ``'type'`` of ``'Linear'``, ``'Reshape'``,
            ``'ConvTranspose2d'`` or ``'Conv2d'``.

    Returns:
        list of modules in forward order.

    Raises:
        NotImplementedError: if a layer spec has an unsupported ``'type'``.
            (Previously such entries were silently skipped, so a typo in the
            config produced a different network without any error.)
    """
    net = []
    curr_dim = ns['in_shape'][0]
    layers = ns['net']
    for li, ls in enumerate(layers):
        # Only the final spec entry is treated as "last" (no BN, uses
        # 'last_activation'). NOTE(review): if the final entry is a Reshape,
        # the layer before it is built as a non-last layer — confirm intended.
        is_last = (li == len(layers) - 1)
        layer_type = ls['type']

        if layer_type == 'Linear':
            seq, curr_dim = make_Linear(ns, ls, curr_dim, is_last)
            net.extend(seq)
        elif layer_type == 'Reshape':
            net.append(mynn.Reshape(ls['in_shape'], ls['out_shape']))
            # Leading entry of the reshaped shape becomes the next input size.
            curr_dim = ls['out_shape'][0]
        elif layer_type == 'ConvTranspose2d':
            seq, curr_dim = make_ConvTranspose2d(ns, ls, curr_dim, is_last)
            net.extend(seq)
        elif layer_type == 'Conv2d':
            seq, curr_dim = make_Conv2d(ns, ls, curr_dim, is_last)
            net.extend(seq)
        else:
            raise NotImplementedError(
                "unsupported layer type %r at index %d; expected one of "
                "'Linear', 'Reshape', 'ConvTranspose2d', 'Conv2d'"
                % (layer_type, li))
    return net


def make_sampler(shape, dist='uniform', **kwargs):
    """Build a closure that draws random tensors with per-sample shape ``shape``.

    Args:
        shape: per-sample shape; a draw of ``k`` samples has shape
            ``(k, *shape)``.
        dist: ``'uniform'`` (values in ``[min, max)``) or ``'normal'``
            (``mean + std * N(0, 1)``).
        **kwargs: optional ``min`` (default ``-1.``), ``max`` (default ``1.``),
            ``mean`` (default ``0.``) and ``std`` (default ``1.``).

    Returns:
        ``sampler(size1, size2=None)``:
          * ``sampler(size1)`` returns ``size1`` samples.
          * ``sampler(size1, size2)`` draws ``size2`` distinct samples and
            tiles that block ``size1`` times, returning ``size1 * size2``
            samples.

    Raises:
        NotImplementedError: if ``dist`` is not ``'uniform'`` or ``'normal'``.
    """
    if dist not in ('uniform', 'normal'):
        # NotImplemented (the singleton) is not an exception class; raising a
        # call to it would surface as an unrelated TypeError.
        raise NotImplementedError("only 'uniform' or 'normal' is supported")

    # Local names avoid shadowing the min/max builtins.
    low = kwargs.get('min', -1.)
    high = kwargs.get('max', 1.)
    mean = kwargs.get('mean', 0.)
    std = kwargs.get('std', 1.)

    def sampler(size1, size2=None):
        """Draw samples on the device returned by ``CUDA()``.

        With only ``size1``, return ``size1`` samples. With ``size2`` too,
        draw ``size2`` distinct samples and repeat that block ``size1`` times.
        """
        if size2 is None:
            n_distinct, n_repeat = size1, None
        else:
            n_distinct, n_repeat = size2, size1

        if dist == 'uniform':
            samples = low + (high - low) * torch.rand(n_distinct, *shape, device=CUDA())
        else:  # 'normal' (validated above)
            samples = mean + std * torch.randn(n_distinct, *shape, device=CUDA())

        if n_repeat is None:
            return samples
        # (n_repeat, n_distinct, *shape) -> flatten the first two dims.
        tiled = samples.unsqueeze(0).repeat(n_repeat, 1, *[1] * len(shape))
        return tiled.view(n_repeat * n_distinct, *shape)

    return sampler


def make_GAN(arch, vis, nick):
    """Build the generator and discriminator from ``arch`` and wrap them in a GAN.

    Args:
        arch: architecture dict with 'gen' and 'dis' network specs.
        vis: forwarded unchanged to ``GAN``.
        nick: forwarded unchanged to ``GAN``.

    Returns:
        the ``GAN`` instance; both networks are moved with ``.cuda()`` first.
    """
    g_spec, d_spec = arch['gen'], arch['dis']

    # Generator: layers from the spec plus a latent sampler over in_shape.
    g_layers = make_Net(g_spec)
    latent = make_sampler(shape=g_spec['in_shape'], dist=g_spec['latent_dist'])
    generator = Gen(*g_layers,
                    z_sampler=latent,
                    in_shape=g_spec['in_shape'],
                    out_shape=g_spec['out_shape'],
                    n_anchor=g_spec['n_anchor'])

    # Discriminator: layers from the spec plus its output form.
    d_layers = make_Net(d_spec)
    discriminator = Dis(*d_layers,
                        in_shape=d_spec['in_shape'],
                        out_shape=d_spec['out_shape'],
                        form=d_spec['form'])

    for module in (generator, discriminator):
        module.cuda()
    return GAN(generator, discriminator, vis, nick)


if __name__ == "__main__":
    # Smoke test: build the MNIST architecture from its YAML spec and print it.
    # make_GAN takes an already-parsed arch dict (not a path) and requires
    # (arch, vis, nick); it returns a single GAN object.
    # NOTE(review): assumes the YAML's top level is the arch dict itself
    # (with 'gen' and 'dis' keys) — confirm against archs/mnist.yaml.
    with open('archs/mnist.yaml', encoding='utf-8') as f:
        arch = yaml.safe_load(f)
    # NOTE(review): 'mnist' is a placeholder nick; no visualiser is needed here.
    gan = make_GAN(arch, None, 'mnist')
    print(gan)
