# This part is based on https://github.com/huawei-noah/Data-Efficient-Model-Compression

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with batch normalisation.

    The skip path is an empty ``nn.Sequential`` (identity) unless the stride
    differs from 1 or the channel count changes; in that case it is a 1x1
    convolution followed by batch normalisation.
    """

    # Multiplier applied to ``planes`` to obtain the block's output channels.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: number of input channels.
            planes: number of channels produced by both 3x3 convolutions.
            stride: stride of the first convolution (and of the projection
                shortcut, when one is needed).
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # Shapes differ between input and output: project the input.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block using the module's own parameters."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

    def functional_forward(self, x, params, prefix):
        """Apply the block using weights looked up in ``params``.

        Convolution strides/paddings are taken from the module's own layers,
        but all weights and biases are read from ``params`` under keys of the
        form ``prefix + '.conv1.weight'``, ``prefix + '.bn1.bias'``, etc.
        Batch normalisation always uses the statistics of the current batch
        (``training=True`` with no running statistics).

        Args:
            x: input tensor.
            params: mapping from parameter name to tensor.
            prefix: name prefix of this block inside ``params``.

        Returns:
            The block output computed with the supplied parameters.
        """

        def batch_norm(inp, name):
            # Batch-statistics normalisation with externally supplied affine terms.
            return F.batch_norm(inp, running_mean=None, running_var=None,
                                weight=params[prefix + '.' + name + '.weight'],
                                bias=params[prefix + '.' + name + '.bias'], training=True)

        main = F.conv2d(x, params[prefix + '.conv1.weight'],
                        stride=self.conv1.stride, padding=self.conv1.padding)
        main = F.relu(batch_norm(main, 'bn1'))

        main = F.conv2d(main, params[prefix + '.conv2.weight'],
                        stride=self.conv2.stride, padding=self.conv2.padding)
        main = batch_norm(main, 'bn2')

        if len(self.shortcut) == 0:
            residual = x
        else:
            residual = F.conv2d(x, params[prefix + '.shortcut.0.weight'], stride=self.shortcut[0].stride)
            residual = batch_norm(residual, 'shortcut.1')

        return F.relu(main + residual)


class PolicyOut(nn.Module):
    """Output head producing per-location logits plus one extra scalar logit.

    A 3x3 conv + batch norm + ReLU is applied first. From the result, a 1x1
    convolution yields ``policy_channels`` values per spatial location and a
    linear layer yields a single additional value. The linear layer takes
    ``in_planes * 8 * 8`` features, so the input must be 8x8 spatially.
    """

    def __init__(self, in_planes, policy_channels=10):
        """Create the head.

        Args:
            in_planes: number of channels of the incoming feature map.
            policy_channels: number of output channels of the 1x1 convolution.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, policy_channels, kernel_size=1, stride=1, padding=0)
        self.linear = nn.Linear(in_planes * 8 * 8, 1)

    def forward(self, x):
        """Return the flattened conv logits concatenated with the scalar logit.

        The result has ``policy_channels * H * W + 1`` columns per sample.
        """
        features = F.relu(self.bn1(self.conv1(x)))
        # Per-location outputs, flattened to one row per sample.
        per_location = torch.flatten(self.conv(features), start_dim=1)
        # Single extra output computed from the whole feature map.
        extra = self.linear(torch.flatten(features, start_dim=1))
        return torch.cat((per_location, extra), dim=1)


class ChessResNet(nn.Module):
    """ResNet-style network over 12-channel 8x8 inputs with selectable heads.

    The network is a "torso" (stem convolution followed by ``BasicBlock``s,
    or an externally supplied module) followed by an output head chosen by
    ``out_method``.
    """

    def __init__(self, planes_list, num_classes=3, out_method="ResNet", torso=None, in_planes=None):
        """Build the network.

        Args:
            planes_list: channel widths; ``planes_list[0]`` is the stem width
                and every following entry adds one ``BasicBlock``. Only used
                when ``torso`` is None.
            num_classes: size of the output. For the "PF" head and the
                fallback head the 1x1 convolution gets ``num_classes // 64``
                channels.
            out_method: "ResNet" (average pool + linear), "Linear"
                (1-channel conv + two linear layers), "PF" (``PolicyOut``);
                any other value selects the conv head that the original
                code refers to as "PB".
            torso: optional pre-built module used instead of a new torso.
            in_planes: number of channels produced by ``torso``. Required
                when ``torso`` is given; ignored (overwritten by
                ``planes_list[0]`` and the block widths) otherwise.

        Raises:
            ValueError: if ``torso`` is given without ``in_planes``.
        """
        super(ChessResNet, self).__init__()

        self.in_planes = in_planes
        if torso is None:
            stem = [
                nn.Conv2d(12, planes_list[0], kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(planes_list[0]),
                nn.ReLU(),
            ]
            self.in_planes = planes_list[0]
            # _make_layer appends the residual blocks and advances self.in_planes.
            self.res_blocks = self._make_layer(planes_list[1:], layers=stem)
        else:
            if in_planes is None:
                # Every head needs the torso's channel count; fail early with
                # a clear message instead of deep inside a layer constructor.
                raise ValueError("in_planes must be given when a torso is supplied")
            self.res_blocks = torso

        if out_method == "ResNet":
            self.out_layer = nn.Sequential(
                nn.AvgPool2d(8),
                nn.Flatten(start_dim=1),
                nn.Linear(self.in_planes, num_classes)
            )
        elif out_method == "Linear":
            self.out_layer = nn.Sequential(
                nn.Conv2d(self.in_planes, 1, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(1),
                nn.ReLU(),
                nn.Flatten(start_dim=1),
                nn.Linear(8 * 8, self.in_planes),
                nn.ReLU(),
                nn.Linear(self.in_planes, num_classes),
            )
        elif out_method == "PF":
            # Intended use per the original note: num_classes == 10 * 64 + 1.
            self.out_layer = PolicyOut(self.in_planes, policy_channels=num_classes // 64)
        else:
            # Intended use per the original note: out_method == "PB" and
            # num_classes == 10 * 64. Other values are not rejected.
            self.out_layer = nn.Sequential(
                nn.Conv2d(self.in_planes, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(self.in_planes),
                nn.ReLU(),
                nn.Conv2d(self.in_planes, num_classes // 64, kernel_size=1, stride=1, padding=0),
                nn.Flatten(start_dim=1)
            )

        # NOTE(review): self.modules() also yields the modules of a supplied
        # torso, so its conv / norm weights are re-initialised here as well.
        # Confirm this is intended when sharing or reusing a trained torso.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, planes_list, layers=None):
        """Append one ``BasicBlock`` per entry of ``planes_list`` and wrap in a Sequential.

        Args:
            planes_list: output widths of the blocks to create, in order.
            layers: optional list of modules to extend in place; a fresh
                list is used when omitted (a shared mutable default would
                leak blocks between calls).

        Returns:
            ``nn.Sequential`` of ``layers`` followed by the new blocks.
            ``self.in_planes`` is updated to the last block's width.
        """
        layers = [] if layers is None else layers
        for planes in planes_list:
            layers.append(BasicBlock(self.in_planes, planes, stride=1))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        """Run the torso and the head.

        Args:
            x: input tensor; it is viewed as ``(-1, 12, 8, 8)`` before use.
            out_feature: accepted for interface compatibility; currently
                not used.

        Returns:
            The output of the selected head.
        """
        out = self.res_blocks(x.view(-1, 12, 8, 8))
        return self.out_layer(out)
