# -*- coding: utf-8 -*-
"""Random_Split.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1eSBwcRorpHI84rs0a2h7tIWwWnJWFwBA
"""

# Commented out IPython magic to ensure Python compatibility.
# %load_ext autoreload

# Commented out IPython magic to ensure Python compatibility.
# %autoreload 2

import numpy as np

from collections import OrderedDict

import matplotlib
import matplotlib.pyplot as plt

font = {'size'   : 18}
matplotlib.rc('font', **font)

import os
from pathlib import Path

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import save_state_dict

def extract_test_accs_from_log(logfile):
    """Collect the accuracy values recorded in a logger file, as percentages.

    Every line containing the substring ``'accuracy'`` is treated as a
    comma-separated record whose last field is a float; that value is
    multiplied by 100.

    Args:
        logfile: Path of the log file to read.

    Returns:
        A list of floats, one per matching line, in file order.
    """
    with open(logfile, 'r') as handle:
        records = handle.readlines()
    return [
        100.0 * float(record.split(',')[-1])
        for record in records
        if 'accuracy' in record
    ]

def extract_test_accs_from_replicate(replicate_dir):
    """Return the best test accuracy of every numbered level in a replicate.

    The number of levels is the number of ``level_<digits>`` entries found
    directly under ``replicate_dir`` (entries such as ``level_pretrain`` are
    ignored). Levels are then read as ``level_0 .. level_{n-1}``, so the
    numbered levels are expected to be contiguous starting at 0.

    Args:
        replicate_dir: Replicate directory (``str`` or path-like).

    Returns:
        A list whose i-th entry is the maximum accuracy (in percent) found in
        ``level_i/main/logger``. Empty if there are no numbered levels.

    Raises:
        FileNotFoundError: If an expected ``level_i/main/logger`` is missing.
        ValueError: If a logger file contains no accuracy line.
    """
    # Imported locally because the module never imports ``glob`` at the top;
    # without this the function raised NameError on every call.
    import glob

    level_dirs = [
        entry
        for entry in glob.glob(os.path.join(replicate_dir, 'level_*'))
        if entry.split('_')[-1].isdigit()
    ]
    return [
        max(extract_test_accs_from_log(
            os.path.join(replicate_dir, 'level_{}/main/logger'.format(level))))
        for level in range(len(level_dirs))
    ]

# ---------------------------------------------------------------------------
# Replicate directories of "lottery_<hash>" experiment folders.
# NOTE(review): variable names suggest CIFAR ResNet-20 / ResNet-14 at widths
# 16 and 32, with "16to32" marking a width-16 -> width-32 transfer and "rsp"
# marking random-split replicates -- confirm against the experiment configs.
# These first entries are absolute paths on a specific machine.
# ---------------------------------------------------------------------------
cifar_resnet_20_16_dir = Path("/datadrive_c/xiaohan/open_lth_data/lottery_93bc65d66dfa64ffaf2a0ab105433a2c/replicate_1/")
cifar_resnet_20_32_dir = Path("/datadrive_c/xiaohan/open_lth_data/lottery_b7773335d251f26316a88e7effb19da8/replicate_1/")

cifar_resnet_20_16to32_rsp_start_from_level_0_dir = Path("/datadrive_c/xiaohan/open_lth_data/lottery_b7773335d251f26316a88e7effb19da8/replicate_118/")
cifar_resnet_20_16to32_rsp_start_from_level_10_dir = Path("/datadrive_c/xiaohan/open_lth_data/lottery_b7773335d251f26316a88e7effb19da8/replicate_218/")

cifar_resnet_14_16_dir = Path("/datadrive_c/xiaohan/open_lth_data/lottery_e0e2dbe02e789b00d38f912bcfdb6b97/replicate_1/")
cifar_resnet_14_32_dir = Path("/datadrive_c/xiaohan/open_lth_data/lottery_b33926c9afc6c6667a4437379b7e823c/replicate_1/")


## Rewinding - 500it
# Paths from here on are relative to the current working directory.
# replicate_1 holds the reference run; replicate_118/119/120 are the
# random-split ("rsp") replicates starting from level 0 / 6 / 10.

