#!/usr/bin/env python

import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from config import config


def get_scheduler(optimizer, args):
    """Create a learning-rate scheduler for ``optimizer`` as selected by ``args.lr_policy``.

    Policies and the ``args`` attributes they read:

    * ``'lambda'``  -- LambdaLR with factor
      ``1 - max(0, epoch + args.epochs - args.niter) / (args.niter_decay + 1)``:
      1.0 at first, then decaying linearly.
    * ``'step'``    -- StepLR: multiply the lr by 0.1 every ``args.niter_decay``
      scheduler steps.
    * ``'plateau'`` -- ReduceLROnPlateau (mode='min', factor=0.2, threshold=0.01,
      patience=5); its ``step`` must be given the monitored metric.
    * ``'cosine'``  -- CosineAnnealingLR over ``args.niter`` steps down to 0.

    Raises:
        NotImplementedError: if ``args.lr_policy`` is none of the above.
    """
    policy = args.lr_policy
    if policy == 'lambda':
        def lambda_rule(epoch):
            # args is read on every call, not captured by value
            return 1.0 - max(0, epoch + args.epochs - args.niter) / float(args.niter_decay + 1)
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if policy == 'step':
        return lr_scheduler.StepLR(optimizer, step_size=args.niter_decay, gamma=0.1)
    if policy == 'plateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    if policy == 'cosine':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.niter, eta_min=0)
    # raise (not return) so a bad policy name fails loudly instead of handing
    # back an exception object in place of a scheduler
    raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)


def update_learning_rate(scheduler, optimizer, metric=None):
    """Advance ``scheduler`` by one step and return the resulting learning rate.

    Args:
        scheduler: a ``torch.optim.lr_scheduler`` scheduler attached to ``optimizer``.
        optimizer: the optimizer whose first parameter group's lr is reported.
        metric: if given, passed to ``scheduler.step(metric)``. This is required
            for ``ReduceLROnPlateau``; leave it ``None`` for other schedulers.

    Returns:
        ``optimizer.param_groups[0]['lr']`` after the step.
    """
    if metric is None:
        scheduler.step()
    else:
        scheduler.step(metric)
    return optimizer.param_groups[0]['lr']


