import pickle
import torch
from torch import nn
from mapper import latent_mappers
from models.eg3d.training.triplane import TriPlaneGenerator
from torch_utils import misc


def get_keys(d, name):
	"""Return the sub-state-dict stored under the ``name.`` prefix.

	If ``d`` contains a ``'state_dict'`` entry, that nested dict is used instead
	of ``d`` itself.

	Args:
		d: a state dict, or a checkpoint dict wrapping one under ``'state_dict'``.
		name: submodule name whose parameters should be extracted (e.g. ``'mapper'``).

	Returns:
		A new dict mapping keys with the ``name.`` prefix stripped to the
		original values. Keys of other submodules are excluded, including ones
		that merely start with the same characters (``'mapper_ema.x'`` is not
		part of ``'mapper'``).
	"""
	if 'state_dict' in d:
		d = d['state_dict']
	# Match on the dotted prefix so sibling modules sharing a name prefix are
	# not mistakenly included (and mangled) in the result.
	prefix = name + '.'
	return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}


class AT3DMapper(nn.Module):
	"""Latent mapper paired with an EG3D ``TriPlaneGenerator`` decoder.

	Construction proceeds in this order:
	  1. ``self.mapper`` is built from ``latent_mappers.LatentMapper(opts)``;
	  2. ``self.face_pool`` is created as an adaptive average pool to 256x256;
	  3. the mapper is optionally restored from ``opts.checkpoint_path``;
	  4. the ``'G_ema'`` generator pickled at ``opts.stylegan_weights`` is
	     rebuilt as ``self.decoder`` with gradients enabled on all its parameters.

	Args:
		opts: options object; the visible code reads ``checkpoint_path`` and
			``stylegan_weights`` from it and forwards it to ``LatentMapper``.
	"""

	def __init__(self, opts):
		super().__init__()
		self.opts = opts
		# Submodules are registered in this order: mapper, face_pool, decoder.
		self.mapper = latent_mappers.LatentMapper(opts)
		# NOTE(review): presumably used elsewhere to resize rendered images to
		# 256x256 — confirm with callers.
		self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
		self.load_weights()
		self.decoder = self._build_decoder(opts.stylegan_weights)

	def _build_decoder(self, weights_path):
		"""Rebuild the pickled ``'G_ema'`` generator as a fresh ``TriPlaneGenerator``.

		A new generator is constructed from the pickled one's ``init_args`` /
		``init_kwargs``. Matching parameters and buffers are copied into it
		(``require_all=False``), gradients are enabled on it, and its
		``neural_rendering_resolution`` and ``rendering_kwargs`` are carried over.

		Args:
			weights_path: path to a pickle file containing a dict with a
				``'G_ema'`` entry.

		Returns:
			The newly built ``TriPlaneGenerator``.
		"""
		# SECURITY: pickle.load can execute arbitrary code — only point
		# ``stylegan_weights`` at trusted files.
		with open(weights_path, 'rb') as weights_file:
			source = pickle.load(weights_file)['G_ema']

		decoder = TriPlaneGenerator(*source.init_args, **source.init_kwargs)
		decoder.requires_grad_(False)
		misc.copy_params_and_buffers(source, decoder, require_all=False)
		# Gradients are switched on for the whole decoder after the copy.
		decoder.requires_grad_(True)
		# Presumably plain attributes rather than parameters/buffers, hence
		# copied over explicitly.
		decoder.neural_rendering_resolution = source.neural_rendering_resolution
		decoder.rendering_kwargs = source.rendering_kwargs
		return decoder

	def load_weights(self):
		"""Restore ``self.mapper`` from ``opts.checkpoint_path`` when it is set.

		The checkpoint is loaded onto the CPU. The ``'mapper'`` entries
		extracted by ``get_keys`` are loaded with ``strict=True``. Nothing
		happens when ``checkpoint_path`` is ``None``.
		"""
		checkpoint_path = self.opts.checkpoint_path
		if checkpoint_path is None:
			return
		print('Loading from checkpoint: {}'.format(checkpoint_path))
		checkpoint = torch.load(checkpoint_path, map_location='cpu')
		mapper_state = get_keys(checkpoint, 'mapper')
		self.mapper.load_state_dict(mapper_state, strict=True)
			