import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class STN(nn.Module):
    """Localization network that regresses ``k`` affine transforms per sample.

    The input goes through a small conv stack (7x7 conv -> 2x2 max-pool ->
    ReLU -> 5x5 conv -> 2x2 max-pool -> ReLU) and a two-layer MLP.  The last
    linear layer starts with zero weights and an identity bias, so a freshly
    built module outputs identity transforms.  Only the affine parameters are
    returned; applying them (``affine_grid`` / ``grid_sample``) is left to the
    caller.

    Args:
        in_plane: Number of channels of the input tensor.
        k: Number of 2x3 affine matrices predicted per sample.
        img_size: Side length of the square input the network is sized for.

    Raises:
        ValueError: If ``img_size`` is too small to survive the conv stack.
    """

    def __init__(self, in_plane, k=2, img_size=32):
        super().__init__()
        self.in_plane = in_plane
        self.k = k
        # Side length of the feature map after conv7 -> pool2 -> conv5 -> pool2.
        self.d_c = ((img_size - 6) // 2 - 4) // 2
        self.img_size = img_size
        if self.d_c < 1:
            raise ValueError(
                f'img_size={img_size} is too small for the localization '
                f'network (feature map side would be {self.d_c})'
            )

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(in_plane, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the k 3 * 2 affine matrices
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * self.d_c * self.d_c, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2 * self.k)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0] * self.k, dtype=torch.float))

    def forward(self, x, downsample=False):
        """Predict ``k`` affine matrices for every sample in ``x``.

        Args:
            x: Tensor of shape ``(N, in_plane, H, W)``.
            downsample: When true, ``x`` is first resized bilinearly to
                ``img_size`` x ``img_size``.

        Returns:
            Tensor of shape ``(N, k, 2, 3)``.

        Raises:
            RuntimeError: If the localization feature map does not have the
                side length this module was built for (the input size does not
                match ``img_size`` and ``downsample`` was not requested).
        """
        if downsample:
            x = F.interpolate(input=x, size=self.img_size, mode='bilinear', align_corners=True)
        # e.g. (1, 2, 32, 32) -> (1, 10, 4, 4) for the default configuration
        feats = self.localization(x)
        # Guard the flatten below: with a mismatched spatial size, view(-1, ...)
        # could otherwise silently fold spatial positions into the batch axis.
        if tuple(feats.shape[-2:]) != (self.d_c, self.d_c):
            raise RuntimeError(
                f'STN built for img_size={self.img_size} expects a '
                f'{self.d_c}x{self.d_c} feature map, got '
                f'{tuple(feats.shape[-2:])}; resize the input or pass '
                f'downsample=True'
            )
        feats = feats.view(-1, 10 * self.d_c * self.d_c)
        theta = self.fc_loc(feats)

        # grid = F.affine_grid(theta, x.size())
        # x = F.grid_sample(x, grid)

        return theta.view(-1, self.k, 2, 3)

    # test
    def test_size(self):
        """Print the tensor shape after every stage for one random sample.

        The sample matches this instance's ``in_plane`` / ``img_size`` and is
        created on the device that holds the module's weights.
        """
        device = self.fc_loc[0].weight.device
        x = torch.rand([1, self.in_plane, self.img_size, self.img_size],
                       dtype=torch.float).to(device)
        print(x.size())
        xs = self.localization(x)
        print(xs.size())
        xs = xs.view(-1, 10 * self.d_c * self.d_c)
        print(xs.size())
        theta = self.fc_loc(xs)
        print(theta.size())
        theta = theta.view(-1, self.k, 2, 3)
        print(theta.size())


