from typing import Callable, Optional
from torch.nn.modules import Module, Sequential
from torch import Tensor
from einops import rearrange
from typing import Type, Union,List
import inspect
import math
import torch
import torch.nn as nn   
from torchvision.models.resnet import BasicBlock, Bottleneck,ResNet
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Spatial stride of the convolution.

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1`` and ``bias=False``.
    """
    pointwise_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **pointwise_options)
class Timedependent_Wrapper(nn.Sequential):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self,x,t):
        for model in self.model:
            if(type(model)==BasicBlock_Time or type(model)==ResNet_Time):
                x=model(x,t)
            else:
                x=model(x)
        return x
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar positions / timesteps.

    Args:
        dim: Width of the produced embedding. Odd values are supported (the
            last column is zero-padded) so the output always has ``dim``
            columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: Rank-1 tensor with one position per batch element
                (``x[:, None]`` below requires rank 1).

        Returns:
            Tensor of shape ``(len(x), dim)``: ``dim // 2`` sine features,
            then ``dim // 2`` cosine features, then one zero column when
            ``dim`` is odd.
        """
        device = x.device
        half_dim = self.dim // 2
        # Frequencies decay geometrically from 1 to 1/10000. For half_dim <= 1
        # the denominator would be 0 (ZeroDivisionError); clamping it keeps a
        # single frequency of 1 in that case.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        angles = x[:, None] * freqs[None, :]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Callers size layers by ``dim`` (e.g. nn.Linear(dim, ...)), so
            # pad the missing column instead of returning dim - 1 features.
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb
class BasicBlock_Time(BasicBlock):
    """torchvision ``BasicBlock`` with optional FiLM-style time conditioning.

    After ``conv1``/``bn1`` the activations are modulated as
    ``h * (1 + scale) + shift``. Here ``scale`` and ``shift`` are per-channel
    values predicted from the time embedding ``t`` by ``self.mlp``. The rest
    of the block (``conv2``/``bn2``, residual add, ReLU) is the stock
    ``BasicBlock`` computation.

    Args:
        inplanes, planes, stride, downsample, groups, base_width, dilation,
        norm_layer: Forwarded unchanged to ``BasicBlock``.
        time_emb_dim: Width of the incoming time embedding ``t``.

    Note: ``groups`` is also reused as the ``GroupNorm`` group count for the
    time MLP (torchvision's ``BasicBlock`` itself only supports ``groups == 1``).
    """

    def __init__(self, 
                 inplanes: int, 
                 planes: int, 
                 stride: int = 1, 
                 downsample: Optional[nn.Module] = None, 
                 groups: int = 1, 
                 base_width: int = 64, 
                 dilation: int = 1, 
                 norm_layer: Callable[..., Module] =None,
                 time_emb_dim: int = 256,
                 ) -> None:
        super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
        # Predicts 2 * planes values per sample: a scale and a shift per channel.
        norm = nn.GroupNorm(groups, time_emb_dim, eps=1e-3)
        self.mlp = nn.Sequential(norm, nn.SiLU(), nn.Linear(time_emb_dim, planes * 2))

    def forward(self, x: Tensor,t=None) -> Tensor:
        """Run the residual block, modulating the first conv's output by ``t``.

        Args:
            x: Input feature map.
            t: Time embedding of shape ``(batch, time_emb_dim)``, or None to
                skip the modulation entirely.

        Returns:
            The block output after the residual add and final ReLU.
        """
        h = self.bn1(self.conv1(x))
        if t is not None:
            # (b, 2*planes) -> (b, 2*planes, 1, 1) so it broadcasts over H, W.
            film = rearrange(self.mlp(t), 'b c -> b c 1 1')
            scale, shift = film.chunk(2, dim=1)
            h = h * (1 + scale) + shift
        h = self.relu(h)

        h = self.bn2(self.conv2(h))

        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return self.relu(h)
