# -*- coding: utf-8 -*-
"""change_width.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1w4-Yn7S2BevCxQ8OAEK3Ixmw34aLGcsN
"""
from collections import OrderedDict

from os.path import exists
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F


def _scaled_noise(shape, delta_scale, device):
    """Return row-normalised Gaussian noise of ``shape``, scaled by ``delta_scale``.

    The noise is sampled on the CPU (so the values drawn after seeding do not
    depend on where the model lives) and only then moved to ``device``.
    """
    noise = F.normalize(torch.randn(*shape), p=2, dim=-1)
    noise *= delta_scale
    return noise.to(device)


def _duplicate_and_perturb(new, p_od, delta):
    """Copy rows ``[:p_od]`` of ``new`` into rows ``[p_od:]`` and push the halves apart by ``delta``."""
    new.data[p_od:, ...] = new.data[:p_od, ...].clone()
    new.data[:p_od, ...] += delta
    new.data[p_od:, ...] -= delta


def almost_even_split(parent, child, split_ratio = 1/2, delta_scale = 0, perturb = 0.):
    """Build a state dict for ``child`` (a wider network) from the weights of ``parent``.

    Every key of ``parent`` is processed as follows:

    * Keys containing ``'bn'``, and ``'shortcut'`` keys holding running
      statistics, are copied unchanged from ``child``.
    * Tensors whose shape is identical in parent and child take the parent's
      values.
    * 1-D tensors of different length: the parent values fill the first
      ``p_od`` entries, are duplicated into the remaining entries, and the two
      halves are shifted by ``+delta`` / ``-delta``.
    * Tensors with two or more dimensions:
        - same input size, larger output size ("input layer"): the parent rows
          are duplicated without noise;
        - larger input size ("passive split"): the parent weights are spread
          over the two input halves as ``split_ratio`` and ``1 - split_ratio``;
        - larger output size ("active split"): the first ``p_od`` rows are
          duplicated and the halves are shifted by ``+delta`` / ``-delta``.

    ``delta`` is Gaussian noise normalised along the last dimension and
    multiplied by ``delta_scale``.

    Args:
        parent: ``OrderedDict`` state dict of the narrower model.
        child: ``OrderedDict`` state dict of the wider model; must contain
            every key of ``parent``. It is not modified.
        split_ratio: fraction of each parent weight assigned to the first half
            of the inputs in a passive split.
        delta_scale: magnitude of the symmetry-breaking noise.
        perturb: currently unused; kept for interface compatibility.

    Returns:
        A new ``OrderedDict`` with one tensor per key of ``parent``.

    Raises:
        AssertionError: if an argument is not an ``OrderedDict``, a parent key
            is missing from ``child``, or a passive/active split is not an
            exact doubling.

    Note:
        The global torch RNG is re-seeded with 103 on every call so that the
        noise is reproducible.
    """
    # Right now we only deal with expanding the width of the neural networks by 2
    torch.manual_seed(103)
    assert isinstance(parent, OrderedDict) and isinstance(child, OrderedDict), \
        "parent and child must both be OrderedDict state dicts"
    out = OrderedDict()
    for k, p in parent.items():
        assert k in child, "key {!r} of parent is missing from child".format(k)
        c = child[k]
        if 'bn' in k or ('shortcut' in k and ('running_mean' in k or 'running_var' in k or 'num_batches' in k)):
            out[k] = c.clone()
            continue
        new = c.clone()
        if p.dim() >= 2:
            print(c.shape)
        if p.shape == c.shape:
            # Nothing to widen: inherit the parent's values regardless of rank.
            new.data = p.data.clone()
            out[k] = new
            continue

        p_od = p.shape[0]
        c_od = c.shape[0]
        if p.dim() == 1:
            delta = _scaled_noise([p_od], delta_scale, new.device)
            new.data[:p_od, ...] = p.data
            _duplicate_and_perturb(new, p_od, delta)
            out[k] = new
            continue

        p_id = p.shape[1]
        c_id = c.shape[1]
        if c_id == p_id and c_od > p_od: # Happens for input layer
            new.data[:p_od, ...] = p.data
            new.data[p_od:, ...] = new.data[:p_od, ...].clone()
        else:
            if c_id > p_id:  # passive split: the previous layer got wider
                assert c_id == 2 * p_id, "we can only process splitting by 2x now"
                new.data[:p_od, :p_id, ...] = p.data * split_ratio
                new.data[:p_od, p_id:, ...] = p.data * (1 - split_ratio)

            if c_od > p_od:  # active split: this layer gets wider
                assert c_od == 2 * p_od, "we can only process splitting by 2x now"
                # Noise must live on the same device as the tensor it perturbs.
                delta = _scaled_noise([p_od, c_id, *c.shape[2:]], delta_scale, new.device)
                _duplicate_and_perturb(new, p_od, delta)
        out[k] = new
    return out