class KMixAugmentor(nn.Module):
    """Mixes ``k`` images using saliency maps, learned affine warps and masks.

    For every mix, saliency maps are obtained from ``cam_model.get_cam``, an
    ``STN`` predicts one affine transform per image, the maps are warped, and
    ``masknet`` turns the warped maps (plus the mixing weights) into ``k``
    soft masks that blend the warped images.

    Tensors created internally are placed on the device of the inputs.  The
    two STN sub-networks are moved to CUDA at construction when it is
    available, otherwise they stay on the CPU.

    Args:
        cam_model: Object exposing ``get_cam(x, target, mixing=False, ...)``.
            Its result is unsqueezed at dim 1 and used as a one-channel map,
            so it is expected to be ``(N, H, W)`` -- confirm against the
            project's CAM implementation.
        masknet: Module mapping a ``(N, 2k, h, w)`` tensor to ``k`` mask
            logits per pixel, ``(N, k, h, w)``.
        k: Number of images mixed together.
        img_size: Side length used for the saliency maps when ``s_size`` is
            ``-1``.
        s_size: Side length requested from ``get_cam``; ``-1`` means "use
            ``img_size``".  Any other value also enables upsampling of the
            predicted masks to the image size.
        stn_size: Side length the mixing STN operates at; ``-1`` means "same
            as the saliency size".
    """

    def __init__(self, cam_model, masknet, k=2, img_size=32, s_size=-1, stn_size=-1):
        super().__init__()
        self.name = 'kmixaugmentor'
        self.cam_model = cam_model
        self.masknet = masknet
        if s_size == -1:
            self.s_size = (img_size, img_size)
            self.upsample = False
        else:
            self.s_size = (s_size, s_size)  # downsample for faster computation
            self.upsample = True
        self.stn_size = self.s_size[0] if stn_size == -1 else stn_size

        # Same placement as a bare .cuda() on GPU machines, but still
        # constructible when no GPU is present.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Input channels: k saliency maps + k weight planes + 1 noise plane.
        self.stn = STN(in_plane=2*k+1, k=k, img_size=self.stn_size).to(device)
        self.T = nn.Parameter(torch.tensor(1.))  # softmax temperature for the masks
        self.k = k
        # Background variant: k saliency maps + 1 noise plane.
        # NOTE: built with STN's default img_size (32) and forward_b does not
        # resize its input, so this path only fits 32x32 saliency maps.
        self.stn_b = STN(in_plane=k+1, k=k).to(device)
        self.T_b = nn.Parameter(torch.tensor(1.))  # temperature for the background masks

    def _device(self):
        """Return the device holding the STN weights (used by the test helpers)."""
        return next(self.stn.parameters()).device

    def forward(self, s_list, lam_list, img_size=None):
        """Predict the mixing masks and affine transforms for ``k`` saliency maps.

        Args:
            s_list: ``k`` saliency maps, each ``(N, 1, h, w)``.
            lam_list: ``k`` scalar mixing weights.
            img_size: Output size of the masks.  Only used (and then required)
                when the augmentor was built with an explicit ``s_size``.

        Returns:
            ``(k_masks, a_list)`` where ``k_masks`` has shape
            ``(k, N, 1, H, W)`` and sums to one over the first axis, and
            ``a_list`` holds ``k`` affine matrices of shape ``(N, 2, 3)``.

        Raises:
            ValueError: If upsampling is enabled and ``img_size`` is missing.
        """
        assert len(s_list) == self.k and len(lam_list) == self.k
        s_size = s_list[0].shape
        device = s_list[0].device

        # Broadcast every scalar weight to a full plane shaped like a saliency map.
        lam_tensor_list = [
            torch.tensor(lam, dtype=torch.float).repeat(s_size).to(device)
            for lam in lam_list
        ]

        # Noise is drawn on the CPU and then moved, so seeded runs consume the
        # CPU generator regardless of the target device.
        z = torch.normal(mean=0.0, std=1.0, size=s_size).to(device)
        szlam = torch.cat([*s_list, *lam_tensor_list, z], dim=1)

        # Predict the affine parameters (the STN resizes its input to stn_size).
        theta = self.stn(szlam, self.stn_size > 0)

        # Warp every saliency map with its own affine transform.
        a_list = [theta[:, i] for i in range(self.k)]
        ts_list = [
            F.grid_sample(s, F.affine_grid(a, s_size, align_corners=True), align_corners=True)
            for s, a in zip(s_list, a_list)
        ]

        # Predict the mixing masks from the warped maps and the weights.
        tslam = torch.cat([*ts_list, *lam_tensor_list], dim=1)
        p_mask = self.masknet(tslam)
        p_mask = F.softmax(p_mask / self.T, dim=1)
        if self.upsample:
            if img_size is None:
                raise ValueError('img_size is required when masks are upsampled (s_size was set)')
            p_mask = F.interpolate(input=p_mask, size=img_size, mode='bilinear', align_corners=True)

        # (N, k, H, W) -> (k, N, 1, H, W) so the masks broadcast over image channels.
        k_masks = p_mask.permute((1, 0, 2, 3)).unsqueeze(2)

        return k_masks, a_list

    def forward_b(self, s_list):
        """Predict affine transforms for the background mix and warp the maps.

        Args:
            s_list: ``k`` saliency maps, each ``(N, 1, h, w)``.

        Returns:
            ``(a_list, ts_list)``: the ``k`` affine matrices ``(N, 2, 3)`` and
            the ``k`` saliency maps warped by them.
        """
        assert len(s_list) == self.k
        s_size = s_list[0].shape
        z = torch.normal(mean=0.0, std=1.0, size=s_size).to(s_list[0].device)
        sz = torch.cat([*s_list, z], dim=1)

        # Predict the affine parameters
        theta = self.stn_b(sz)

        # Warp every saliency map with its own affine transform.
        a_list = [theta[:, i] for i in range(self.k)]
        ts_list = [
            F.grid_sample(s, F.affine_grid(a, s_size, align_corners=True), align_corners=True)
            for s, a in zip(s_list, a_list)
        ]

        return a_list, ts_list

    def mix_data_background(self, x, y, perm=None, return_report=False):
        """Paste a warped foreground onto a warped background (``k`` must be 2).

        The warped saliency map of the first image is the foreground mask and
        its complement the background mask; both go through a softmax with
        temperature ``T_b`` before blending.

        Args:
            x: Image batch ``(N, C, H, W)``.
            y: Labels, indexable by a permutation of ``range(N)``.
            perm: Optional pair whose second item replaces the random
                permutations (for control experiments).
            return_report: Return a dict with the intermediate tensors instead
                of the ``(mixed_x, labels)`` pair.

        Returns:
            ``(mixed_x, y[index_list[0]])`` -- the labels of the foreground
            images -- or the report dict.
        """
        # get two sequences of data
        assert self.k == 2
        batch_size = x.size()[0]
        index_list = [torch.randperm(batch_size) for _ in range(self.k)]

        # for control experiments
        if perm is not None:
            index_list = perm[1]

        s = self.cam_model.get_cam(x, y, mixing=False)  # scale to input size
        s = s.unsqueeze(1).to(x.device)
        s_list = [s[index, :] for index in index_list]

        a_list, ts_list = self.forward_b(s_list)
        tx_list = []
        for a, index in zip(a_list, index_list):
            grid = F.affine_grid(a, x.size(), align_corners=True)
            tx_list.append(
                F.grid_sample(x[index], grid, align_corners=True, padding_mode='reflection'))

        # how to generalize to k images?
        m1 = ts_list[0]  # foreground mask
        m2 = 1 - m1  # background mask
        two_masks = torch.stack([m1, m2])
        two_masks = F.softmax(two_masks / self.T_b, dim=0)

        mixed_x = torch.mul(two_masks, torch.stack([*tx_list]))
        mixed_x = torch.sum(mixed_x, 0)

        if return_report:
            return {'x': x, 'index_list': index_list, 'tx_list': tx_list,
                    'mixed_x': mixed_x, 's_list': s_list, 'k_masks': two_masks,
                    'y': y[index_list[0]], 'a_list': a_list
                    }
        return mixed_x, y[index_list[0]]

    def mix_data(self, x, y, alpha=3, perm=None, return_report=False, return_mask=False, transfer=False):
        """Mix ``k`` permutations of the batch with learned warps and masks.

        Args:
            x: Image batch ``(N, C, H, W)``.
            y: Labels; returned unchanged so the caller can combine losses
                using ``index_list``.
            alpha: Dirichlet concentration for the mixing weights; a
                non-positive value gives uniform weights ``1/k``.
            perm: Optional ``(lam_list, index_list)`` pair replacing the random
                draws (for control experiments).
            return_report: Return a dict with the intermediate tensors.
            return_mask: Also return ``r_list``, each mask's share of the total
                mask area per sample, shape ``(k, N)``.  Ignored when
                ``return_report`` is set.
            transfer: When true, ``get_cam`` is called with ``None`` as target
                instead of ``y``.

        Returns:
            The report dict, or ``(mixed_x, y, index_list, r_list, lam_list)``
            when ``return_mask`` is set, or
            ``(mixed_x, y, index_list, lam_list)`` otherwise.
        """
        if alpha > 0:
            lam_list = np.random.dirichlet([alpha] * self.k)  # lam: [0.1, 0.3, 0.2, ...]
        else:
            lam_list = [1/self.k] * self.k

        cam_target = None if transfer else y
        batch_size = x.size()[0]
        index_list = [torch.randperm(batch_size) for _ in range(self.k)]

        # for control experiments
        if perm is not None:
            lam_list = perm[0]
            index_list = perm[1]

        s = self.cam_model.get_cam(x, cam_target, mixing=False, s_size=self.s_size)
        s = s.unsqueeze(1).to(x.device)
        s_list = [s[index, :] for index in index_list]

        k_masks, a_list = self.forward(s_list, lam_list, img_size=x.size()[-1])
        tx_list = []
        for a, index in zip(a_list, index_list):
            grid = F.affine_grid(a, x.size(), align_corners=True)
            tx_list.append(
                F.grid_sample(x[index], grid, align_corners=True, padding_mode='border'))
        mixed_x = torch.mul(k_masks, torch.stack([*tx_list]))
        mixed_x = torch.sum(mixed_x, 0)

        if return_report:
            if self.upsample:
                # Recompute the saliency maps at the image size for the report.
                s = self.cam_model.get_cam(x, cam_target, mixing=False, s_size=x.size()[-1])
                s = s.unsqueeze(1).to(x.device)
                s_list = [s[index, :] for index in index_list]

            return {'x': x, 'index_list': index_list, 'tx_list': tx_list,
                    'mixed_x': mixed_x, 's_list': s_list, 'k_masks': k_masks,
                    'lam_list': lam_list
                    }
        if return_mask:
            r_masks = k_masks.squeeze(2)
            r_sums = torch.sum(r_masks, dim=(2, 3))
            r_norm = torch.sum(r_sums, dim=0)
            r_list = r_sums / r_norm
            return mixed_x, y, index_list, r_list, lam_list
        return mixed_x, y, index_list, lam_list

    def test_forward_k3(self):
        """Smoke-test ``forward`` with three identical random maps."""
        x1 = torch.rand([3, 1, 32, 32], dtype=torch.float).to(self._device())
        x_list = [x1, x1, x1]
        lam_list = [0.7, 0.2, 0.1]
        k_masks, a_list = self.forward(x_list, lam_list, img_size=x1.size()[-1])
        print(f'Two_masks: {k_masks.size()}')
        print(f'a1: {a_list[0].size()}')
        print(f'a2: {a_list[1].size()}')
        print(f'a3: {a_list[2].size()}')
        print(f'alist length: {len(a_list)}')

    def test_forward_k2(self):
        """Smoke-test ``forward`` with two identical random maps."""
        x1 = torch.rand([3, 1, 32, 32], dtype=torch.float).to(self._device())
        x_list = [x1, x1]
        lam_list = [0.7, 0.3]
        k_masks, a_list = self.forward(x_list, lam_list, img_size=x1.size()[-1])
        print(f'Two_masks: {k_masks.size()}')
        print(f'a1: {a_list[0].size()}')
        print(f'a2: {a_list[1].size()}')
        print(f'alist length: {len(a_list)}')

    def test_mix(self):
        """Smoke-test ``mix_data`` on a random 32x32 batch."""
        device = self._device()
        x = torch.rand([5, 3, 32, 32], dtype=torch.float).to(device)
        y = torch.tensor([2, 4, 5, 7, 4]).to(device)
        mixed_x, _, _, _ = self.mix_data(x, y, 1)
        print(f'mixed_x: {mixed_x.size()}')

    def test_background_mix(self):
        """Smoke-test ``mix_data_background`` on a random 32x32 batch."""
        device = self._device()
        x = torch.rand([5, 3, 32, 32], dtype=torch.float).to(device)
        y = torch.tensor([2, 4, 5, 7, 4]).to(device)
        mixed_x, y = self.mix_data_background(x, y)
        print(f'mixed_x: {mixed_x.size()}')

    def test_adj_y_mix(self):
        """Smoke-test ``mix_data`` with ``return_mask=True``."""
        device = self._device()
        x = torch.rand([8, 3, 32, 32], dtype=torch.float).to(device)
        y = torch.tensor([2, 4, 5, 7, 4, 0, 1, 2]).to(device)
        mixed_x, y, _, r_list, _ = self.mix_data(x, y, 1, perm=None, return_report=False, return_mask=True)
        print(f'y_x with adj y: {r_list.size()}')

    def test_mix_upsample(self):
        """Smoke-test the report path of ``mix_data`` on a random 64x64 batch."""
        device = self._device()
        x = torch.rand([8, 3, 64, 64], dtype=torch.float).to(device)
        y = torch.tensor([2, 4, 5, 7, 4, 0, 1, 2]).to(device)
        report = self.mix_data(x, y, 1, return_report=True)
        print(f'mixed_x: {report["mixed_x"].size()}')
        print(f's_map: {report["s_list"][0].size()}')

    def test_stn_downsample(self):
        """Smoke-test ``mix_data`` on a 64x64 batch (meant for a smaller ``stn_size``)."""
        device = self._device()
        x = torch.rand([8, 3, 64, 64], dtype=torch.float).to(device)
        y = torch.tensor([2, 4, 5, 7, 4, 0, 1, 2]).to(device)
        report = self.mix_data(x, y, 1, return_report=True)
        print(f'mixed_x: {report["mixed_x"].size()}')
        print(f's_map: {report["s_list"][0].size()}')


