-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
-- Usage: convert_model.lua <model_epoch1.th7>
require 'torch'
local fairseq = require 'fairseq'

-- Load the Torch checkpoint named on the command line. `model` is kept global
-- because push_state looks up WeightNorm wrappers under `model.module`.
-- Fail early with a usage hint instead of letting torch.load choke on nil.
assert(arg and arg[1], 'Usage: convert_model.lua <model_epoch1.th7>')
model = torch.load(arg[1])

--- Find the nn.WeightNorm wrapper around `module`.
-- Walks the list returned by `container:listModules()` (stopping at the first
-- nil entry) and returns the first nn.WeightNorm whose `modules[1]` is
-- exactly `module`.
-- @param container nn container to search
-- @param module    the wrapped module being looked for
-- @return the matching nn.WeightNorm, or nothing (nil) if none wraps `module`
function find_weight_norm(container, module)
  local candidates = container:listModules()
  local idx = 1
  local candidate = candidates[idx]
  while candidate ~= nil do
    if torch.type(candidate) == 'nn.WeightNorm' then
      if candidate.modules[1] == module then
        return candidate
      end
    end
    idx = idx + 1
    candidate = candidates[idx]
  end
end

--- Copy one module's parameters into `dict` under the prefix `key`.
-- nn.Linear and nn.TemporalConvolutionTBC modules must be wrapped by an
-- nn.WeightNorm (searched for under the global `model.module`). For them, the
-- decomposed weight is stored as `<key>.weight_v` / `<key>.weight_g`.
-- Any other module stores its raw `<key>.weight`. A `<key>.bias` entry is
-- added whenever the module has a bias. Every tensor is converted with :float().
-- @param dict   destination table (parameter name -> tensor)
-- @param key    parameter-name prefix, e.g. 'fc1' or 'convolutions.0'
-- @param module the nn module whose parameters are exported
-- @raise if `module` is nil, or a Linear/TBC module has no WeightNorm wrapper
function push_state(dict, key, module)
  -- Callers pass e.g. luts[2] or table.remove(fcs, 1), which can be nil when
  -- the checkpoint layout differs from what is expected; name the parameter.
  assert(module ~= nil, "push_state: no module found for '" .. key .. "'")

  local mtype = torch.type(module)
  if mtype == 'nn.Linear' or mtype == 'nn.TemporalConvolutionTBC' then
    local wn = find_weight_norm(model.module, module)
    assert(wn, "push_state: no nn.WeightNorm wrapper found for '" .. key
      .. "' (" .. mtype .. ")")
    if mtype == 'nn.Linear' then
      dict[key .. '.weight_v'] = wn.v:float()
      dict[key .. '.weight_g'] = wn.g:float()
    else
      -- Reshape the flat v to wn.viewOut and swap dims 2 and 3; g becomes
      -- (outPlanes, 1, 1). Presumably this matches the target framework's
      -- conv weight layout — NOTE(review): confirm against the loader.
      dict[key .. '.weight_v'] = wn.v:float():view(wn.viewOut):transpose(2, 3)
      dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
    end
  else
    dict[key .. '.weight'] = module.weight:float()
  end

  if module.bias then
    dict[key .. '.bias'] = module.bias:float()
  end
end

-- Output tables: flat per-component parameter tables filled by encoder_state
-- and decoder_state, plus the merged, prefixed table that is finally saved.
-- They are globals because the functions below write to them by name.
encoder_dict, decoder_dict, combined_dict = {}, {}, {}

