import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import logging
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.functional as F


logger = logging.getLogger(__name__)

class BarVisualPrompt(nn.Module):
    """Border prompt whose pixels are the product of two learnable "bars".

    The prompt image is ``torch.bmm(left_bar, right_bar)``, where ``left_bar``
    has shape ``(3, bar_height, bar_width)`` and ``right_bar`` has shape
    ``(3, bar_width, bar_height)``; the product is ``(3, bar_height, bar_height)``.
    ``forward`` zero-pads the input, adds the sigmoid of the prompt wherever the
    ``mask`` buffer is 1, and clamps the sum to ``[0, 1]``.

    NOTE(review): the prompt is multiplied element-wise with a
    ``(3, output_size, output_size)`` mask, so ``bar_height`` presumably equals
    ``output_size`` -- confirm against the caller's configuration.

    Args:
        args: Configuration object. Reads ``bar_width``, ``bar_height``,
            ``output_size``, ``input_size``, ``mask_size``, ``network`` and
            ``init_method`` (comma-separated names; entry 0 initialises
            ``left_bar``, entry 1 initialises ``right_bar``).
        normalize: Optional callable applied to the result of ``forward``.

    Raises:
        ValueError: If ``args.mask_size`` is negative, or an init method name
            is not one of the names accepted by ``get_init``.
        IndexError: If ``args.init_method`` holds fewer than two names.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: bar\n')
        super(BarVisualPrompt, self).__init__()
        width = args.bar_width
        height = args.bar_height
        output_size = args.output_size
        input_size = args.input_size
        mask_size = args.mask_size
        self.normalize = normalize
        self.network = args.network

        # For an odd size difference the extra pixel goes to the left/top side.
        self.l_pad = int((output_size-input_size+1)/2)
        self.r_pad = int((output_size-input_size)/2)

        # Tolerate "a, b" as well as "a,b".
        init_methods = [name.strip() for name in args.init_method.split(',')]
        self.left_bar = torch.nn.Parameter(torch.empty(3, height, width))
        self.get_init(init_methods[0], self.left_bar)
        self.right_bar = torch.nn.Parameter(torch.empty(3, width, height))
        self.get_init(init_methods[1], self.right_bar)

        if mask_size > 0:
            # 0 in the centre, 1 on the border. The border is split the same
            # way as l_pad/r_pad so the mask is always exactly output_size
            # wide, including when output_size - mask_size is odd.
            border = output_size - mask_size
            far = border // 2
            near = border - far
            mask = F.pad(torch.zeros(3, mask_size, mask_size),
                         (near, far, near, far), value=1)
        elif mask_size == 0:
            # No protected centre: the prompt covers the whole image.
            mask = torch.ones(3, output_size, output_size)
        else:
            raise ValueError("Pad Should Not Exceed Half Of Output Size")
        self.register_buffer("mask", mask)
        logger.info(f'width: {args.bar_width}, height: {args.bar_height}, output size: {args.output_size}, input size: {args.input_size}, mask size: {args.mask_size}')

    @property
    def program(self):
        """Current prompt image, ``torch.bmm(left_bar, right_bar)``.

        Computed on access rather than cached on the module: a cached product
        is a non-leaf autograd tensor, which ``copy.deepcopy`` refuses to copy
        and which is not moved by ``Module.to``.
        """
        return torch.bmm(self.left_bar, self.right_bar)

    def get_init(self, init_method, params):
        """Initialise ``params`` in place according to ``init_method``.

        Args:
            init_method: One of ``'zero'``, ``'random'``, ``'xavier'``,
                ``'kaiming'``, ``'uniform'`` or ``'normal'``.
            params: Tensor or parameter to initialise in place.

        Raises:
            ValueError: If ``init_method`` is not a recognised name. The bars
                are allocated with ``torch.empty``, so skipping initialisation
                would leave them holding arbitrary memory.
        """
        if init_method == 'zero':
            params.data.fill_(0)
        elif init_method == 'random':
            params.data.normal_(0, 1)
        elif init_method == 'xavier':
            torch.nn.init.xavier_uniform_(params)
        elif init_method == 'kaiming':
            torch.nn.init.kaiming_uniform_(params, nonlinearity='relu')
        elif init_method == 'uniform':
            torch.nn.init.uniform_(params, a=-0.1, b=0.1)
        elif init_method == 'normal':
            torch.nn.init.normal_(params, mean=0.0, std=0.01)
        else:
            raise ValueError(f"Unknown init method: {init_method!r}")

    def forward(self, x):
        """Pad ``x``, add the masked sigmoid prompt, clamp, then normalise.

        Args:
            x: Input tensor whose last two dimensions are padded by
                ``l_pad``/``r_pad`` (presumably ``(batch, 3, input_size,
                input_size)`` -- TODO confirm).

        Returns:
            The prompted tensor, passed through ``normalize`` when one was given.
        """
        padded = F.pad(x, (self.l_pad, self.r_pad, self.l_pad, self.r_pad), value=0)
        x = padded + torch.sigmoid(self.program) * self.mask
        x = x.clamp(0, 1)
        if self.normalize is not None:
            x = self.normalize(x)
        return x


class PatchVisualPrompt(nn.Module):
    """Additive prompt made by tiling one learnable patch across the image.

    ``program`` is a ``(3, patch_size, patch_size)`` parameter initialised to
    zeros. ``forward`` repeats it ``patch_num`` times along both spatial axes
    and adds the result to the input.

    Args:
        args: Configuration object. Reads ``output_size``, ``patch_size``,
            ``pad_size`` and ``network``; ``bar_width``, ``bar_height`` and
            ``input_size`` are only read for the log line.
        normalize: Optional callable applied to the result of ``forward``.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: patch\n')
        super(PatchVisualPrompt, self).__init__()
        self.normalize = normalize
        self.network = args.network

        side = args.output_size
        patch = args.patch_size
        pad = args.pad_size
        self.output_size = side
        self.patch_size = patch
        self.pad_size = pad
        self.patch_num = side // patch
        # Size of one cell before it is padded up to a full patch.
        self.mask_patch_size = patch - pad
        self.mask_size = self.patch_num * self.mask_patch_size
        self.mask_pad_size = pad
        # An odd pad puts its extra pixel on the left/top side.
        self.mask_l_pad = int((pad + 1) / 2)
        self.mask_r_pad = int(pad / 2)

        self.program = nn.Parameter(torch.zeros(3, patch, patch))
        self.register_buffer("mask", torch.ones(3, side, side))

        logger.info(f'width: {args.bar_width}, height: {args.bar_height}, output size: {args.output_size}, input size: {args.input_size}, patch size: {args.patch_size}, patch num: {self.patch_num}, mask l pad: {self.mask_l_pad}, mask r pad: {self.mask_r_pad}')

    def get_init(self, init_method, params):
        """Initialise ``params`` in place; unrecognised names leave it untouched.

        Args:
            init_method: ``'zero'``, ``'random'``, ``'xavier'``, ``'kaiming'``,
                ``'uniform'`` or ``'normal'``.
            params: Tensor or parameter to initialise in place.
        """
        initializers = {
            'zero': lambda t: t.data.fill_(0),
            'random': lambda t: t.data.normal_(0, 1),
            'xavier': lambda t: torch.nn.init.xavier_uniform_(t),
            'kaiming': lambda t: torch.nn.init.kaiming_uniform_(t, nonlinearity='relu'),
            'uniform': lambda t: torch.nn.init.uniform_(t, a=-0.1, b=0.1),
            'normal': lambda t: torch.nn.init.normal_(t, mean=0.0, std=0.01),
        }
        initializer = initializers.get(init_method)
        if initializer is not None:
            initializer(params)

    def patch_mask(self, x, value=1):
        """Pad every ``mask_patch_size`` cell of ``x`` up to ``patch_size``.

        ``x`` is cut into non-overlapping square cells, each cell is padded with
        ``value`` (``mask_l_pad`` on the left/top, ``mask_r_pad`` on the
        right/bottom), and the padded cells are stitched back together.

        Args:
            x: Tensor indexed as ``(batch, 3, H, W)``; it must split into
                ``patch_num`` cells per side for the reshapes to succeed.
            value: Fill value for the padding around each cell.

        Returns:
            Tensor of shape ``(batch, 3, output_size, output_size)``.
        """
        n = x.size(0)
        cell = self.mask_patch_size
        tile = self.patch_size
        grid = self.patch_num
        full = (self.output_size, self.output_size)
        pads = (self.mask_l_pad, self.mask_r_pad, self.mask_l_pad, self.mask_r_pad)

        # (n, 3, rows, cols, cell, cell) -> one flat list of cells per channel.
        cells = x.unfold(2, cell, cell).unfold(3, cell, cell)
        cells = cells.contiguous().view(n, 3, -1, cell, cell)

        tiles = F.pad(cells, pads, value=value)
        tiles = tiles.view(n, 3, grid, grid, tile, tile)

        # F.fold expects (N, C * kernel_area, blocks); channels ride in N here.
        columns = tiles.contiguous().view(n * 3, -1, tile * tile).permute(0, 2, 1)
        stitched = F.fold(columns, output_size=full, kernel_size=(tile, tile), stride=(tile, tile))
        return stitched.view(n, 3, *full)

    def forward(self, x):
        """Add the tiled patch to ``x`` and apply ``normalize`` if it is set."""
        tiled_program = self.program.repeat(1, self.patch_num, self.patch_num)
        prompted = x + tiled_program
        if self.normalize is None:
            return prompted
        return self.normalize(prompted)


