
from copy import deepcopy
import torchvision.models as models
import torch.nn.functional as F
import torch

from Variables import *

class DropBlock(torch.nn.Module):
    """DropBlock regularisation for 2D feature maps.

    In training mode, seed positions are drawn from a Bernoulli distribution
    with rate gamma, every seed is grown into a ``block_size`` x ``block_size``
    square via max pooling, the covered activations are zeroed and the
    remaining ones are rescaled by ``numel / number_kept``. In eval mode the
    input is returned unchanged.

    Args:
        block_size: Side length of the square blocks that are dropped.
        p: Enters gamma as ``1 - p``, i.e. it plays the role of the keep
            probability in eq. (1) of the DropBlock paper.
        gamma_fac: Extra multiplicative factor applied to gamma.
    """

    def __init__(self, block_size: int, p: float = 0.5, gamma_fac: float = 1):
        super().__init__()
        self.block_size = block_size
        self.p = p
        self.gamma_fac = gamma_fac

    def calculate_gamma(self, x: torch.Tensor) -> float:
        """Compute gamma, eq (1) in the paper.

        Both spatial dimensions (the last two axes of ``x``) are taken into
        account, so the result is also correct for non-square feature maps.
        For square inputs it equals ``feat_size**2 / (feat_size - block_size + 1)**2``
        times the drop rate.

        Args:
            x (Tensor): Input tensor whose last two axes are height and width.
        Returns:
            gamma, scaled by ``gamma_fac``. A plain float when ``p`` and
            ``block_size`` are Python numbers, a 0-dim tensor when they are
            tensors.
        """
        height, width = x.shape[-2], x.shape[-1]
        drop_rate = (1 - self.p) / (self.block_size ** 2)
        # Ratio between all positions and those where a full block fits.
        valid_ratio = (height * width) / (
            (height - self.block_size + 1) * (width - self.block_size + 1)
        )
        return (drop_rate * valid_ratio) * self.gamma_fac

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply DropBlock to ``x`` in training mode; identity otherwise."""
        if not self.training:
            return x

        # Accepts both Python ints and 0-dim integer tensors.
        block = int(self.block_size)
        gamma = self.calculate_gamma(x)
        seeds = torch.bernoulli(torch.ones_like(x) * gamma)
        mask_block = 1 - F.max_pool2d(
            seeds,
            kernel_size=(block, block),
            stride=(1, 1),
            padding=(block // 2, block // 2),
        )
        # With an even block size the pooled map is one element larger than
        # the input along each spatial axis; crop it back so it lines up with x.
        mask_block = mask_block[..., : x.shape[-2], : x.shape[-1]]
        # Guard against division by zero when every activation was dropped
        # (the output is then all zeros instead of NaN).
        kept = mask_block.sum().clamp(min=1)
        return mask_block * x * (mask_block.numel() / kept)
    
class Classifier_ResNet50(torch.nn.Module):
    """ResNet-50 classifier with a single-channel stem and an ``OUTPUT_NEURONS``-wide head.

    The backbone is initialised from the checkpoint at ``EM_PRETRAINED_WEIGHTS``
    (CEM500K MoCo weights, see the reference in ``__init__``).
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # New classification head first, then the single-channel stem.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, OUTPUT_NEURONS)
        backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # for resnet50 use pretrained weights from Conrad, Ryan, and Kedar Narayan.
        # "CEM500K, a large-scale heterogeneous unlabeled cellular electron microscopy
        # image dataset for deep learning." Elife 10 (2021): e65894.
        checkpoint = torch.load(EM_PRETRAINED_WEIGHTS, map_location='cpu')
        encoder_weights = deepcopy(checkpoint['state_dict'])

        # Rename the entries so they match torchvision's resnet50: query-encoder
        # parameters lose their prefix, everything else (including the fc
        # projection head) is dropped.
        strip = len("module.encoder_q.")
        for name in list(encoder_weights.keys()):
            value = encoder_weights.pop(name)
            is_encoder = name.startswith('module.encoder_q')
            is_projection = name.startswith('module.encoder_q.fc')
            if is_encoder and not is_projection:
                encoder_weights[name[strip:]] = value

        # strict=False: the checkpoint carries no weights for the new fc layer.
        backbone.load_state_dict(encoder_weights, strict=False)
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ResNet-50 and return its output."""
        return self.model(x)

class Classifier_ResNet50_DropBlock(torch.nn.Module):
    """ResNet-50 classifier with DropBlock after ``layer3`` and ``layer4``.

    The network is split into ``model1`` (stem, residual stages, the two
    DropBlock layers and the average pool) and ``model2`` (fully connected
    layer followed by ``final_act``).

    Args:
        block_size: Block size forwarded to both DropBlock layers.
        p: ``p`` forwarded to both DropBlock layers.
        final_act: Module applied after the fully connected layer.
        output_neurons: Number of outputs of the fully connected layer.
        load_em_weights: If true, initialise the backbone from the checkpoint
            at ``EM_PRETRAINED_WEIGHTS``.
    """

    def __init__(self, block_size: int = 3, p: float = 0.9, final_act: torch.nn.Module = torch.nn.Identity(), output_neurons: int = OUTPUT_NEURONS, load_em_weights: bool = True):
        super().__init__()

        # Hyperparameters are registered as buffers so they travel with the module.
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer("p", torch.tensor(p))
        self.register_buffer("output_neurons", torch.tensor(output_neurons))

        model = models.resnet50(pretrained=True)
        num_ftrs = model.fc.in_features
        # Use the plain value for the layer size rather than the 0-dim tensor
        # buffer, so Linear.out_features stays a regular number.
        model.fc = torch.nn.Linear(num_ftrs, output_neurons)
        model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if load_em_weights:
            self._load_em_weights(model)

        # The DropBlock after layer3 uses a quarter of the gamma of the one after layer4.
        dropblock3 = DropBlock(block_size=self.block_size, p=self.p, gamma_fac=1/4)
        dropblock4 = DropBlock(block_size=self.block_size, p=self.p)
        self.model1 = torch.nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool,
            model.layer1,
            model.layer2,
            model.layer3,
            dropblock3,
            model.layer4,
            dropblock4,
            model.avgpool,
        )
        self.model2 = torch.nn.Sequential(model.fc, final_act)

    @staticmethod
    def _load_em_weights(model: torch.nn.Module) -> None:
        """Load the query-encoder weights from ``EM_PRETRAINED_WEIGHTS`` into ``model``."""
        state = torch.load(EM_PRETRAINED_WEIGHTS, map_location='cpu')
        prefix = "module.encoder_q."
        # Format the parameter names to match torchvision resnet50: only keep
        # query encoder parameters and discard the fc projection head.
        resnet50_state_dict = {
            key[len(prefix):]: value
            for key, value in state['state_dict'].items()
            if key.startswith(prefix) and not key.startswith('module.encoder_q.fc')
        }
        # strict=False: the checkpoint has no entry for the replaced fc layer.
        model.load_state_dict(resnet50_state_dict, strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``model2`` applied to the flattened output of ``model1``."""
        out_x = self.model1(x)
        out_x = torch.flatten(out_x, 1)
        return self.model2(out_x)
    

