import torch
import torch.nn as nn
import torchvision


class ResNet(nn.Module):
    """torchvision ResNet classifier with a configurable number of input channels.

    The network is assembled from a torchvision ResNet: the stem convolution
    is (re)built for ``in_channel`` inputs and the final fully-connected layer
    is replaced by a new ``num_classes``-way layer. Two ``nn.Sequential``
    views over the *same* layers are exposed:

    * ``self.model``    -- stem, residual stages, pooling, flatten, classifier.
    * ``self.backbone`` -- the same stack without the classifier.

    Args:
        num_classes: Output size of the new classification layer.
        pretrain: Forwarded to the torchvision builder as ``pretrained``.
            When true, the stem convolution is also initialised from the
            torchvision stem (reused as-is for 3 input channels, otherwise
            adapted with :meth:`weight_transform`).
        in_channel: Number of channels of the input tensor.
        net_idx: ResNet depth as a string, one of ``VALID_NET_IDX``.

    Raises:
        ValueError: If ``net_idx`` is not one of ``VALID_NET_IDX``.
    """

    VALID_NET_IDX = ('18', '34', '50', '101', '152')

    def __init__(self, num_classes, pretrain=False, in_channel=20, net_idx='18'):
        super(ResNet, self).__init__()

        if net_idx not in self.VALID_NET_IDX:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError('Valid ResNet idx: 18, 34, 50, 101, 152')

        self.in_channel = in_channel

        # Look the builder up by name and call only that one, so a single
        # network is constructed instead of one per supported depth.
        build_resnet = getattr(torchvision.models, 'resnet' + net_idx)
        resnet = build_resnet(num_classes=1000, pretrained=pretrain)

        inconv = self._build_input_conv(resnet.conv1, pretrain)
        flatten = nn.Flatten(1)
        fc = nn.Linear(resnet.fc.in_features, num_classes)

        # Layers shared by the classifier and the headless backbone.
        trunk = [inconv, resnet.bn1, resnet.relu, resnet.maxpool,
                 resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
                 resnet.avgpool, flatten]
        self.model = nn.Sequential(*trunk, fc)
        self.backbone = nn.Sequential(*trunk)

    def _build_input_conv(self, pretrained_conv, pretrain):
        """Return the stem convolution for ``self.in_channel`` input channels.

        Args:
            pretrained_conv: Stem convolution of the torchvision network.
            pretrain: Whether ``pretrained_conv`` carries weights worth reusing.

        Returns:
            ``pretrained_conv`` itself when ``pretrain`` is true and the input
            has 3 channels; otherwise a new 7x7 / stride-2 convolution, whose
            weights are filled from ``pretrained_conv`` when ``pretrain`` is true.
        """
        if pretrain and self.in_channel == 3:
            return pretrained_conv

        conv = nn.Conv2d(self.in_channel, 64, kernel_size=(7, 7), stride=(2, 2),
                         padding=(3, 3), bias=False)
        if pretrain:
            # Copy into the parameter itself. Assigning into ``state_dict()``
            # only mutates a temporary dict and leaves the layer untouched.
            with torch.no_grad():
                conv.weight.copy_(self.weight_transform(pretrained_conv.weight.detach()))
        return conv

    def weight_transform(self, in_weight):
        """Adapt convolution weights to ``self.in_channel`` input channels.

        The weights are averaged over the input-channel axis (dim 1) and that
        average is replicated ``self.in_channel`` times along the same axis.

        Args:
            in_weight: 4-D weight tensor ``(out_channels, in_channels, kH, kW)``.

        Returns:
            Tensor of shape ``(out_channels, self.in_channel, kH, kW)`` with the
            dtype and device of ``in_weight``.
        """
        avg_weight = in_weight.mean(dim=1, keepdim=True)  # (out_channels, 1, kH, kW)
        return avg_weight.repeat(1, self.in_channel, 1, 1)

    def forward(self, x):
        """Run the full classifier on ``x`` and return its output."""
        return self.model(x)