def shrink_width(parent, child):
  """Placeholder for narrowing a network; not implemented, returns ``None``.

  Planned scope: only shrinking the width by a factor of 2.
  Ideas noted so far: drop the sparsest map and maybe scale the remaining
  weights by a factor of 2; adding the masks and weights is another option.
  """
  return None

def replicate_mask_state_dict(source_mask, target_mask):
    """Tile every mask in ``source_mask`` up to the shape of its counterpart in ``target_mask``.

    The two mappings are walked in lockstep and must list the same keys in the
    same order. For each pair, the source tensor is repeated along every
    dimension by ``target_size // source_size``.

    Args:
        source_mask: ordered mapping of key -> mask tensor for the smaller model.
        target_mask: ordered mapping of key -> tensor for the larger model;
            only its shapes are used.

    Returns:
        An ``OrderedDict`` mapping each key to the tiled source mask.

    Raises:
        ValueError: if the mappings differ in length, a pair of tensors differs
            in rank, or a target size is not a whole multiple of the source size.
        AssertionError: if the keys at the same position differ.
    """
    if len(source_mask) != len(target_mask):
        # zip() would silently drop the trailing entries of the longer mapping.
        raise ValueError(
            "source_mask has {} entries but target_mask has {}".format(
                len(source_mask), len(target_mask)))

    out = OrderedDict()
    for (k1, v1), (k2, v2) in zip(source_mask.items(), target_mask.items()):
        assert k1 == k2, "mask keys differ: {!r} vs {!r}".format(k1, k2)
        if v1.dim() != v2.dim():
            raise ValueError(
                "rank mismatch for {!r}: {} vs {}".format(
                    k1, tuple(v1.shape), tuple(v2.shape)))
        repeats = []
        for src_size, tgt_size in zip(v1.shape, v2.shape):
            # Integer arithmetic only: a truncated float ratio would quietly
            # yield a mask of the wrong shape.
            if src_size == 0 or tgt_size % src_size != 0:
                raise ValueError(
                    "cannot tile {!r} from shape {} to shape {}".format(
                        k1, tuple(v1.shape), tuple(v2.shape)))
            repeats.append(tgt_size // src_size)
        with torch.no_grad():
            # repeat() needs at least one argument, so scalars are just copied.
            out[k1] = v1.repeat(*repeats) if repeats else v1.clone()
    return out

def main():
    """Widen a saved parent model and mask to the child architecture and save them.

    Loads the parent and child checkpoints and masks from the hard-coded
    experiment directories, builds the widened state dict with
    ``almost_even_split`` and the widened mask with
    ``replicate_mask_state_dict``, and writes both under
    ``<child_dir>/replicate_<replicate>/level_10/main``. The hyper-parameters
    used are recorded in ``<child_dir>/replicate_<replicate>/change_width_details``.

    Raises:
        RuntimeError: if the output model or mask file already exists.
    """
    import os  # function-scope import: only needed to create the output directory

    parent_dir = "/u/hy6385/open_lth_data/lottery_a050d4cded0448e379626e47a9e737e0" # lenet_150_50
    child_dir = "/u/hy6385/open_lth_data/lottery_108ebe7f8e2dbe540dbed0e9011edec1" # lenet_300_100
    parent_main = parent_dir + "/replicate_1/level_10/main"
    child_main = child_dir + "/replicate_1/level_0/main"

    # change the replicate whenever do a new experiment
    replicate = 121
    split_ratio = 1/2
    delta_scale = 1e-2

    replicate_dir = child_dir + "/replicate_{}".format(replicate)
    out_dir = replicate_dir + "/level_10/main"
    model_path = out_dir + "/model_ep0_it0.pth"
    mask_path = out_dir + "/mask.pth"
    # Fail before doing any work, and never overwrite either output of a
    # previous experiment.
    if exists(model_path) or exists(mask_path):
        raise RuntimeError("The file already exists. Choose a new replicate number. ")

    parent_model = torch.load(parent_main + "/model_ep0_it0.pth")
    child_model = torch.load(child_main + "/model_ep0_it0.pth")
    source_mask = torch.load(parent_main + "/mask.pth")
    target_mask = torch.load(child_main + "/mask.pth")

    # Pass split_ratio explicitly so the value logged below is the one used.
    lenet300double_dict = almost_even_split(
        parent_model, child_model, split_ratio=split_ratio, delta_scale=delta_scale)
    lenet300double_mask = replicate_mask_state_dict(source_mask, target_mask)

    # torch.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    torch.save(lenet300double_dict, model_path)
    torch.save(lenet300double_mask, mask_path)

    with open(replicate_dir + "/change_width_details", "w") as f:
        f.write("split_ratio: {} \n".format(split_ratio) )
        f.write("delta_scale: {} \n".format(delta_scale) )

# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()