class Classifier_ResNet101(torch.nn.Module):
    """torchvision ResNet-101 (``pretrained=True``) with an ``OUTPUT_NEURONS``-wide head."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet101(pretrained=True)
        # Swap the final fully connected layer for one with OUTPUT_NEURONS outputs.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, OUTPUT_NEURONS)
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ResNet-101 and return its output."""
        return self.model(x)
    
class Classifier_ResNet101_DropBlock(torch.nn.Module):
    """ResNet-101 classifier with DropBlock after ``layer3`` and ``layer4``.

    The network is split into ``model1`` (stem, residual stages, the two
    DropBlock layers and the average pool) and ``model2`` (fully connected
    layer followed by ``final_act``).

    Args:
        block_size: Block size forwarded to both DropBlock layers.
        p: ``p`` forwarded to both DropBlock layers.
        final_act: Module applied after the fully connected layer.
        output_neurons: Number of outputs of the fully connected layer.
    """

    def __init__(self, block_size: int = 3, p: float = 0.9, final_act: torch.nn.Module = torch.nn.Identity(), output_neurons: int = OUTPUT_NEURONS):
        super().__init__()
        # Hyperparameters are registered as buffers so they travel with the module.
        self.register_buffer("block_size", torch.tensor(block_size))
        self.register_buffer("p", torch.tensor(p))
        self.register_buffer("output_neurons", torch.tensor(output_neurons))

        model = models.resnet101(pretrained=True)
        num_ftrs = model.fc.in_features
        # Honour the ``output_neurons`` argument; it defaults to OUTPUT_NEURONS,
        # so callers relying on the default get the same layer as before.
        model.fc = torch.nn.Linear(num_ftrs, output_neurons)

        # The DropBlock after layer3 uses a quarter of the gamma of the one after layer4.
        dropblock3 = DropBlock(block_size=self.block_size, p=self.p, gamma_fac=1/4)
        dropblock4 = DropBlock(block_size=self.block_size, p=self.p)
        self.model1 = torch.nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool,
            model.layer1,
            model.layer2,
            model.layer3,
            dropblock3,
            model.layer4,
            dropblock4,
            model.avgpool,
        )
        self.model2 = torch.nn.Sequential(model.fc, final_act)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``model2`` applied to the flattened output of ``model1``."""
        out_x = self.model1(x)
        out_x = torch.flatten(out_x, 1)
        return self.model2(out_x)
    
    

class MagnificationModel_Head(torch.nn.Module):
    """Classifier whose head sees both image features and the virus size.

    One of the ResNet classifiers in this file is used as feature extractor:
    its fully connected layer is replaced by an identity. The extracted
    features are concatenated with the virus size, repeated ``num_ftrs``
    times, and passed through a two-layer head ending in ``final_act``.

    Args:
        backbone: ``"resnet50"`` or ``"resnet101"``.
        dropblock: If truthy, use the DropBlock variant of the backbone.
        final_act: Module applied after the last linear layer of the head.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, dropblock, final_act = torch.nn.Identity()):
        super().__init__()
        if backbone == "resnet50":
            model = Classifier_ResNet50_DropBlock() if dropblock else Classifier_ResNet50()
        elif backbone == "resnet101":
            model = Classifier_ResNet101_DropBlock() if dropblock else Classifier_ResNet101()
        else:
            # Previously this fell through and failed later on an unbound ``model``.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected 'resnet50' or 'resnet101'"
            )

        # Remember the feature width, then strip the classification layer so
        # the backbone outputs features instead of class scores.
        if dropblock:
            # DropBlock variants keep the fc layer as the first entry of ``model2``.
            self.num_ftrs = model.model2[0].in_features
            model.model2[0] = torch.nn.Identity()
        else:
            self.num_ftrs = model.model.fc.in_features
            model.model.fc = torch.nn.Identity()

        self.model = model
        # The head input is twice as wide as the feature vector: the repeated
        # virus size plus the image features.
        self.magnification_head = torch.nn.Sequential(
            torch.nn.Linear(self.num_ftrs * 2, self.num_ftrs),
            torch.nn.ReLU(),
            torch.nn.Linear(self.num_ftrs, OUTPUT_NEURONS),
            final_act,
        )

    def forward(self, x) -> torch.Tensor:
        """Compute the head output for ``x = (img, virus_size)``.

        Args:
            x: Indexable pair; ``x[0]`` is fed to the backbone and ``x[1]`` is
                the virus size (assumed to be one value per sample, i.e. shape
                ``(batch,)`` -- TODO confirm against the data loader).
        """
        img = x[0]
        virus_size = x[1]
        img_out = self.model(img)
        # Repeat the virus size so it is as wide as the image feature vector.
        virus_size = virus_size[:, None].repeat(1, self.num_ftrs)
        head_input = torch.concat((virus_size, img_out), dim=-1)
        return self.magnification_head(head_input)
    

