from timm import create_model 
from .train_util import ExponentialMovingAverage
import torch 
from torchvision import transforms 

class NFClassifiers(torch.nn.Module):
    """Three timm classifiers, each paired with an exponential-moving-average copy.

    ``clf`` classifies clean images, ``clf_noisy`` noisy images and ``clf_sim``
    denoised images. Each has a matching ``ema_*`` wrapper
    (``ExponentialMovingAverage``) that is given a separately built network of
    the same architecture; :meth:`update_ema` refreshes the wrappers and
    :meth:`ema_forward` classifies with the networks they return.

    NOTE(review): the ``ema_*`` wrappers are discovered through the instance
    ``__dict__`` (see :meth:`_ema_wrappers`), i.e. they are presumably not
    registered submodules. If so, ``.to()`` / ``.eval()`` / ``.parameters()``
    on this module do not reach their networks; confirm how
    ``ExponentialMovingAverage`` handles device placement and train/eval mode.

    Args:
        model_type: timm architecture name passed to ``create_model``.
        num_classes: Output size of every classifier head (default 1000).
        ema_decay: Decay used by all three EMA wrappers (default 0.999).
    """

    def __init__(self, model_type, num_classes=1000, ema_decay=0.999):
        super().__init__()

        def new_model(in_chans):
            # Fresh, randomly initialised timm model with SiLU activations,
            # dropout and stochastic-depth (drop path) regularisation.
            return create_model(
                model_name=model_type,
                pretrained=False,
                num_classes=num_classes,
                act_layer=torch.nn.SiLU,
                drop_rate=0.25,
                in_chans=in_chans,
                drop_path_rate=0.1,
            )

        def new_ema(model):
            # The wrapper tracks `model`'s parameters and is handed its own
            # network of the same architecture (presumably the shadow copy).
            return ExponentialMovingAverage(
                model.parameters(), decay=ema_decay, model=new_model(3)
            )

        # Standard ImageNet per-channel mean / std.
        # NOTE(review): assumes RGB inputs scaled to [0, 1] - confirm with callers.
        self.preprocess = transforms.Normalize(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )

        # Each model is built immediately before its EMA network, in a fixed
        # order, so seeded random initialisation stays reproducible.
        self.clf_noisy = new_model(3)
        self.ema_clf_noisy = new_ema(self.clf_noisy)

        self.clf = new_model(3)
        self.ema_clf = new_ema(self.clf)

        self.clf_sim = new_model(3)
        self.ema_clf_sim = new_ema(self.clf_sim)

    def forward(self, clean, noisy, denoised, ema_model=False):
        """Classify the clean, noisy and denoised batches.

        Args:
            clean: Batch for the clean-image classifier.
            noisy: Batch for the noisy-image classifier.
            denoised: Batch for the denoised-image classifier.
            ema_model: If True, delegate to :meth:`ema_forward` and use the
                EMA networks instead of the trained ones.

        Returns:
            Tuple ``(clean_logits, noisy_logits, sim_logits)``.
        """
        if ema_model:
            return self.ema_forward(clean, noisy, denoised)
        return self._classify(
            clean, noisy, denoised, self.clf, self.clf_noisy, self.clf_sim
        )

    def ema_forward(self, clean, noisy, denoised):
        """Like :meth:`forward`, but using the networks held by the EMA wrappers."""
        return self._classify(
            clean,
            noisy,
            denoised,
            self.ema_clf.get_model(),
            self.ema_clf_noisy.get_model(),
            self.ema_clf_sim.get_model(),
        )

    def _classify(self, clean, noisy, denoised, clf, clf_noisy, clf_sim):
        """Normalise the three batches and run each through its classifier."""
        clean = self.preprocess(clean)
        noisy = self.preprocess(noisy)
        denoised = self.preprocess(denoised)
        return clf(clean), clf_noisy(noisy), clf_sim(denoised)

    def update_ema(self):
        """Update each EMA wrapper from its classifier's current parameters."""
        self.ema_clf.update(self.clf.parameters())
        self.ema_clf_noisy.update(self.clf_noisy.parameters())
        self.ema_clf_sim.update(self.clf_sim.parameters())

    def _ema_wrappers(self):
        """Return ``{name: wrapper}`` for ``ema*`` attributes stored outside ``_modules``."""
        return {k: v for k, v in vars(self).items() if k.startswith('ema')}

    def state_dict(self):
        """Return a nested checkpoint dict keyed by attribute name.

        The dict has one entry per registered submodule, holding that
        submodule's ``state_dict()``. It also has one entry per ``ema*`` wrapper
        kept as a plain attribute, holding that wrapper's ``state_dict()``.

        NOTE(review): ``nn.Module.state_dict`` takes ``destination`` / ``prefix``
        / ``keep_vars`` and writes flat keys. This override takes none of those
        and returns a nested dict. A parent ``nn.Module``'s ``state_dict()``
        therefore cannot include this module as a child.
        """
        sd = {name: module.state_dict() for name, module in self._modules.items()}
        for name, ema in self._ema_wrappers().items():
            sd[name] = ema.state_dict()
        return sd

    def load_state_dict(self, sd, strict=True):
        """Restore weights from a dict produced by :meth:`state_dict`.

        Args:
            sd: Nested checkpoint dict keyed by attribute name.
            strict: Forwarded to each submodule's ``load_state_dict``. When
                False, submodules missing from ``sd`` are left untouched, and
                ``ema*`` keys with no matching wrapper are ignored.

        Raises:
            KeyError: If ``strict`` and a submodule has no entry in ``sd``, or
                ``sd`` has an ``ema*`` entry with no matching wrapper.
        """
        for name, module in self._modules.items():
            if not strict and name not in sd:
                continue
            module.load_state_dict(sd[name], strict=strict)

        ema_wrappers = self._ema_wrappers()
        for name, ema_sd in sd.items():
            # ema* entries that are registered submodules were restored above;
            # only wrappers living in the instance __dict__ are handled here.
            if not name.startswith('ema') or name in self._modules:
                continue
            if not strict and name not in ema_wrappers:
                continue
            ema_wrappers[name].load_state_dict(ema_sd)
