# Based on implementations from the 4M repo: https://github.com/apple/ml-4m/
from functools import partial

import fourm.utils.data_constants as data_constants
from fourm.data.modality_transforms import (CaptionTransform, DepthTransform,
                                      DetectionTransform, MaskTransform,
                                      NormalTransform, RGBTransform,
                                      SemsegTransform, TokTransform,
                                      CaptionEmbTransform, MetadataTransform,
                                      HumanPoseTransform, ColorPaletteTransform,
                                      SAMInstanceTokTransform, SAMInstanceTransform)
from fourm.models.decoder_embeddings import (ImageTokenDecoderEmbedding,
                                   SequenceDecoderEmbedding)
from fourm.models.encoder_embeddings import (ImageEncoderEmbedding,
                                   ImageTokenEncoderEmbedding,
                                   SequenceEncoderEmbedding,
                                   SequenceEmbEncoderEmbedding)
from fourm.utils import generate_uint15_hash

# Registry of per-modality settings, keyed by modality name. Names carrying an
# '@<res>' suffix (e.g. 'rgb@224') are resolution-specific variants.
#
# Keys that appear in the entries below (not every entry has every key):
#   'type'              - one of 'img', 'seq', 'seq_emb', 'feature_map'
#   'id'                - generate_uint15_hash(<modality name>)
#   'input_size'        - 224 or 448 on the '@<res>' entries
#   'patch_size'        - 16, 14 or 56 depending on the entry
#   'vocab_size'        - present on token / sequence entries
#   'encoder_embedding' - functools.partial over an encoder embedding class, or None
#   'decoder_embedding' - functools.partial over a decoder embedding class, or None
#   'min_tokens'        - always 0 in this table
#   'max_tokens'        - an int, or None where the inline comment says it
#                         "will be set"; the quoted numbers match
#                         (input_size // patch_size) ** 2. The code that fills
#                         them in is not part of this section.
#   'num_channels', 'num_labels', 'path', 'pretokenized', 'shared_vocab'
#                       - optional extras; their consumers are not visible here.
#
# Entries commented "used for tokenizer training" only carry 'type',
# 'num_channels' and 'id' (plus 'path' / 'num_labels' in some cases).
MODALITY_INFO = {
    # 4M-7 modalities
    'rgb@224': {
        'input_size': 224,
        'patch_size': 16,
        'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3),
        'decoder_embedding': None,
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'num_channels': 3,
        'id': generate_uint15_hash('rgb@224'),
        'path': 'rgb',
    },
    'rgb': { # used for tokenizer training
        'type': 'img',
        'num_channels': 3,
        'id': generate_uint15_hash('rgb'),
        'path': 'rgb',
    },
    # 'caption' and 'det' are configured identically apart from their id.
    'caption': {
        'vocab_size': 30_000,
        'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
        'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
        'min_tokens': 0,
        'max_tokens': 256,
        'type': 'seq',
        'id': generate_uint15_hash('caption'),
    },
    'det': { 
        'vocab_size': 30_000,
        'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
        'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=256, padding_idx=0),
        'min_tokens': 0,
        'max_tokens': 256,
        'type': 'seq',
        'id': generate_uint15_hash('det'),
    },
    # The 'tok_*@224' entries below share one layout; they differ only in
    # vocab_size (and patch_size for the dinov2 / imagebind variants).
    'tok_rgb@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 16384,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_rgb@224'),
        'pretokenized': True,
    },
    'tok_depth@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_depth@224'),
        'pretokenized': True,
    },
    'depth': { # used for tokenizer training
        'type': 'img',
        'num_channels': 1,
        'id': generate_uint15_hash('depth'),
    },
    'tok_normal@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_normal@224'),
        'pretokenized': True,
    },
    'normal': { # used for tokenizer training
        'type': 'img',
        'num_channels': 3,
        'id': generate_uint15_hash('normal'),
    },
    'tok_semseg@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 4096,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_semseg@224'),
        'pretokenized': True,
    },
    'semseg_coco': { # used for tokenizer training
        'type': 'img', 
        'num_channels': 64,
        'num_labels': data_constants.COCO_SEMSEG_NUM_CLASSES,
        'id': generate_uint15_hash('semseg_coco'),
    },
    'tok_clip@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_clip@224'),
        'pretokenized': True,
    },
    'CLIP-B16': { # used for tokenizer training
        'type': 'feature_map',
        'num_channels': 512,
        'id': generate_uint15_hash('CLIP-B16'),
    },

    # 4M-21 modalities
    # 't5_caption' is the only 'seq_emb' entry; it has an encoder embedding
    # but no decoder embedding and no vocab_size.
    't5_caption': {
        'encoder_embedding': partial(SequenceEmbEncoderEmbedding, max_length=77, padding_idx=0),
        'decoder_embedding': None,
        'min_tokens': 0,
        'max_tokens': 77,
        'type': 'seq_emb',
        'id': generate_uint15_hash('t5_caption'),
    },
    # The sequence entries below pass sincos_pos_emb=True to their embeddings
    # and list 'caption' under 'shared_vocab'
    # (presumably they reuse the caption vocabulary - confirm against the consumer).
    'metadata': { 
        'vocab_size': 30_000,
        'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True),
        'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=40, padding_idx=0, sincos_pos_emb=True),
        'min_tokens': 0,
        'max_tokens': 40, # At most 2x19=38 for 19 metadata types, +1 for EOS, +1 for sentinel
        'type': 'seq',
        'id': generate_uint15_hash('metadata'),
        'shared_vocab': ['caption'],
        'path': 'metadata',
    },
    'human_poses': { 
        'vocab_size': 30_000,
        'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True),
        'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=263, padding_idx=0, sincos_pos_emb=True),
        'min_tokens': 0,
        # 7*39 = 273, +1 for EOS, +1 for S_1 = 275.
        # Earlier values recorded here: 263 (= 261 + 1 + 1, to include EOS),
        # 261 in one of the models, and 238.
        # NOTE(review): the embeddings above use max_length=263 while
        # max_tokens is 275 - confirm that this mismatch is intended.
        'max_tokens': 275,
        'type': 'seq',
        'num_channels': 207, # for tokenization training, only the pose part is needed
        'id': generate_uint15_hash('human_poses'),
        'shared_vocab': ['caption'],
    },
    'color_palette': { 
        'vocab_size': 30_000,
        'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True),
        'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=23, padding_idx=0, sincos_pos_emb=True),
        'min_tokens': 0,
        'max_tokens': 23, #7x3=21 for 7 colors, +1 for EOS, +1 for sentinel
        'type': 'seq',
        'id': generate_uint15_hash('color_palette'),
        'shared_vocab': ['caption'],
        'path': 'color_palette',
    },
    # 'sam_mask' has neither an encoder nor a decoder embedding.
    'sam_mask': {
        'encoder_embedding': None,
        'decoder_embedding': None,
        'min_tokens': 0,
        'max_tokens': 64,
        'type': 'img',
        'num_channels': 1,
        'id': generate_uint15_hash('sam_mask'),
    },
    'sam_instance': {
        'vocab_size': 30_000,
        'encoder_embedding': partial(SequenceEncoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True),
        'decoder_embedding': partial(SequenceDecoderEmbedding, vocab_size=30_000, max_length=290, padding_idx=0, sincos_pos_emb=True),
        'min_tokens': 0,
        'max_tokens': 290,
        'type': 'seq',
        'id': generate_uint15_hash('sam_instance'),
        'shared_vocab': ['caption'],
        'pretokenized': True,
    },
    'tok_canny_edge@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_canny_edge@224'),
        'pretokenized': True,
    },
    'canny_edge': { # used for tokenizer training
        'type': 'img',
        'num_channels': 1,
        'id': generate_uint15_hash('canny_edge'),
    },
    'tok_sam_edge@224': {
        'input_size': 224,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 196
        'type': 'img',
        'id': generate_uint15_hash('tok_sam_edge@224'),
        'pretokenized': True,
    },
    # patch_size is 14 (not 16) for the dinov2 / imagebind token maps, hence
    # the 256 in the max_tokens comments.
    'tok_dinov2@224': {
        'input_size': 224,
        'patch_size': 14,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 256
        'type': 'img',
        'id': generate_uint15_hash('tok_dinov2@224'),
        'pretokenized': True,
    },
    'DINOv2-B14': { # used for tokenizer training
        'type': 'feature_map',
        'num_channels': 768,
        'id': generate_uint15_hash('DINOv2-B14'),
    },
    'tok_imagebind@224': {
        'input_size': 224,
        'patch_size': 14,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 256
        'type': 'img',
        'id': generate_uint15_hash('tok_imagebind@224'),
        'pretokenized': True,
    },
    'ImageBind-H14': { # used for tokenizer training
        'type': 'feature_map',
        'num_channels': 1280,
        'id': generate_uint15_hash('ImageBind-H14'),
    },
    # The '*_global' token entries have no '@<res>' suffix and no input_size;
    # they use a fixed max_tokens of 16 and pass sincos_pos_emb=False.
    'tok_dinov2_global': {
        'vocab_size': 8192,
        'patch_size': 56,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
        'min_tokens': 0,
        'max_tokens': 16,
        'type': 'img',
        'id': generate_uint15_hash('tok_dinov2_global'),
        'pretokenized': True,
    },
    'DINOv2-B14-global': { # used for tokenizer training
        'type': 'feature_map',
        'num_channels': 768,
        'id': generate_uint15_hash('DINOv2-B14-global'),
    },
    'tok_imagebind_global': {
        'vocab_size': 8192,
        'patch_size': 56,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192, sincos_pos_emb=False),
        'min_tokens': 0,
        'max_tokens': 16,
        'type': 'img',
        'id': generate_uint15_hash('tok_imagebind_global'),
        'pretokenized': True,
    },
    'ImageBind-H14-global': { # used for tokenizer training
        'type': 'feature_map',
        'num_channels': 1280,
        'id': generate_uint15_hash('ImageBind-H14-global'),
    },

    ### 224->448 super resolution modalities
    # Same settings as the corresponding '@224' entries, with input_size 448.
    'rgb@448': {
        'input_size': 448,
        'patch_size': 16,
        'encoder_embedding': partial(ImageEncoderEmbedding, num_channels=3),
        'decoder_embedding': None,
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 784
        'type': 'img',
        'num_channels': 3,
        'id': generate_uint15_hash('rgb@448'),
        'path': 'rgb',
    },
    'tok_rgb@448': {
        'input_size': 448,
        'patch_size': 16,
        'vocab_size': 16384,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=16384),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=16384),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 784
        'type': 'img',
        'id': generate_uint15_hash('tok_rgb@448'),
        'pretokenized': True,
    },
    'tok_depth@448': {
        'input_size': 448,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 784
        'type': 'img',
        'id': generate_uint15_hash('tok_depth@448'),
        'pretokenized': True,
    },
    'tok_normal@448': {
        'input_size': 448,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 784
        'type': 'img',
        'id': generate_uint15_hash('tok_normal@448'),
        'pretokenized': True,
    },
    'tok_semseg@448': {
        'input_size': 448,
        'patch_size': 16,
        'vocab_size': 4096,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=4096),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=4096),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 784
        'type': 'img',
        'id': generate_uint15_hash('tok_semseg@448'),
        'pretokenized': True,
    },
    'tok_clip@448': {
        'input_size': 448,
        'patch_size': 16,
        'vocab_size': 8192,
        'encoder_embedding': partial(ImageTokenEncoderEmbedding, vocab_size=8192),
        'decoder_embedding': partial(ImageTokenDecoderEmbedding, vocab_size=8192),
        'min_tokens': 0,
        'max_tokens': None, # Will be set to 784
        'type': 'img',
        'id': generate_uint15_hash('tok_clip@448'),
        'pretokenized': True,
    },
}

