import torch 
from vit_models import deit_base_patch16_LS
from torchvision import models 
import os
from torch import nn

class Nets(torch.nn.Module):
    """A named collection of classifiers that are all run on the same batch.

    Each entry of ``paths`` maps a model key to a checkpoint path:

    * keys starting with ``'VIT'`` build ``deit_base_patch16_LS`` on
      ``cuda:RANK``; if the path does not exist the weights from
      ``workdirs/deit_3_base_224_1k.pth`` are loaded instead;
    * every other key builds a torchvision ResNet-50 (``IMAGENET1K_V1``
      weights) wrapped in ``DataParallel(device_ids=[RANK])``; if the path
      exists, its ``'state_dict'`` entry (when present) is loaded on top.

    All models are put in eval mode.

    Args:
        paths: mapping from model key to checkpoint path. A path that does not
            exist on disk (e.g. ``"NONE"``) selects the default weights.
        indices: optional column selection applied to every model output in
            :meth:`forward` as ``out[:, indices]``; ``None`` keeps all columns.
        RANK: CUDA device index the models are placed on.
    """

    def __init__(self, paths, indices, RANK=0):
        super().__init__()
        self.nets = nn.ModuleDict()
        for key in paths:
            self.nets[key] = self._make_net(key, paths[key], RANK)
        self.indices = indices

    @staticmethod
    def _is_vit(key):
        """Return True when ``key`` names a DeiT model rather than a ResNet-50."""
        return key.startswith('VIT')

    @staticmethod
    def _load(path):
        """Load a checkpoint with all tensors mapped to the CPU."""
        # Without map_location, torch.load restores tensors onto the device they
        # were saved from (typically cuda:0), so every RANK would allocate the
        # checkpoint there. load_state_dict copies the CPU tensors into the
        # parameters that already live on cuda:RANK.
        # NOTE(review): torch>=2.6 defaults to weights_only=True; checkpoints
        # that pickle non-tensor objects may need weights_only=False — confirm.
        return torch.load(path, map_location='cpu')

    @classmethod
    def _make_net(cls, key, path, RANK):
        """Build the model for ``key`` on ``cuda:RANK`` and load its weights."""
        if cls._is_vit(key):
            net = deit_base_patch16_LS().cuda(RANK).eval()
            if os.path.exists(path):
                net.load_state_dict(cls._load(path)['model'])
                print("Loaded finetuned DeIT")
            else:
                default_ckpt = cls._load("workdirs/deit_3_base_224_1k.pth")
                net.load_state_dict(default_ckpt['model'])
                print("Loaded default DeIT")
            return net

        net = torch.nn.DataParallel(
            models.resnet50(weights='IMAGENET1K_V1').cuda(RANK),
            device_ids=[RANK],
        ).eval()
        if not os.path.exists(path):
            print("Loaded default RN50 model")
            return net

        ckpt = cls._load(path)
        num_steps = ckpt.get("steps", 0)
        if 'state_dict' in ckpt:
            # Keys are loaded into the DataParallel wrapper, so they presumably
            # carry the 'module.' prefix — TODO confirm against the saver.
            net.load_state_dict(ckpt['state_dict'])
            print(f"Loaded {key} model finetuned for {num_steps} steps")
        else:
            # Previously this case claimed a finetuned load while silently
            # keeping the ImageNet weights.
            print(f"WARNING: {path} has no 'state_dict' entry; "
                  f"{key} keeps the IMAGENET1K_V1 RN50 weights")
        return net

    def forward(self, images):
        """Run every model on ``images`` and collect the outputs on the CPU.

        Args:
            images: input batch, passed unchanged to every model (presumably
                already on ``cuda:RANK`` — TODO confirm with callers).

        Returns:
            dict mapping each model key to its output moved to the CPU, with
            ``[:, self.indices]`` applied when ``self.indices`` is not None.
            DeiT outputs are computed under float16 CUDA autocast and cast
            back to float32.
        """
        out = {}
        for key, model in self.nets.items():
            if self._is_vit(key):
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    logits = model(images)
                logits = logits.float()
            else:
                logits = model(images)
            logits = logits.cpu()
            if self.indices is not None:
                logits = logits[:, self.indices]
            out[key] = logits
        return out
   
def get_nets(indices=None):
    """Build two ``Nets`` instances with the same keys but different checkpoints.

    Args:
        indices: optional output-column selection forwarded to both instances
            (``None`` keeps every column).

    Returns:
        ``(orig_net, net)``. ``orig_net`` loads the DA / AM / DAM checkpoints
        ``workdirs/deepaugment.pth.tar``, ``workdirs/checkpoint.pth.tar`` and
        ``workdirs/deepaugment_and_augmix.pth.tar``, and uses the default
        weights for BASE and VIT. ``net`` loads the ``workdirs/rn50_*.pt``
        ResNet checkpoints and ``workdirs/diffvit_extra_v0/checkpoint.pth``.
    """
    # "NONE" is not an existing path, so Nets falls back to default weights.
    orig_paths = {
        "DA": "workdirs/deepaugment.pth.tar",
        "AM": "workdirs/checkpoint.pth.tar",
        "DAM": "workdirs/deepaugment_and_augmix.pth.tar",
        "BASE": "NONE",
        "VIT": "NONE",
    }
    new_paths = {
        "DA": "workdirs/rn50_dadiff.pt",
        "AM": "workdirs/rn50_diffaug.pt",
        "DAM": "workdirs/rn50_diffdaaug.pt",
        "BASE": "workdirs/rn50_diffbase_scratch.pt",
        "VIT": "workdirs/diffvit_extra_v0/checkpoint.pth",
    }
    # Built in order: the original checkpoints first, then the new ones.
    return tuple(Nets(paths, indices) for paths in (orig_paths, new_paths))