# coding: utf-8

import yaml

def parse_yaml(paths):
	"""Load a sequence of YAML config files and merge them into training configs.

	Files are processed in order. For most sections a later file replaces the
	value from an earlier one. The exception is ``quantizer``: its
	``module_name`` is overridden, but its ``init_args`` are merged key by key
	across files.

	Recognised top-level sections:
	  * ``quantizer``     -> model_configs['quantizer'] (merged across files)
	  * ``backbone``      -> model_configs['encoder'] ('Encoder') and
	                         model_configs['decoder'] ('Decoder'); each gets its
	                         own copy of ``backbone.init_args``
	  * ``wav2vec2``      -> model_configs['wav2vec2'] ('Wav2Vec2')
	  * ``contrast_loss`` -> model_configs['contrast_loss'] ('ContrastiveLoss')
	  * ``optimizer``, ``scheduler``, ``loss_weights``, ``batch_size``,
	    ``num_epochs``
	Every top-level key of every file is also shallow-merged into
	``extra_kwargs``.

	When a quantizer is configured, its ``init_args['channels']`` is copied
	into the encoder/decoder ``z_channels`` and the wav2vec2 ``out_channels``.
	This keeps their dimensionality equal to the codebook's. The overrides are
	applied to copies, so the loaded YAML data (as exposed through
	``extra_kwargs``) is left unmodified.

	Args:
		paths: iterable of paths to UTF-8 encoded YAML files.

	Returns:
		If any model section was found:
			``(model_configs, loss_weights, batch_size, num_epochs, extra_kwargs)``.
			Each model config additionally carries ``optim_config`` and
			``scheduler_config``.
		Otherwise:
			``(optim_config, scheduler_config, loss_weights, batch_size,
			num_epochs, extra_kwargs)``.
		In both cases ``scheduler_config['initial_lr']`` equals
		``optim_config['lr']``. ``loss_weights`` is None when no file
		defines it.

	Raises:
		KeyError: if no file defines ``optimizer``, ``scheduler``,
			``batch_size`` or ``num_epochs``, or if ``optimizer`` has no ``lr``.
	"""
	missing = object()  # sentinel: an explicit YAML ``null`` is still a value
	model_configs = dict()
	loss_weights = None
	extra_kwargs = dict()
	optim_config = missing
	scheduler_config = missing
	batch_size = missing
	num_epochs = missing
	for path in paths:
		# YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
		with open(path, 'r', encoding='utf-8') as f:
			# An empty file loads as None; treat it as an empty mapping.
			config_f = yaml.safe_load(f) or dict()
		if 'quantizer' in config_f:
			quantizer_section = config_f['quantizer']
			quantizer = model_configs.setdefault('quantizer', dict(init_args=dict()))
			if 'module_name' in quantizer_section:
				quantizer['module_name'] = quantizer_section['module_name']
			if 'init_args' in quantizer_section:
				quantizer['init_args'].update(quantizer_section['init_args'])
		# The init_args below are shallow copies. Encoder and decoder must not
		# share one dict, and the channel overrides applied later must not
		# write back into config_f (which extra_kwargs also references).
		if 'backbone' in config_f:
			backbone_args = config_f['backbone']['init_args']
			model_configs['encoder'] = dict(module_name='Encoder', init_args=dict(backbone_args))
			model_configs['decoder'] = dict(module_name='Decoder', init_args=dict(backbone_args))
		if 'wav2vec2' in config_f:
			model_configs['wav2vec2'] = dict(module_name='Wav2Vec2', init_args=dict(config_f['wav2vec2']['init_args']))
		if 'contrast_loss' in config_f:
			model_configs['contrast_loss'] = dict(module_name='ContrastiveLoss', init_args=dict(config_f['contrast_loss']['init_args']))
		if 'optimizer' in config_f:
			optim_config = config_f['optimizer']
		if 'scheduler' in config_f:
			scheduler_config = config_f['scheduler']
		if 'loss_weights' in config_f:
			loss_weights = config_f['loss_weights']
		if 'batch_size' in config_f:
			batch_size = config_f['batch_size']
		if 'num_epochs' in config_f:
			num_epochs = config_f['num_epochs']
		extra_kwargs.update(config_f)

	# These settings are required on both return paths; fail with a clear message.
	required = (
		('optimizer', optim_config),
		('scheduler', scheduler_config),
		('batch_size', batch_size),
		('num_epochs', num_epochs),
	)
	for key, value in required:
		if value is missing:
			raise KeyError("'{}' is not defined in any of the given config files".format(key))

	if 'quantizer' in model_configs:
		channels = model_configs['quantizer']['init_args']['channels']
		if 'encoder' in model_configs:
			# Match the codebook dimensionality.
			model_configs['encoder']['init_args']['z_channels'] = channels
			model_configs['decoder']['init_args']['z_channels'] = channels
		if 'wav2vec2' in model_configs:
			model_configs['wav2vec2']['init_args']['out_channels'] = channels

	# Copy before adding initial_lr so extra_kwargs['scheduler'] stays as loaded.
	scheduler_config = dict(scheduler_config)
	scheduler_config['initial_lr'] = optim_config['lr']
	if model_configs:
		for config in model_configs.values():
			# NOTE(review): none of the configs built above contain 'optimizer'
			# or 'scheduler' keys, so the global settings are always used here.
			# The pop is kept as a hook for per-module overrides.
			config['optim_config'] = config.pop('optimizer', optim_config)
			config['scheduler_config'] = config.pop('scheduler', scheduler_config)
		return model_configs, loss_weights, batch_size, num_epochs, extra_kwargs
	else:
		return optim_config, scheduler_config, loss_weights, batch_size, num_epochs, extra_kwargs