import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class AdvProgram(nn.Module):
    """Input-space prompt that pads the input and adds a learned program.

    The input is zero-padded up to ``out_size`` and a 3-channel learnable
    ``program`` (squashed through a sigmoid) is added on top, weighted
    element-wise by ``mask``.

    Args:
        out_size: Side length of the padded output and of ``program``.
        mask: Square 2-D array-like. It is replicated to 3 channels and
            padded with ones up to ``out_size``; the program is multiplied
            by the resulting buffer.
        init: ``"zero"`` or ``"randn"`` initialisation of ``program``.
        normalize: Optional callable applied to the result of ``forward``.

    Raises:
        ValueError: If ``init`` is not one of the supported methods.
    """

    def __init__(self, out_size, mask, init='zero', normalize=None):
        super().__init__()
        rows, cols = mask.shape[0], mask.shape[1]
        assert rows == cols
        self.out_size = out_size

        program_shape = (3, out_size, out_size)
        if init == "zero":
            initial = torch.zeros(*program_shape)
        elif init == "randn":
            initial = torch.randn(*program_shape)
        else:
            raise ValueError("init method not supported")
        self.program = torch.nn.Parameter(data=initial)
        self.normalize = normalize

        # When the size gap is odd, the extra pixel goes to the left/top side.
        gap = out_size - rows
        self.l_pad = int((gap + 1) / 2)
        self.r_pad = int(gap / 2)

        # One copy of the mask per channel, then ones everywhere in the border.
        stacked = np.repeat(np.expand_dims(mask, 0), repeats=3, axis=0)
        stacked = torch.Tensor(stacked)
        border = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        self.register_buffer("mask", F.pad(stacked, border, value=1))

    def forward(self, x):
        """Return the zero-padded ``x`` plus ``sigmoid(program) * mask``.

        NOTE(review): assumes the last two dims of ``x`` match the mask size
        and that ``x`` has 3 channels -- confirm against callers.
        """
        border = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        canvas = F.pad(x, border, value=0)
        prompt = torch.sigmoid(self.program) * self.mask
        out = canvas + prompt
        if self.normalize is None:
            return out
        return self.normalize(out)


class AdvProgramInterPad(nn.Module):
    """Feature-map prompt that pads the input and adds a masked program.

    The input is zero-padded up to ``feature_map_out_size`` and a learnable
    ``program`` with ``feature_map_num`` channels is added, weighted
    element-wise by a mask buffer derived from ``feature_map``.

    Args:
        feature_map_out_size: Side length after padding.
        feature_map: Square 2-D array-like. It is replicated
            ``feature_map_num`` times and padded with ones to form the mask.
        feature_map_num: Number of channels of the program and the mask.
        init: ``"zero"`` or ``"randn"`` initialisation of ``program``.

    Raises:
        ValueError: If ``init`` is not one of the supported methods.
    """

    def __init__(self, feature_map_out_size, feature_map, feature_map_num, init='zero'):
        super().__init__()
        rows, cols = feature_map.shape[0], feature_map.shape[1]
        assert rows == cols
        in_size = rows
        self.out_size = feature_map_out_size

        program_shape = (feature_map_num, feature_map_out_size, feature_map_out_size)
        if init == "zero":
            initial = torch.zeros(*program_shape)
        elif init == "randn":
            initial = torch.randn(*program_shape)
        else:
            raise ValueError("init method not supported")
        self.program = torch.nn.Parameter(data=initial)

        # When the size gap is odd, the extra element goes to the left/top side.
        gap = feature_map_out_size - in_size
        self.l_pad = int((gap + 1) / 2)
        self.r_pad = int(gap / 2)

        # One mask copy per channel, with ones filling the padded border.
        stacked = np.repeat(np.expand_dims(feature_map, 0), repeats=feature_map_num, axis=0)
        stacked = torch.Tensor(stacked)
        border = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        self.register_buffer("mask", F.pad(stacked, border, value=1))

        # Reported count is the number of positions added by the padding.
        border_params = feature_map_num * (feature_map_out_size ** 2 - in_size ** 2)
        print(
            f"Before Prompting: [{feature_map_num}, {in_size}, {in_size}]; After Prompting: [{feature_map_num}, {feature_map_out_size}, {feature_map_out_size}]")
        print(
            f"Trainable Params: {border_params}, namely {border_params / (1024 ** 2)}M")

    def forward(self, x):
        """Return the zero-padded ``x`` plus ``program * mask``."""
        border = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        canvas = F.pad(x, border, value=0)
        return canvas + self.program * self.mask


