from .unet import UNetModel

# Value handed to UNetModel as ``num_classes`` when create_model() is called
# with class_cond=True (None is passed instead when class_cond is False).
NUM_CLASSES = 1000

# Keyword arguments for create_model(); i_DDPM() uses this preset for the
# 'AFHQ' and 'FFHQ' dataset names.
AFHQ_DICT = {
    # Size / width / depth arguments.
    "image_size": 256,
    "num_channels": 128,
    "num_res_blocks": 1,
    # Empty string: create_model() picks the multipliers from image_size.
    "channel_mult": "",
    # Output and conditioning.
    "learn_sigma": True,
    "class_cond": False,
    # Attention settings.
    "attention_resolutions": "16",
    "num_heads": 4,
    "num_head_channels": 64,
    "num_heads_upsample": -1,
    "use_new_attention_order": False,
    # Residual-block and runtime options.
    "use_scale_shift_norm": True,
    "resblock_updown": True,
    "dropout": 0.0,
    "use_checkpoint": False,
    "use_fp16": False,
}


# Keyword arguments for create_model(); i_DDPM() uses this preset for the
# 'IMAGENET' dataset name.
IMAGENET_DICT = {
    # Size / width / depth arguments.
    "image_size": 256,
    "num_channels": 256,
    "num_res_blocks": 2,
    # Empty string: create_model() picks the multipliers from image_size.
    "channel_mult": "",
    # Output and conditioning (class_cond=True makes create_model() pass
    # NUM_CLASSES to the network).
    "learn_sigma": True,
    "class_cond": True,
    # Attention settings.
    "attention_resolutions": "32,16,8",
    "num_heads": 4,
    "num_head_channels": 64,
    "num_heads_upsample": -1,
    "use_new_attention_order": False,
    # Residual-block and runtime options.
    "use_scale_shift_norm": True,
    "resblock_updown": True,
    "dropout": 0.0,
    "use_checkpoint": False,
    "use_fp16": False,
}


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    """Construct a ``UNetModel`` from flat, mostly string/scalar settings.

    Args:
        image_size: Passed to the network as ``image_size``; also selects the
            default channel multipliers and is divided by each attention
            resolution.
        num_channels: Passed to the network as ``model_channels``.
        num_res_blocks: Forwarded unchanged.
        channel_mult: Comma-separated integers (e.g. ``"1,2,4"``), or ``""``
            to use the built-in preset for ``image_size`` (512, 256, 128 or
            64).
        learn_sigma: When true the network gets 6 output channels instead
            of 3.
        class_cond: When true ``NUM_CLASSES`` is passed as ``num_classes``;
            otherwise ``None`` is passed.
        attention_resolutions: Comma-separated integers; each entry ``r`` is
            converted to ``image_size // r`` before being handed to the
            network.
        use_checkpoint, num_heads, num_head_channels, num_heads_upsample,
        use_scale_shift_norm, dropout, resblock_updown, use_fp16,
        use_new_attention_order: Forwarded unchanged to ``UNetModel``.

    Returns:
        The constructed ``UNetModel`` (always built with 3 input channels).

    Raises:
        ValueError: If ``channel_mult`` is ``""`` and ``image_size`` has no
            preset, or if an entry of ``channel_mult`` /
            ``attention_resolutions`` cannot be parsed by ``int``.
    """
    # (image size, per-level channel multipliers) used when none are given.
    default_multipliers = (
        (512, (0.5, 1, 1, 2, 2, 4, 4)),
        (256, (1, 1, 2, 2, 4, 4)),
        (128, (1, 1, 2, 3, 4)),
        (64, (1, 2, 3, 4)),
    )

    if channel_mult == "":
        for size, multipliers in default_multipliers:
            if image_size == size:
                channel_mult = multipliers
                break
        else:
            # No preset matched the requested size.
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(map(int, channel_mult.split(",")))

    # Each listed resolution becomes the ratio image_size // resolution.
    attention_ds = tuple(
        image_size // int(resolution)
        for resolution in attention_resolutions.split(",")
    )

    out_channels = 6 if learn_sigma else 3
    num_classes = NUM_CLASSES if class_cond else None

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=out_channels,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=num_classes,
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )


def i_DDPM(dataset_name='AFHQ'):
    """Build the model configured for ``dataset_name``.

    Args:
        dataset_name: ``'AFHQ'`` or ``'FFHQ'`` (both use ``AFHQ_DICT``), or
            ``'IMAGENET'`` (uses ``IMAGENET_DICT``). Matching is exact and
            case-sensitive.

    Returns:
        The model returned by ``create_model`` for the selected preset.

    Raises:
        SystemExit: With exit status 1 if ``dataset_name`` is not one of the
            supported names; ``'Not implemented.'`` is printed first.
    """
    if dataset_name in ('AFHQ', 'FFHQ'):
        return create_model(**AFHQ_DICT)
    if dataset_name == 'IMAGENET':
        return create_model(**IMAGENET_DICT)

    print('Not implemented.')
    # Raise SystemExit directly instead of calling the site-provided ``exit``
    # helper: that helper is meant for interactive sessions and may be absent,
    # and calling it without an argument reports success (status 0) even
    # though this is a failure path.
    raise SystemExit(1)