cifar_resnet_20_16_rw500_dir = Path("./lottery_23b644efaef60c49ca88fc5e37e2595a/replicate_1/")
cifar_resnet_20_32_rw500_dir = Path("./lottery_e83cd7fb6f734bac485b58acd238bb22/replicate_1/")

cifar_resnet_20_16to32_rw500_rsp_start_from_level_0_dir  = Path("./lottery_e83cd7fb6f734bac485b58acd238bb22/replicate_118/")
cifar_resnet_20_16to32_rw500_rsp_start_from_level_6_dir  = Path("./lottery_e83cd7fb6f734bac485b58acd238bb22/replicate_119/")
cifar_resnet_20_16to32_rw500_rsp_start_from_level_10_dir = Path("./lottery_e83cd7fb6f734bac485b58acd238bb22/replicate_120/")


## Rewinding - 1000it
# Same layout as the 500-iteration group; the rsp replicates below are the
# directories this script later writes rewind checkpoints and masks into.

cifar_resnet_20_16_rw1000_dir = Path("./lottery_231c88c353aecc912eada550c58eff10/replicate_1/")
cifar_resnet_20_32_rw1000_dir = Path("./lottery_0aebeaf1c5129110bc2f91ab9dc99578/replicate_1/")

cifar_resnet_20_16to32_rw1000_rsp_start_from_level_0_dir  = Path("./lottery_0aebeaf1c5129110bc2f91ab9dc99578/replicate_118/")
cifar_resnet_20_16to32_rw1000_rsp_start_from_level_6_dir  = Path("./lottery_0aebeaf1c5129110bc2f91ab9dc99578/replicate_119/")
cifar_resnet_20_16to32_rw1000_rsp_start_from_level_10_dir = Path("./lottery_0aebeaf1c5129110bc2f91ab9dc99578/replicate_120/")


## VGG19
# Widths 32 and 64 (per the variable names); rsp replicates start from
# level 0 / 8 / 11 and are written to later in this script.

cifar_vgg19_32_rw_dir = Path("./lottery_edb979b6c106786c89409057251082ea/replicate_1/")
cifar_vgg19_64_rw_dir = Path("./lottery_32da8a3ac7d84f9f0dda4fd60265ded3/replicate_1/")

cifar_vgg19_32to64_rw_rsp_start_from_level_0_dir  = Path("./lottery_32da8a3ac7d84f9f0dda4fd60265ded3/replicate_118/")
cifar_vgg19_32to64_rw_rsp_start_from_level_8_dir  = Path("./lottery_32da8a3ac7d84f9f0dda4fd60265ded3/replicate_119/")
cifar_vgg19_32to64_rw_rsp_start_from_level_11_dir  = Path("./lottery_32da8a3ac7d84f9f0dda4fd60265ded3/replicate_120/")

# ---------------------------------------------------------------------------
# Main script body: load masks and rewind checkpoints of the narrow and wide
# runs, build random-split (rsp) initialisations / replicated masks for the
# wide network, and write them into the rsp replicate directories.
#
# NOTE(review): random_split_state_dict and replicate_mask_state_dict are
# defined further down in this file. When the file is executed top-to-bottom
# as a plain script, those definitions must be moved above this point.
# ---------------------------------------------------------------------------

# Best test accuracy: wide rw1000 reference run vs. the matching rsp replicate.
print(max(extract_test_accs_from_log(cifar_resnet_20_32_rw1000_dir  / 'level_6/main/logger')))
print(max(extract_test_accs_from_log(cifar_resnet_20_16to32_rw1000_rsp_start_from_level_6_dir  / 'level_0/main/logger')))
print(max(extract_test_accs_from_log(cifar_resnet_20_32_rw1000_dir  / 'level_10/main/logger')))
print(max(extract_test_accs_from_log(cifar_resnet_20_16to32_rw1000_rsp_start_from_level_10_dir  / 'level_0/main/logger')))

# --- Pruning masks ----------------------------------------------------------
# Rewinding - 500. The original code referenced `cifar_resnet_20_{16,32}_rw_dir`,
# which is never defined; the `rw500` directories are the only candidates
# (the `rw1000` ones are handled separately just below).
m_20_16_rw_l10 = torch.load(cifar_resnet_20_16_rw500_dir / 'level_10/main/mask.pth')
m_20_32_rw_l10 = torch.load(cifar_resnet_20_32_rw500_dir / 'level_10/main/mask.pth')
m_20_16_rw_l6  = torch.load(cifar_resnet_20_16_rw500_dir / 'level_6/main/mask.pth')
m_20_32_rw_l6  = torch.load(cifar_resnet_20_32_rw500_dir / 'level_6/main/mask.pth')