class FullVisualPrompt(nn.Module):
    """Additive prompt with one learnable value per output pixel.

    ``program`` is a ``(3, output_size, output_size)`` parameter initialised to
    zeros; ``forward`` adds it to the input.

    Args:
        args: Configuration object. Reads ``output_size``, ``patch_size``,
            ``pad_size`` and ``network``; ``bar_width``, ``bar_height`` and
            ``input_size`` are only read for the log line.
        normalize: Optional callable applied to the result of ``forward``.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: full\n')
        super(FullVisualPrompt, self).__init__()
        self.normalize = normalize
        self.network = args.network

        side, patch, pad = args.output_size, args.patch_size, args.pad_size
        self.output_size = side
        self.patch_size = patch
        self.pad_size = pad
        self.patch_num = side // patch
        # Cell size before padding restores it to a full patch.
        self.mask_patch_size = patch - pad
        self.mask_size = self.patch_num * self.mask_patch_size
        self.mask_pad_size = pad
        # An odd pad puts its extra pixel on the left/top side.
        self.mask_l_pad = int((pad + 1) / 2)
        self.mask_r_pad = int(pad / 2)

        self.program = nn.Parameter(torch.zeros(3, side, side))
        self.register_buffer("mask", torch.ones(3, side, side))

        logger.info(f'width: {args.bar_width}, height: {args.bar_height}, output size: {args.output_size}, input size: {args.input_size}, patch size: {args.patch_size}, patch num: {self.patch_num}, mask l pad: {self.mask_l_pad}, mask r pad: {self.mask_r_pad}')

    def get_init(self, init_method, params):
        """Initialise ``params`` in place; unrecognised names leave it untouched.

        Args:
            init_method: ``'zero'``, ``'random'``, ``'xavier'``, ``'kaiming'``,
                ``'uniform'`` or ``'normal'``.
            params: Tensor or parameter to initialise in place.
        """
        table = {
            'zero': lambda t: t.data.fill_(0),
            'random': lambda t: t.data.normal_(0, 1),
            'xavier': lambda t: torch.nn.init.xavier_uniform_(t),
            'kaiming': lambda t: torch.nn.init.kaiming_uniform_(t, nonlinearity='relu'),
            'uniform': lambda t: torch.nn.init.uniform_(t, a=-0.1, b=0.1),
            'normal': lambda t: torch.nn.init.normal_(t, mean=0.0, std=0.01),
        }
        if init_method in table:
            table[init_method](params)

    def patch_mask(self, x, value=1):
        """Pad every ``mask_patch_size`` cell of ``x`` up to ``patch_size``.

        Args:
            x: Tensor indexed as ``(batch, 3, H, W)``; it must split into
                ``patch_num`` cells per side for the reshapes to succeed.
            value: Fill value for the padding around each cell.

        Returns:
            Tensor of shape ``(batch, 3, output_size, output_size)``.
        """
        count = x.size(0)
        inner, outer = self.mask_patch_size, self.patch_size
        lead, trail = self.mask_l_pad, self.mask_r_pad

        # Cut the image into non-overlapping inner x inner cells.
        blocks = x.unfold(2, inner, inner).unfold(3, inner, inner)
        blocks = blocks.contiguous().view(count, 3, -1, inner, inner)

        # Grow each cell to outer x outer by padding its last two dims.
        grown = F.pad(blocks, (lead, trail, lead, trail), value=value)
        grown = grown.view(count, 3, self.patch_num, self.patch_num, outer, outer)

        # Fold treats each (image, channel) pair as a single-channel sample.
        flat = grown.contiguous().view(count * 3, -1, outer * outer)
        flat = flat.permute(0, 2, 1)
        rebuilt = F.fold(
            flat,
            output_size=(self.output_size, self.output_size),
            kernel_size=(outer, outer),
            stride=(outer, outer),
        )
        return rebuilt.view(count, 3, self.output_size, self.output_size)

    def forward(self, x):
        """Add the full-size prompt to ``x`` and apply ``normalize`` if set."""
        prompted = x + self.program
        return prompted if self.normalize is None else self.normalize(prompted)


class BarFullVisualPrompt(nn.Module):
    """Additive full-image prompt formed by the product of two learnable bars.

    The prompt is ``torch.bmm(left_bar, right_bar)`` with ``left_bar`` of shape
    ``(3, bar_height, bar_width)`` and ``right_bar`` of shape
    ``(3, bar_width, bar_height)``; ``forward`` adds it to the input unmasked.

    NOTE(review): the product is ``(3, bar_height, bar_height)`` and is added
    directly to the input, so ``bar_height`` presumably matches the input's
    spatial size -- confirm against the caller's configuration.

    Args:
        args: Configuration object. Reads ``bar_width``, ``bar_height``,
            ``output_size``, ``patch_size``, ``pad_size``, ``network`` and
            ``init_method`` (comma-separated names; entry 0 initialises
            ``left_bar``, entry 1 initialises ``right_bar``). ``input_size`` is
            only read for the log line.
        normalize: Optional callable applied to the result of ``forward``.

    Raises:
        ValueError: If an init method name is not accepted by ``get_init``.
        IndexError: If ``args.init_method`` holds fewer than two names.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: barfull\n')
        super(BarFullVisualPrompt, self).__init__()
        width = args.bar_width
        height = args.bar_height
        self.output_size = args.output_size
        self.normalize = normalize
        self.network = args.network
        self.patch_size = args.patch_size
        self.pad_size = args.pad_size
        self.patch_num = self.output_size // self.patch_size
        # Cell size before padding restores it to a full patch (see patch_mask).
        self.mask_patch_size = self.patch_size - self.pad_size
        self.mask_size = self.patch_num * self.mask_patch_size
        self.mask_pad_size = self.pad_size

        # An odd pad puts its extra pixel on the left/top side.
        self.mask_l_pad = int((self.mask_pad_size+1)/2)
        self.mask_r_pad = int(self.mask_pad_size/2)

        # Tolerate "a, b" as well as "a,b".
        init_methods = [name.strip() for name in args.init_method.split(',')]
        self.left_bar = torch.nn.Parameter(torch.empty(3, height, width))
        self.get_init(init_methods[0], self.left_bar)
        self.right_bar = torch.nn.Parameter(torch.empty(3, width, height))
        self.get_init(init_methods[1], self.right_bar)

        mask = torch.ones(3, self.output_size, self.output_size)
        self.register_buffer("mask", mask)

        logger.info(f'width: {args.bar_width}, height: {args.bar_height}, output size: {args.output_size}, input size: {args.input_size}, patch size: {args.patch_size}, patch num: {self.patch_num}, mask l pad: {self.mask_l_pad}, mask r pad: {self.mask_r_pad}')

    @property
    def program(self):
        """Current prompt image, ``torch.bmm(left_bar, right_bar)``.

        Computed on access rather than cached on the module: a cached product
        is a non-leaf autograd tensor, which ``copy.deepcopy`` refuses to copy
        and which is not moved by ``Module.to``.
        """
        return torch.bmm(self.left_bar, self.right_bar)

    def get_init(self, init_method, params):
        """Initialise ``params`` in place according to ``init_method``.

        Args:
            init_method: One of ``'zero'``, ``'random'``, ``'xavier'``,
                ``'kaiming'``, ``'uniform'`` or ``'normal'``.
            params: Tensor or parameter to initialise in place.

        Raises:
            ValueError: If ``init_method`` is not a recognised name. The bars
                are allocated with ``torch.empty``, so skipping initialisation
                would leave them holding arbitrary memory.
        """
        if init_method == 'zero':
            params.data.fill_(0)
        elif init_method == 'random':
            params.data.normal_(0, 1)
        elif init_method == 'xavier':
            torch.nn.init.xavier_uniform_(params)
        elif init_method == 'kaiming':
            torch.nn.init.kaiming_uniform_(params, nonlinearity='relu')
        elif init_method == 'uniform':
            torch.nn.init.uniform_(params, a=-0.1, b=0.1)
        elif init_method == 'normal':
            torch.nn.init.normal_(params, mean=0.0, std=0.01)
        else:
            raise ValueError(f"Unknown init method: {init_method!r}")

    def patch_mask(self, x, value=1):
        """Pad every ``mask_patch_size`` cell of ``x`` up to ``patch_size``.

        ``x`` is cut into non-overlapping square cells, each cell is padded with
        ``value`` (``mask_l_pad`` on the left/top, ``mask_r_pad`` on the
        right/bottom), and the padded cells are stitched back together.

        Args:
            x: Tensor indexed as ``(batch, 3, H, W)``; it must split into
                ``patch_num`` cells per side for the reshapes to succeed.
            value: Fill value for the padding around each cell.

        Returns:
            Tensor of shape ``(batch, 3, output_size, output_size)``.
        """
        batch_size = x.size(0)
        x_patches = x.unfold(2, self.mask_patch_size, self.mask_patch_size).unfold(3, self.mask_patch_size, self.mask_patch_size)
        x_patches = x_patches.contiguous().view(batch_size, 3, -1, self.mask_patch_size, self.mask_patch_size)
        padded_x = F.pad(x_patches, (self.mask_l_pad, self.mask_r_pad, self.mask_l_pad, self.mask_r_pad), value=value)
        padded_x = padded_x.view(batch_size, 3, self.patch_num, self.patch_num, self.patch_size, self.patch_size)
        # F.fold expects (N, C * kernel_area, blocks); channels ride in N here.
        padded_x = padded_x.contiguous().view(batch_size*3, -1, self.patch_size * self.patch_size)
        padded_x = padded_x.permute(0, 2, 1)
        recovered_x = F.fold(
            padded_x,
            output_size=(self.output_size, self.output_size),
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
        )
        recovered_x = recovered_x.view(batch_size, 3, self.output_size, self.output_size)
        return recovered_x

    def forward(self, x):
        """Add the bar-product prompt to ``x`` and apply ``normalize`` if set."""
        x = x + self.program
        if self.normalize is not None:
            x = self.normalize(x)
        return x


