import torch
import torch.nn as nn
import torch.nn.functional as F
import os


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The first convolution applies ``stride``; the second always uses stride 1.
    The input is added back through ``shortcut`` before the final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Output channels = expansion * out_channels.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        shortcut_width = self.expansion * out_channels

        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # A 1x1 projection is only needed when the main branch changes the
        # spatial size or the channel count; otherwise the shortcut is an
        # empty Sequential, which passes its input through unchanged.
        needs_projection = stride != 1 or in_channels != shortcut_width
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, shortcut_width,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(shortcut_width),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(hidden))
        summed += self.shortcut(x)
        return F.relu(summed)


class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1, 3x3 (strided), 1x1.

    The last 1x1 convolution widens the output to ``expansion * planes``
    channels. The input is added back through ``shortcut`` before the final
    ReLU.

    Args:
        in_planes: Number of channels of the input tensor.
        planes: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
        is_last: When true, ``forward`` also returns the pre-activation sum.
    """

    # Output channels = expansion * planes.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        wide = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, wide, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(wide)

        # Project the input only when its shape differs from the main branch
        # output; an empty Sequential passes the input through unchanged.
        if stride != 1 or in_planes != wide:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, wide, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(wide),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the block.

        Returns:
            The activated output, or the tuple ``(output, preact)`` when
            ``is_last`` is set, where ``preact`` is the sum before the ReLU.
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        preact = self.bn3(self.conv3(hidden))
        preact += self.shortcut(x)
        activated = F.relu(preact)
        if not self.is_last:
            return activated
        return activated, preact

def resnet18(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` units with stage depths [2, 2, 2, 2].

    All keyword arguments are forwarded unchanged to the ``ResNet`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, **kwargs)


def resnet34(**kwargs):
    """Build a ``ResNet`` of ``BasicBlock`` units with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded unchanged to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)