--- Export the encoder's parameters into the global `encoder_dict`.
-- The first two nn.LookupTable modules become 'embed_tokens' and
-- 'embed_positions'. Among the nn.Linear modules (in findModules order),
-- the first is 'fc1' and the last is 'fc2'. The remaining ones are consumed
-- in order as 'projections.<i>' whenever a convolution's channel count
-- differs from the previous layer's. Each nn.TemporalConvolutionTBC becomes
-- 'convolutions.<i>' (0-based).
-- @param encoder the encoder nn module
-- @raise if there are fewer than 2 Linears or some Linear is left unassigned
function encoder_state(encoder)
  local luts = encoder:findModules('nn.LookupTable')
  push_state(encoder_dict, 'embed_tokens', luts[1])
  push_state(encoder_dict, 'embed_positions', luts[2])

  local fcs = encoder:findModules('nn.Linear')
  assert(#fcs >= 2,
    'encoder_state: expected at least 2 nn.Linear modules (fc1, fc2), found ' .. #fcs)
  -- fc1's output width (weight is out x in) is the channel count entering
  -- the first convolution.
  local nInputPlane = fcs[1].weight:size(1)
  push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
  push_state(encoder_dict, 'fc2', table.remove(fcs))

  for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
    local idx = tostring(i - 1)
    push_state(encoder_dict, 'convolutions.' .. idx, module)
    -- Half of the conv's size(3); presumably the conv emits 2x channels for
    -- a gated (GLU) activation — NOTE(review): confirm.
    local nOutputPlane = module.weight:size(3) / 2
    if nInputPlane ~= nOutputPlane then
      push_state(encoder_dict, 'projections.' .. idx, table.remove(fcs, 1))
    end
    nInputPlane = nOutputPlane
  end
  assert(#fcs == 0, 'encoder_state: ' .. #fcs .. ' nn.Linear module(s) left unassigned')
end

--- Export the decoder's parameters into the global `decoder_dict`.
-- The first two nn.LookupTable modules become 'embed_tokens' and
-- 'embed_positions'. Among the nn.Linear modules (in findModules order):
-- the first is 'fc1', the last two are 'fc2' and 'fc3', and the rest are
-- consumed per convolution layer i (0-based). Layer i first takes an optional
-- 'projections.<i>' when the channel count changes, then
-- 'attention.<i>.in_projection' and 'attention.<i>.out_projection'. Each
-- nn.TemporalConvolutionTBC becomes 'convolutions.<i>'.
-- @param decoder the decoder nn module
-- @raise if there are fewer than 3 Linears or some Linear is left unassigned
function decoder_state(decoder)
  local luts = decoder:findModules('nn.LookupTable')
  push_state(decoder_dict, 'embed_tokens', luts[1])
  push_state(decoder_dict, 'embed_positions', luts[2])

  local fcs = decoder:findModules('nn.Linear')
  assert(#fcs >= 3,
    'decoder_state: expected at least 3 nn.Linear modules (fc1, fc2, fc3), found ' .. #fcs)
  -- fc1's output width (weight is out x in) is the channel count entering
  -- the first convolution.
  local nInputPlane = fcs[1].weight:size(1)
  push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
  local fc3 = table.remove(fcs)
  local fc2 = table.remove(fcs)
  push_state(decoder_dict, 'fc2', fc2)
  push_state(decoder_dict, 'fc3', fc3)

  for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
    local idx = tostring(i - 1)
    -- Half of the conv's size(3); presumably the conv emits 2x channels for
    -- a gated (GLU) activation — NOTE(review): confirm.
    local nOutputPlane = module.weight:size(3) / 2
    if nInputPlane ~= nOutputPlane then
      push_state(decoder_dict, 'projections.' .. idx, table.remove(fcs, 1))
    end
    nInputPlane = nOutputPlane

    local prefix = 'attention.' .. idx
    push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
    push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
    push_state(decoder_dict, 'convolutions.' .. idx, module)
  end
  assert(#fcs == 0, 'decoder_state: ' .. #fcs .. ' nn.Linear module(s) left unassigned')
end


-- The encoder and decoder are taken from children 2 and 3 of the top-level
-- container (checkpoint layout assumed by this script — NOTE(review): confirm
-- for other fairseq model variants).
local encoder = model.module.modules[2]
local decoder = model.module.modules[3]
assert(encoder and decoder,
  'expected encoder at model.module.modules[2] and decoder at model.module.modules[3]')

encoder_state(encoder)
decoder_state(decoder)

-- Merge both parameter tables into one flat table, namespacing keys with
-- 'encoder.' / 'decoder.' so the two components cannot collide.
for prefix, dict in pairs({ encoder = encoder_dict, decoder = decoder_dict }) do
  for k, v in pairs(dict) do
    combined_dict[prefix .. '.' .. k] = v
  end
end

-- Output path: optional second command-line argument, defaulting to the
-- historical 'state_dict.t7'.
local out_path = arg[2] or 'state_dict.t7'
torch.save(out_path, combined_dict)