# Transform instances looked up by modality name.
# Note: the @res suffix is ignored for modality transforms, so the keys here
# are written without it.
MODALITY_TRANSFORMS = {
    # 4M-7 modalities
    'rgb': RGBTransform(imagenet_default_mean_and_std=True),
    'caption': CaptionTransform(aligned_captions=True),
    'det': DetectionTransform(
        det_threshold=0.6,
        det_max_instances=None,
        bbox_order='dist_to_orig',
        coord_bins=1000,
        min_visibility=0.0,
    ),
    # Every token modality gets its own, separately constructed TokTransform.
    **{
        tok_name: TokTransform()
        for tok_name in (
            'tok_rgb',
            'tok_depth',
            'tok_normal',
            'tok_semseg',
            'tok_clip',
        )
    },
    # 4M-21 modalities
    't5_caption': CaptionEmbTransform(),
    'metadata': MetadataTransform(
        special_vmin=0,
        special_vmax=999,
        shuffle=True,
        random_trunc=False,
        return_chunks=True,
    ),
    'human_poses': HumanPoseTransform(coord_bins=1000),
    'color_palette': ColorPaletteTransform(coord_bins=1000),
    'sam_instance': SAMInstanceTokTransform(
        image_size=224,
        points_per_side=7,
        point_order='random',
    ),
    **{
        tok_name: TokTransform()
        for tok_name in (
            'tok_canny_edge',
            'tok_sam_edge',
            'tok_dinov2',
            'tok_imagebind',
            'tok_dinov2_global',
            'tok_imagebind_global',
        )
    },
    # Other
    'mask_valid': MaskTransform(mask_pool_size=1),
}