class AdvProgramInterAdd(nn.Module):
    """Additive prompt: a learnable ``shift`` that is summed onto the input.

    The shape of ``shift`` depends on the arguments:

    * square ``feature_map`` -> ``[feature_map_num, size, size]``
    * non-square ``feature_map`` with ``feature_map_num == 0`` ->
      ``[feature_map.shape[0], feature_map.shape[1]]``

    Args:
        feature_map: Object whose ``shape[0]`` and ``shape[1]`` define the
            prompt size.
        feature_map_num: Channel count for the square case; must be ``0``
            for the non-square case.
        init: ``"zero"`` or ``"randn"`` initialisation of ``shift``.

    Raises:
        ValueError: If the shape/``feature_map_num`` combination or ``init``
            is unsupported (both use the same message).
    """

    def __init__(self, feature_map, feature_map_num, init='zero'):
        super().__init__()
        rows, cols = feature_map.shape[0], feature_map.shape[1]

        # Step 1: decide the prompt shape and the message announcing it.
        if rows == cols:
            self.out_size = rows
            shift_shape = (feature_map_num, self.out_size, self.out_size)
            banner = f"Prompting: [{feature_map_num}, {self.out_size}, {self.out_size}]"
        elif feature_map_num == 0:
            shift_shape = (rows, cols)
            banner = f"Prompting: [{feature_map.shape[0]}, {feature_map.shape[1]}]"
        else:
            raise ValueError("init method not supported")

        # Step 2: create the parameter with the requested initialisation.
        if init == "zero":
            initial = torch.zeros(*shift_shape)
        elif init == "randn":
            initial = torch.randn(*shift_shape)
        else:
            raise ValueError("init method not supported")
        self.shift = torch.nn.Parameter(data=initial)

        print(banner)

    def forward(self, x):
        """Return ``x + shift`` (standard broadcasting applies)."""
        return x + self.shift


class AdvProgramInterAddSimple(nn.Module):
    """Additive prompt with a single shift shared across leading dimensions.

    Only an additive ``shift`` is learned (there is no scale parameter).
    Its shape depends on the arguments:

    * square ``feature_map`` -> ``[size, size]``
    * non-square ``feature_map`` with ``feature_map_num == 0`` ->
      ``[feature_map.shape[1]]``

    Args:
        feature_map: Object whose ``shape[0]`` and ``shape[1]`` define the
            prompt size.
        feature_map_num: Only consulted for non-square ``feature_map``,
            where it must be ``0``.
        init: ``"zero"`` or ``"randn"`` initialisation of ``shift``.

    Raises:
        ValueError: If ``feature_map`` is non-square while
            ``feature_map_num`` is not ``0``, or if ``init`` is unsupported.
    """

    def __init__(self, feature_map, feature_map_num, init='zero'):
        super(AdvProgramInterAddSimple, self).__init__()
        is_square = feature_map.shape[0] == feature_map.shape[1]
        if is_square:
            self.out_size = feature_map.shape[0]
            shift_shape = (self.out_size, self.out_size)
        elif feature_map_num == 0:
            shift_shape = (feature_map.shape[1],)
        else:
            # The problem here is the shape combination, not the init method.
            raise ValueError(
                "feature_map must be square, or feature_map_num must be 0")

        if init == "zero":
            initial = torch.zeros(*shift_shape)
        elif init == "randn":
            initial = torch.randn(*shift_shape)
        else:
            raise ValueError("init method not supported")
        self.shift = torch.nn.Parameter(data=initial)

        # Report the real parameter count; only `shift` is trainable.
        num_params = self.shift.numel()
        if is_square:
            print(f"Prompting: [{self.out_size}, {self.out_size}]")
            print(f"Trainable Params: {num_params}, namely {num_params / (1024 ** 2)}M")
        else:
            print(f"Prompting: [{feature_map.shape[1]}]")
            print(f"Trainable Params: {num_params}")

    def forward(self, x):
        """Return ``x + shift`` (standard broadcasting applies)."""
        return x + self.shift


