

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from collections import OrderedDict
import functools

from . import loss
from .openai_model import Transformer

def get_loss(model, args, cfg, tokenizer=None):
    """Build the training criterion for this model.

    Returns a ``loss.CLIPLoss`` with label caching enabled. The distributed
    ``rank`` and ``world_size`` are read from ``args``. ``model``, ``cfg`` and
    ``tokenizer`` are accepted for interface compatibility but are not used.
    """
    criterion_kwargs = dict(
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
    )
    return loss.CLIPLoss(**criterion_kwargs)
    
def get_metric_names(cfg):
    """Return the (fresh) list of metric names reported for this model.

    ``cfg`` is accepted for interface compatibility and ignored.
    """
    names = ['loss', 'clip_loss', 'clip_acc']
    return names
    

def inflate_positional_embeds(
    current_model_state_dict, new_state_dict,
    num_frames=4,
    load_temporal_fix='bilinear',
):
    """Adapt a checkpoint's ``'visual.temporal_embed'`` to ``num_frames`` frames.

    Allows loading a SpaceTimeTransformer/TimeSformer-style checkpoint that
    was trained with a different number of frames. ``new_state_dict`` is
    modified in place and also returned.

    Args:
        current_model_state_dict: state dict of the model being loaded into.
            Only its keys and its ``'visual.pos_embed'`` entry are inspected.
        new_state_dict: checkpoint state dict to adapt.
        num_frames: number of frames the current model expects.
        load_temporal_fix: how to grow a temporal embedding that has FEWER
            frames than ``num_frames``:
            ``'zeros'`` pads the missing frames with zeros,
            ``'interp'`` uses nearest-neighbour interpolation, and
            ``'bilinear'`` uses bilinear interpolation.
            An embedding with MORE frames is always truncated to its first
            ``num_frames`` frames.

    Returns:
        ``new_state_dict`` (the same object), possibly with a resized
        ``'visual.temporal_embed'``.

    Raises:
        NotImplementedError: if the embedding must be grown and
            ``load_temporal_fix`` is not one of the values above, or if
            ``'visual.pos_embed'`` has a different number of patches in the
            two state dicts.
    """
    temporal_key = 'visual.temporal_embed'
    if temporal_key in new_state_dict and temporal_key in current_model_state_dict:
        load_temporal_embed = new_state_dict[temporal_key]
        # Indexing assumes a (batch, frames, dim) layout, as the original code did.
        load_num_frames = load_temporal_embed.shape[1]
        curr_num_frames = num_frames
        embed_dim = load_temporal_embed.shape[2]

        if load_num_frames > curr_num_frames:
            print(f'### loaded SpaceTimeTransformer model has MORE frames than current... '
                  f'### loading weights, truncating to the first {curr_num_frames} frames')
            new_state_dict[temporal_key] = load_temporal_embed[:, :curr_num_frames, :]
        elif load_num_frames < curr_num_frames:
            print(f'### loaded SpaceTimeTransformer model has FEWER frames than current... '
                  f'### loading weights, filling in the extras via {load_temporal_fix}')
            if load_temporal_fix == 'zeros':
                # new_zeros keeps the checkpoint tensor's dtype and device
                # (torch.zeros would always give float32 on the CPU).
                new_temporal_embed = load_temporal_embed.new_zeros(
                    (load_temporal_embed.shape[0], curr_num_frames, embed_dim))
                new_temporal_embed[:, :load_num_frames] = load_temporal_embed
            elif load_temporal_fix in ('interp', 'bilinear'):
                mode = 'bilinear' if load_temporal_fix == 'bilinear' else 'nearest'
                # Add a leading axis so F.interpolate sees a 4-D "image" and
                # resizes the last two axes (frames, dim) to the target size.
                new_temporal_embed = F.interpolate(
                    load_temporal_embed.unsqueeze(0),
                    (curr_num_frames, embed_dim),
                    mode=mode,
                ).squeeze(0)
            else:
                raise NotImplementedError(
                    f'Unsupported load_temporal_fix: {load_temporal_fix!r}')
            new_state_dict[temporal_key] = new_temporal_embed

    # Loading with a different spatial patch count is not supported; fail loudly.
    pos_key = 'visual.pos_embed'
    if pos_key in new_state_dict and pos_key in current_model_state_dict:
        load_num_patches = new_state_dict[pos_key].shape[1]
        curr_num_patches = current_model_state_dict[pos_key].shape[1]
        if load_num_patches != curr_num_patches:
            raise NotImplementedError(
                'Loading models with different spatial resolution / patch number not yet implemented, sorry.')

    return new_state_dict