def get_norm_layer(norm_type='instance'):
    """Return a factory (``functools.partial``) for the requested 2-D normalisation layer.

    ``'batch'`` gives an affine ``nn.BatchNorm2d``; ``'instance'`` gives a
    non-affine ``nn.InstanceNorm2d`` without running statistics. Call the
    result with the channel count to build a layer.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def init_weights(net, gain=0.02):
    """Initialise the weights of ``net`` in place.

    * Modules whose class name contains ``Conv`` or ``Linear``: Xavier-normal
      weights with the given ``gain``; biases (when present) set to 0.
    * Modules whose class name contains ``BatchNorm2d`` (and that have learnable
      affine parameters): weights drawn from N(1.0, gain), biases set to 0.

    All other modules keep their default initialisation.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            init.xavier_normal_(m.weight.data, gain=gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        # Must be a sibling of the Conv/Linear branch; when nested under it, the
        # BatchNorm case could never match. affine=False norms have weight None.
        elif 'BatchNorm2d' in classname and getattr(m, 'weight', None) is not None:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('--Initialize network\'s weights with {}'.format('xavier'))
    net.apply(init_func)


def init_net(net, init_gain=0.02, gpu_id='cuda:0'):
    """Initialise the weights of ``net`` in place (via ``init_weights``) and return it.

    ``gpu_id`` is accepted for call-site compatibility only: this function does
    not move ``net`` to any device.
    """
    init_weights(net, init_gain)
    return net


# define Generator
def define_G(input_nc, output_nc, ngf, norm='batch', use_dropout=True, init_gain=0.02, gpu_id='cuda:0', n_blocks=9, use_lstm=config['lstm'], lstm_layers=0, use_params=False, args=None):
    """Build a generator network, or load a previously saved one.

    Args:
        input_nc: number of channels fed to the generator's outermost level.
        output_nc: number of output channels.
        ngf: number of filters in the outermost level.
        norm: ``'batch'`` or ``'instance'`` (see ``get_norm_layer``).
        use_dropout: forwarded to the generator constructor.
        init_gain: gain used by the weight initialisation.
        gpu_id: forwarded to ``init_net``.
        n_blocks: number of U-Net down-sampling levels (``num_downs``).
        use_lstm: forwarded to the generator constructor. The default is
            ``config['lstm']`` as evaluated when this module was imported.
        lstm_layers: number of LSTM layers, forwarded to the generator.
        use_params: forwarded to the generator (LSTM parameter input).
        args: namespace providing ``model_path``, ``model_type`` and ``model_name``.

    Returns:
        If ``args.model_path`` is not None, the pickled generator loaded from that
        path and moved to the CPU, without re-initialisation. Otherwise a freshly
        initialised ``UnetGeneratorLSTMBetween`` (model name contains ``'twonet'``)
        or ``UnetGenerator`` (model name contains ``'unet'``).

    Raises:
        ValueError: if ``args.model_name`` matches neither architecture.
    """
    # loading generator with weight from given path
    if args.model_path is not None:
        import inspect

        # SECURITY: torch.load unpickles arbitrary Python objects -- only load trusted files.
        load_kwargs = {'map_location': lambda storage, loc: storage}  # deserialize on the CPU
        # The whole module is pickled, not just a state_dict. Newer PyTorch releases
        # default to weights_only=True, which rejects such files, so opt out explicitly
        # where the argument exists (older releases do not accept it).
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(args.model_path, **load_kwargs).cpu()

    norm_layer = get_norm_layer(norm_type=norm)

    # model types 'vd' and 's' prepend parameter planes to the generator input
    params = args.model_type in ('vd', 's')
    if 'twonet' in args.model_name:
        # the LSTM sits in the decoder of the second-innermost level;
        # 'after' in the name puts it after that level's transposed conv
        beforeConv = 'after' not in args.model_name
        net = UnetGeneratorLSTMBetween(input_nc, output_nc, num_downs=n_blocks, ngf=ngf, norm_layer=norm_layer, use_dropout=use_dropout, params=params, use_lstm=use_lstm, lstm_layers=lstm_layers, use_params=use_params, atDecoder=True, beforeConv=beforeConv)
    elif 'unet' in args.model_name:
        net = UnetGenerator(input_nc, output_nc, num_downs=n_blocks, ngf=ngf, norm_layer=norm_layer, use_dropout=use_dropout, params=params, use_lstm=use_lstm, lstm_layers=lstm_layers, use_params=use_params)
    else:
        raise ValueError("Unknown model: {}".format(args.model_name))
    return init_net(net, init_gain, gpu_id=gpu_id)


# define Discriminator 
def define_D(input_nc, ndf, n_layers_D=3, norm='batch', use_sigmoid=False, init_gain=0.02, gpu_id='cuda:0', args=None, loadStored=True):
    """Build a PatchGAN discriminator (``NLayerDiscriminator``) and initialise its weights.

    Args:
        input_nc: number of input channels.
        ndf: number of filters in the first conv layer.
        n_layers_D: number of stride-2 conv layers.
        norm: ``'batch'`` or ``'instance'`` (see ``get_norm_layer``).
        use_sigmoid: append a sigmoid to the discriminator output.
        init_gain: gain used by the weight initialisation.
        gpu_id: forwarded to ``init_net``.
        args, loadStored: accepted for call-site compatibility; currently unused.
            ``args`` defaults to ``None`` rather than a shared mutable ``{}``.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    net = NLayerDiscriminator(input_nc, ndf=ndf, n_layers=n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    return init_net(net, init_gain, gpu_id=gpu_id)


# Discriminator with patchGAN architecture
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator.

    A stack of 4x4 convolutions maps an image to a single-channel map of
    per-patch scores: ``n_layers`` of them use stride 2, followed by two with
    stride 1. Channel counts grow as ``ndf * min(2**n, 8)``.

    Args:
        input_nc: number of input channels.
        ndf: number of filters in the first conv layer.
        n_layers: number of stride-2 conv layers.
        norm_layer: normalisation layer class, or a ``functools.partial`` of one.
        use_sigmoid: append ``nn.Sigmoid`` so the output lies in (0, 1).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()

        # Unwrap functools.partial to compare the underlying class: convs that are
        # followed by a norm layer only get a bias when that norm is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw = 4
        padw = 1
        # first layer: no normalisation, so the conv keeps its default bias
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # one more normalised layer at stride 1
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # map to one score channel
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Return the patch score map for ``input``."""
        return self.model(input)


class UnetGenerator(nn.Module):
    """U-Net generator, optionally with an ``LSTMblock`` at the innermost level.

    The net is built inside-out from ``UnetSkipConnectionBlock`` levels:
    - the innermost level;
    - ``num_downs - 6`` extra ``ngf*8`` levels;
    - fixed ``ngf*8``, ``ngf*4``, ``ngf*2`` and ``ngf`` levels;
    - the outermost level.
    That is ``num_downs`` levels in total; at least 6 are always built.

    Args:
        input_nc: channels of the tensor given to the outermost level. This
            includes the parameter planes added in ``forward`` when ``params``.
        output_nc: channels of the generated output.
        num_downs: number of down-sampling levels.
        ngf: number of filters in the outermost level.
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: apply dropout in the extra ``ngf*8`` levels.
        params: if True, ``forward`` prepends one constant plane per parameter
            value to the input channels.
        use_lstm: put an ``LSTMblock`` inside the innermost level.
        lstm_layers: number of stacked LSTM layers.
        use_params: give the LSTM one extra scalar input per step.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, params=False, use_lstm=False, lstm_layers=0, use_params=False):
        super(UnetGenerator, self).__init__()

        self.params = params
        self.use_lstm = use_lstm

        # Size of the innermost feature maps, which the LSTM flattens per step.
        # Assumes config's input size matches the tensors actually fed in.
        featureMap_width = int(config['input_width'] / 2 ** num_downs)
        featureMap_height = int(config['input_height'] / 2 ** num_downs)
        num_ft_maps = ngf * 8

        # build the net inside-out, starting with the innermost level
        if use_lstm:
            lstm = LSTMblock(featureMap_height, featureMap_width, num_ft_maps, num_layers=lstm_layers, use_params=use_params)
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=lstm, norm_layer=norm_layer, innermost=True, use_lstm=self.use_lstm)
        else:
            # The innermost level deliberately has no dropout; this was used in the
            # evaluation to compare the LSTM and no-LSTM variants fairly. Pass
            # use_dropout=use_dropout here to enable dropout at this level.
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_lstm=self.use_lstm)

        for _ in range(num_downs - 6):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)

        # add the outermost layer, number of input channels input_nc
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

    def forward(self, input, params_=None, hidden_state=None, cell_state=None):
        """Generate an output for ``input``.

        When built with ``params=True``, the parameters come from ``params_``.
        If ``params_`` is None, ``input`` is unpacked as an ``(images, params)``
        pair. ``params[i, 0, 0, j]`` is parameter ``j`` of sample ``i``. Each
        parameter is broadcast to a constant ``H x W`` plane, and these planes
        are concatenated in front of the image channels.

        ``hidden_state`` and ``cell_state`` are accepted for interface
        compatibility and ignored.
        """
        if self.params:
            if params_ is None:
                input, params = input
            else:
                params = params_
            batch, height, width = input.size(0), input.shape[2], input.shape[3]
            # (batch, P) parameter values in the default float dtype, on the input's device
            values = params[:batch, 0, 0, :].to(device=input.device, dtype=torch.get_default_dtype())
            # one broadcast instead of batch*P tiny tensors and per-element device copies
            planes = values[:, :, None, None].expand(-1, -1, height, width)
            input = torch.cat((planes, input), 1)
        return self.model(input)


# U-Net level used by the U-Net generators: encoder block, decoder block and skip connection
class UnetSkipConnectionBlock(nn.Module):
    """One level of a U-Net: down-sampling, a nested submodule, up-sampling.

    Layer layouts by level type:

    * outermost: Conv -> submodule -> ReLU -> ConvTranspose -> Tanh (no skip).
    * innermost: LeakyReLU -> Conv -> [submodule] -> ReLU -> ConvTranspose -> norm.
    * atDecoder: like a standard level, with ``lstm`` in the decoder path, either
      before the ConvTranspose (``beforeConv=True``) or after it.
    * standard:  LeakyReLU -> Conv -> norm -> submodule -> ReLU -> ConvTranspose -> norm.

    Non-outermost levels end with ``Dropout(0.5)`` when ``use_dropout``. Their
    ``forward`` concatenates the block input with the level output along dim 1
    (the skip connection), giving ``input_nc + outer_nc`` channels.

    Args:
        outer_nc: channels produced by this level's up-sampling path.
        inner_nc: channels produced by the down-sampling conv.
        input_nc: channels of the block input; defaults to ``outer_nc``.
        submodule: the nested inner level (may be ``None`` only when ``innermost``).
        outermost, innermost, atDecoder: select the level type (checked in that order).
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: append ``Dropout(0.5)`` (not for the outermost level).
        use_lstm: stored as ``self.use_lstm``; not used by this block itself.
        lstm: module inserted into the decoder path; required when ``atDecoder``.
        beforeConv: position of ``lstm`` relative to the ConvTranspose.

    Raises:
        ValueError: if ``atDecoder`` selects the decoder-LSTM level but ``lstm`` is None.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, use_lstm=False, lstm=None, atDecoder=False, beforeConv=True):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        self.use_lstm = use_lstm

        # Unwrap functools.partial to compare the underlying class: convs followed
        # by a norm layer only get a bias when that norm is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # the submodule output is concatenated with its input, hence inner_nc * 2
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            model = [downconv, submodule, uprelu, upconv, nn.Tanh()]

        elif innermost:
            # nothing below is concatenated, so the upconv sees inner_nc channels
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            middle = [] if submodule is None else [submodule]
            model = [downrelu, downconv] + middle + [uprelu, upconv, upnorm]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        elif atDecoder:
            # unet-layer used by the twonets: the LSTM sits in the decoder path
            if lstm is None:
                raise ValueError('UnetSkipConnectionBlock: atDecoder=True requires an lstm module')
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            if beforeConv:
                up = [uprelu, lstm, upconv, upnorm]
            else:
                up = [uprelu, upconv, lstm, upnorm]
            model = [downrelu, downconv, downnorm, submodule] + up
            if use_dropout:
                model.append(nn.Dropout(0.5))

        else:
            # standard level; the submodule output has inner_nc * 2 channels because of
            # its skip connection (assuming it was built with outer_nc == this inner_nc)
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            model = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the level; non-outermost levels concatenate ``x`` and the output on dim 1.

        NOTE: non-outermost levels start with an in-place LeakyReLU, so the ``x``
        that gets concatenated has already been modified in place by ``self.model``.
        """
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)

# block that runs an LSTM over the flattened feature maps
# recursive/non-recursive mode has to be set with the method set_recursive
# if simulation parameters should be used, they have to be set with the method set_parameter
class LSTMblock(nn.Module):
    """Run an ``nn.LSTM`` over flattened feature maps, using the batch axis as time.

    For an input of shape ``(N, C, H, W)``, each sample ``x[t]`` is flattened to
    ``C*H*W`` values and fed to the LSTM as one step, with the hidden state
    carried from step to step. The outputs are reshaped back to ``(N, C, H, W)``.

    The initial hidden/cell states are the Parameters ``hidden_state`` and
    ``current_state``. ``set_recursive`` resets them and chooses whether the
    final states of one call become the initial states of the next. With
    ``use_params=True``, the per-step scalars must be supplied via
    ``set_parameter`` before calling ``forward``.

    NOTE(review): because the states are registered Parameters, they appear in
    ``parameters()`` and hence in any optimizer built from them. Reassigning
    them creates new Parameter objects that such an optimizer does not track.

    Args:
        ftMap_height, ftMap_width, num_ft_maps: H, W and C of the incoming feature
            maps; their product is the LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        batch_first: forwarded to ``nn.LSTM``.
        use_params: append one scalar parameter to every LSTM input step.
    """

    def __init__(self, ftMap_height, ftMap_width, num_ft_maps, num_layers=1, batch_first=False, use_params=False):
        super(LSTMblock, self).__init__()

        # one step = one flattened feature map (+1 scalar parameter when use_params)
        self.input_size = ftMap_width * ftMap_height * num_ft_maps
        if use_params:
            self.input_size += 1
        self.hidden_size = ftMap_width * ftMap_height * num_ft_maps

        self.lstm = nn.LSTM(self.input_size, self.hidden_size, num_layers, batch_first=batch_first)

        self.hidden_state = nn.Parameter(torch.zeros(num_layers * 1, 1, self.hidden_size))
        self.current_state = nn.Parameter(torch.zeros(num_layers * 1, 1, self.hidden_size))

        self.inputShape = [0, 1 * 1, 1, self.input_size]
        self.recursive = False

        if use_params:
            # placeholder; replaced by set_parameter with a (N, 1, 1) tensor
            self.parameter = nn.Parameter(torch.zeros(1))

    # input of shape [batch_size, num_features, height, width]
    def forward(self, x, hidden_state=None, cell_state=None):
        """Apply the LSTM step by step over the samples of ``x`` and reshape back.

        ``hidden_state`` and ``cell_state`` are accepted but ignored; the stored
        ``self.hidden_state`` / ``self.current_state`` are always the start states.
        """
        x_size = x.size()
        self.inputShape = x_size

        # each sample becomes one flattened step: shape (N, 1, C*H*W)
        steps = torch.unsqueeze(x.view(x_size[0], x_size[1] * x_size[2] * x_size[3]), 1)

        # with params: append the per-sample scalar (stored as (N, 1, 1)) to each step
        if hasattr(self, 'parameter'):
            steps = torch.cat((steps, self.parameter), 2)

        cell = self.current_state
        hidden = self.hidden_state
        step_outputs = []
        for step in steps:
            step_output, (hidden, cell) = self.lstm(torch.unsqueeze(step, 0), (hidden, cell))
            step_outputs.append(step_output)

        # in recursive mode the final states become the start states of the next call
        if self.recursive:
            self.hidden_state = nn.Parameter(hidden)
            self.current_state = nn.Parameter(cell)

        if not step_outputs:
            # empty batch: nothing was run (torch.cat cannot take an empty list)
            return x.new_zeros(x_size[0], x_size[1], x_size[2], x_size[3])

        # concatenate once (instead of growing a tensor every step), then restore the input shape
        output = torch.cat(step_outputs, 0)
        return output.view(x_size[0], x_size[1], x_size[2], x_size[3])

    # LSTM has 2 modes, recursive or not
    def set_recursive(self, recursive=False, device="cpu"):
        """Select recursive mode and reset both stored states to zeros on ``device``.

        In recursive mode the final hidden/cell states of each ``forward`` call are
        kept as the start states of the next call; otherwise the stored states are
        left untouched by ``forward``.
        """
        self.recursive = recursive
        state_shape = (self.lstm.num_layers * 1, 1, self.hidden_size)
        self.hidden_state = nn.Parameter(torch.zeros(state_shape, device=torch.device(device)))
        self.current_state = nn.Parameter(torch.zeros(state_shape, device=torch.device(device)))

    def set_parameter(self, parameter=0):
        """Set the per-step scalars used when the block was built with ``use_params=True``.

        ``parameter`` must be a tensor with one value per sample of the next
        ``forward`` input; it is stored with shape ``(N, 1, 1)``. This does nothing
        if the block was built without ``use_params``.
        """
        if hasattr(self, 'parameter'):
            self.parameter = nn.Parameter(parameter.view(parameter.size()[0], 1, 1))



class GANLoss(nn.Module):
    """Adversarial loss of a discriminator output against a constant real/fake label.

    Args:
        use_lsgan: use least-squares (``nn.MSELoss``); otherwise ``nn.BCELoss``,
            which expects probabilities (e.g. a discriminator with a sigmoid output).
        target_real_label: target value for real samples.
        target_fake_label: target value for fake samples.
        detect_anomaly: if True (the default, as before), enable autograd anomaly
            detection globally via ``torch.autograd.set_detect_anomaly(True)``.
            This process-wide debugging switch slows down every backward pass;
            pass False for normal training.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, detect_anomaly=True):
        super(GANLoss, self).__init__()
        # buffers, so the labels follow the module across .to(device) calls
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if detect_anomaly:
            torch.autograd.set_detect_anomaly(True)
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label expanded (without copying) to ``input``'s shape."""
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(input)

    # Implemented as forward (not __call__) so nn.Module hooks run as usual;
    # calling the module still works exactly as before.
    def forward(self, input, target_is_real):
        """Return the loss of ``input`` against the real (True) or fake (False) label."""
        return self.loss(input, self.get_target_tensor(input, target_is_real))