# Rewinding - 1000
m_20_16_rw1000_l10 = torch.load(cifar_resnet_20_16_rw1000_dir / 'level_10/main/mask.pth')
m_20_32_rw1000_l10 = torch.load(cifar_resnet_20_32_rw1000_dir / 'level_10/main/mask.pth')
m_20_16_rw1000_l6  = torch.load(cifar_resnet_20_16_rw1000_dir / 'level_6/main/mask.pth')
m_20_32_rw1000_l6  = torch.load(cifar_resnet_20_32_rw1000_dir / 'level_6/main/mask.pth')

# VGG19
m_vgg19_w32_rw_l8  = torch.load(cifar_vgg19_32_rw_dir / 'level_8/main/mask.pth')
m_vgg19_w64_rw_l8  = torch.load(cifar_vgg19_64_rw_dir / 'level_8/main/mask.pth')
m_vgg19_w32_rw_l11 = torch.load(cifar_vgg19_32_rw_dir / 'level_11/main/mask.pth')
m_vgg19_w64_rw_l11 = torch.load(cifar_vgg19_64_rw_dir / 'level_11/main/mask.pth')

# --- Rewind checkpoints -----------------------------------------------------
# Rewinding - 500 (checkpoint file ep1_it109; presumably the 500-iteration
# rewind point -- confirm against the training config).
rewind_20_16_rw = torch.load(cifar_resnet_20_16_rw500_dir / 'level_pretrain/main/model_ep1_it109.pth')
rewind_20_32_rw = torch.load(cifar_resnet_20_32_rw500_dir / 'level_pretrain/main/model_ep1_it109.pth')

# Rewinding - 1000
rewind_20_16_rw1000 = torch.load(cifar_resnet_20_16_rw1000_dir / 'level_pretrain/main/model_ep2_it218.pth')
rewind_20_32_rw1000 = torch.load(cifar_resnet_20_32_rw1000_dir / 'level_pretrain/main/model_ep2_it218.pth')

# VGG19
rewind_vgg19_w32_rw = torch.load(cifar_vgg19_32_rw_dir / 'level_pretrain/main/model_ep0_it100.pth')
rewind_vgg19_w64_rw = torch.load(cifar_vgg19_64_rw_dir / 'level_pretrain/main/model_ep0_it100.pth')

# --- Random-split initialisations of the wide networks ----------------------
rewind_20_16to32_rsp = random_split_state_dict(rewind_20_16_rw, rewind_20_32_rw)

rewind1000_20_16to32_rsp = random_split_state_dict(rewind_20_16_rw1000, rewind_20_32_rw1000)

rewind_vgg19_w32to64_rw_rsp = random_split_state_dict(rewind_vgg19_w32_rw, rewind_vgg19_w64_rw)

# --- Write the rsp rewind checkpoints ---------------------------------------
# Rewinding - 500: stored both as the pretrain checkpoint and under level_0.
# (The original used the undefined `..._rw_rsp_start_from_level_*_dir` names.)
for rsp_dir in (
    cifar_resnet_20_16to32_rw500_rsp_start_from_level_0_dir,
    cifar_resnet_20_16to32_rw500_rsp_start_from_level_6_dir,
    cifar_resnet_20_16to32_rw500_rsp_start_from_level_10_dir,
):
    save_state_dict(rewind_20_16to32_rsp, rsp_dir / 'level_pretrain/main/model_ep1_it109.pth')
    save_state_dict(rewind_20_16to32_rsp, rsp_dir / 'level_0/main/model_ep1_it109.pth')

# Rewinding - 1,000it: only the pretrain checkpoint is written (no level_0 copy).
for rsp_dir in (
    cifar_resnet_20_16to32_rw1000_rsp_start_from_level_0_dir,
    cifar_resnet_20_16to32_rw1000_rsp_start_from_level_6_dir,
    cifar_resnet_20_16to32_rw1000_rsp_start_from_level_10_dir,
):
    save_state_dict(rewind1000_20_16to32_rsp, rsp_dir / 'level_pretrain/main/model_ep2_it218.pth')