def rsetattr(obj, attr, val):
    """Assign ``val`` to a possibly dotted attribute path on ``obj``.

    ``rsetattr(m, 'a.b.c', v)`` behaves like ``m.a.b.c = v``. The parent
    object (``m.a.b``) is looked up with ``rgetattr``. With no dot in
    ``attr``, the attribute is set directly on ``obj``. Returns ``None``.
    """
    parent_path, _dot, leaf_name = attr.rpartition('.')
    if parent_path:
        target = rgetattr(obj, parent_path)
    else:
        target = obj
    return setattr(target, leaf_name, val)


def rgetattr(obj, attr, *args):
    """Read a possibly dotted attribute path from ``obj``.

    ``rgetattr(m, 'a.b.c')`` behaves like ``m.a.b.c``. Any extra positional
    argument is passed to every ``getattr`` call as its default, so it is
    used at whichever level of the path first misses.
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current


# Utility to convert CLIP-style parameter names to TimeSformer-style names.
def remap_keys(clip_state_dict, transformer_layers=12):
    """Rename CLIP visual-encoder state-dict keys to TimeSformer-style keys.

    The input mapping is NOT modified. Reshaped tensors are stored only in
    the returned dict, so calling this twice on the same dict gives the same
    result.

    Args:
        clip_state_dict: mapping from CLIP parameter names to tensors.
        transformer_layers: number of residual blocks whose keys are remapped
            (``transformer.resblocks.0`` .. ``transformer.resblocks.{n-1}``).

    Returns:
        OrderedDict of remapped entries, in ``clip_state_dict``'s iteration
        order.
        - ``'proj'`` is skipped because of a possible dimension mismatch; it
          is loaded later.
        - ``'temporal_embedding'`` is skipped.
        - ``class_embedding`` gains two leading singleton dimensions.
        - ``positional_embedding`` gains one leading singleton dimension.

    Raises:
        KeyError: if ``clip_state_dict`` contains a key with no mapping
            (e.g. a block index >= ``transformer_layers``).
    """
    key_mapping = {
        "class_embedding": "cls_token",
        "positional_embedding": "pos_embed",
        "conv1.weight": "patch_embed.proj.weight",
        "ln_pre.weight": "ln_pre.weight",
        "ln_pre.bias": "ln_pre.bias",
        "ln_post.weight": "norm.weight",
        "ln_post.bias": "norm.bias",
    }
    # Per-block suffix renames: CLIP resblock suffix -> TimeSformer block suffix.
    block_suffixes = {
        "attn.in_proj_weight": "attn.qkv.weight",
        "attn.in_proj_bias": "attn.qkv.bias",
        "attn.out_proj.weight": "attn.proj.weight",
        "attn.out_proj.bias": "attn.proj.bias",
        "ln_1.weight": "norm1.weight",
        "ln_1.bias": "norm1.bias",
        "mlp.c_fc.weight": "mlp.fc1.weight",
        "mlp.c_fc.bias": "mlp.fc1.bias",
        "mlp.c_proj.weight": "mlp.fc2.weight",
        "mlp.c_proj.bias": "mlp.fc2.bias",
        "ln_2.weight": "norm2.weight",
        "ln_2.bias": "norm2.bias",
    }
    for layer in range(transformer_layers):
        for clip_suffix, tsf_suffix in block_suffixes.items():
            key_mapping[f"transformer.resblocks.{layer}.{clip_suffix}"] = f"blocks.{layer}.{tsf_suffix}"

    remapped_state_dict = OrderedDict()
    for key, value in clip_state_dict.items():
        if key == 'proj':
            continue  # due to possible dim mismatch, we load this later
        if key == 'temporal_embedding':
            continue  # temporal embedding is added to CoCa before attention pooling
        # Reshape into a local variable rather than writing back, so the
        # caller's dict is left untouched.
        if key == "class_embedding":
            value = value.unsqueeze(0).unsqueeze(0)
        elif key == "positional_embedding":
            value = value.unsqueeze(0)
        remapped_state_dict[key_mapping[key]] = value

    return remapped_state_dict