def resnet50(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` units with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded unchanged to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


def resnet101(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` units with stage depths [3, 4, 23, 3].

    All keyword arguments are forwarded unchanged to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


class ResNet(nn.Module):
    """Four-stage residual network with a 7x7 stem and a linear head.

    The stem is a 7x7 stride-2 convolution, batch norm, ReLU and a 3x3
    stride-2 max pool. It is followed by four stages of ``block`` units with
    base widths 64, 128, 256 and 512 (stages 2-4 start with a stride-2 block),
    global average pooling and a fully connected layer.

    Args:
        block: Block class to stack. It must accept
            ``(in_channels, out_channels, stride)`` and expose an integer
            ``expansion`` class attribute.
        num_blocks: Sequence with the number of blocks in each of the four
            stages.
        num_classes: Output width of the final linear layer.
        embed_only: When true, ``forward`` returns the pooled, flattened
            features and skips the linear layer.
    """

    # The default for num_blocks is a tuple rather than a list so the shared
    # default object cannot be mutated between instantiations.
    def __init__(self, block=BasicBlock, num_blocks=(2, 2, 2, 2), num_classes=10, embed_only=False):
        super(ResNet, self).__init__()
        self.embed_only = embed_only
        # Running input width for the next block; updated by make_layer.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks into one stage.

        Only the first block uses ``stride``; the remaining ones use stride 1.
        ``self.in_channels`` is advanced to ``out_channels * block.expansion``
        after every block so the next block (or stage) receives the right
        input width.

        Returns:
            An ``nn.Sequential`` holding the created blocks.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores, or the flattened pooled features if ``embed_only``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        if self.embed_only:
            return out
        return self.fc(out)


class ResNet18(nn.Module):
    """``ResNet`` with its default configuration, optionally initialised from a checkpoint.

    Args:
        num_classes: Output width of the classifier head.
        embed_only: Forwarded to ``ResNet``; when true the model returns
            pooled features instead of class scores.
        from_scratch: When true, no checkpoint is read.
        path_to_weights: Checkpoint file read with ``torch.load`` when
            ``from_scratch`` is false.
        device: Device the network is moved to.
    """

    def __init__(self, num_classes=10,  embed_only=False, from_scratch=False, path_to_weights="pnn/resnet18.pth", device="cuda"):
        super(ResNet18, self).__init__()

        self.resnet18 = ResNet(num_classes=num_classes,  embed_only=embed_only).to(device)
        if not from_scratch:
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # that come from a trusted source.
            # Assumes the checkpoint is a flat mapping of parameter names to
            # tensors - TODO confirm against the files actually shipped.
            weights = torch.load(path_to_weights, map_location=torch.device('cpu'))
            # Build the key view once instead of calling state_dict() again
            # for every checkpoint entry.
            model_keys = self.resnet18.state_dict().keys()
            # Keep only entries the model knows about, and drop the
            # classifier head so `num_classes` may differ from the checkpoint.
            state_dict = {
                k: v for k, v in weights.items()
                if k in model_keys and k not in ("fc.weight", "fc.bias")
            }
            self.resnet18.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        """Delegate to the wrapped ``ResNet``."""
        return self.resnet18(x)

class ResNet34(nn.Module):
    """``ResNet`` of ``BasicBlock`` units with stage depths [3, 4, 6, 3].

    Args:
        num_classes: Output width of the classifier head.
        embed_only: Forwarded to ``ResNet``; when true the model returns
            pooled features instead of class scores.
        from_scratch: When true, no checkpoint is read.
        path_to_weights: Checkpoint file read with ``torch.load`` when
            ``from_scratch`` is false.
        device: Device the network is moved to.
    """

    def __init__(self, num_classes=10, embed_only=False, from_scratch=False, path_to_weights="pnn/resnet34.pth", device="cuda"):
        super(ResNet34, self).__init__()

        self.resnet34 = ResNet(BasicBlock, num_blocks=[3, 4, 6, 3], num_classes=num_classes, embed_only=embed_only).to(device)
        if not from_scratch:
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # that come from a trusted source.
            weights = torch.load(path_to_weights, map_location=torch.device('cpu'))
            # Build the key view once instead of calling state_dict() again
            # for every checkpoint entry.
            model_keys = self.resnet34.state_dict().keys()
            # Keep only entries the model knows about, and drop the
            # classifier head so `num_classes` may differ from the checkpoint.
            state_dict = {
                k: v for k, v in weights.items()
                if k in model_keys and k not in ("fc.weight", "fc.bias")
            }
            self.resnet34.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        """Delegate to the wrapped ``ResNet``."""
        return self.resnet34(x)

class ResNet50(nn.Module):
    """``ResNet`` of ``Bottleneck`` units with stage depths [3, 4, 6, 3].

    Args:
        num_classes: Output width of the classifier head.
        embed_only: Forwarded to ``ResNet``; when true the model returns
            pooled features instead of class scores.
        from_scratch: When true, no checkpoint is read.
        path_to_weights: Checkpoint file read with ``torch.load`` when
            ``from_scratch`` is false.
        device: Device the network is moved to.
    """

    def __init__(self, num_classes=10,  embed_only=False, from_scratch=False, path_to_weights="pnn/resnet50.pth", device="cuda"):
        super(ResNet50, self).__init__()

        self.resnet50 = ResNet(Bottleneck, num_blocks=[3,4,6,3], num_classes=num_classes,  embed_only=embed_only).to(device)
        if not from_scratch:
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # that come from a trusted source.
            weights = torch.load(path_to_weights, map_location=torch.device('cpu'))
            # Build the key view once instead of calling state_dict() again
            # for every checkpoint entry.
            model_keys = self.resnet50.state_dict().keys()
            # Keep only entries the model knows about, and drop the
            # classifier head so `num_classes` may differ from the checkpoint.
            state_dict = {
                k: v for k, v in weights.items()
                if k in model_keys and k not in ("fc.weight", "fc.bias")
            }
            self.resnet50.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        """Delegate to the wrapped ``ResNet``."""
        return self.resnet50(x)
    

class ResNet32_frame(nn.Module):
    """Residual network with a 3x3 stride-1 stem and four stages of five ``BasicBlock`` units.

    Unlike ``ResNet`` in this file, the stem does not downsample and there is
    no max pool; only stages 2-4 begin with a stride-2 block. The stage widths
    are 64, 128, 256 and 512.

    Args:
        num_classes: Output width of the final linear layer.
        embed_only: When true, ``forward`` returns the pooled, flattened
            features and skips the linear layer.
    """

    def __init__(self, num_classes=10, embed_only=False):
        super().__init__()
        # Running input width for the next block; updated by _make_layer.
        self.in_planes = 64
        self.embed_only = embed_only

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (width, stride of the first block) for layer1 .. layer4.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (planes, first_stride) in enumerate(stage_plan, start=1):
            stage = self._make_layer(BasicBlock, planes, 5, stride=first_stride)
            setattr(self, f"layer{index}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``.

        ``self.in_planes`` is advanced to ``planes * block.expansion`` after
        each block so the following block receives the right input width.
        """
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class scores, or the flattened pooled features if ``embed_only``."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        pooled = self.avgpool(feats)
        flat = pooled.view(pooled.size(0), -1)
        return flat if self.embed_only else self.fc(flat)
    
class ResNet32(nn.Module):
    """``ResNet32_frame`` wrapper that can initialise the network from a checkpoint.

    Args:
        num_classes: Output width of the classifier head.
        embed_only: Forwarded to ``ResNet32_frame``; when true the model
            returns pooled features instead of class scores.
        from_scratch: When true, no checkpoint is read.
        path_to_weights: Checkpoint file read with ``torch.load``. Nothing is
            loaded when this is ``None`` or empty, even if ``from_scratch``
            is false.
        device: Device the network is moved to.
    """

    def __init__(self, num_classes=10, embed_only=False, from_scratch=False, path_to_weights=None, device="cuda"):
        super(ResNet32, self).__init__()

        self.resnet32 = ResNet32_frame(num_classes=num_classes, embed_only=embed_only).to(device)
        if not from_scratch and path_to_weights:
            # NOTE: torch.load unpickles the file, so only load checkpoints
            # that come from a trusted source.
            weights = torch.load(path_to_weights, map_location=torch.device('cpu'))
            # Build the key view once instead of calling state_dict() again
            # for every checkpoint entry.
            model_keys = self.resnet32.state_dict().keys()
            # Keep only entries the model knows about, and drop the
            # classifier head so `num_classes` may differ from the checkpoint.
            state_dict = {
                k: v for k, v in weights.items()
                if k in model_keys and k not in ("fc.weight", "fc.bias")
            }
            self.resnet32.load_state_dict(state_dict, strict=False)

    def forward(self, x):
        """Delegate to the wrapped ``ResNet32_frame``."""
        return self.resnet32(x)


class Projector(nn.Module):
    """Two-layer MLP head: Linear -> (BatchNorm1d) -> ReLU -> Linear.

    The input width is looked up in ``model_dict`` under ``name``.

    Args:
        name: Key into ``model_dict``; its feature width is the input and
            hidden width of the head.
        out_dim: Output width of the second linear layer.
        apply_bn: When true, a ``BatchNorm1d`` is inserted after the first
            linear layer.
        device: Device the head is moved to. The default is ``"cuda"``, as in
            the other modules of this file; the previous default ``"gpu"`` is
            not a valid torch device string, so constructing the head without
            an explicit device always raised.
    """

    def __init__(self, name='resnet18', out_dim=128, apply_bn=False, device="cuda"):
        super(Projector, self).__init__()
        _, dim_in = model_dict[name]
        # The sub-layers stay registered as attributes (including `bn` when
        # it is not part of the pipeline) so the module's parameter layout is
        # the same as before.
        self.linear1 = nn.Linear(dim_in, dim_in)
        self.linear2 = nn.Linear(dim_in, out_dim)
        self.bn = nn.BatchNorm1d(dim_in)
        self.relu = nn.ReLU()
        if apply_bn:
            self.projector = nn.Sequential(self.linear1, self.bn, self.relu, self.linear2)
        else:
            self.projector = nn.Sequential(self.linear1, self.relu, self.linear2)
        # Move the whole module so that every registered sub-layer, not only
        # the ones inside `projector`, ends up on the same device.
        self.to(device)

    def forward(self, x):
        """Apply the projection head to ``x``."""
        return self.projector(x)



class LinearClassifier(nn.Module):
    """Single linear layer mapping encoder features to class scores.

    Args:
        name: Key into ``model_dict``; its feature width is the input width
            of the layer.
        num_classes: Output width of the layer.
        device: Device the layer is moved to.
    """

    def __init__(self, name='resnet18', num_classes=10, device="cuda"):
        super().__init__()
        _encoder_cls, feat_dim = model_dict[name]
        head = nn.Linear(feat_dim, num_classes)
        self.fc = head.to(device)

    def forward(self, features):
        """Return the class scores for ``features``."""
        scores = self.fc(features)
        return scores
    



class ResidualBlock(nn.Module):
    """Identity-shortcut block of two 3x3 convolutions that keep the channel count.

    Both convolutions use padding 1, so the output has the same shape as the
    input and can be added to it directly.

    Args:
        channel: Number of input and output channels of both convolutions.
    """

    def __init__(self, channel):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channel, channel, **conv_kwargs)
        self.conv2 = nn.Conv2d(channel, channel, **conv_kwargs)

    def forward(self, x):
        """Return ``relu(x + conv2(relu(conv1(x))))``."""
        hidden = F.relu(self.conv1(x))
        residual = self.conv2(hidden)
        return F.relu(x + residual)

class Net4mnist(nn.Module):
    """Small single-channel convolutional feature extractor.

    Two stages of ``5x5 conv -> ReLU -> 2x2 max pool -> ResidualBlock`` with
    16 and 32 channels, followed by flattening. There is no classifier head;
    ``forward`` returns the flattened feature map.

    ``conv2_drop`` is constructed but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        stem_channels, deep_channels = 16, 32
        self.conv1 = nn.Conv2d(1, stem_channels, kernel_size=5)
        self.conv2 = nn.Conv2d(stem_channels, deep_channels, kernel_size=5)
        self.res_block_1 = ResidualBlock(stem_channels)
        self.res_block_2 = ResidualBlock(deep_channels)
        self.conv2_drop = nn.Dropout2d()

    def forward(self, x):
        """Return the features of ``x`` flattened to ``(batch, -1)``."""
        batch = x.size(0)
        stages = (
            (self.conv1, self.res_block_1),
            (self.conv2, self.res_block_2),
        )
        for conv, res_block in stages:
            x = res_block(F.max_pool2d(F.relu(conv(x)), 2))
        return x.view(batch, -1)


class Decoder1(nn.Module):
    """Decode a flat feature vector into a square image.

    Pipeline: ``Linear -> BatchNorm1d -> ReLU -> Unflatten`` to an
    ``initial_size x initial_size`` map, then repeated
    ``ConvTranspose2d(4, 2, 1) -> BatchNorm2d -> ReLU`` stages that double the
    resolution, then a 1x1 convolution and ``Tanh``. If the doubled size does
    not land exactly on ``target_size``, ``forward`` resizes the result with
    bilinear interpolation.

    Args:
        in_channels: Width of the input feature vector.
        out_channels: Number of channels of the decoded image.
        initial_size: Side length of the first feature map.
        target_size: Side length of the decoded image.

    Raises:
        ValueError: If ``initial_size`` is smaller than 1, or if
            ``in_channels`` is smaller than ``initial_size ** 2`` (the first
            feature map would then have zero channels).
    """

    def __init__(self, in_channels, out_channels=3, initial_size=7, target_size=224):
        super(Decoder1, self).__init__()

        if initial_size < 1:
            raise ValueError(f"initial_size must be >= 1, got {initial_size}")

        self.initial_size = initial_size
        self.target_size = target_size
        # The feature vector is spread over the initial map, so each spatial
        # position receives in_channels // initial_size**2 channels.
        self.initial_channels = in_channels // (initial_size * initial_size)
        if self.initial_channels < 1:
            raise ValueError(
                f"in_channels ({in_channels}) must be at least "
                f"initial_size ** 2 ({initial_size * initial_size})"
            )

        # Number of x2 upsampling stages that fit without exceeding target_size.
        num_upsamples = 0
        current_size = initial_size
        while current_size * 2 <= target_size:
            current_size *= 2
            num_upsamples += 1

        flat_dim = self.initial_channels * initial_size * initial_size
        layers = [
            nn.Linear(in_channels, flat_dim),
            nn.BatchNorm1d(flat_dim),
            nn.ReLU(),
            nn.Unflatten(1, (self.initial_channels, initial_size, initial_size))
        ]

        current_channels = self.initial_channels
        for i in range(num_upsamples):
            # Halve the width per stage but never go below 64; the last
            # stage always outputs 64 channels.
            out_channels_conv = max(current_channels // 2, 64) if i < num_upsamples - 1 else 64
            layers.extend([
                nn.ConvTranspose2d(current_channels, out_channels_conv, 4, 2, 1),
                nn.BatchNorm2d(out_channels_conv),
                nn.ReLU()
            ])
            current_channels = out_channels_conv

        layers.extend([
            nn.Conv2d(current_channels, out_channels, 1, 1, 0),
            nn.Tanh()
        ])

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` and resize to ``target_size`` when needed."""
        out = self.decoder(x)
        wanted = (self.target_size, self.target_size)
        if tuple(out.shape[-2:]) != wanted:
            out = F.interpolate(out, size=wanted,
                                mode='bilinear', align_corners=True)
        return out
    
class Decoder2(nn.Module):
    """Decode a flat feature vector into an image with transposed convolutions.

    A fully connected stage expands the vector to an
    ``in_channels x initial_size x initial_size`` map. Each following
    ``ConvTranspose2d(4, 2, 1) -> BatchNorm2d -> ReLU`` stage doubles the
    resolution and halves the channel count until the side length reaches
    ``target_size``. A final 3x3 transposed convolution maps to
    ``out_channels`` and is followed by ``Tanh``.

    The output side length is ``initial_size * 2 ** k`` for the smallest ``k``
    that reaches ``target_size``; it is not resized afterwards.

    Args:
        in_channels: Width of the input feature vector and channel count of
            the first feature map.
        out_channels: Number of channels of the decoded image.
        initial_size: Side length of the first feature map.
        target_size: Side length at which upsampling stops.

    Raises:
        ValueError: If ``in_channels`` or ``initial_size`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels=3, initial_size=7, target_size=224):
        super(Decoder2, self).__init__()

        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if initial_size < 1:
            # A non-positive size would never reach target_size when doubled.
            raise ValueError(f"initial_size must be >= 1, got {initial_size}")

        self.initial_size = initial_size
        # forward() reshapes the fc output with this value.
        self.initial_channels = in_channels
        self.target_size = target_size

        flat_dim = in_channels * initial_size * initial_size
        self.fc = nn.Sequential(
            nn.Linear(in_channels, flat_dim),
            nn.BatchNorm1d(flat_dim),
            nn.ReLU()
        )

        layers = []
        current_size = initial_size
        current_channels = in_channels

        while current_size < target_size:
            # Use a separate name so the `out_channels` argument is not
            # overwritten, and never let the width drop to zero.
            next_channels = max(current_channels // 2, 1)
            layers.extend([
                nn.ConvTranspose2d(current_channels, next_channels, 4, 2, 1),
                nn.BatchNorm2d(next_channels),
                nn.ReLU()
            ])
            current_size *= 2
            current_channels = next_channels

        layers.extend([
            nn.ConvTranspose2d(current_channels, out_channels, 3, 1, 1),
            nn.Tanh()
        ])

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode the feature vectors ``x`` into images."""
        x = self.fc(x)
        x = x.view(-1, self.initial_channels, self.initial_size, self.initial_size)
        return self.decoder(x)

class AutoEncoder(nn.Module):
    """Auto-encoder pairing one of the ResNet encoders with a ``Decoder1``.

    Args:
        encoder_name: One of ``'resnet18'``, ``'resnet32'``, ``'resnet34'``
            or ``'resnet50'``.
        num_classes: Forwarded to the encoder.
        embed_only: Forwarded to the encoder. The decoder's first layer
            expects the encoder's feature width, so ``forward`` only works
            when the encoder returns features (``embed_only=True``) or when
            ``num_classes`` happens to equal that width; a warning is emitted
            otherwise.
        from_scratch: Forwarded to the encoder; when true no checkpoint is
            read.
        path_to_weights: Forwarded to the encoder.
            NOTE(review): the default names a resnet34 checkpoint while the
            default encoder is resnet18 - confirm that this is intended.
        device: Device for both encoder and decoder.
        target_size: Side length of the reconstructed image.

    Raises:
        ValueError: If ``encoder_name`` is not one of the supported encoders.
    """

    def __init__(self, encoder_name='resnet18', num_classes=10, embed_only=False,
                 from_scratch=False, path_to_weights="pnn/resnet34.pth",
                 device="cuda", target_size=224):
        super(AutoEncoder, self).__init__()

        if encoder_name not in model_dict:
            raise ValueError(f"Unsupported encoder: {encoder_name}")

        encoder_class, feat_dim = model_dict[encoder_name]

        # Side length of the decoder's first feature map for each supported
        # encoder. model_dict also lists entries that are not usable as
        # encoders here, so the supported set is spelled out explicitly.
        decoder_initial_sizes = {
            'resnet18': 7,
            'resnet32': 8,
            'resnet34': 7,
            'resnet50': 7,
        }
        if encoder_name not in decoder_initial_sizes:
            raise ValueError(f"Unsupported encoder: {encoder_name}")
        initial_size = decoder_initial_sizes[encoder_name]

        if not embed_only and num_classes != feat_dim:
            # Behaviour is left unchanged (encode() still returns class
            # scores), but the decoder cannot consume them; say so up front
            # instead of failing later with a shape error inside forward().
            import warnings
            warnings.warn(
                f"AutoEncoder built with embed_only=False: the encoder outputs "
                f"{num_classes} class scores but the decoder expects "
                f"{feat_dim} features; pass embed_only=True to use forward().",
                stacklevel=2,
            )

        self.encoder = encoder_class(num_classes=num_classes,
                                     embed_only=embed_only,
                                     from_scratch=from_scratch,
                                     path_to_weights=path_to_weights,
                                     device=device)

        self.encoder_name = encoder_name

        # set decoder
        self.decoder = Decoder1(
            in_channels=feat_dim,
            out_channels=3,
            initial_size=initial_size,
            target_size=target_size
        ).to(device)

    def forward(self, x, return_features=False):
        """Encode then decode ``x``.

        Returns:
            The reconstruction, or ``(reconstruction, features)`` when
            ``return_features`` is true.
        """
        features = self.encoder(x)
        reconstructed = self.decoder(features)
        if return_features:
            return reconstructed, features
        return reconstructed

    def encode(self, x):
        """Return the encoder output for ``x``."""
        return self.encoder(x)

    def decode(self, features):
        """Return the decoder output for ``features``."""
        return self.decoder(features)

    def get_latent_size(self):
        """Return the feature width registered in ``model_dict`` for this encoder."""
        return model_dict[self.encoder_name][1]

# Registry of the models defined in this file, keyed by name.
# Each value is a two-item list: [model class, feature width].
# Projector, LinearClassifier and AutoEncoder look entries up by name and use
# the second item as the width of the features they consume.
# NOTE(review): AutoEncoder rejects 'Net4mnist' and 'autoencoder' as encoder
# names even though they are listed here; the 512 registered for
# 'autoencoder' is presumably the latent width of its default encoder -
# confirm against callers.
model_dict = {
    'resnet18': [ResNet18, 512],
    'resnet32': [ResNet32, 512],
    'resnet34': [ResNet34, 512],
    'resnet50': [ResNet50, 2048],
    'Net4mnist': [Net4mnist, 512],
    'autoencoder': [AutoEncoder, 512],
}