class MagnificationModel_Embeddings(torch.nn.Module):
    """Classifier that receives the virus size as extra input channels.

    The (integer) virus size is embedded, projected to ``num_channels`` image
    planes of size ``IMG_SIZE`` and concatenated with the image along the
    channel axis before being fed to the ResNet backbone.

    Args:
        backbone: ``"resnet50"`` (single-channel image, backbone initialised
            from ``EM_PRETRAINED_WEIGHTS``) or ``"resnet101"`` (three-channel
            image).
        embedding_dim: Size of the virus-size embedding vector.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, embedding_dim = 50):
        super().__init__()

        n_classes = 500 # the capsid sizes range up to 478 for the herpes data set and 110 in the covid data set
        self.num_channels = 3

        if backbone == "resnet50":
            # NOTE(review): the embedding used to be defined only in the
            # resnet101 branch, so forward() failed for resnet50. The stem is
            # widened here to image + embedding channels; confirm that this
            # matches the intended design.
            image_channels = 1
            model = models.resnet50(pretrained=True)
            num_ftrs = model.fc.in_features
            model.fc = torch.nn.Linear(num_ftrs, OUTPUT_NEURONS)
            model.conv1 = torch.nn.Conv2d(image_channels + self.num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            # for resnet50 use pretrained weights from Conrad, Ryan, and Kedar Narayan. "CEM500K, a large-scale heterogeneous unlabeled cellular electron microscopy image dataset for deep learning." Elife 10 (2021): e65894.
            state = torch.load(EM_PRETRAINED_WEIGHTS, map_location='cpu')
            prefix = "module.encoder_q."
            # Format the parameter names to match torchvision resnet50: only
            # keep query encoder parameters and discard the fc projection head.
            resnet50_state_dict = {
                key[len(prefix):]: value
                for key, value in state['state_dict'].items()
                if key.startswith(prefix) and not key.startswith('module.encoder_q.fc')
            }
            # The checkpoint stem has fewer input channels than the widened
            # conv1; load_state_dict rejects shape mismatches even with
            # strict=False, so the stem kernel is handled separately.
            em_conv1 = resnet50_state_dict.pop('conv1.weight', None)
            model.load_state_dict(resnet50_state_dict, strict=False)
            if em_conv1 is not None:
                with torch.no_grad():
                    # Image channel(s) start from the pretrained kernel; the
                    # embedding channels keep their default initialisation.
                    model.conv1.weight[:, :image_channels].copy_(em_conv1)
        elif backbone == "resnet101":
            model = models.resnet101(pretrained=True)
            num_ftrs = model.fc.in_features
            model.fc = torch.nn.Linear(num_ftrs, OUTPUT_NEURONS)
            # 3 image channels + 3 embedding channels.
            model.conv1 = torch.nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        else:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected 'resnet50' or 'resnet101'"
            )

        # Shared by both backbones: virus size -> embedding -> image-shaped planes.
        self.magnification_embedding = torch.nn.Sequential(
            torch.nn.Embedding(n_classes, embedding_dim),
            torch.nn.Linear(embedding_dim, self.num_channels * IMG_SIZE[0] * IMG_SIZE[1]),
        )

        self.model = model
        self.num_ftrs = num_ftrs

    def forward(self, x) -> torch.Tensor:
        """Compute the backbone output for ``x = (img, virus_size)``.

        Args:
            x: Indexable pair; ``x[0]`` is the image batch and ``x[1]`` holds
                the virus sizes, which are cast to int and used as embedding
                indices (so they must lie in ``[0, 500)``).
        """
        img = x[0]
        virus_size = x[1].int()
        magnification_embedding = self.magnification_embedding(virus_size)
        magnification_embedding = magnification_embedding.view(-1, self.num_channels, IMG_SIZE[0], IMG_SIZE[1]) # reshape to image input
        concat_input = torch.cat((img, magnification_embedding), dim = 1)
        return self.model(concat_input)
        