class FusionModel(nn.Module):
    """Two-stream classifier that fuses a flow backbone and an RGB backbone.

    Each stream model is built from its factory (``model_1`` for flow,
    ``model_2`` for RGB) and only its ``backbone`` attribute is kept. The two
    feature vectors are concatenated along dim 1 and classified by a single
    linear layer.

    Args:
        num_classes: Output size of the fusion classifier (also passed to
            both stream factories).
        pretrain: Passed to both stream factories.
        distill: When true, ``forward`` also returns the per-stream features.
        model_1: Factory for the flow stream; must expose ``.backbone``.
        model_2: Factory for the RGB stream; must expose ``.backbone``.
        flow_in_channel: ``in_channel`` passed to the flow stream.
        rgb_in_channel: ``in_channel`` passed to the RGB stream.
        net_idx: Passed to both stream factories.
    """

    def __init__(self, num_classes, pretrain, distill=False, model_1=ResNet, model_2=ResNet,
                 flow_in_channel=20, rgb_in_channel=3, net_idx='18'):
        super(FusionModel, self).__init__()

        self.distill = distill  # return features for UMT

        flow_model = model_1(num_classes=num_classes, pretrain=pretrain, in_channel=flow_in_channel, net_idx=net_idx)
        rgb_model = model_2(num_classes=num_classes, pretrain=pretrain, in_channel=rgb_in_channel, net_idx=net_idx)

        self.flow_backbone = flow_model.backbone
        self.rgb_backbone = rgb_model.backbone

        # Size the fusion layer from the streams instead of assuming 512 + 512,
        # so backbones with a different feature width still line up.
        joint_dim = self._feature_dim(flow_model) + self._feature_dim(rgb_model)
        self.linear = nn.Linear(joint_dim, num_classes)

    @staticmethod
    def _feature_dim(stream_model, default=512):
        """Return the feature width produced by ``stream_model.backbone``.

        The width is read from the ``in_features`` of the last layer of
        ``stream_model.model`` when that is an ``nn.Sequential`` ending in an
        ``nn.Linear``; otherwise ``default`` is returned.
        """
        head = getattr(stream_model, 'model', None)
        if isinstance(head, nn.Sequential) and len(head) > 0:
            last_layer = head[-1]
            if isinstance(last_layer, nn.Linear):
                return last_layer.in_features
        return default

    def forward(self, flow, rgb):
        """Classify a (flow, rgb) pair.

        Returns:
            The fused output, or ``(output, flow_feature, rgb_feature)`` when
            ``self.distill`` is true.
        """
        flow_feature = self.flow_backbone(flow)  # (batch_size, flow_feature_dim)
        rgb_feature = self.rgb_backbone(rgb)  # (batch_size, rgb_feature_dim)

        joint_feature = torch.cat((flow_feature, rgb_feature), 1)
        output = self.linear(joint_feature)

        if self.distill:
            return output, flow_feature, rgb_feature
        return output


class Encoder(nn.Module):
    """Feature extractor that exposes only the ``backbone`` of a ``ResNet``.

    Args:
        in_channel: Number of input channels passed to ``ResNet``.
        pretrain: Passed to ``ResNet``.
        net_idx: Passed to ``ResNet``.
    """

    def __init__(self, in_channel, pretrain, net_idx='18'):
        super().__init__()

        # Only the backbone is retained; the 101-way classifier that ResNet
        # builds alongside it is not referenced by this module.
        classifier = ResNet(num_classes=101, pretrain=pretrain,
                            in_channel=in_channel, net_idx=net_idx)
        self.encoder = classifier.backbone

    def forward(self, x):
        """Return the backbone features for ``x``."""
        return self.encoder(x)


class LinearClassifier(nn.Module):
    """Single linear layer mapping ``in_dim`` features to ``out_dim`` outputs."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=in_dim, out_features=out_dim)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.linear(x)


if __name__ == '__main__':
    # Manual check: build the two-stream model and print its module tree.
    # A single stream can be inspected the same way with
    # ResNet(101, pretrain=True, in_channel=3, net_idx='18').
    fusion_net = FusionModel(num_classes=101, pretrain=True,
                             model_1=ResNet, model_2=ResNet)
    print(fusion_net)