# VGG19: the same rewind weights are also stored under the file name
# model_ep160_it0.pth in level_0.
# NOTE(review): presumably so level_0 looks like a finished training run to
# the framework -- confirm.
for rsp_dir in (
    cifar_vgg19_32to64_rw_rsp_start_from_level_0_dir,
    cifar_vgg19_32to64_rw_rsp_start_from_level_8_dir,
    cifar_vgg19_32to64_rw_rsp_start_from_level_11_dir,
):
    save_state_dict(rewind_vgg19_w32to64_rw_rsp, rsp_dir / 'level_pretrain/main/model_ep0_it100.pth')
    save_state_dict(rewind_vgg19_w32to64_rw_rsp, rsp_dir / 'level_0/main/model_ep0_it100.pth')
    save_state_dict(rewind_vgg19_w32to64_rw_rsp, rsp_dir / 'level_0/main/model_ep160_it0.pth')

# --- Replicate the narrow masks to the wide shapes and write them -----------
# Rewinding - 500. The original saved `m_20_16to32_rsp_l6/l10`, names that were
# never assigned; the masks computed here are what was meant to be written.
m_20_16to32_rw500_rsp_l6  = replicate_mask_state_dict(m_20_16_rw_l6 , m_20_32_rw_l6)
m_20_16to32_rw500_rsp_l10 = replicate_mask_state_dict(m_20_16_rw_l10, m_20_32_rw_l10)

# NOTE(review): these go to level_6 / level_10, whereas the rw1000 masks below
# go to level_0 -- confirm which layout the training run expects.
save_state_dict(m_20_16to32_rw500_rsp_l6,  cifar_resnet_20_16to32_rw500_rsp_start_from_level_6_dir  / 'level_6/main/mask.pth')
save_state_dict(m_20_16to32_rw500_rsp_l10, cifar_resnet_20_16to32_rw500_rsp_start_from_level_10_dir / 'level_10/main/mask.pth')

# Rewinding - 1000
m_20_16to32_rw1000_rsp_l6  = replicate_mask_state_dict(m_20_16_rw1000_l6 , m_20_32_rw1000_l6)
m_20_16to32_rw1000_rsp_l10 = replicate_mask_state_dict(m_20_16_rw1000_l10, m_20_32_rw1000_l10)

save_state_dict(m_20_16to32_rw1000_rsp_l6 , cifar_resnet_20_16to32_rw1000_rsp_start_from_level_6_dir  / 'level_0/main/mask.pth')
save_state_dict(m_20_16to32_rw1000_rsp_l10, cifar_resnet_20_16to32_rw1000_rsp_start_from_level_10_dir / 'level_0/main/mask.pth')

# VGG19
m_vgg19_w32to64_rw_rsp_l8 = replicate_mask_state_dict(m_vgg19_w32_rw_l8, m_vgg19_w64_rw_l8)
m_vgg19_w32to64_rw_rsp_l11 = replicate_mask_state_dict(m_vgg19_w32_rw_l11, m_vgg19_w64_rw_l11)

save_state_dict(m_vgg19_w32to64_rw_rsp_l8,  cifar_vgg19_32to64_rw_rsp_start_from_level_8_dir  / 'level_8/main/mask.pth')
save_state_dict(m_vgg19_w32to64_rw_rsp_l11, cifar_vgg19_32to64_rw_rsp_start_from_level_11_dir / 'level_11/main/mask.pth')

