from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn import init


def get_model(model_name, out_features, in_channels, arch_params):
    """Build the network registered under ``model_name``.

    Parameters:
        model_name (str)   -- key of a torchvision model ("ResNet18") or of one
                              of the architectures defined in this module
        out_features (int) -- forwarded as ``num_classes`` (torchvision models)
                              or ``out_features`` (local models)
        in_channels (int)  -- number of input channels requested by the caller
        arch_params (dict) -- extra keyword arguments for the constructor

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError -- if ``model_name`` is not a known model.
    """
    torchvision_factories = {"ResNet18": torchvision.models.resnet18}

    local_architectures = {
        "ConvNet": ConvNet,
        "MLP": MLP,
        "PureConvNet": PureConvNet,
        "CombResnet18": CombRenset18,
        "PartialResNet": partialResNetSeperated,
        "PartialResNetMTL": PartialResNetMTL,
        "MLPMTL": MLPMTL,
    }

    factory = torchvision_factories.get(model_name)
    if factory is not None:
        net = factory(pretrained=False, num_classes=out_features, **arch_params)
        # Replace the stem convolution so the ResNet accepts 'in_channels'
        # input channels instead of the stock three.
        del net.conv1
        net.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        return net

    architecture = local_architectures.get(model_name)
    if architecture is None:
        raise ValueError(f"Model name {model_name} not recognized!")

    net = architecture(out_features=out_features, in_channels=in_channels, **arch_params)
    # Only the multi-task MLP gets an explicit re-initialization.
    if model_name == 'MLPMTL':
        init_weights(net, 'xavier')
    return net


def dim_after_conv2D(input_dim, stride, kernel_size):
    """Return ``(input_dim - kernel_size + 2) // stride``.

    NOTE(review): this differs from the usual convolution output-size formula
    ``(input_dim + 2 * padding - kernel_size) // stride + 1`` -- confirm the
    intended padding convention before relying on it.
    """
    span = input_dim - kernel_size + 2
    return span // stride



def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method:
                             normal | xavier | kaiming_uniform | orthogonal
        init_gain (float) -- std for 'normal', gain for 'xavier' and 'orthogonal'
                             (not used by 'kaiming_uniform'); also the std of the
                             BatchNorm2d scale initialization.

    Conv*/Linear* layers get their weight initialized with the chosen scheme and
    their bias (when present) set to zero. BatchNorm2d layers with affine
    parameters get weight ~ N(1, init_gain) and bias = 0.

    Raises:
        NotImplementedError -- if 'init_type' is unknown and 'net' contains at
                               least one Conv*/Linear* layer.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
        if hasattr(m, 'weight') and is_conv_or_linear:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming_uniform':
                # Fan-based initializers need >= 2 dimensions, so they are only
                # applied to the weight; the bias is zeroed below.
                init.kaiming_uniform_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                # orthogonal_ also rejects 1-D tensors; the bias is zeroed below.
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False both parameters are None and there is nothing to do.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)




class CombRenset18(nn.Module):
    """ResNet18 stem plus ``layer1``, pooled to a square grid and channel-averaged.

    A full torchvision ResNet18 is instantiated (with its first convolution
    replaced so it takes ``in_channels`` channels), but ``forward`` only runs
    conv1, bn1, relu, maxpool and layer1. The feature map is then adaptively
    max-pooled to ``int(sqrt(out_features))`` per side and averaged over the
    channel dimension. ``n_task`` is accepted but not used.
    """

    def __init__(self, out_features, in_channels, n_task):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=False, num_classes=out_features)
        # Swap the stem so the network accepts 'in_channels' input channels.
        del backbone.conv1
        backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet_model = backbone

        grid_side = int(sqrt(out_features))
        self.pool = nn.AdaptiveMaxPool2d((grid_side, grid_side))

    def forward(self, x):
        backbone = self.resnet_model
        stages = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            self.pool,
        )
        for stage in stages:
            x = stage(x)
        # Collapse the channel dimension by averaging.
        return x.mean(dim=1)


