import torch
from typing import Type, Any, Callable, Union, List, Optional
from torch import Tensor
from torch import nn
from torchvision import models

class Residual(nn.Module):
    """Two-convolution residual unit: conv-BN-ReLU-conv-BN, add shortcut, ReLU.

    Args:
        input_channels: channels of the input tensor.
        num_channels: channels produced by both 3x3 convolutions.
        use_1x1conv: if True, the shortcut goes through a 1x1 convolution
            (with the same stride) so it matches the main path's shape; if
            False the input is added unchanged, which requires
            ``input_channels == num_channels`` and ``strides == 1``.
        strides: stride of the first 3x3 convolution and of the 1x1 shortcut.
    """

    def __init__(self, input_channels, num_channels,
                 use_1x1conv=True, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        """Return ``relu(main_path(X) + shortcut(X))``."""
        Y = torch.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check: truth-testing a Module is not a presence test
        # (container modules such as an empty nn.Sequential are falsy).
        shortcut = self.conv3(X) if self.conv3 is not None else X
        return torch.relu(Y + shortcut)

def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    """Return a list of ``num_residuals`` Residual units forming one stage.

    Unless ``first_block`` is set, the first unit maps ``input_channels`` to
    ``num_channels`` with stride 2 and a 1x1 shortcut. Every other unit maps
    ``num_channels`` to ``num_channels`` with Residual's defaults.
    """
    downsample_first = not first_block
    return [
        Residual(input_channels, num_channels, use_1x1conv=True, strides=2)
        if index == 0 and downsample_first
        else Residual(num_channels, num_channels)
        for index in range(num_residuals)
    ]

def get_res_model(in_channel, out_channel):
    """Build a four-stage residual feature extractor.

    Structure: 7x7/2 conv stem with BN, ReLU and 3x3/2 max-pool, then four
    stages of two Residual units each (64, 128, 256, ``out_channel``
    channels), then global average pooling and flattening. The output is
    therefore ``(batch, out_channel)``.
    """
    stem = nn.Sequential(
        nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    )
    # (input channels, output channels, is the first stage)
    stage_specs = [
        (64, 64, True),
        (64, 128, False),
        (128, 256, False),
        (256, out_channel, False),
    ]
    stages = [
        nn.Sequential(*resnet_block(c_in, c_out, 2, first_block=is_first))
        for c_in, c_out, is_first in stage_specs
    ]
    return nn.Sequential(stem, *stages,
                         nn.AdaptiveAvgPool2d((1, 1)),
                         nn.Flatten())

class fix_resnet(models.ResNet):
    """torchvision ResNet whose forward returns pooled features, not logits.

    The classifier ``fc`` is set to ``None`` after construction. ``forward``
    stops after global average pooling and ``torch.flatten(x, 1)``.

    Args mirror ``torchvision.models.ResNet``. ``layers`` defaults to
    ``[2, 2, 2, 2]`` (the ResNet-18 layout) when omitted or ``None``.
    """

    def __init__(
        self,
        block: Type[Union[models.resnet.BasicBlock, models.resnet.Bottleneck]] = models.resnet.BasicBlock,
        layers: Optional[List[int]] = None,
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        # Avoid a shared mutable default list; build a fresh one per instance.
        if layers is None:
            layers = [2, 2, 2, 2]
        super().__init__(block, layers, num_classes, zero_init_residual, groups,
                         width_per_group, replace_stride_with_dilation, norm_layer)
        # NOTE(review): the get_resnet* helpers load a full torchvision state
        # dict (including fc.*) into this model. That appears to rely on 'fc'
        # staying registered with a None value rather than being replaced by
        # e.g. nn.Identity(). Verify before changing this line.
        self.fc = None

    def forward(self, x: Tensor) -> Tensor:
        """Return flattened, globally pooled features of ``x``."""
        return self._forward_impl(x)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # Same sequence as torchvision's ResNet, minus the final self.fc call.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        return torch.flatten(x, 1)

def get_resnet18(in_features=3, num_classes=1000, weights=models.ResNet18_Weights.IMAGENET1K_V1):
    """Return a ResNet-18 feature extractor (``fix_resnet``) loaded from torchvision.

    Args:
        in_features: number of input channels. When it is not 3, ``conv1`` is
            replaced by a freshly initialised ``Conv2d(in_features, 64, 7, 2, 3)``
            on both models, so the pretrained stem weights are not used.
        num_classes: forwarded to ``models.resnet18``. The head is discarded
            because ``fix_resnet`` sets ``fc`` to None.
        weights: a ``models.ResNet18_Weights`` member, or ``None``.
    """
    resnet = models.resnet18(weights=weights, num_classes=num_classes)
    fix = fix_resnet(layers=[2, 2, 2, 2])
    if in_features != 3:
        # Replace the stem on BOTH models. Replacing it only on `resnet` made
        # load_state_dict fail on the conv1.weight shape mismatch.
        resnet.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
        fix.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
    fix.load_state_dict(resnet.state_dict())
    return fix

def get_resnet34(in_features=3, num_classes=1000, weights=models.ResNet34_Weights.IMAGENET1K_V1):
    """Return a ResNet-34 feature extractor (``fix_resnet``) loaded from torchvision.

    Args:
        in_features: number of input channels. When it is not 3, ``conv1`` is
            replaced by a freshly initialised ``Conv2d(in_features, 64, 7, 2, 3)``
            on both models, so the pretrained stem weights are not used.
        num_classes: forwarded to ``models.resnet34``. The head is discarded
            because ``fix_resnet`` sets ``fc`` to None.
        weights: a ``models.ResNet34_Weights`` member, or ``None``.
    """
    resnet = models.resnet34(weights=weights, num_classes=num_classes)
    fix = fix_resnet(layers=[3, 4, 6, 3])
    if in_features != 3:
        # Replace the stem on BOTH models. Replacing it only on `resnet` made
        # load_state_dict fail on the conv1.weight shape mismatch.
        resnet.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
        fix.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
    fix.load_state_dict(resnet.state_dict())
    return fix

def get_resnet50(in_features=3, num_classes=1000, weights=models.ResNet50_Weights.IMAGENET1K_V1):
    """Return a ResNet-50 feature extractor (``fix_resnet``) loaded from torchvision.

    Args:
        in_features: number of input channels. When it is not 3, ``conv1`` is
            replaced by a freshly initialised ``Conv2d(in_features, 64, 7, 2, 3)``
            on both models, so the pretrained stem weights are not used.
        num_classes: forwarded to ``models.resnet50``. The head is discarded
            because ``fix_resnet`` sets ``fc`` to None.
        weights: a ``models.ResNet50_Weights`` member, or ``None``.
    """
    resnet = models.resnet50(weights=weights, num_classes=num_classes)
    fix = fix_resnet(block=models.resnet.Bottleneck, layers=[3, 4, 6, 3])
    if in_features != 3:
        # Replace the stem on BOTH models. Replacing it only on `resnet` made
        # load_state_dict fail on the conv1.weight shape mismatch.
        resnet.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
        fix.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
    fix.load_state_dict(resnet.state_dict())
    return fix

def get_resnet101(in_features=3, num_classes=1000, weights=models.ResNet101_Weights.IMAGENET1K_V1):
    """Return a ResNet-101 feature extractor (``fix_resnet``) loaded from torchvision.

    Args:
        in_features: number of input channels. When it is not 3, ``conv1`` is
            replaced by a freshly initialised ``Conv2d(in_features, 64, 7, 2, 3)``
            on both models, so the pretrained stem weights are not used.
        num_classes: forwarded to ``models.resnet101``. The head is discarded
            because ``fix_resnet`` sets ``fc`` to None.
        weights: a ``models.ResNet101_Weights`` member, or ``None``.
    """
    resnet = models.resnet101(weights=weights, num_classes=num_classes)
    fix = fix_resnet(block=models.resnet.Bottleneck, layers=[3, 4, 23, 3])
    if in_features != 3:
        # Replace the stem on BOTH models. Replacing it only on `resnet` made
        # load_state_dict fail on the conv1.weight shape mismatch.
        resnet.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
        fix.conv1 = nn.Conv2d(in_features, 64, kernel_size=7, stride=2, padding=3, bias=False)
    fix.load_state_dict(resnet.state_dict())
    return fix

class fix_vit(models.VisionTransformer):
    """torchvision VisionTransformer whose forward returns the class-token embedding.

    The classification ``heads`` module is still built (so full torchvision
    state dicts load), but ``forward`` stops before calling it.
    """

    def __init__(self, image_size: int, patch_size: int, num_layers: int, num_heads: int, hidden_dim: int, mlp_dim: int, dropout: float = 0, attention_dropout: float = 0, num_classes: int = 1000):
        super().__init__(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            num_classes=num_classes,
        )

    def forward(self, x: torch.Tensor):
        """Return the encoder output at the class-token position for each sample."""
        # Patchify and embed the image into a token sequence.
        tokens = self._process_input(x)
        # Prepend one copy of the learned class token per sample.
        cls_tokens = self.class_token.expand(tokens.shape[0], -1, -1)
        encoded = self.encoder(torch.cat((cls_tokens, tokens), dim=1))
        # Position 0 is the class token; self.heads is intentionally skipped.
        return encoded[:, 0]
    
def get_vit_b_16(num_classes=1000, weights=models.ViT_B_16_Weights.IMAGENET1K_V1):
    """Return a ViT-B/16 feature extractor (``fix_vit``) loaded from torchvision.

    Args:
        num_classes: forwarded to ``models.vit_b_16`` and to ``fix_vit`` so the
            ``heads`` shapes of both models agree when the state dict is loaded.
        weights: a ``models.ViT_B_16_Weights`` member (not the enum class
            itself), or ``None`` for random initialisation.

    Returns:
        A ``fix_vit`` whose ``forward`` returns the class-token embedding.
    """
    vit = models.vit_b_16(weights=weights, num_classes=num_classes)
    fix = fix_vit(
        # Match whatever image size the builder used, so pos_embedding shapes agree.
        image_size=vit.image_size,
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        num_classes=num_classes,
    )
    fix.load_state_dict(vit.state_dict())
    return fix

def get_vit_b_32(num_classes=1000, weights=models.ViT_B_32_Weights.IMAGENET1K_V1):
    """Return a ViT-B/32 feature extractor (``fix_vit``) loaded from torchvision.

    Args:
        num_classes: forwarded to ``models.vit_b_32`` and to ``fix_vit`` so the
            ``heads`` shapes of both models agree when the state dict is loaded.
        weights: a ``models.ViT_B_32_Weights`` member (not the enum class
            itself), or ``None`` for random initialisation.

    Returns:
        A ``fix_vit`` whose ``forward`` returns the class-token embedding.
    """
    vit = models.vit_b_32(weights=weights, num_classes=num_classes)
    fix = fix_vit(
        # Match whatever image size the builder used, so pos_embedding shapes agree.
        image_size=vit.image_size,
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        num_classes=num_classes,
    )
    fix.load_state_dict(vit.state_dict())
    return fix

def get_vit_l_16(num_classes=1000, weights=models.ViT_L_16_Weights.IMAGENET1K_V1):
    """Return a ViT-L/16 feature extractor (``fix_vit``) loaded from torchvision.

    Args:
        num_classes: forwarded to ``models.vit_l_16`` and to ``fix_vit`` so the
            ``heads`` shapes of both models agree when the state dict is loaded.
        weights: a ``models.ViT_L_16_Weights`` member (not the enum class
            itself), or ``None`` for random initialisation.

    Returns:
        A ``fix_vit`` whose ``forward`` returns the class-token embedding.
    """
    vit = models.vit_l_16(weights=weights, num_classes=num_classes)
    fix = fix_vit(
        # Match whatever image size the builder used, so pos_embedding shapes agree.
        image_size=vit.image_size,
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        num_classes=num_classes,
    )
    fix.load_state_dict(vit.state_dict())
    return fix

class fix_net(nn.Module):
    """A feature network followed by a single linear head, plus analysis helpers.

    Args:
        feature_net: module producing the features fed to the head
            (presumably shaped (batch, in_features) — TODO confirm).
        in_features: input width of the linear head ``fc``.
        out_features: output width of the linear head ``fc``.
    """

    def __init__(self, feature_net: nn.Module, in_features, out_features) -> None:
        super().__init__()
        self.feature_net = feature_net
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def get_grad_num(self, grad: Tensor, direct):
        """Print sign statistics of a 2-D gradient matrix.

        ``total``/``positive``/``negative`` count non-zero, > 0 and < 0
        entries of ``grad``. The ``res*`` counts restrict those to columns
        ``j`` where ``direct[j] != 0``.

        Args:
            grad: 2-D tensor.
            direct: 1-D sequence or tensor indexed by column of ``grad``.
        """
        n_cols = grad.shape[1]
        # Vectorized replacement for the per-element Python double loop.
        # col_mask broadcasts over rows: True where direct[j] != 0.
        col_mask = torch.as_tensor(direct, device=grad.device)[:n_cols] != 0
        pos = grad > 0
        neg = grad < 0
        total = int((grad != 0).sum())
        positive = int(pos.sum())
        negative = int(neg.sum())
        res_positive = int((pos & col_mask).sum())
        res_negative = int((neg & col_mask).sum())
        res = res_positive + res_negative
        print('total:',total,',positive:',positive,',negative:',negative, ',res:',res,',res_positive:',res_positive,',res_negative:',res_negative)

    def forward(self, x):
        """Return ``fc(feature_net(x))``, caching the features in ``self.feature``."""
        self.feature = self.feature_net(x)
        return self.fc(self.feature)

    def feature_forward(self, x):
        """Return ``feature_net(x)`` without applying the head or caching."""
        return self.feature_net(x)

    def init_feature_net(self):
        """Set every feature_net parameter whose name contains 'weight' to 0."""
        for name, param in self.feature_net.named_parameters():
            if 'weight' in name:
                nn.init.constant_(param.data, 0)

    def trs(self, features: Tensor):
        """Return the mean over samples of each sample's unbiased feature variance.

        Each sample (index along dim 0) is flattened and its variance taken;
        the per-sample variances are averaged. The result is a 0-d float32
        tensor on the CPU, as the previous per-row loop produced.
        """
        # Flatten per sample instead of squeeze(): squeeze() also removed the
        # batch axis for a batch of one, so the average then ran over feature
        # dimensions instead of samples. A 1-D input is still treated as one
        # scalar feature per sample.
        flat = features.flatten(1) if features.dim() > 1 else features.unsqueeze(1)
        per_sample = torch.var(flat, dim=1).to(device='cpu', dtype=torch.float32)
        return per_sample.sum() / flat.shape[0]

    def normalize(self, feature: Tensor, scale=1, bias=0):
        """Min-max rescale ``feature`` into ``[bias, bias + scale]``.

        A constant tensor (max == min) maps to ``bias`` everywhere instead of
        NaN.
        """
        lo = feature.min()
        hi = feature.max()
        span = hi - lo
        # Divide by 1 for a zero span so (feature - lo) / span is 0, not NaN.
        span = torch.where(span == 0, torch.ones_like(span), span)
        return scale * (feature - lo) / span + bias

    def fixByName(self, filter=('fc',)):
        """Freeze every parameter whose name contains none of the ``filter`` substrings.

        If ``filter`` is None, every parameter is frozen. Matching parameters
        are left as they are; they are not re-enabled.
        """
        for name, param in self.named_parameters():
            keep_trainable = filter is not None and any(key in name for key in filter)
            if not keep_trainable:
                param.requires_grad = False

    def isPositiveGradient(self):
        """Print name and positive-entry count for feature_net weights with any grad > 0.

        Parameters without a gradient yet are skipped. Always returns False.
        """
        for name, param in self.feature_net.named_parameters():
            if 'weight' not in name or param.grad is None:
                continue
            val = (param.grad > 0).float().sum()
            if val > 0:
                print(name, val)
        return False
    
class MLP(nn.Module):
    """Fully connected network with a configurable number of hidden layers.

    Args:
        in_features: input width.
        hidden_features: width of every hidden layer.
        hidden_num: number of hidden layers. ``< 1`` builds a single
            ``Linear(in_features, out_features)`` with zero-initialised weight,
            whose output is passed through ReLU.
        out_features: output width.
    """

    def __init__(self, in_features, hidden_features, hidden_num, out_features) -> None:
        super().__init__()
        self.hidden_num = hidden_num
        if hidden_num < 1:
            self.fc1 = nn.Linear(in_features, out_features)
            self.init(self.fc1.weight, 0)
        else:
            self.in_layer = nn.Linear(in_features, hidden_features)
            if hidden_num > 1:
                # Linear layers only; forward() applies ReLU after each one.
                self.hidden_layer = nn.Sequential(
                    *[nn.Linear(hidden_features, hidden_features) for _ in range(hidden_num - 1)]
                )
            self.out_layer = nn.Linear(hidden_features, out_features)
        self.relu = nn.ReLU()
        # For hidden_num == 1, forward() stores detached copies
        # [input, hidden activation] here; otherwise it becomes [] on first call.
        self.feature = None

    def init_param(self):
        """Set every parameter whose name contains 'weight' to 0."""
        for n, p in self.named_parameters():
            if 'weight' in n:
                nn.init.constant_(p.data, 0)

    def forward(self, x):
        """Run the network on ``x``; see the class docstring for the layouts."""
        if self.hidden_num < 1:
            return self.relu(self.fc1(x))

        f = self.relu(self.in_layer(x))
        if self.hidden_num < 2:
            snapshot = [x.detach().clone(), f.detach().clone()]
            if self.feature is None:
                self.feature = snapshot
            else:
                # Overwrite in place. The previous code also appended here,
                # growing the cache by one tensor on every call.
                self.feature[0], self.feature[1] = snapshot
            return self.out_layer(f)

        if self.feature is None:
            self.feature = []
        # ReLU after every hidden Linear. Applying it once after the whole
        # Sequential left consecutive hidden layers as a single affine map.
        for layer in self.hidden_layer:
            f = self.relu(layer(f))
        return self.out_layer(f)

    def forward_with_feature(self, x):
        # NOTE(review): unfinished. It computes fc1(x) and returns None, and
        # fc1 only exists when hidden_num < 1.
        f1 = self.fc1(x)

    def init(self, param, value):
        """Fill ``param`` in place with the constant ``value``."""
        nn.init.constant_(param, value)