import torch
import timm
import torch.nn as nn
import torchvision.transforms as transforms

class TimmModel(nn.Module):
    """Wrap a pretrained timm model behind its own resize + normalize preprocessing.

    ``forward`` optionally undoes an upstream normalization (``x * std + mean``).
    It then resizes and normalizes the input with the statistics timm resolves
    for the wrapped model, and returns the model's output.

    Args:
        model_name: Name passed to ``timm.create_model`` (loaded with
            ``pretrained=True``).
        do_inverse_normalize: If True, compute ``x * std + mean`` on the input
            before timm's preprocessing.
        mean: Per-channel means of the upstream normalization. The defaults are
            the common ImageNet statistics.
        std: Per-channel standard deviations of the upstream normalization.
            Must have the same length as ``mean``.

    Raises:
        ValueError: If ``do_inverse_normalize`` is True and ``mean`` and ``std``
            have different lengths.

    The inverse-normalization statistics are non-persistent buffers:
    - They follow ``.to()``, ``.cuda()`` and ``.half()`` on this module.
    - They are not stored in ``state_dict``.
    Move the module to the input's device before calling it.
    """

    def __init__(self, model_name, do_inverse_normalize=True,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        self.do_inverse_normalize = do_inverse_normalize
        if self.do_inverse_normalize:
            self._register_inverse_normalization(mean, std)
        self.model = timm.create_model(model_name, pretrained=True)
        self.preprocess = self._build_preprocess(
            timm.data.resolve_model_data_config(self.model)
        )

    def _register_inverse_normalization(self, mean, std):
        """Store *mean* and *std* as (1, C, 1, 1) non-persistent buffers."""
        if len(mean) != len(std):
            raise ValueError(
                f"mean and std must have the same length, got {len(mean)} and {len(std)}"
            )
        n_channels = len(mean)
        # Buffers, unlike plain tensor attributes, follow the module across
        # devices and dtypes. persistent=False keeps them out of state_dict,
        # so checkpoints keep the same keys.
        self.register_buffer(
            "mean", torch.tensor(mean).reshape(1, n_channels, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(std).reshape(1, n_channels, 1, 1), persistent=False
        )

    @staticmethod
    def _build_preprocess(data_config):
        """Build the resize + normalize pipeline described by a timm data config."""
        # Only 'bilinear' is honored explicitly; every other mode falls back to bicubic.
        if data_config['interpolation'] == 'bilinear':
            interpolation = transforms.InterpolationMode.BILINEAR
        else:
            interpolation = transforms.InterpolationMode.BICUBIC
        return transforms.Compose([
            # Drop the leading (channel) entry of input_size so Resize gets (H, W).
            transforms.Resize(size=data_config['input_size'][1:], interpolation=interpolation),
            transforms.Normalize(mean=data_config['mean'], std=data_config['std']),
        ])

    def forward(self, x):
        """Return the wrapped model's output for the input batch *x*.

        *x* is presumably an (N, C, H, W) float tensor, since the statistics
        broadcast as (1, C, 1, 1). It is un-normalized first when
        ``do_inverse_normalize`` is set.
        """
        if self.do_inverse_normalize:
            x = x * self.std + self.mean
        x = self.preprocess(x)
        return self.model(x)