class ConvNet(torch.nn.Module):
    """Two convolutions, a fixed 4x4 adaptive average pool, and a two-layer head.

    Parameters:
        out_features (int)      -- size of the final linear output
        in_channels (int)       -- channels of the input image
        kernel_size, stride     -- shared by both convolutions
        linear_layer_size (int) -- width of the hidden linear layer
        channels_1, channels_2  -- output channels of conv1 / conv2
    """

    def __init__(self, out_features, in_channels, kernel_size, stride, linear_layer_size, channels_1, channels_2):
        super().__init__()
        conv_kwargs = {"kernel_size": kernel_size, "stride": stride}
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels_1, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=channels_1, out_channels=channels_2, **conv_kwargs)

        # The adaptive pool fixes the spatial size, so the head's input width
        # does not depend on the image resolution.
        pooled_h, pooled_w = 4, 4
        self.pool = nn.AdaptiveAvgPool2d((pooled_h, pooled_w))

        flat_features = pooled_h * pooled_w * channels_2
        self.fc1 = nn.Linear(in_features=flat_features, out_features=linear_layer_size)
        self.fc2 = nn.Linear(in_features=linear_layer_size, out_features=out_features)

    def forward(self, x):
        n_samples = x.shape[0]
        h = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        h = self.pool(F.relu(self.conv2(h)))
        # Flatten everything except the batch dimension.
        h = h.view(n_samples, -1)
        hidden = F.relu(self.fc1(h))
        return self.fc2(hidden)


class MLP(torch.nn.Module):
    """One-hidden-layer perceptron with a tanh activation.

    The first linear layer is sized for ``in_channels * 40 * 20`` input values,
    so each flattened sample must have exactly that many elements.
    """

    def __init__(self, out_features, in_channels, hidden_layer_size):
        super().__init__()
        flat_inputs = in_channels * 40 * 20
        self.fc1 = nn.Linear(in_features=flat_inputs, out_features=hidden_layer_size)
        self.fc2 = nn.Linear(in_features=hidden_layer_size, out_features=out_features)

    def forward(self, x):
        # Flatten everything except the batch dimension.
        flat = x.view(x.shape[0], -1)
        hidden = torch.tanh(self.fc1(flat))
        return self.fc2(hidden)


class PureConvNet(torch.nn.Module):
    """Fully convolutional network producing a single-channel square map.

    Pipeline: conv1 -> activation -> (optional conv2 -> activation) ->
    adaptive pooling to ``int(sqrt(out_features))`` per side -> 1x1 conv to one
    channel.

    Parameters:
        out_features (int)     -- number of output cells; the pooled map has
                                  int(sqrt(out_features)) rows and columns
        pooling (str)          -- "average" or "max"
        use_second_conv (bool) -- whether conv2 is applied in forward
        kernel_size            -- kernel size of conv1 and conv2
        in_channels (int)      -- channels of the input image
        channels_1, channels_2 -- output channels of conv1 / conv2
        act_func (str)         -- key of ``PureConvNet.act_funcs``

    Raises:
        ValueError -- if ``pooling`` is not "average" or "max".
        KeyError   -- if ``act_func`` is not a key of ``act_funcs``.
    """

    # torch.tanh replaces the deprecated F.tanh alias (same function).
    act_funcs = {"relu": F.relu, "tanh": torch.tanh, "identity": lambda x: x}

    def __init__(self, out_features, pooling, use_second_conv, kernel_size, in_channels,
                 channels_1=20, channels_2=20, act_func="relu"):
        super().__init__()
        self.use_second_conv = use_second_conv

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=channels_1, kernel_size=kernel_size, stride=1)
        # conv2 is always created (even when unused) so the parameter layout
        # does not depend on 'use_second_conv'.
        self.conv2 = nn.Conv2d(in_channels=channels_1, out_channels=channels_2, kernel_size=kernel_size, stride=1)

        output_shape = (int(sqrt(out_features)), int(sqrt(out_features)))
        if pooling == "average":
            self.pool = nn.AdaptiveAvgPool2d(output_shape)
        elif pooling == "max":
            self.pool = nn.AdaptiveMaxPool2d(output_shape)
        else:
            # Fail at construction time instead of leaving self.pool undefined
            # and crashing later inside forward().
            raise ValueError(f"Unknown pooling {pooling!r}; expected 'average' or 'max'")

        self.conv3 = nn.Conv2d(in_channels=channels_2 if use_second_conv else channels_1,
                               out_channels=1, kernel_size=1, stride=1)
        self.act_func = PureConvNet.act_funcs[act_func]

    def forward(self, x):
        x = self.act_func(self.conv1(x))
        if self.use_second_conv:
            x = self.act_func(self.conv2(x))
        x = self.pool(x)
        x = self.conv3(x)
        return x