class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with a fixed target ``shape``."""

    def __init__(self, shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed as ``self.shape``."""
        return input.view(*self.shape)


class BlowBlock(nn.Module):
    """Project channels into single-channel ``size[0] x size[1]`` maps.

    A 1x1 transposed convolution produces ``size[0] * size[1]`` channels.
    ``View`` then regroups them into shape ``(-1, 1, size[0], size[1])``, and the
    result passes through BatchNorm2d(1), ReLU and Dropout(0.5).

    NOTE(review): presumably the input is ``(N, input_nums, 1, 1)``, giving an
    output of ``(N, 1, size[0], size[1])`` -- confirm with callers.
    """

    def __init__(self, input_nums, size):
        super(BlowBlock, self).__init__()
        rows, cols = size[0], size[1]
        self.model = nn.Sequential(
            nn.ConvTranspose2d(input_nums, rows * cols, 1),
            View((-1, 1, rows, cols)),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, input):
        """Apply the projection pipeline to ``input``."""
        return self.model(input)


# UNet-Generator where the LSTM is inside the second-innermost layer, the rest is equal
class UnetGeneratorLSTMBetween(nn.Module):
    """U-Net generator with an ``LSTMblock`` in the decoder of the second-innermost level.

    Levels, from the inside out:
    - the innermost level (no LSTM);
    - an ``ngf*8`` level built with ``atDecoder`` that carries the LSTM;
    - ``num_downs - 6`` extra ``ngf*8`` levels (with ``use_dropout``);
    - the ``ngf*4``, ``ngf*2``, ``ngf`` and outermost levels.

    Args:
        input_nc: channels of the tensor given to the outermost level. This
            includes the parameter planes added in ``forward`` when ``params``.
        output_nc: channels of the generated output.
        num_downs: number of down-sampling levels.
        ngf: number of filters in the outermost level.
        norm_layer: normalisation layer class or ``functools.partial`` of one.
        use_dropout: apply dropout in the extra ``ngf*8`` levels.
        params: if True, ``forward`` prepends one constant plane per parameter
            value to the input channels.
        use_lstm: stored and forwarded to the inner levels. The LSTM is built
            regardless of this flag.
        lstm_layers: number of stacked LSTM layers.
        use_params: give the LSTM one extra scalar input per step.
        atDecoder: build the second-innermost level as the decoder-LSTM variant.
            Without it, the LSTM is created but not placed in the network.
        beforeConv: put the LSTM before (True) or after that level's transposed conv.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True, params=False, use_lstm=False, lstm_layers=0, use_params=False, atDecoder=False, beforeConv=True):
        super(UnetGeneratorLSTMBetween, self).__init__()
        self.params = params
        self.use_lstm = use_lstm

        # Spatial size of the innermost level's output (one level above the bottleneck).
        # Assumes config's input size matches the tensors actually fed in.
        featureMap_width = int(config['input_width'] / 2 ** (num_downs - 1))
        featureMap_height = int(config['input_height'] / 2 ** (num_downs - 1))
        # The LSTM flattens C*H*W values per step. Before the upconv it sees the
        # concatenated innermost output (ngf*8*2 channels at this size). After the
        # upconv it sees ngf*8 channels at twice the resolution, i.e. ngf*8*4 times w*h.
        num_ft_maps = ngf * 8 * 2 if beforeConv else ngf * 8 * 4

        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True, use_lstm=self.use_lstm)

        lstm = LSTMblock(featureMap_height, featureMap_width, num_ft_maps, num_layers=lstm_layers, use_params=use_params)
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, atDecoder=atDecoder, lstm=lstm, use_lstm=self.use_lstm, beforeConv=beforeConv)

        for _ in range(num_downs - 6):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)

        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)

        # add the outermost layer, number of input channels input_nc
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

    def forward(self, input, params_=None, hidden_state=None, cell_state=None):
        """Generate an output for ``input``.

        When built with ``params=True``, the parameters come from ``params_``.
        If ``params_`` is None, ``input`` is unpacked as an ``(images, params)``
        pair. ``params[i, 0, 0, j]`` is parameter ``j`` of sample ``i``. Each
        parameter is broadcast to a constant ``H x W`` plane, and these planes
        are concatenated in front of the image channels.

        ``hidden_state`` and ``cell_state`` are accepted for interface
        compatibility and ignored.
        """
        if self.params:
            if params_ is None:
                input, params = input
            else:
                params = params_
            batch, height, width = input.size(0), input.shape[2], input.shape[3]
            # Build the planes directly on the input's device (in the default float
            # dtype); CPU planes multiplied by CUDA-resident parameter scalars would
            # raise a device mismatch.
            values = params[:batch, 0, 0, :].to(device=input.device, dtype=torch.get_default_dtype())
            planes = values[:, :, None, None].expand(-1, -1, height, width)
            input = torch.cat((planes, input), 1)
        return self.model(input)