class ExpansiveVisualPrompt(nn.Module):
    """Resize the input, pad it, and add a learned program around it.

    ``forward`` bilinearly resizes the input to the mask's side length,
    zero-pads it up to ``out_size`` and adds ``sigmoid(program) * mask``.

    Args:
        out_size: Side length of the padded output and of ``program``.
        mask: Square 2-D array-like whose side length becomes ``in_size``.
            It is replicated to 3 channels and padded with ones.
        init: ``"zero"`` or ``"randn"`` initialisation of ``program``.
        normalize: Optional callable applied to the result of ``forward``.

    Raises:
        ValueError: If ``init`` is not one of the supported methods.
    """

    def __init__(self, out_size, mask, init = 'zero', normalize=None):
        super().__init__()
        rows, cols = mask.shape[0], mask.shape[1]
        assert rows == cols
        self.in_size = rows
        self.out_size = out_size

        program_shape = (3, out_size, out_size)
        if init == "zero":
            initial = torch.zeros(*program_shape)
        elif init == "randn":
            initial = torch.randn(*program_shape)
        else:
            raise ValueError("init method not supported")
        self.program = torch.nn.Parameter(data=initial)
        self.normalize = normalize

        # When the size gap is odd, the extra pixel goes to the left/top side.
        gap = out_size - self.in_size
        self.l_pad = int((gap + 1) / 2)
        self.r_pad = int(gap / 2)

        # One copy of the mask per channel, then ones everywhere in the border.
        stacked = np.repeat(np.expand_dims(mask, 0), repeats=3, axis=0)
        stacked = torch.Tensor(stacked)
        border = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        self.register_buffer("mask", F.pad(stacked, border, value=1))

    def forward(self, x):
        """Return the resized, padded ``x`` plus ``sigmoid(program) * mask``.

        NOTE(review): ``F.interpolate`` in bilinear mode expects a 4-D
        input; assumes ``x`` has 3 channels -- confirm against callers.
        """
        target = (self.in_size, self.in_size)
        resized = F.interpolate(x, size=target, mode='bilinear', align_corners=False)
        border = (self.l_pad, self.r_pad, self.l_pad, self.r_pad)
        canvas = F.pad(resized, border, value=0)
        out = canvas + torch.sigmoid(self.program) * self.mask
        if self.normalize is None:
            return out
        return self.normalize(out)


class VisualPrompt(nn.Module):
    """Border prompt: a learned program added on a frame of width ``pad``.

    A zero-initialised 3-channel ``program`` of side ``size`` is multiplied
    by a mask that is one on the outer ``pad`` pixels and zero inside, then
    added to the input.

    Args:
        size: Side length of the program and the mask.
        pad: Width of the border in which the program is active. When
            ``size == 2 * pad`` the mask is all ones.

    Raises:
        ValueError: If ``2 * pad`` exceeds ``size``.
    """

    def __init__(self, size, pad):
        super(VisualPrompt, self).__init__()

        self.size = size
        self.program = torch.nn.Parameter(data=torch.zeros(3, size, size))

        if size > 2 * pad:
            # Zero interior surrounded by a frame of ones of width `pad`.
            mask = torch.zeros(3, size - 2 * pad, size - 2 * pad)
            self.register_buffer("mask", F.pad(mask, (pad,) * 4, value=1))
        elif size == 2 * pad:
            # The border covers the whole image: no interior is left.
            mask = torch.ones(3, size, size)
            self.register_buffer("mask", mask)
        else:
            raise ValueError("Pad Should Not Exceed Half Of Size")

    def forward(self, x):
        """Return ``x + program * mask`` as a new tensor.

        The addition is out-of-place so the caller's tensor is never
        modified, and inputs that are leaf tensors requiring grad are
        accepted.
        """
        return x + self.program * self.mask

class VPT(nn.Module):
    """Token prompt: inserts learnable tokens right after the first token.

    Args:
        embedding_shape: Shape of the un-prompted embedding. Only the last
            two entries are used: ``embedding_shape[-2]`` is the original
            sequence length and ``embedding_shape[-1]`` the embedding size.
        prompt_length: Number of prompt tokens to insert.
        init: ``"zero"`` or ``"randn"`` initialisation of the prompt tokens.

    Raises:
        ValueError: If ``init`` is not one of the supported methods.
    """

    def __init__(self, embedding_shape, prompt_length=5, init='zero'):
        super(VPT, self).__init__()
        self.embedding_shape = embedding_shape
        if init == "zero":
            self.program = torch.nn.Parameter(data=torch.zeros(prompt_length, embedding_shape[-1]))
        elif init == "randn":
            self.program = torch.nn.Parameter(data=torch.randn(prompt_length, embedding_shape[-1]))
        else:
            raise ValueError("init method not supported")

    def forward(self, x):
        """Return ``[first token, prompt tokens, trailing tokens]`` along dim 1.

        The trailing tokens are the last ``embedding_shape[-2] - 1`` entries
        of ``x``. An input that already carries prompt tokens from an
        earlier call therefore has them replaced rather than stacked.
        """
        first_token = x[:, :1, :]
        # Broadcast the prompt over the batch; torch.cat copies it anyway.
        prompts = self.program.unsqueeze(0).expand(x.shape[0], -1, -1)

        num_trailing = self.embedding_shape[-2] - 1
        if num_trailing == 0:
            # `x[:, -0:, :]` would select ALL of x and duplicate the first
            # token, so take an explicit empty slice instead.
            trailing = x[:, :0, :]
        else:
            trailing = x[:, -num_trailing:, :]

        return torch.cat([first_token, prompts, trailing], dim=1)


if __name__ == "__main__":
    # Smoke test: 5 prompt tokens inserted into a 197-token sequence.
    prompt = VPT((128, 197, 192), 5)
    # Named `tokens` so the builtin `input` is not shadowed.
    tokens = torch.randn((128, 197, 192))
    res = prompt(tokens)
    print(res.shape)  # prints torch.Size([128, 202, 192])
