import torch
from torch import nn
import torch.nn.functional as F
import einops

import model.resnet as models
from PixelDiff.sensor import *

class PPM(nn.Module):
    """Pyramid pooling module.

    For every entry of ``bins`` a branch is built that adaptively
    average-pools the input to that output size, applies a bias-free 1x1
    convolution down to ``reduction_dim`` channels, then batch norm and ReLU.
    ``forward`` resizes each branch output back to the input's spatial size
    and concatenates all of them with the input along the channel axis.

    Args:
        in_dim: Number of channels of the incoming feature map.
        reduction_dim: Number of channels produced by each pooling branch.
        bins: Iterable of output sizes passed to ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, in_dim, reduction_dim, bins):
        super().__init__()
        branches = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_size),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True),
            )
            for pool_size in bins
        ]
        # Registered as a ModuleList so the branches' parameters are tracked.
        self.features = nn.ModuleList(branches)

    def forward(self, x):
        """Return ``x`` concatenated (dim 1) with every upsampled branch output."""
        spatial = x.shape[2:]
        pooled = [
            F.interpolate(branch(x), spatial, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        return torch.cat([x, *pooled], 1)


class PSPNet(nn.Module):
    """PSPNet segmentation head on a dilated ResNet, fed through a ``FoveatedSensor``.

    The input image is first passed through ``self.sensor``; the network runs
    on the sensor output, and the resulting logits are resampled with
    ``F.grid_sample`` on a grid produced by ``self.sensor.backwardDeform``.
    NOTE(review): presumably ``backwardDeform`` maps regular output
    coordinates back into the sensor image - confirm against PixelDiff.sensor.

    Args:
        layers: ResNet depth; one of 50, 101, 152.
        bins: Pooling sizes handed to ``PPM``; ``2048 % len(bins)`` must be 0.
        dropout: Probability used by the ``nn.Dropout2d`` layers in both heads.
        classes: Number of output classes (must be > 1).
        zoom_factor: One of 1, 2, 4, 8; scales the size the logits are
            interpolated to before resampling.
        use_ppm: Whether to insert the pyramid pooling module before ``cls``.
        criterion: Loss module applied to both heads in training mode.
            ``None`` (default) creates a fresh
            ``nn.CrossEntropyLoss(ignore_index=255)`` for this instance.
        pretrained: Forwarded to the ``model.resnet`` constructor.
        sensor_h, sensor_w: Size passed to ``FoveatedSensor``.
        deform: ``None``/``'sigmoid'``, ``'circular_sigmoid'``, or anything
            else to select ``Monotone_Linear_Spline_Deformation``.
        constrain_t: Forwarded to ``FoveatedSensor``.
        constrain_factor: Forwarded to ``FoveatedSensor``; ``None`` means 0.6.
        spline_args: Argument for ``Monotone_Linear_Spline_Deformation``.
    """

    def __init__(self, layers=50, bins=(1, 2, 3, 6), dropout=0.1, classes=2, zoom_factor=8, use_ppm=True,
                 criterion=None, pretrained=True, sensor_h=257, sensor_w=513, deform='sigmoid',
                 constrain_t=True, constrain_factor=None, spline_args=12):
        super(PSPNet, self).__init__()
        assert layers in [50, 101, 152]
        assert 2048 % len(bins) == 0
        assert classes > 1
        assert zoom_factor in [1, 2, 4, 8]
        self.zoom_factor = zoom_factor
        self.use_ppm = use_ppm
        # Build the default loss per instance instead of sharing one module
        # object (created at function-definition time) between all models.
        self.criterion = nn.CrossEntropyLoss(ignore_index=255) if criterion is None else criterion

        if deform is None or deform == 'sigmoid':
            self.deform = IndependentAnisotropicHalfNormalTunableSigmoidDeformation(1.0)
        elif deform == 'circular_sigmoid':
            self.deform = AnisotropicHalfNormalTunableSigmoidDeformation(1.0)
        else:
            self.deform = Monotone_Linear_Spline_Deformation(spline_args)

        if constrain_factor is None:
            constrain_factor = 0.6
        self.sensor = FoveatedSensor((sensor_h, sensor_w), deform=self.deform, constrain_t=constrain_t,
                                     constrain_factor=constrain_factor)

        if layers == 50:
            resnet = models.resnet50(pretrained=pretrained)
        elif layers == 101:
            resnet = models.resnet101(pretrained=pretrained)
        else:
            resnet = models.resnet152(pretrained=pretrained)
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.conv2, resnet.bn2, resnet.relu,
                                    resnet.conv3, resnet.bn3, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        # Replace the stride of the last two stages by dilation (2, then 4).
        self._dilate_stage(self.layer3, 2)
        self._dilate_stage(self.layer4, 4)

        fea_dim = 2048
        if use_ppm:
            self.ppm = PPM(fea_dim, fea_dim // len(bins), bins)
            # PPM concatenates len(bins) branches of fea_dim/len(bins) channels
            # onto its input, doubling the channel count.
            fea_dim *= 2
        self.cls = nn.Sequential(
            nn.Conv2d(fea_dim, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, classes, kernel_size=1)
        )
        # nn.Module.__init__ leaves the module in training mode, so at
        # construction time this branch is taken and the aux head is created.
        if self.training:
            self.aux = nn.Sequential(
                nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout),
                nn.Conv2d(256, classes, kernel_size=1)
            )

    @staticmethod
    def _dilate_stage(stage, rate):
        """Set every ``conv2`` in ``stage`` to dilation/padding ``rate`` with stride 1,
        and every ``downsample.0`` to stride 1."""
        for name, module in stage.named_modules():
            if 'conv2' in name:
                module.dilation, module.padding, module.stride = (rate, rate), (rate, rate), (1, 1)
            elif 'downsample.0' in name:
                module.stride = (1, 1)

    @staticmethod
    def _regular_grid(out_h, out_w, device):
        """Return an ``(out_h, out_w, 2)`` grid of (x, y) coordinates.

        Coordinates start at -1 and advance by ``2 / out_w`` (x) and
        ``2 / out_h`` (y), giving exactly ``out_w`` / ``out_h`` samples.
        """
        step_x = 2 / out_w
        step_y = 2 / out_h
        # With a non-integer step, torch.arange(-1, 1, step) may emit one extra
        # sample when rounding makes 2 / step land just above the intended
        # count. Pulling the end in by half a step fixes the length without
        # changing any sample value (values depend only on start and step).
        xs = torch.arange(-1, 1 - step_x / 2, step_x, device=device)
        ys = torch.arange(-1, 1 - step_y / 2, step_y, device=device)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
        return torch.stack((grid_x, grid_y), dim=2)

    def _resampling_grid(self, out_h, out_w, batch, device):
        """Build the ``(batch, out_h, out_w, 2)`` grid used with ``F.grid_sample``.

        The regular grid is passed through ``self.sensor.backwardDeform`` and
        repeated along a new batch axis; no gradients are tracked.
        """
        with torch.no_grad():
            grid = self._regular_grid(out_h, out_w, device)
            grid = self.sensor.backwardDeform(grid)
            grid = einops.repeat(grid, 'h w c -> b h w c', b=batch)
        return grid

    def forward(self, x, y=None):
        """Run the network.

        Args:
            x: Input batch; dims 2 and 3 are used as height and width.
            y: Target labels, required in training mode. Its dims 1 and 2 set
                the output resolution (assumes shape (B, H, W) - TODO confirm).

        Returns:
            Training mode: ``(pred, main_loss, aux_loss, y)`` where ``pred`` is
            the per-pixel argmax of the resampled logits.
            Eval mode: ``(logits, sensor_img)``, both resampled to the input's
            height and width.

        Raises:
            ValueError: If the module is in training mode and ``y`` is None.
        """
        if self.training and y is None:
            raise ValueError('PSPNet.forward requires target labels `y` in training mode')

        in_h, in_w = x.shape[2], x.shape[3]
        x = self.sensor(x)

        if not self.training:
            # Keep the sensor output so it can be returned after resampling.
            sensor_img = x.clone()

        x_size = x.size()
        assert (x_size[2] - 1) % 8 == 0 and (x_size[3] - 1) % 8 == 0
        # Exact integer arithmetic; divisibility by 8 is asserted above.
        h = (x_size[2] - 1) // 8 * self.zoom_factor + 1
        w = (x_size[3] - 1) // 8 * self.zoom_factor + 1

        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)
        if self.use_ppm:
            x = self.ppm(x)
        x = self.cls(x)
        if self.zoom_factor != 1:
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)

        if self.training:
            aux = self.aux(x_tmp)
            if self.zoom_factor != 1:
                aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)

            # Resample both heads to the label resolution before the loss.
            grid = self._resampling_grid(y.shape[1], y.shape[2], x.shape[0], x.device)
            x = F.grid_sample(x, grid, mode='nearest', align_corners=False)
            aux = F.grid_sample(aux, grid, mode='nearest', align_corners=False)
            main_loss = self.criterion(x, y)
            aux_loss = self.criterion(aux, y)
            return x.max(1)[1], main_loss, aux_loss, y

        # Eval: resample to the resolution of the original (pre-sensor) input.
        grid = self._resampling_grid(in_h, in_w, x.shape[0], x.device)
        x = F.grid_sample(x, grid, mode='nearest', align_corners=False)
        sensor_img = F.grid_sample(sensor_img, grid, mode='nearest', align_corners=False)
        return x, sensor_img


if __name__ == '__main__':
    # Smoke test: build the model on GPU and push random batches through the
    # training-mode forward pass.
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1'
    num_classes = 21
    images = torch.rand(4, 3, 257, 513).cuda()
    model = PSPNet(layers=50, bins=(1, 2, 3, 6), dropout=0.1, classes=num_classes, zoom_factor=8, use_ppm=True, pretrained=True).cuda()
    # Printing model.parameters() directly only shows a generator repr.
    print('parameters:', sum(p.numel() for p in model.parameters()))
    # train() propagates the mode to every submodule; assigning
    # model.training only flips the flag on the top-level module.
    model.train()
    print(model)
    for i in range(100):
        # torch.rand(...).long() truncates to all zeros; draw real class ids.
        labels = torch.randint(0, num_classes, (4, 257, 513)).cuda()
        output, main_loss, aux_loss, y = model(images, labels)
        model.zero_grad()
    print('PSPNet', output.size())
