"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.

This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
import torch


class CheckpointIO(object):
    """Save and restore the ``state_dict`` of a named set of objects.

    Every registered object (passed to the constructor or to
    :meth:`register`) must expose ``state_dict()`` and ``load_state_dict()``.
    All of them are stored together in one ``torch.save`` file whose path is
    ``fname_template.format(step)``.
    """

    def __init__(self, fname_template, **kwargs):
        """
        Args:
            fname_template: checkpoint path containing one ``{}`` placeholder
                that is filled with the step number.
            **kwargs: name -> object providing ``state_dict`` /
                ``load_state_dict``.
        """
        dirname = os.path.dirname(fname_template)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the template actually contains one.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs

    def register(self, **kwargs):
        """Add (or replace) named objects to be checkpointed."""
        self.module_dict.update(kwargs)

    def save(self, step):
        """Write the state dicts of all registered objects for ``step``."""
        fname = self.fname_template.format(step)
        print('Saving checkpoint into %s...' % fname)
        outdict = {name: module.state_dict()
                   for name, module in self.module_dict.items()}
        torch.save(outdict, fname)

    def load(self, step, kd=False):
        """Restore registered objects from the checkpoint for ``step``.

        Args:
            step: value substituted into ``fname_template``.
            kd: when False, every registered object loads its own entry via
                ``load_state_dict``. When True, only partial initialisation
                from a differently-named checkpoint is performed:

                * ``generator`` is loaded as-is;
                * ``style_encoder`` and ``discriminator`` parameters are
                  copied from remapped keys (see the ``_load_*_kd`` helpers);
                * ``mapping_network`` and any other registered name are not
                  loaded.

        Raises:
            AssertionError: if the checkpoint file does not exist.
            KeyError: if an expected entry is missing from the checkpoint.
            ValueError: if a remapped tensor's shape does not match the
                destination parameter (kd mode only).
        """
        fname = self.fname_template.format(step)
        assert os.path.exists(fname), fname + ' does not exist!'
        print('Loading checkpoint from %s...' % fname)
        # Remap tensors to CPU when no GPU is available so that checkpoints
        # written on a GPU machine can still be loaded.
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        module_dict = torch.load(fname, map_location=map_location)

        if not kd:
            for name, module in self.module_dict.items():
                module.load_state_dict(module_dict[name])
            return

        for name, module in self.module_dict.items():
            if name == 'generator':
                module.load_state_dict(module_dict[name])
            elif name == 'style_encoder':
                self._load_style_encoder_kd(module, module_dict[name])
            elif name == 'discriminator':
                self._load_discriminator_kd(module, module_dict[name])
            # 'mapping_network' (and any other name) is intentionally left
            # untouched in kd mode.

    @staticmethod
    def _copy_param(param, src, src_name):
        """Copy ``src`` into ``param`` in place.

        ``param`` keeps its own device and dtype, which avoids forcing a
        device (the previous code always called ``.cuda()``). The shape check
        is an explicit raise so it is not stripped under ``python -O``.
        """
        if param.shape != src.shape:
            raise ValueError('shape mismatch for %s: expected %s, got %s'
                             % (src_name, tuple(param.shape), tuple(src.shape)))
        with torch.no_grad():
            param.copy_(src)

    def _load_style_encoder_kd(self, module, teacher_state):
        """Initialise a style encoder from a checkpoint with remapped keys.

        * ``...shared.<rest>`` parameters (not containing 'unshared') are read
          from ``encode.<rest>``.
        * Every parameter whose name contains 'unshared' is read from
          ``unshared.0.weight`` or ``unshared.0.bias``. Every unshared head is
          therefore initialised from the checkpoint's first head.
        * Other parameters are left unchanged.
        """
        for param_name, param in module.named_parameters():
            if 'unshared' in param_name:
                if 'weight' in param_name:
                    mapped_name = 'unshared.0.weight'
                else:
                    mapped_name = 'unshared.0.bias'
            elif 'shared' in param_name:
                mapped_name = 'encode.' + param_name.split('shared.')[-1]
            else:
                continue
            self._copy_param(param, teacher_state[mapped_name], mapped_name)

    def _load_discriminator_kd(self, module, teacher_state):
        """Initialise a discriminator from a checkpoint with remapped keys.

        The second dotted component of each parameter name is treated as a
        layer index. This assumes names look like 'main.<idx>.<...>'; TODO:
        confirm against the model. Parameters that share the index of the
        last registered parameter (the top layer) are skipped. All others are
        read from the key formed by replacing the 'main' prefix with 'encode'.
        """
        named_params = list(module.named_parameters())
        if not named_params:
            return
        top_layer_index = named_params[-1][0].split('.')[1]
        for param_name, param in named_params:
            if param_name.split('.')[1] == top_layer_index:
                continue
            mapped_name = 'encode' + param_name.split('main')[-1]
            self._copy_param(param, teacher_state[mapped_name], mapped_name)


