import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.transforms import Normalize
import modules.timerewarder
from modules.transformer import OrderPredictor
import os
import re
from vip import load_vip

def get_cost_encoder(name, device, ckpt=None, task_name=None):
    """Build the cost encoder selected by ``name``.

    Every encoder that is actually constructed is moved to ``device`` and
    switched to eval mode before being returned.

    Args:
        name: Encoder kind. One of ``'resnet'``, ``'timerewarder'``,
            ``'progressor'``, ``'goalac'``, ``'gt'``, ``'order'`` or ``'vip'``.
        device: Target passed to ``.to(...)`` on the constructed encoder.
        ckpt: Checkpoint path. Required for ``'timerewarder'``,
            ``'progressor'`` and ``'order'``; forwarded as ``modelpath`` for
            ``'vip'``; unused otherwise. For ``'timerewarder'`` /
            ``'progressor'`` the path string itself selects the load options:
            ``T`` is 1 when it contains ``'image'`` or ``'notext'`` (else 3),
            ``bin_num`` is parsed from a ``_<N>bins`` fragment (else -1), and
            ``progressor`` is True when it contains ``'progressor'``.
        task_name: Must not be ``None`` for ``'vip'``; it is only validated,
            not otherwise used here.

    Returns:
        The encoder module, or ``None`` for ``'goalac'`` and ``'gt'``.

    Raises:
        ValueError: If a required ``ckpt`` or ``task_name`` is ``None``.
        NotImplementedError: If ``name`` is not a known encoder kind.
    """
    if name in ('goalac', 'gt'):
        # These modes do not use a learned encoder.
        return None

    if name == 'resnet':
        encoder = ResNet(base_encoder=models.__dict__['resnet50'])
    elif name in ('timerewarder', 'progressor'):
        if ckpt is None:
            raise ValueError(
                "ckpt cannot be None for {} encoder".format(name))

        # Load options are encoded in the checkpoint path.
        T = 1 if ('image' in ckpt or 'notext' in ckpt) else 3
        bins_match = re.search(r'_(\d+)bins', ckpt)
        bin_num = int(bins_match.group(1)) if bins_match else -1
        progressor = 'progressor' in ckpt

        # `import modules.timerewarder` binds only the top-level name
        # `modules`, so the submodule must be reached through it; a bare
        # `timerewarder` would raise NameError.
        encoder = modules.timerewarder.load(
            model_path=ckpt, T=T, bin_num=bin_num, progressor=progressor)[0]
    elif name == 'order':
        if ckpt is None:
            raise ValueError("ckpt cannot be None for order encoder")
        encoder = OrderPredictor()
        # Load onto CPU first so a checkpoint saved on another device still
        # loads; the model is moved to `device` below.
        state = torch.load(ckpt, map_location='cpu')
        encoder.load_state_dict(state['model'])
    elif name == 'vip':
        if task_name is None:
            raise ValueError("task_name cannot be None for VIP encoder")

        vip_path = ckpt
        # NOTE(review): configpath is still a placeholder string and has to
        # be replaced with a real config location before this branch is used.
        encoder = load_vip(modelpath=vip_path, configpath='[path to config_vip]')
    else:
        raise NotImplementedError(
            "Unknown cost encoder: {!r}".format(name))

    encoder = encoder.to(device)
    encoder.eval()
    return encoder

class ResNet(nn.Module):
    """Image feature extractor wrapping a torchvision-style ResNet.

    The backbone is built by calling ``base_encoder(num_classes=1000,
    pretrained=True)``. Inputs are scaled by 1/255 and, optionally,
    normalized per channel before being encoded.
    """

    def __init__(self, base_encoder=models.__dict__['resnet50']):
        """Create the backbone and the per-channel input normalizer.

        Args:
            base_encoder: Callable returning the backbone module; it must
                accept ``num_classes`` and ``pretrained`` keyword arguments.
        """
        super().__init__()
        self.encoder = base_encoder(num_classes=1000, pretrained=True)
        # Per-channel statistics; these match the values torchvision documents
        # for its ImageNet-pretrained models.
        channel_mean = torch.tensor([0.485, 0.456, 0.406])
        channel_std = torch.tensor([0.229, 0.224, 0.225])
        self.img_norm = Normalize(mean=channel_mean, std=channel_std)

    def forward(self, obs, spacial=True, normalize=True):
        """Encode ``obs`` into one flat feature vector per batch element.

        Args:
            obs: Input tensor; only the last three entries along dim 1 are
                used, and values are divided by 255
                (assumes 0-255 pixel values - TODO confirm with callers).
            spacial: If True, run only the first eight child modules of the
                backbone and flatten their output; if False, run the whole
                backbone.
            normalize: If True, apply the per-channel normalization.

        Returns:
            Tensor reshaped to ``(obs.shape[0], -1)``.
        """
        batch = obs.shape[0]
        frame = obs[:, -3:] / 255.0
        if normalize:
            frame = self.img_norm(frame)

        if spacial:
            # Stop after the eighth child; for torchvision ResNets this is
            # presumably everything before the pooling/fc head - verify if a
            # different backbone is passed in.
            feats = frame
            for stage in list(self.encoder.children())[:8]:
                feats = stage(feats)
        else:
            feats = self.encoder(frame)

        return feats.view(batch, -1)
    