class PatchBarVisualPrompt(nn.Module):
    """Bar-product prompt applied only on the padding grid between cells.

    The prompt is ``torch.bmm(left_bar, right_bar)``. The ``mask`` buffer is
    built by running an all-zero ``(1, 3, mask_size, mask_size)`` tensor through
    ``patch_mask`` with ``value=1``, so it is 0 inside each cell and 1 on the
    padding added around it. ``forward`` adds ``program * mask`` to the input.

    Constructing the module also renders the mask with matplotlib, writes it to
    ``patch_bar_mask.png`` in the working directory and calls ``plt.show()``.

    Args:
        args: Configuration object. Reads ``bar_width``, ``bar_height``,
            ``output_size``, ``patch_size``, ``pad_size``, ``network`` and
            ``init_method`` (comma-separated names; entry 0 initialises
            ``left_bar``, entry 1 initialises ``right_bar``). ``input_size`` is
            only read for the log line.
        normalize: Optional callable applied to the result of ``forward``.

    Raises:
        ValueError: If an init method name is not accepted by ``get_init``.
        IndexError: If ``args.init_method`` holds fewer than two names.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: patchbar\n')
        super(PatchBarVisualPrompt, self).__init__()
        width = args.bar_width
        height = args.bar_height
        self.output_size = args.output_size
        self.normalize = normalize
        self.network = args.network
        self.patch_size = args.patch_size
        self.pad_size = args.pad_size
        self.patch_num = self.output_size // self.patch_size
        # Cell size before padding restores it to a full patch.
        self.mask_patch_size = self.patch_size - self.pad_size
        self.mask_size = self.patch_num * self.mask_patch_size
        self.mask_pad_size = self.pad_size

        # An odd pad puts its extra pixel on the left/top side.
        self.mask_l_pad = int((self.mask_pad_size+1)/2)
        self.mask_r_pad = int(self.mask_pad_size/2)

        # Tolerate "a, b" as well as "a,b".
        init_methods = [name.strip() for name in args.init_method.split(',')]
        self.left_bar = torch.nn.Parameter(torch.empty(3, height, width))
        self.get_init(init_methods[0], self.left_bar)
        self.right_bar = torch.nn.Parameter(torch.empty(3, width, height))
        self.get_init(init_methods[1], self.right_bar)

        mask = torch.zeros(1, 3, self.mask_size, self.mask_size)
        padded_mask = self.patch_mask(mask, value=1).squeeze(0)
        self._save_mask_figure(padded_mask)
        self.register_buffer("mask", padded_mask)
        logger.info(f'width: {args.bar_width}, height: {args.bar_height}, output size: {args.output_size}, input size: {args.input_size}, patch size: {args.patch_size}, patch num: {self.patch_num}, mask l pad: {self.mask_l_pad}, mask r pad: {self.mask_r_pad}')

    @staticmethod
    def _save_mask_figure(mask):
        """Render a ``(3, H, W)`` mask and write it to ``patch_bar_mask.png``."""
        image_to_plot = mask.detach().cpu().numpy().transpose(1, 2, 0)
        fig, ax = plt.subplots()
        try:
            ax.imshow(image_to_plot)
            ax.axis('off')  # Turn off the axis labels
            # Save before show(): on interactive/inline backends show() can
            # discard the figure, and a later savefig would write a blank image.
            fig.savefig('patch_bar_mask.png')
            plt.show()
        finally:
            # Release the figure so repeated construction does not pile them up.
            plt.close(fig)

    @property
    def program(self):
        """Current prompt image, ``torch.bmm(left_bar, right_bar)``.

        Computed on access rather than cached on the module: a cached product
        is a non-leaf autograd tensor, which ``copy.deepcopy`` refuses to copy
        and which is not moved by ``Module.to``.
        """
        return torch.bmm(self.left_bar, self.right_bar)

    def get_init(self, init_method, params):
        """Initialise ``params`` in place according to ``init_method``.

        Args:
            init_method: One of ``'zero'``, ``'random'``, ``'xavier'``,
                ``'kaiming'``, ``'uniform'`` or ``'normal'``.
            params: Tensor or parameter to initialise in place.

        Raises:
            ValueError: If ``init_method`` is not a recognised name. The bars
                are allocated with ``torch.empty``, so skipping initialisation
                would leave them holding arbitrary memory.
        """
        if init_method == 'zero':
            params.data.fill_(0)
        elif init_method == 'random':
            params.data.normal_(0, 1)
        elif init_method == 'xavier':
            torch.nn.init.xavier_uniform_(params)
        elif init_method == 'kaiming':
            torch.nn.init.kaiming_uniform_(params, nonlinearity='relu')
        elif init_method == 'uniform':
            torch.nn.init.uniform_(params, a=-0.1, b=0.1)
        elif init_method == 'normal':
            torch.nn.init.normal_(params, mean=0.0, std=0.01)
        else:
            raise ValueError(f"Unknown init method: {init_method!r}")

    def patch_mask(self, x, value=1):
        """Pad every ``mask_patch_size`` cell of ``x`` up to ``patch_size``.

        ``x`` is cut into non-overlapping square cells, each cell is padded with
        ``value`` (``mask_l_pad`` on the left/top, ``mask_r_pad`` on the
        right/bottom), and the padded cells are stitched back together.

        Args:
            x: Tensor indexed as ``(batch, 3, H, W)``; it must split into
                ``patch_num`` cells per side for the reshapes to succeed.
            value: Fill value for the padding around each cell.

        Returns:
            Tensor of shape ``(batch, 3, output_size, output_size)``.
        """
        batch_size = x.size(0)
        x_patches = x.unfold(2, self.mask_patch_size, self.mask_patch_size).unfold(3, self.mask_patch_size, self.mask_patch_size)
        x_patches = x_patches.contiguous().view(batch_size, 3, -1, self.mask_patch_size, self.mask_patch_size)
        padded_x = F.pad(x_patches, (self.mask_l_pad, self.mask_r_pad, self.mask_l_pad, self.mask_r_pad), value=value)
        padded_x = padded_x.view(batch_size, 3, self.patch_num, self.patch_num, self.patch_size, self.patch_size)
        # F.fold expects (N, C * kernel_area, blocks); channels ride in N here.
        padded_x = padded_x.contiguous().view(batch_size*3, -1, self.patch_size * self.patch_size)
        padded_x = padded_x.permute(0, 2, 1)
        recovered_x = F.fold(
            padded_x,
            output_size=(self.output_size, self.output_size),
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
        )
        recovered_x = recovered_x.view(batch_size, 3, self.output_size, self.output_size)
        return recovered_x

    def forward(self, x):
        """Add the masked bar-product prompt to ``x``, then normalise if set."""
        x = x + self.program * self.mask
        if self.normalize is not None:
            x = self.normalize(x)
        return x
    

class PatchPadVisualPrompt(nn.Module):
    """Prompt that spreads the input into padded cells and fills the padding.

    ``forward`` cuts the input into ``mask_patch_size`` cells, pads each cell
    with zeros up to ``patch_size`` (via ``patch_mask``), and adds
    ``program * mask``. ``program`` is a ``(3, output_size, output_size)``
    parameter initialised to zeros; ``mask`` is 0 inside each cell and 1 on the
    padding around it.

    Constructing the module also renders the mask with matplotlib, writes it to
    ``patch_bar_mask.png`` in the working directory and calls ``plt.show()``.

    Args:
        args: Configuration object. Reads ``output_size``, ``patch_size``,
            ``pad_size`` and ``network``; ``bar_width``, ``bar_height`` and
            ``input_size`` are only read for the log line.
        normalize: Optional callable applied to the result of ``forward``.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: patchpad\n')
        super(PatchPadVisualPrompt, self).__init__()
        self.output_size = args.output_size
        self.normalize = normalize
        self.network = args.network
        self.patch_size = args.patch_size
        self.pad_size = args.pad_size
        self.patch_num = self.output_size // self.patch_size
        # Cell size before padding restores it to a full patch.
        self.mask_patch_size = self.patch_size - self.pad_size
        self.mask_size = self.patch_num * self.mask_patch_size
        self.mask_pad_size = self.pad_size

        # An odd pad puts its extra pixel on the left/top side.
        self.mask_l_pad = int((self.mask_pad_size+1)/2)
        self.mask_r_pad = int(self.mask_pad_size/2)

        self.program = torch.nn.Parameter(data=torch.zeros(3, self.output_size, self.output_size))

        mask = torch.zeros(1, 3, self.mask_size, self.mask_size)
        padded_mask = self.patch_mask(mask, value=1).squeeze(0)
        self._save_mask_figure(padded_mask)
        self.register_buffer("mask", padded_mask)
        logger.info(f'width: {args.bar_width}, height: {args.bar_height}, output size: {args.output_size}, input size: {args.input_size}, patch size: {args.patch_size}, patch num: {self.patch_num}, mask l pad: {self.mask_l_pad}, mask r pad: {self.mask_r_pad}')

    @staticmethod
    def _save_mask_figure(mask):
        """Render a ``(3, H, W)`` mask and write it to ``patch_bar_mask.png``."""
        image_to_plot = mask.detach().cpu().numpy().transpose(1, 2, 0)
        fig, ax = plt.subplots()
        try:
            ax.imshow(image_to_plot)
            ax.axis('off')  # Turn off the axis labels
            # Save before show(): on interactive/inline backends show() can
            # discard the figure, and a later savefig would write a blank image.
            fig.savefig('patch_bar_mask.png')
            plt.show()
        finally:
            # Release the figure so repeated construction does not pile them up.
            plt.close(fig)

    def get_init(self, init_method, params):
        """Initialise ``params`` in place; unrecognised names leave it untouched.

        Args:
            init_method: ``'zero'``, ``'random'``, ``'xavier'``, ``'kaiming'``,
                ``'uniform'`` or ``'normal'``.
            params: Tensor or parameter to initialise in place.
        """
        if init_method == 'zero':
            params.data.fill_(0)
        elif init_method == 'random':
            params.data.normal_(0, 1)
        elif init_method == 'xavier':
            torch.nn.init.xavier_uniform_(params)
        elif init_method == 'kaiming':
            torch.nn.init.kaiming_uniform_(params, nonlinearity='relu')
        elif init_method == 'uniform':
            torch.nn.init.uniform_(params, a=-0.1, b=0.1)
        elif init_method == 'normal':
            torch.nn.init.normal_(params, mean=0.0, std=0.01)

    def patch_mask(self, x, value=1):
        """Pad every ``mask_patch_size`` cell of ``x`` up to ``patch_size``.

        ``x`` is cut into non-overlapping square cells, each cell is padded with
        ``value`` (``mask_l_pad`` on the left/top, ``mask_r_pad`` on the
        right/bottom), and the padded cells are stitched back together.

        Args:
            x: Tensor indexed as ``(batch, 3, H, W)``; it must split into
                ``patch_num`` cells per side for the reshapes to succeed.
            value: Fill value for the padding around each cell.

        Returns:
            Tensor of shape ``(batch, 3, output_size, output_size)``.
        """
        batch_size = x.size(0)
        x_patches = x.unfold(2, self.mask_patch_size, self.mask_patch_size).unfold(3, self.mask_patch_size, self.mask_patch_size)
        x_patches = x_patches.contiguous().view(batch_size, 3, -1, self.mask_patch_size, self.mask_patch_size)
        padded_x = F.pad(x_patches, (self.mask_l_pad, self.mask_r_pad, self.mask_l_pad, self.mask_r_pad), value=value)
        padded_x = padded_x.view(batch_size, 3, self.patch_num, self.patch_num, self.patch_size, self.patch_size)
        # F.fold expects (N, C * kernel_area, blocks); channels ride in N here.
        padded_x = padded_x.contiguous().view(batch_size*3, -1, self.patch_size * self.patch_size)
        padded_x = padded_x.permute(0, 2, 1)
        recovered_x = F.fold(
            padded_x,
            output_size=(self.output_size, self.output_size),
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
        )
        recovered_x = recovered_x.view(batch_size, 3, self.output_size, self.output_size)
        return recovered_x

    def forward(self, x):
        """Zero-pad each cell of ``x``, add the masked prompt, then normalise."""
        x = self.patch_mask(x, value=0) + self.program * self.mask
        if self.normalize is not None:
            x = self.normalize(x)
        return x
    

