import torchvision
import torch
import numpy as np
import SNGANDiffAug



# Open the Chainer generator checkpoint and sort its array names so that the
# key mapping below is built (and printed) in a deterministic order.
gen = np.load('./pretrained_models/ResNetGenerator_850000.npz')
keys = gen.files
keys.sort()

# The PyTorch generator whose state_dict keys/shapes the converted weights
# must match.
gennew = SNGANDiffAug.Generator(
    G_ch=64,
    use_ema=True,
    dim_z=128,
    ema=True,
    G_attn='0',
    resolution=128,
    G_shared=True,
    shared_dim=120,
    hier=True,
)

st = gennew.state_dict()

# Build `mapping`: Chainer generator npz key -> SNGANDiffAug.Generator
# state_dict key. Handled key forms are "blockN/<conv>/<W|b>",
# "blockN/<b1|b2>/<stat or param>[/...]", "l7/<W|b>" and "b7/<param>";
# every other key is left unmapped (and therefore not converted).
mapping = {}
conv_mapping = {'c1': 'conv1', 'c2': 'conv2', 'c_sc': 'conv_sc'}
conv_weight_mapping = {'W': 'weight', 'b': 'bias'}
bn_mapping = {'b1': 'bn1', 'b2': 'bn2'}
bn_weight_mapping = {'avg_var': 'stored_var', 'avg_mean': 'stored_mean', 'betas': 'bias.weight', 'gammas': 'gain.weight'}
for each in keys:
    k = each.split('/')
    if 'block' in each:
        # Chainer "blockN" becomes PyTorch "blocks.(N-2)".
        block_num = int(k[0].split('block')[1]) - 2
        prefix = 'blocks.' + str(block_num) + '.0.'
        if 'b1' in k[1] or 'b2' in k[1]:
            # Batch-norm entry: record it like the conv entries below.
            # Chainer's "N" entry has no target in bn_weight_mapping, so skip it.
            if 'N' not in k[2]:
                mapping[each] = prefix + bn_mapping[k[1]] + '.' + bn_weight_mapping[k[2]]
        else:
            mapping[each] = prefix + conv_mapping[k[1]] + '.' + conv_weight_mapping[k[2]]
    elif 'l' in k[0]:
        # Only layer "l7" is converted, into output_layer.2.
        # NOTE(review): any other "l*" entry (e.g. an input linear layer, if the
        # checkpoint has one) is not mapped — confirm this is intended.
        if k[0].split('l')[1] == '7':
            if 'W' in each:
                mapping[each] = 'output_layer.2.weight'
            if 'b' in each:
                mapping[each] = 'output_layer.2.bias'
    elif 'b' in k[0]:
        # Batch norm "b7" is converted into output_layer.0; its "N" entry is ignored.
        if k[0].split('b')[1] == '7':
            if 'gamma' in each:
                mapping[each] = 'output_layer.0.gain'
            if 'beta' in each:
                mapping[each] = 'output_layer.0.bias'
            if 'avg_mean' in each:
                mapping[each] = 'output_layer.0.stored_mean'
            if 'avg_var' in each:
                mapping[each] = 'output_layer.0.stored_var'
    


# Flag every converted parameter whose Chainer array shape disagrees with the
# shape of the corresponding PyTorch generator state_dict entry.
for src_name, dst_name in mapping.items():
    src_shape = gen[src_name].shape
    dst_shape = st[dst_name].shape
    if src_shape != dst_shape:
        print("$$$$$$$", src_name, dst_name, src_shape, dst_shape)



# Re-key every mapped array under its PyTorch name as a tensor and save the
# result (mismatched shapes flagged above are saved as-is).
new_st = {dst_name: torch.from_numpy(gen[src_name]) for src_name, dst_name in mapping.items()}

torch.save(new_st, './pretrained_models/sngan.pth')


# The PyTorch discriminator whose state_dict keys/shapes the converted
# weights must match.
Dnew = SNGANDiffAug.Discriminator(
    D_ch=64,
    D_attn='0',
    resolution=128,
    D_wide=False,
    prior_dim=0,
)
st = Dnew.state_dict()


# Open the Chainer discriminator checkpoint and sort its array names so the
# mapping below is built (and printed) in a deterministic order.
D = np.load('./pretrained_models/SNResNetProjectionDiscriminator_850000.npz')
keys = D.files
keys.sort()


# Build `mapping`: Chainer discriminator npz key -> SNGANDiffAug.Discriminator
# state_dict key. Handled key forms are "blockN/<conv>/<W|b|u>",
# "blockN/<b1|b2>/<stat or param>[/...]" and "l7/<W|b|u>"; every other key
# is left unmapped (and therefore not converted).
mapping = {}
conv_mapping = {'c1': 'conv1', 'c2': 'conv2', 'c_sc': 'conv_sc'}
# 'u' -> 'u0': presumably the spectral-norm power-iteration vector — TODO confirm.
conv_weight_mapping = {'W': 'weight', 'b': 'bias', 'u': 'u0'}
bn_mapping = {'b1': 'bn1', 'b2': 'bn2'}
bn_weight_mapping = {'avg_var': 'stored_var', 'avg_mean': 'stored_mean', 'betas': 'bias.weight', 'gammas': 'gain.weight'}
for each in keys:
    k = each.split('/')
    if 'block' in each:
        # Chainer "blockN" becomes PyTorch "blocks.(N-1)".
        block_num = int(k[0].split('block')[1]) - 1
        prefix = 'blocks.' + str(block_num) + '.0.'
        if 'b1' in k[1] or 'b2' in k[1]:
            # Batch-norm entry: record it like the conv entries below.
            # Chainer's "N" entry has no target in bn_weight_mapping, so skip it.
            if 'N' not in k[2]:
                mapping[each] = prefix + bn_mapping[k[1]] + '.' + bn_weight_mapping[k[2]]
        else:
            mapping[each] = prefix + conv_mapping[k[1]] + '.' + conv_weight_mapping[k[2]]
    elif 'l' in k[0]:
        # Only layer "l7" is converted, into `linear`.
        # NOTE(review): other "l*" entries such as "l_y" (a class embedding, if
        # present) are not mapped — confirm this is intended given prior_dim=0.
        if k[0].split('l')[1] == '7':
            if 'W' in each:
                mapping[each] = 'linear.weight'
            if 'b' in each:
                mapping[each] = 'linear.bias'
            if 'u' in each:
                mapping[each] = 'linear.u0'



# Print every mapped pair with both shapes, and flag mismatches with "#######".
# Each npz array is read once per key instead of once per use.
for key, value in mapping.items():
    src_shape = D[key].shape
    dst_shape = st[value].shape
    print(key, value, src_shape, dst_shape)
    if src_shape != dst_shape:
        print("#######", key, value, src_shape, dst_shape)

# Re-key every mapped array under its PyTorch name as a tensor.
new_st = {}
for key, value in mapping.items():
    new_st[value] = torch.from_numpy(D[key])

# Sanity-load into the freshly built discriminator. With strict=False,
# load_state_dict does not fail on keys that are absent from either side.
# It returns those keys instead, so report them rather than dropping the
# information (e.g. model parameters the mapping does not fill).
incompatible = Dnew.load_state_dict(new_st, strict=False)
if incompatible.missing_keys:
    print("Discriminator keys not set from the checkpoint:", incompatible.missing_keys)
if incompatible.unexpected_keys:
    print("Converted keys not present in the Discriminator:", incompatible.unexpected_keys)
torch.save(new_st, './pretrained_models/SN_discriminator.pt')


