import mia.conversion as conversion
import data.synthetic as synthetic
from PIL import Image
from torchvision import transforms
import yaml
import torch.nn as nn
from pathlib import Path

def get_datasets_and_converter(args):
    """Build the synthetic train/test datasets and the per-mode converters.

    Args:
        args: Run configuration. Reads ``sample_modes``, ``num_data_ratio``,
            ``use_meta_test_set``, ``validation_repeat`` and
            ``dataset_config`` (accessed both as ``.path`` and as
            ``['use_jitter']``, so it must support attribute and item
            lookup).

    Returns:
        A ``(train_dataset, test_dataset, converter)`` tuple, where
        ``converter`` is a dict mapping every entry of ``args.sample_modes``
        to one shared ``conversion.Converter('synthetic')`` instance.
    """
    # A single Converter object is reused as the value for every mode.
    shared_converter = conversion.Converter('synthetic')
    converter = {}
    for mode in args.sample_modes:
        converter[mode] = shared_converter

    train_dataset = synthetic.Synthetic(
        root=args.dataset_config.path,
        train=True,
        modes=args.sample_modes,
        ratio=args.num_data_ratio,
        use_jitter=args.dataset_config['use_jitter'],
    )

    # Pick the evaluation dataset class; only the plain Synthetic dataset is
    # given an explicit ``train=False`` flag, SyntheticMeta receives none.
    if args.use_meta_test_set:
        test_dataset_cls, split_kwargs = synthetic.SyntheticMeta, {}
    else:
        test_dataset_cls, split_kwargs = synthetic.Synthetic, {'train': False}

    # The evaluation set always uses the full ratio (1.0) and is repeated
    # ``args.validation_repeat`` times via ``repeat_factor``.
    test_dataset = test_dataset_cls(
        root=args.dataset_config.path,
        **split_kwargs,
        modes=args.sample_modes,
        ratio=1.0,
        use_jitter=args.dataset_config['use_jitter'],
        repeat_factor=args.validation_repeat,
    )

    print(f'Num train data: {len(train_dataset)}')
    print(f'Num valid data: {len(test_dataset)}')
    return train_dataset, test_dataset, converter


def get_model(args):
    """Instantiate the model selected by ``args.model_type``.

    For ``'multimodal'`` a ``CrossModalTransferModel`` is constructed from the
    hyper-parameters stored on ``args`` and the result of calling
    ``.to(args.device)`` on it is returned. No other model type is handled in
    this branch.

    Args:
        args: Run configuration holding ``model_type``, ``modes``,
            ``latent_spatial_shapes``, ``device`` and the INR,
            ``context_encoder_*`` / ``context_pooler_*``, ``grad_encoder_*``
            and meta-SGD settings read below.

    Returns:
        The constructed model after ``.to(args.device)`` when
        ``args.model_type == 'multimodal'``.
    """
    if args.model_type == 'multimodal':
        # Imported lazily so the model package is only loaded when this
        # branch is actually taken.
        from mia.models.multimodal_models import CrossModalTransferModel

        modes = args.modes
        latent_spatial_shapes = args.latent_spatial_shapes

        inr_cfg = {
            'dim_hidden': args.dim_hidden,
            'num_layers': args.num_layers,
            'inr_type': args.inr_type,
            'ff_dim': args.ff_dim,
            'sigma': args.sigma,
        }

        context_cfg = {
            'type': args.context_encoder_type,
            'dim': args.context_encoder_dim,
            'num_querys': args.context_encoder_num_querys,
            'um_depth': args.context_encoder_um_depth,
            'mm_depth': args.context_encoder_mm_depth,
            'heads': args.context_encoder_heads,
            # Per-head width: encoder dim floor-divided by the head count.
            'dim_head': args.context_encoder_dim // args.context_encoder_heads,
            'mlp_ratio': args.context_encoder_mlp_ratio,
            'dropout': args.context_encoder_dropout,
            'embed_dropout': args.context_encoder_embed_dropout,
            'pos_embed_type': args.context_encoder_pos_embed_type,
            'pooler_pos_embed_type': args.context_pooler_pos_embed_type,
            'pooler_depth': args.context_pooler_depth,
            'topk': args.context_encoder_topk,
        }

        grad_cfg = {
            'type': args.grad_encoder_type,
            'dim': args.grad_encoder_dim,
            'um_depth': args.grad_encoder_um_depth,
            'mm_depth': args.grad_encoder_mm_depth,
            'heads': args.grad_encoder_heads,
            # Per-head width: encoder dim floor-divided by the head count.
            'dim_head': args.grad_encoder_dim // args.grad_encoder_heads,
            'mlp_ratio': args.grad_encoder_mlp_ratio,
            'dropout': args.grad_encoder_dropout,
            'pos_embed_type': args.grad_encoder_pos_embed_type,
            'mm_attn_type': args.grad_encoder_mm_attn_type,
            'use_latent': args.grad_encoder_use_latent,
            'use_alfa': args.grad_encoder_use_alfa,
            'dim_alfa': args.grad_encoder_dim_alfa,
            'depth_alfa': args.grad_encoder_depth_alfa,
            'use_fuser': args.grad_encoder_use_fuser,
            'depth_fuser': args.grad_encoder_depth_fuser,
            'projection_mlp_depth': args.grad_encoder_projection_mlp_depth,
        }

        meta_sgd_cfg = {
            'use_meta_sgd': args.use_meta_sgd,
            'inner_lr_init': args.meta_sgd_lr_init,
        }

        model = CrossModalTransferModel(
            args=args,
            modes=modes,
            latent_spatial_shapes=latent_spatial_shapes,
            inr_dict=inr_cfg,
            context_encoder_dict=context_cfg,
            grad_encoder_dict=grad_cfg,
            meta_sgd_dict=meta_sgd_cfg,
        )
        return model.to(args.device)
