from __future__ import print_function

import argparse
import os
import re
from functools import reduce

import torch
assert torch.__version__.split('.')[0] == '0', 'Only working on PyTorch 0.x.x'
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.serialization import load_lua


class LambdaBase(nn.Sequential):
    """Sequential container whose forward pass is driven by a function.

    Any child modules are each applied to the same input; subclasses then
    combine those results (or the raw input, when there are no children)
    using ``lambda_func``.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Return the list of child outputs for ``input``.

        Falls back to returning ``input`` unchanged when the container has
        no children.
        """
        results = [child(input) for child in self._modules.values()]
        if not results:
            return input
        return results


class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the prepared input."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)


class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each prepared value, returning a list."""

    def forward(self, input):
        fn = self.lambda_func
        # one output per prepared value: [fn(v1), fn(v2), ...]
        return [fn(item) for item in self.forward_prepare(input)]


class LambdaReduce(LambdaBase):
    """Left-fold ``lambda_func`` over the prepared values."""

    def forward(self, input):
        prepared = self.forward_prepare(input)
        # single combined result, e.g. a sum or concatenation of branches
        return reduce(self.lambda_func, prepared)


def copy_param(m, n):
    """Copy learned tensors from a deserialized Lua module into a torch module.

    Copies ``weight`` and ``bias`` into ``n``'s parameters, and
    ``running_mean`` / ``running_var`` into ``n``'s buffers. Each copy is
    skipped when the source value is absent or None, or when ``n`` has no
    corresponding tensor. Lua fields that were nil may be missing entirely
    from the deserialized object, so attributes are read with ``getattr``.

    Args:
        m: source module (legacy nn object or unwrapped attribute table).
        n: destination ``torch.nn`` module, modified in place.
    """
    weight = getattr(m, 'weight', None)
    if weight is not None:
        n.weight.data.copy_(weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        n.bias.data.copy_(bias)
    for stat in ('running_mean', 'running_var'):
        dst = getattr(n, stat, None)
        src = getattr(m, stat, None)
        # a destination buffer may exist but be None (no tracked stats)
        if dst is not None and src is not None:
            dst.copy_(src)


def add_submodule(seq, *args):
    """Append each module in ``args`` to ``seq`` under the next integer name.

    Names are the decimal string of the current child count, so children
    end up keyed '0', '1', '2', ... in insertion order.
    """
    for module in args:
        next_index = len(seq._modules)
        seq.add_module(str(next_index), module)


def lua_recursive_model(module, seq):
    """Populate ``seq`` with torch.nn equivalents of ``module``'s children.

    Iterates ``module.modules`` in order. For each supported Lua layer an
    equivalent ``torch.nn`` module is built, learned tensors are copied with
    ``copy_param`` where the layer has them, and the result is appended to
    ``seq``. Containers (Sequential, ConcatTable, Concat) recurse into their
    own children. Unsupported layers are reported with ``print`` and skipped.

    Args:
        module: deserialized Lua container exposing a ``modules`` sequence.
        seq: ``nn.Sequential`` (or Lambda* container) to append to; mutated.
    """
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            # load_lua wraps classes it does not know; unwrap to the raw
            # attribute table and dispatch on the Lua type name, treating
            # cudnn.* layers like their nn.* counterparts.
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        # Lua Torch passes sizes as (W, H) while torch.nn expects (H, W);
        # the tuples below are swapped so non-square shapes map correctly.
        if name == 'SpatialConvolution':
            # groups can be absent (legacy object) or None (unwrapped table)
            if getattr(m, 'groups', None) is None:
                m.groups = 1
            n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kH, m.kW),
                          (m.dH, m.dW), (m.padH, m.padW), 1, m.groups,
                          bias=(getattr(m, 'bias', None) is not None))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialBatchNormalization':
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum,
                               m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'ReLU':
            add_submodule(seq, nn.ReLU())
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW),
                             ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW),
                             ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialUpSamplingNearest':
            add_submodule(seq,
                          nn.UpsamplingNearest2d(scale_factor=m.scale_factor))
        elif name == 'View':
            # flatten everything after dimension 0
            add_submodule(seq, Lambda(lambda x: x.view(x.size(0), -1)))
        elif name == 'Linear':
            # nn.Linear needs >= 2-D input: promote a 1-D vector to (1, N)
            n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
            n2 = nn.Linear(m.weight.size(1), m.weight.size(0),
                           bias=(getattr(m, 'bias', None) is not None))
            copy_param(m, n2)
            add_submodule(seq, nn.Sequential(n1, n2))
        elif name == 'Dropout':
            add_submodule(seq, nn.Dropout(m.p))
        elif name == 'SoftMax':
            add_submodule(seq, nn.Softmax())
        elif name in ('Identity', 'Copy'):
            add_submodule(seq, Lambda(lambda x: x))  # pass-through
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane,
                                   (m.kH, m.kW), (m.dH, m.dW),
                                   (m.padH, m.padW))
            # Without this copy the layer keeps its random initial weights.
            # NOTE(review): bias=True is always used here; if the Lua layer
            # had no bias, n.bias keeps its random init.
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialReplicationPadding':
            add_submodule(seq, nn.ReplicationPad2d(
                (m.pad_l, m.pad_r, m.pad_t, m.pad_b)))
        elif name == 'SpatialReflectionPadding':
            add_submodule(seq, nn.ReflectionPad2d(
                (m.pad_l, m.pad_r, m.pad_t, m.pad_b)))
        elif name == 'Narrow':
            # NOTE(review): dimension/index are passed through unchanged;
            # Lua Torch indices are presumably 1-based -- verify offsets.
            n = Lambda(
                lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
            add_submodule(seq, n)
        elif name == 'SpatialCrossMapLRN':
            lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta,
                                                     m.k)
            # legacy modules operate on raw tensors, not Variables
            n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq, n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'ConcatTable':  # output is a list, one entry per branch
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'CAddTable':  # input is a list; summed pairwise
            add_submodule(seq, LambdaReduce(lambda x, y: x + y))
        elif name == 'Concat':
            # NOTE(review): m.dimension is used as-is (see the Narrow note).
            n = LambdaReduce(
                lambda x, y, dim=m.dimension: torch.cat((x, y), dim))
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif real is not m:
            # unknown class that load_lua wrapped: report its full Lua type
            print('Not Implement', name, real._typename)
        else:
            print('Not Implement', name)


def lua_recursive_source(module):
    """Generate PyTorch source fragments describing ``module``'s children.

    Produces one string per layer, or per opening/closing line of a
    container. The layer mapping follows ``lua_recursive_model``. Several
    fragments end with a ``#<Layer>`` marker that ``simplify_source`` uses
    to strip default arguments. Every returned fragment is prefixed with one
    tab per nesting level. Trailing commas and newlines are not included.

    Args:
        module: deserialized Lua container exposing a ``modules`` sequence.

    Returns:
        list of str: source fragments in order.
    """
    lines = []
    for m in module.modules:
        name = type(m).__name__
        if name == 'TorchObject':
            # unknown class wrapped by load_lua: use its Lua type name and
            # raw attribute table; cudnn.* layers map like nn.* ones
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        # Lua Torch sizes are (W, H); torch.nn expects (H, W).
        if name == 'SpatialConvolution':
            if getattr(m, 'groups', None) is None:
                m.groups = 1
            lines.append(
                'nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(
                    m.nInputPlane, m.nOutputPlane, (m.kH, m.kW),
                    (m.dH, m.dW), (m.padH, m.padW), 1, m.groups,
                    getattr(m, 'bias', None) is not None))
        elif name == 'SpatialBatchNormalization':
            lines.append('nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(
                m.running_mean.size(0), m.eps, m.momentum, m.affine))
        elif name == 'ReLU':
            lines.append('nn.ReLU()')
        elif name == 'SpatialMaxPooling':
            lines.append('nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format(
                (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode))
        elif name == 'SpatialAveragePooling':
            lines.append('nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format(
                (m.kH, m.kW), (m.dH, m.dW), (m.padH, m.padW), m.ceil_mode))
        elif name == 'SpatialUpSamplingNearest':
            lines.append('nn.UpsamplingNearest2d(scale_factor={})'.format(
                m.scale_factor))
        elif name == 'View':
            lines.append('Lambda(lambda x: x.view(x.size(0),-1)), # View')
        elif name == 'Linear':
            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
            s2 = 'nn.Linear({},{},bias={})'.format(
                m.weight.size(1), m.weight.size(0),
                getattr(m, 'bias', None) is not None)
            lines.append('nn.Sequential({},{}),#Linear'.format(s1, s2))
        elif name == 'Dropout':
            lines.append('nn.Dropout({})'.format(m.p))
        elif name == 'SoftMax':
            lines.append('nn.Softmax()')
        elif name == 'Identity':
            lines.append('Lambda(lambda x: x), # Identity')
        elif name == 'SpatialFullConvolution':
            lines.append('nn.ConvTranspose2d({},{},{},{},{})'.format(
                m.nInputPlane, m.nOutputPlane, (m.kH, m.kW), (m.dH, m.dW),
                (m.padH, m.padW)))
        elif name == 'SpatialReplicationPadding':
            lines.append('nn.ReplicationPad2d({})'.format(
                (m.pad_l, m.pad_r, m.pad_t, m.pad_b)))
        elif name == 'SpatialReflectionPadding':
            lines.append('nn.ReflectionPad2d({})'.format(
                (m.pad_l, m.pad_r, m.pad_t, m.pad_b)))
        elif name == 'Copy':
            lines.append('Lambda(lambda x: x), # Copy')
        elif name == 'Narrow':
            lines.append('Lambda(lambda x,a={}: x.narrow(*a))'.format(
                (m.dimension, m.index, m.length)))
        elif name == 'SpatialCrossMapLRN':
            lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format(
                (m.size, m.alpha, m.beta, m.k))
            # legacy modules operate on raw tensors, not Variables
            lines.append(
                'Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(
                    lrn))
        elif name == 'Sequential':
            lines.append('nn.Sequential( # Sequential')
            lines.extend(lua_recursive_source(m))
            lines.append(')')
        elif name == 'ConcatTable':
            lines.append('LambdaMap(lambda x: x, # ConcatTable')
            lines.extend(lua_recursive_source(m))
            lines.append(')')
        elif name == 'CAddTable':
            lines.append('LambdaReduce(lambda x,y: x+y), # CAddTable')
        elif name == 'Concat':
            lines.append(
                'LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(
                    m.dimension))
            lines.extend(lua_recursive_source(m))
            lines.append(')')
        else:
            # Append as ONE element: `+=` with a bare string would add it
            # character by character and corrupt the generated source.
            lines.append('# ' + name + ' Not Implement')
    return ['\t' + line for line in lines]


def simplify_source(s):
    """Strip default arguments and helper markers from generated source lines.

    Each fragment (as produced by ``lua_recursive_source``) has the rewrite
    rules below applied in order. Then ``',\\n'`` is appended and the first
    character, the outermost tab, is dropped. Every ``#<Layer>`` marker must
    be removed by some rule: a leftover marker would turn the appended
    separating comma into part of a comment and break the generated code.

    Args:
        s: iterable of source fragments.

    Returns:
        str: the concatenated, simplified source ('' for empty input).
    """
    rules = (
        (',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'),
        (',(0, 0),1,1,bias=True),#Conv2d', ')'),
        (',1,1,bias=True),#Conv2d', ')'),
        (',bias=True),#Conv2d', ')'),
        ('),#Conv2d', ')'),
        (',1e-05,0.1,True),#BatchNorm2d', ')'),
        ('),#BatchNorm2d', ')'),
        (',(0, 0),ceil_mode=False),#MaxPool2d', ')'),
        (',ceil_mode=False),#MaxPool2d', ')'),
        ('),#MaxPool2d', ')'),
        (',(0, 0),ceil_mode=False),#AvgPool2d', ')'),
        (',ceil_mode=False),#AvgPool2d', ')'),
        # fallback so non-default AvgPool2d args still lose their marker
        ('),#AvgPool2d', ')'),
        (',bias=True)),#Linear', ')), # Linear'),
        (')),#Linear', ')), # Linear'),
    )
    out = []
    for line in s:
        for old, new in rules:
            line = line.replace(old, new)
        out.append('{},\n'.format(line)[1:])
    return ''.join(out)


def torch_to_pytorch(t7_filename, outputname=None):
    """Convert a Lua Torch ``.t7`` model into PyTorch source plus weights.

    Writes ``<outputname>.py``, which defines the Lambda helper classes and
    the network as an ``nn.Sequential`` assigned to a module-level variable.
    Also writes ``<outputname>.pth``, the matching ``state_dict``.

    Args:
        t7_filename: path to the serialized Lua Torch model.
        outputname: output path prefix (without extension). Defaults to the
            model's file name with '.t7' removed and '.'/'-' replaced by
            '_', placed in the model's directory.
    """
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict':
        # a plain Lua table was loaded; the network is under its 'model' key
        model = model.model
    model.gradInput = None
    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    # Derive names from the file name only: directory components (or '..')
    # would otherwise leak into the generated variable name and output path.
    stem = os.path.basename(t7_filename).replace('.t7', '')
    stem = stem.replace('.', '_').replace('-', '_')
    # The variable must be a valid Python identifier in the generated file.
    varname = re.sub(r'\W', '_', stem)
    if not varname or varname[0].isdigit():
        varname = '_' + varname
    # s ends with ',\n' after simplification; drop it before closing.
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None:
        outputname = os.path.join(os.path.dirname(t7_filename), stem)
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')


# Command-line entry point; guarded so importing this module does not parse
# sys.argv or start a conversion.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Convert torch t7 model to pytorch')
    parser.add_argument('--model', '-m', type=str, required=True,
                        help='torch model file in t7 format')
    parser.add_argument('--output', '-o', type=str, default=None,
                        help='output file name prefix, xxx.py xxx.pth')
    args = parser.parse_args()

    torch_to_pytorch(args.model, args.output)
