"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from lavis.models.base_model import BaseEncoder
from lavis.models.beats.BEATs import BEATs, BEATsConfig
import torch 
from lavis.common.utils import is_url
from lavis.common.dist_utils import download_cached_file
import os 


# Remote copy of the BEATs checkpoint (its file name matches the local default
# below). A URL like this can be passed as ``checkpoint_path``; BeatsEncoder
# downloads and caches URLs before loading them.
# ckp_path =  "https://valle.blob.core.windows.net/share/BEATs/BEATs_iter3_plus_AS2M.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"

# Default checkpoint: a bare file name. Because it is checked with
# os.path.isfile(), it is resolved relative to the current working directory.
ckp_path = "BEATs_iter3_plus_AS2M.pt"

class BeatsEncoder(BaseEncoder):
    """Encoder that wraps a pre-trained BEATs model.

    The BEATs weights are loaded once at construction time and the wrapped
    model is switched to eval mode. ``forward`` runs feature extraction under
    ``torch.no_grad()``, so this encoder is used as a frozen feature extractor.
    """

    def __init__(self, checkpoint_path=ckp_path):
        """Build the encoder from a BEATs checkpoint.

        Args:
            checkpoint_path: URL or local filesystem path of a BEATs
                checkpoint. The loaded object must provide a ``'cfg'`` entry
                (passed to ``BEATsConfig``) and a ``'model'`` state dict.
                Defaults to the module-level ``ckp_path``.

        Raises:
            RuntimeError: if ``checkpoint_path`` is neither a URL nor an
                existing file.
        """
        super().__init__()

        # load the pre-trained checkpoints
        checkpoint = self._load_checkpoint(checkpoint_path)

        cfg = BEATsConfig(checkpoint['cfg'])
        # Width of the features produced by the wrapped model.
        self.num_features = cfg.encoder_embed_dim
        self.model = BEATs(cfg)
        self.model.load_state_dict(checkpoint['model'])
        # NOTE(review): a later `.train()` call on this encoder would
        # presumably flip the wrapped model back to training mode as well;
        # confirm whether callers rely on that.
        self.model.eval()

    @staticmethod
    def _load_checkpoint(checkpoint_path):
        """Load a checkpoint from a URL (downloaded and cached) or a local file."""
        if is_url(checkpoint_path):
            cached_file = download_cached_file(
                checkpoint_path, check_hash=False, progress=True
            )
            # Load onto CPU: the freshly built BEATs module lives on CPU and
            # load_state_dict copies values in. This also avoids failures on
            # CPU-only hosts when the checkpoint holds CUDA tensors.
            return torch.load(cached_file, map_location="cpu")
        if os.path.isfile(checkpoint_path):
            print('loading checkpoint for BEATs Encoder')
            return torch.load(checkpoint_path, map_location="cpu")
        # Without this, the caller would crash later with an opaque
        # UnboundLocalError on `checkpoint`.
        raise RuntimeError(
            f"Invalid BEATs checkpoint {checkpoint_path!r}: "
            "it is neither a URL nor an existing file."
        )

    @classmethod
    def from_config(cls, cfg):
        """Create an encoder from a mapping-like config.

        Reads the optional ``"checkpoint_path"`` key, falling back to the
        module-level ``ckp_path``.
        """
        checkpoint_path = cfg.get("checkpoint_path", ckp_path)
        return cls(checkpoint_path)

    def forward(self, x):
        """Extract BEATs features from ``x`` without tracking gradients.

        ``x`` is squeezed along dim 1 before being passed to
        ``BEATs.extract_features``. This presumably expects a
        (batch, 1, samples) waveform tensor; TODO confirm against callers.

        Returns:
            The first element of ``extract_features``' return value
            (presumably the feature tensor; confirm against BEATs).
        """
        with torch.no_grad():
            return self.model.extract_features(x.squeeze(1))[0]