# Transform instances for the DIVAE setup, keyed by modality name.
# RGB is built with imagenet_default_mean_and_std=False here.
MODALITY_TRANSFORMS_DIVAE = {
    'rgb': RGBTransform(
        imagenet_default_mean_and_std=False,
    ),
    'depth': DepthTransform(
        standardize_depth=True,
    ),
    'normal': NormalTransform(
        standardize_surface_normals=False,
    ),
    'mask_valid': MaskTransform(
        mask_pool_size=1,
    ),
    'semseg_coco': SemsegTransform(
        shift_idx_by_one=True,
    ),
    # canny_edge goes through the same transform class as rgb.
    'canny_edge': RGBTransform(
        imagenet_default_mean_and_std=False,
    ),
    'human_poses': HumanPoseTransform(
        coord_bins=1000,
        only_pose=True,
    ),
    'sam_mask': SAMInstanceTransform(
        mask_size=64,
        max_instance_n=1,
    ),
}

# Transform instances for the VQ-ControlNet setup, keyed by modality name.
# Only rgb, mask_valid and caption are registered.
MODALITY_TRANSFORMS_VQCONTROLNET = {
    'rgb': RGBTransform(
        imagenet_default_mean_and_std=False,
    ),
    'mask_valid': MaskTransform(
        mask_pool_size=1,
    ),
    'caption': CaptionTransform(
        aligned_captions=True,
    ),
}