class ResNet_Time(ResNet):
    """torchvision ``ResNet`` conditioned on a time input ``t``.

    ``t`` is embedded by ``time_mlp`` (``SinusoidalPosEmb`` followed by a
    two-layer MLP) into a ``time_emb_dim`` vector. That vector scales and
    shifts the stem activations through ``time_mlp_layer1``. It is also
    passed to every ``BasicBlock_Time`` / ``ResNet_Time`` /
    ``Timedependent_Wrapper`` found in ``layer1``..``layer4``. When ``t`` is
    None, no modulation is applied anywhere.

    Args:
        block, layers, num_classes, zero_init_residual, groups,
        width_per_group, replace_stride_with_dilation, norm_layer:
            Forwarded to ``torchvision.models.resnet.ResNet``.
        time_emb_dim: Width of the time embedding. Must be divisible by
            ``groups`` because it is normalized with ``GroupNorm(groups, ...)``.
    """

    def __init__(self, 
                 block: Type[Union[BasicBlock, Bottleneck]], 
                 layers: List[int], 
                 num_classes: int = 1000, 
                 zero_init_residual: bool = False, 
                 groups: int = 1, 
                 width_per_group: int = 64, 
                 replace_stride_with_dilation: Optional[List[bool]] = None, 
                 norm_layer: Optional[Callable[..., Module]] = None, 
                 time_emb_dim: int = 256,
                 ) -> None:
        import functools

        # torchvision's ``_make_layer`` builds blocks with positional args only,
        # so a time-aware block would always get its own default
        # ``time_emb_dim``. That fails at forward time whenever it differs from
        # ours. Bind it explicitly. ``_make_layer`` also reads ``block.expansion``,
        # so copy that attribute onto the partial.
        if isinstance(block, type) and issubclass(block, BasicBlock_Time):
            block_factory = functools.partial(block, time_emb_dim=time_emb_dim)
            block_factory.expansion = block.expansion
        else:
            block_factory = block
        super().__init__(block_factory, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)

        self.time_emb_dim = time_emb_dim
        fourier_dim = time_emb_dim // 4
        self.sinu_pos_emb = SinusoidalPosEmb(fourier_dim)

        self.time_mlp = nn.Sequential(
            self.sinu_pos_emb,
            nn.Linear(fourier_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # One (scale, shift) pair per stem output channel (64 in torchvision).
        self.time_mlp_layer1 = nn.Sequential(
            nn.GroupNorm(groups, time_emb_dim, eps=1e-3),
            nn.SiLU(),
            nn.Linear(time_emb_dim, self.conv1.out_channels * 2),
        )

    def _forward_time(self, layer, x: Tensor, t=None) -> Tensor:
        """Apply the modules of ``layer`` in order, passing ``t`` to time-aware ones."""
        time_aware = (BasicBlock_Time, ResNet_Time, Timedependent_Wrapper)
        for module in layer:
            if isinstance(module, time_aware):
                x = module(x, t)
            else:
                x = module(x)
        return x

    def _forward_impl(self, x: Tensor, t=None) -> Tensor:
        """Full forward pass; ``t`` (raw timesteps) may be None."""
        if t is not None:
            t = self.time_mlp(t)
        x = self.conv1(x)
        x = self.bn1(x)
        if t is not None:
            # (b, 2*C) -> (b, 2*C, 1, 1) so scale/shift broadcast over H, W.
            time_emb = self.time_mlp_layer1(t)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale, shift = time_emb.chunk(2, dim=1)
            x = x * (1 + scale) + shift
        x = self.relu(x)
        x = self.maxpool(x)

        x = self._forward_time(self.layer1, x, t)
        x = self._forward_time(self.layer2, x, t)
        x = self._forward_time(self.layer3, x, t)
        x = self._forward_time(self.layer4, x, t)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor, t=None) -> Tensor:
        """Run the network on images ``x`` conditioned on timesteps ``t``."""
        return self._forward_impl(x, t)

    def load_state_dict_from_original(self, static_dict, strict=True):
        """Copy weights from a plain (non-time) ResNet state dict into this model.

        Keys containing both ``"layer"`` and ``"model"`` are looked up in
        ``static_dict`` with every ``"model."`` removed, which is the prefix a
        ``Timedependent_Wrapper`` adds. Every other key is looked up unchanged.
        A key is skipped, and reported, when it is missing from
        ``static_dict`` or its tensor shape differs. Examples are the time-MLP
        parameters, which a plain ResNet lacks, and a resized ``fc`` head.

        Args:
            static_dict: Mapping of parameter/buffer names to tensors, e.g.
                a torchvision ResNet ``state_dict()``.
            strict: If True, raise (after loading the matched keys) when any
                key was skipped; otherwise print the skipped keys.

        Raises:
            KeyError: if ``strict`` and at least one key was skipped; the
                exception argument is the list of skipped keys.
        """
        own_state = self.state_dict()
        matched = {}
        error_keys = []
        for key, own_tensor in own_state.items():
            if "layer" in key and "model" in key:
                source_key = key.replace("model.", "")
            else:
                source_key = key
            if source_key not in static_dict:
                error_keys.append(key)
                continue
            source = static_dict[source_key]
            if source.shape != own_tensor.shape:
                error_keys.append(key)
                continue
            matched[key] = source
        # ``self.state_dict()`` returns a fresh dict, so assigning into it never
        # touches the model; ``load_state_dict`` writes into the real
        # parameters/buffers. strict=False because ``matched`` is partial.
        self.load_state_dict(matched, strict=False)
        if strict:
            if error_keys:
                raise KeyError(error_keys)
        else:
            print("The following keys are not loaded")
            print(error_keys)