from timm import create_model 
from .train_util import ExponentialMovingAverage
import torch 
import torchvision
from torchvision import transforms 

class SwinClassifiers(torch.nn.Module):
    """Bundle of Swin-V2-T classifiers used to compare clean and denoised images.

    Three networks are built from ``torchvision.models.swin_v2_t`` with the
    ``IMAGENET1K_V1`` weights:

    * ``clf`` -- reference classifier; all of its parameters have
      ``requires_grad = False`` and it is switched to eval mode on every
      forward pass.
    * ``clf_sim`` -- trainable classifier that is applied to the denoised
      images (and, while the module is in training mode, to the clean
      images as well).
    * ``ema_clf_sim`` -- an ``ExponentialMovingAverage`` built from the
      parameters of ``clf_sim``; its ``get_model()`` is used when the
      forward pass is requested with ``ema_model=True``.

    ``state_dict`` / ``load_state_dict`` are overridden and use a nested
    layout (one entry per sub-module plus one per ``ema*`` attribute), so
    checkpoints are not interchangeable with the flat ``nn.Module`` format.
    """

    def __init__(self, model_type):
        """Build the three classifiers and the input preprocessing.

        Args:
            model_type: Accepted for interface compatibility; it is not used
                by this constructor.
        """
        super().__init__()

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # First step computes (x + 1) / 2, i.e. it rescales inputs that live
        # in [-1, 1] to [0, 1] (assumed input range -- confirm with callers);
        # the second step standardises with the mean/std given above.
        self.preprocess = transforms.Compose([
            transforms.Normalize([-1, -1, -1], [2, 2, 2]),
            transforms.Normalize(mean, std),
        ])

        self.clf = self._new_backbone()
        for param in self.clf.parameters():
            param.requires_grad = False

        self.clf_sim = self._new_backbone()
        self.ema_clf_sim = ExponentialMovingAverage(
            self.clf_sim.parameters(),
            decay=0.999,
            model=self._new_backbone(),
        )

    @staticmethod
    def _new_backbone():
        """Return a fresh Swin-V2-T initialised with the IMAGENET1K_V1 weights."""
        return torchvision.models.swin_v2_t(weights='IMAGENET1K_V1')

    def forward(self, clean, noisy, denoised, ema_model=False):
        """Classify the clean and the denoised batch.

        Args:
            clean: Batch of clean images.
            noisy: Unused; kept so that callers can pass all three batches.
            denoised: Batch of denoised images.
            ema_model: When true, delegate to :meth:`ema_forward`.

        Returns:
            Tuple ``(clean_logits, sim_logits)``. ``sim_logits`` always comes
            from ``clf_sim`` applied to ``denoised``. ``clean_logits`` comes
            from ``clf_sim`` while the module is in training mode and from
            the frozen ``clf`` otherwise.
        """
        if ema_model:
            return self.ema_forward(clean, noisy, denoised)

        clean = self.preprocess(clean)
        denoised = self.preprocess(denoised)

        # ``nn.Module.train()`` also flips child modules, so the frozen
        # reference classifier is put back into eval mode on every call.
        self.clf.eval()
        clean_branch = self.clf_sim if self.training else self.clf
        clean_logits = clean_branch(clean)

        sim_logits = self.clf_sim(denoised)

        return clean_logits, sim_logits

    def ema_forward(self, clean, noisy, denoised):
        """Classify using the frozen classifier and the EMA model.

        Args:
            clean: Batch of clean images.
            noisy: Unused; kept for signature parity with :meth:`forward`.
            denoised: Batch of denoised images.

        Returns:
            Tuple ``(clean_logits, sim_logits)`` where ``clean_logits`` comes
            from ``clf`` and ``sim_logits`` from the model returned by
            ``ema_clf_sim.get_model()`` in eval mode.
        """
        clean = self.preprocess(clean)
        denoised = self.preprocess(denoised)

        self.clf.eval()
        clean_logits = self.clf(clean)

        # Follow the device of the input instead of forcing the default CUDA
        # device, so CPU-only and multi-GPU setups work as well.
        ema_net = self.ema_clf_sim.get_model().to(denoised.device).eval()
        sim_logits = ema_net(denoised)

        return clean_logits, sim_logits

    def update_ema(self):
        """Feed the current ``clf_sim`` parameters to ``ema_clf_sim.update``."""
        self.ema_clf_sim.update(self.clf_sim.parameters())

    def state_dict(self):
        """Return a nested checkpoint dictionary.

        Returns:
            Dict mapping each registered sub-module name to that sub-module's
            ``state_dict()``, followed by one entry per instance attribute
            whose name starts with ``'ema'`` holding its ``state_dict()``.
        """
        sd = {name: module.state_dict() for name, module in self._modules.items()}
        for name, value in self.__dict__.items():
            if name.startswith('ema'):
                sd[name] = value.state_dict()
        return sd

    def load_state_dict(self, sd, strict=True):
        """Restore the nested checkpoint produced by :meth:`state_dict`.

        Args:
            sd: Nested dictionary; it must contain one entry per registered
                sub-module. Entries whose key starts with ``'ema'`` are
                loaded into the instance attribute of the same name.
            strict: Forwarded to each sub-module's ``load_state_dict``. It is
                not forwarded to the ``ema*`` objects.

        Raises:
            KeyError: If ``sd`` lacks an entry for a registered sub-module,
                or holds an ``ema*`` key with no matching instance attribute.
        """
        for name, module in self._modules.items():
            module.load_state_dict(sd[name], strict=strict)
        for name, payload in sd.items():
            if name.startswith('ema'):
                self.__dict__[name].load_state_dict(payload)