def test():
    # Manual smoke test: prints tensor shapes for the STN and for several
    # KMixAugmentor configurations (k=2, k=3, mask upsampling, STN downsampling).
    # It moves models to CUDA and loads a checkpoint, so it needs a GPU and the
    # project-local modules imported below.
    print('==> STN')
    stn = STN(in_plane=2, k=2)
    stn.test_size()

    print('==> MixAugmenter Forward')
    # Project-local imports, kept inside the function so that importing this
    # module does not require them.
    from cam import CamModel
    from networks.masknet import MaskNet_K1
    from networks import get_model
    from utils import load_model

    # NOTE(review): the right-hand side below is a placeholder, not valid
    # Python -- replace it with the checkpoint path (a string) before running.
    mask_model_path = <PATH_TO_MASK_MODEL>
    pretrained_model = get_model(
                        model_name='resnet18',
                        num_class=10,
                        datamixer=None,
                        use_cuda=True,
                        data_parallel=False)
    load_model(pretrained_model, mask_model_path, location=0)
    # The CAM model built from the loaded network is shared by every augmentor below.
    cam_model = CamModel(pretrained_model, 0, 10, cam_method='simcam')

    # k = 2: default sizes (32x32 saliency maps, no mask upsampling).
    k = 2
    masknet = MaskNet_K1(k, in_planes=2*k, n_channel=4).cuda()
    ma = KMixAugmentor(cam_model, masknet, k=k).cuda()
    print('  |forward k2')
    ma.test_forward_k2()
    print('  |mix k2')
    ma.test_mix()
    print('  |background mix k2')
    ma.test_background_mix()
    print('  |adjusted y mix k2')
    ma.test_adj_y_mix()
    # k = 3: the mask network and the augmentor are rebuilt for three images.
    print('  |forward k3')
    k = 3
    masknet = MaskNet_K1(k, in_planes=2*k, n_channel=4).cuda()
    ma = KMixAugmentor(cam_model, masknet, k=k).cuda()
    ma.test_forward_k3()
    print('  |mix k3')
    ma.test_mix()
    # 64x64 images with 32x32 saliency maps: predicted masks are upsampled.
    print('  |upsample k2')
    k = 2
    masknet = MaskNet_K1(k, in_planes=2*k, n_channel=4).cuda()
    ma = KMixAugmentor(cam_model, masknet, k=k, img_size=64, s_size=32).cuda()
    ma.test_mix_upsample()
    # 64x64 saliency maps with the STN working at 32x32.
    print('  |downsample k2')
    k = 2
    masknet = MaskNet_K1(k, in_planes=2*k, n_channel=4).cuda()
    ma = KMixAugmentor(cam_model, masknet, k=k, img_size=64, s_size=-1, stn_size=32).cuda()
    ma.test_stn_downsample()


# Run the smoke tests only when this file is executed as a script.
if __name__ == "__main__":
    test()