class PadVisualPrompt(nn.Module):
    """Border prompt with one learnable value per output pixel.

    ``forward`` zero-pads the input, adds ``sigmoid(program)`` wherever the
    ``mask`` buffer is 1, and clamps the sum to ``[0, 1]``. ``program`` is a
    ``(3, output_size, output_size)`` parameter initialised to zeros. ``mask``
    is 0 on a centred ``mask_size`` square and 1 on the surrounding border, or
    all ones when ``mask_size`` is 0.

    Args:
        args: Configuration object. Reads ``output_size``, ``input_size`` and
            ``mask_size``.
        normalize: Optional callable applied to the result of ``forward``.

    Raises:
        ValueError: If ``args.mask_size`` is negative.
    """

    def __init__(self, args, normalize=None):
        logger.info('prompt method: pad\n')
        super(PadVisualPrompt, self).__init__()
        mask_size = args.mask_size
        pad_size = (args.output_size - args.mask_size)//2
        output_size = args.output_size
        input_size = args.input_size
        # For an odd size difference the extra pixel goes to the left/top side.
        self.l_pad = int((output_size-input_size+1)/2)
        self.r_pad = int((output_size-input_size)/2)
        self.normalize = normalize
        self.program = torch.nn.Parameter(data=torch.zeros(3, output_size, output_size))

        if mask_size > 0:
            # Split the border like l_pad/r_pad so the mask is always exactly
            # output_size wide, including when output_size - mask_size is odd
            # (a symmetric pad would come out one pixel short).
            near_pad = (output_size - mask_size) - pad_size
            mask = F.pad(torch.zeros(3, mask_size, mask_size),
                         (near_pad, pad_size, near_pad, pad_size), value=1)
        elif mask_size == 0:
            # No protected centre: the prompt covers the whole image.
            mask = torch.ones(3, output_size, output_size)
        else:
            raise ValueError("Pad Should Not Exceed Half Of Output Size")
        self.register_buffer("mask", mask)
        logger.info(f'input size: {args.input_size}, output size: {args.output_size}, pad size: {pad_size}')

    def forward(self, x):
        """Pad ``x``, add the masked sigmoid prompt, clamp, then normalise.

        Args:
            x: Input tensor whose last two dimensions are padded by
                ``l_pad``/``r_pad`` (presumably ``(batch, 3, input_size,
                input_size)`` -- TODO confirm).

        Returns:
            The prompted tensor, passed through ``normalize`` when one was given.
        """
        padded = F.pad(x, (self.l_pad, self.r_pad, self.l_pad, self.r_pad), value=0)
        x = padded + torch.sigmoid(self.program) * self.mask
        x = x.clamp(0, 1)
        if self.normalize is not None:
            x = self.normalize(x)
        return x