#M 5/8
# The first layers are shared; each task then gets its own separate tower (one tower per task), for multi-task use.
class PartialResNetMTL(nn.Module):
    """
    Truncated ResNet18 trunk shared by ``n_task`` task-specific towers.

    The shared trunk is conv1, bn1, relu, maxpool and the first block of
    ``layer1`` of a torchvision ResNet18. Each tower is the second ``layer1``
    block of its own fresh ResNet18, a 3x3 convolution down to one channel, and
    an adaptive max pool to ``int(sqrt(out_features))`` per side.

    ``in_channels`` is accepted but not used: the stem convolution is taken
    unchanged from torchvision.
    """
    def __init__(self, out_features, in_channels, n_task):
        super().__init__()
        grid_side = int(sqrt(out_features))
        backbone = torchvision.models.resnet18(pretrained=False)
        # Shared layers: the first five modules of ResNet18.
        self.conv = backbone.conv1
        self.bn = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.block = backbone.layer1[0]
        # One independent tower per task.
        self.towers = nn.ModuleList(
            [self._build_tower(grid_side) for _ in range(n_task)]
        )

    @staticmethod
    def _build_tower(grid_side):
        """Return one tower: basic block -> 1-channel conv -> adaptive max pool."""
        return nn.Sequential(
            torchvision.models.resnet18(pretrained=False).layer1[1],
            nn.Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.AdaptiveMaxPool2d((grid_side, grid_side)),
        )

    def forward(self, x):
        shared = self.block(self.maxpool(self.relu(self.bn(self.conv(x)))))
        per_task = []
        for tower in self.towers:
            grid = tower(shared).squeeze(1)
            # Flatten each sample's grid into a vector for the optimization model.
            per_task.append(grid.reshape(grid.shape[0], -1))
        # Stacked along a new leading (task) dimension.
        return torch.stack(per_task)


# No shared layers.

class partialResNetSeperated(nn.Module):
    """
    Standalone truncated ResNet18 with no layers shared between tasks.

    A single tower is built: conv1, bn1, relu, maxpool and ``layer1`` of a
    torchvision ResNet18, followed by a 3x3 convolution down to one channel and
    an adaptive max pool to ``int(sqrt(out_features))`` per side.

    ``in_channels`` and ``n_task`` are accepted but not used.
    """
    def __init__(self, out_features, in_channels, n_task):
        super().__init__()
        grid_side = int(sqrt(out_features))
        backbone = torchvision.models.resnet18(pretrained=False)
        # Reduces the 64 feature channels to a single channel.
        to_single_channel = nn.Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.tower = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            to_single_channel,
            nn.AdaptiveMaxPool2d((grid_side, grid_side)),
        )

    def forward(self, x):
        # Drop the singleton channel dimension; the spatial grid is kept as is.
        return self.tower(x).squeeze(1)



class MLPMTL(nn.Module):
    """Multi-task MLP: a shared two-layer trunk feeding ``n_task`` towers.

    The trunk expects each flattened sample to have ``3 * 96 * 96`` values;
    ``in_channels`` is accepted but not used. Each tower maps the 512-wide
    trunk output to ``out_features`` values.
    """

    def __init__(self, out_features, in_channels, n_task):
        super().__init__()
        flat_inputs = 3 * 96 * 96
        trunk_layers = [
            nn.Linear(in_features=flat_inputs, out_features=1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
        ]
        self.sharedlayer = nn.Sequential(*trunk_layers)
        # One independent head per task.
        self.towers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(in_features=512, out_features=out_features),
                    nn.BatchNorm1d(out_features),
                    nn.ReLU(),
                    nn.Linear(in_features=out_features, out_features=out_features),
                )
                for _ in range(n_task)
            ]
        )
        print('self.sharedlayer', self.sharedlayer)

    def forward(self, x):
        # Flatten everything except the batch dimension.
        flat = x.reshape(x.shape[0], -1)
        shared = self.sharedlayer(flat)
        # Stacked along a new leading (task) dimension.
        return torch.stack([tower(shared) for tower in self.towers])