def random_split_state_dict(parent, child):
    """Build a state dict for a 2x-wider model by splitting ``parent`` weights.

    For every key of ``parent``:

    * Entries whose key contains ``'bn'``, ``'running'`` or ``'num'``, and
      tensors with at most one dimension, are copied from ``child`` unchanged.
    * Otherwise the tensor is treated as ``(out, in, ...)``:

      - *passive split* (child has more input channels): the parent weights
        are written into both halves of the input axis and the whole tensor
        is halved.
      - *active split* (child has more output channels): the first half of
        the output axis is duplicated into the second half, then independent
        random noise (unit L2 norm along the last axis, scaled by 1e-2) is
        added to the first half and subtracted from the second.

    Args:
        parent: ``OrderedDict`` state dict of the narrow model.
        child: ``OrderedDict`` state dict of the wide model; supplies shapes,
            device and the values of the entries that are copied.

    Returns:
        A new ``OrderedDict`` with ``parent``'s keys and ``child``'s shapes.
        Neither input is modified.

    Raises:
        AssertionError: If an input is not an ``OrderedDict``, a parent key is
            missing from ``child``, or a widened axis is not exactly 2x.
    """
    assert isinstance(parent, OrderedDict) and isinstance(child, OrderedDict)
    out = OrderedDict()
    for k in parent.keys():
        assert k in child
        p = parent[k]
        c = child[k]
        if 'bn' in k or 'running' in k or 'num' in k or p.dim() <= 1:
            out[k] = c.clone()
            continue
        new = c.clone()
        p_od, p_id = p.shape[0], p.shape[1]
        c_od, c_id = c.shape[0], c.shape[1]
        print(c.shape)

        passive_split = c_id > p_id
        if passive_split:
            assert c_id == 2 * p_id  # we can only process splitting by 2x now
            new.data[:p_od, :p_id, ...] = p.data.clone()
            new.data[:p_od, p_id:, ...] = p.data.clone()
            # Each input now appears twice, so halve the weights.
            new.data /= 2.0

        active_split = c_od > p_od
        if active_split:
            assert c_od == 2 * p_od  # we can only process splitting by 2x now
            # Use ``c.device`` rather than ``c.get_device()``: the latter
            # returns -1 for CPU tensors, which ``.to()`` rejects.
            device = c.device
            # Take trailing dims from the tensor itself so 2-D (linear)
            # weights work as well as 4-D conv kernels.
            noise_shape = (p_od, c_id) + tuple(c.shape[2:])
            delta1 = F.normalize(torch.randn(noise_shape).to(device), p=2, dim=-1)
            delta2 = F.normalize(torch.randn(noise_shape).to(device), p=2, dim=-1)
            delta1 *= 1e-2
            delta2 *= 1e-2

            if not passive_split:
                new.data[:p_od, :p_id, ...] = p.data.clone()
            # Duplicate the first half of the output channels, then perturb
            # the two copies in opposite directions.
            new.data[p_od:, ...] = new[:p_od, ...].data.clone()
            new.data[:p_od, ...] += delta1
            new.data[p_od:, ...] -= delta2

        out[k] = new
    return out

def replicate_mask_state_dict(mask_source, target_model):
    """Tile every tensor of ``mask_source`` up to the matching target shape.

    The two mappings are walked in lockstep (by position, not by key). For
    each pair, the source tensor is repeated along every axis by the integer
    ratio ``target_dim // source_dim``.

    Args:
        mask_source: Ordered mapping of name -> tensor to be tiled.
        target_model: Ordered mapping whose tensors provide the target shapes;
            only the shapes are used.

    Returns:
        An ``OrderedDict`` keyed by ``mask_source``'s names holding the tiled
        tensors.
    """
    tiled = OrderedDict()
    for (name, small), (_, wide) in zip(mask_source.items(), target_model.items()):
        # Per-axis repetition count: how many times the source fits the target.
        factors = [wide.shape[axis] // small.shape[axis] for axis in range(small.dim())]
        with torch.no_grad():
            tiled[name] = small.repeat(*factors)
    return tiled

# Inline variant of the mask-tiling logic in replicate_mask_state_dict above:
# entries whose key contains 'fc' are taken directly from the second mapping,
# all other entries are tiled from the first mapping up to the second's shape.
# NOTE(review): m_20_16_l3, m_20_32_l3 and m_20_16to32_l3 are not defined
# anywhere in the visible part of this file, so this loop raises NameError
# when run as-is -- looks like a leftover notebook cell; confirm and either
# define the inputs or remove it.
for (k1,v1), (k2,v2) in zip(m_20_16_l3.items(), m_20_32_l3.items()):
    if 'fc' in k1:
        m_20_16to32_l3[k1] = v2
    else:
        # Per-axis repetition count from the ratio of target to source size.
        repeats = [int(v2.shape[i] / v1.shape[i]) for i in range(len(v1.shape))]
        with torch.no_grad():
            m_20_16to32_l3[k1] = v1.repeat(*repeats)

