import requests
import torch
from PIL import Image
from lightning_fabric import seed_everything
from torchvision.transforms.functional import pil_to_tensor

from image_hijacks.utils import quantise_image
from image_hijacks.utils.testing import TestCase

IMG_URL = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"

# Seconds to wait on the network when fetching IMG_URL. requests applies no
# timeout by default, so a stalled connection would hang test collection.
_IMG_DOWNLOAD_TIMEOUT_S = 30

# Test image converted to RGB, fetched once at import time.
# NOTE: importing this module therefore requires network access.
with requests.get(IMG_URL, stream=True, timeout=_IMG_DOWNLOAD_TIMEOUT_S) as _resp:
    # Fail with a clear HTTP error (e.g. 404) instead of an opaque PIL decode error.
    _resp.raise_for_status()
    # Response.raw does not undo Content-Encoding (gzip/deflate) unless asked to.
    _resp.raw.decode_content = True
    # convert() loads the pixel data, so the image no longer needs the stream
    # once the response is closed at the end of this `with` block.
    IMG = Image.open(_resp.raw).convert("RGB")
del _resp


class TestUtils(TestCase):
    """Snapshot tests for helpers imported from ``image_hijacks.utils``."""

    def test_quantise_image(self):
        """Check ``quantise_image`` against pretty-printed tensor snapshots.

        A seeded random tensor of shape (1, 3, 224, 224) is snapshotted both
        before and after quantisation. ``assertExpectedPretty`` comes from
        ``image_hijacks.utils.testing.TestCase`` (not visible here). It
        presumably compares a pretty-printed form of the tensor against the
        inline expected string — TODO confirm.

        NOTE(review): keep each expected string inline as the direct argument
        of its ``assertExpectedPretty`` call. The backslash-continued
        triple-quoted literals suggest ``TestCase`` builds on expecttest-style
        inline snapshots — confirm. If so, its accept mode rewrites the literal
        found at the call site.
        """
        # Seed all RNGs so torch.rand below returns the same values every run;
        # every snapshot in this test is derived from that tensor.
        # NOTE(review): the seeded stream may still differ across torch
        # versions/backends — TODO confirm the snapshots are pinned accordingly.
        seed_everything(0)
        img = torch.rand((1, 3, 224, 224))
        # Snapshot the raw input first, so an RNG change is reported here
        # rather than showing up as a quantise_image mismatch further down.
        self.assertExpectedPretty(
            img,
            """\
tensor([[[[0.496, 0.768, 0.088,  ..., 0.995, 0.681, 0.514],
          [0.067, 0.748, 0.144,  ..., 0.916, 0.300, 0.646],
          [0.523, 0.049, 0.915,  ..., 0.902, 0.016, 0.428],
          ...,
          [0.006, 0.797, 0.683,  ..., 0.750, 0.216, 0.777],
          [0.688, 0.935, 0.975,  ..., 0.962, 0.548, 0.361],
          [0.041, 0.337, 0.306,  ..., 0.317, 0.600, 0.201]],

         [[0.977, 0.772, 0.229,  ..., 0.114, 0.919, 0.236],
          [0.183, 0.680, 0.239,  ..., 0.091, 0.588, 0.838],
          [0.869, 0.825, 0.922,  ..., 0.157, 0.733, 0.988],
          ...,
          [0.286, 0.803, 0.896,  ..., 0.826, 0.898, 0.631],
          [0.339, 0.514, 0.834,  ..., 0.653, 0.068, 0.658],
          [0.651, 0.343, 0.585,  ..., 0.900, 0.990, 0.353]],

         [[0.051, 0.426, 0.563,  ..., 0.926, 0.947, 0.743],
          [0.735, 0.783, 0.230,  ..., 0.163, 0.793, 0.975],
          [0.261, 0.463, 0.641,  ..., 0.237, 0.480, 0.992],
          ...,
          [0.830, 0.432, 0.524,  ..., 0.829, 0.746, 0.327],
          [0.125, 0.141, 0.153,  ..., 0.903, 0.883, 0.966],
          [0.315, 0.638, 0.247,  ..., 0.433, 0.192, 0.853]]]])""",
        )
        # Raw input scaled up towards a pixel-value range.
        # NOTE(review): this scales by 256, while the quantised snapshot below
        # scales by 255, where the displayed entries print as whole numbers.
        # Confirm whether 255 was intended before re-accepting this snapshot.
        self.assertExpectedPretty(
            img * 256,
            """\
tensor([[[[127.042, 196.665,  22.650,  ..., 254.623, 174.238, 131.624],
          [ 17.074, 191.408,  36.828,  ..., 234.392,  76.797, 165.489],
          [133.837,  12.580, 234.154,  ..., 231.035,   4.182, 109.552],
          ...,
          [  1.598, 203.956, 174.799,  ..., 191.980,  55.169, 198.804],
          [176.128, 239.320, 249.476,  ..., 246.160, 140.411,  92.453],
          [ 10.513,  86.150,  78.413,  ...,  81.115, 153.513,  51.468]],

         [[250.009, 197.525,  58.714,  ...,  29.093, 235.194,  60.294],
          [ 46.824, 174.135,  61.226,  ...,  23.325, 150.452, 214.588],
          [222.447, 211.123, 236.105,  ...,  40.176, 187.626, 252.940],
          ...,
          [ 73.237, 205.652, 229.314,  ..., 211.492, 229.760, 161.570],
          [ 86.908, 131.687, 213.401,  ..., 167.244,  17.417, 168.555],
          [166.759,  87.755, 149.657,  ..., 230.337, 253.383,  90.296]],

         [[ 13.123, 109.042, 144.127,  ..., 237.073, 242.364, 190.089],
          [188.118, 200.522,  58.994,  ...,  41.784, 203.129, 249.547],
          [ 66.921, 118.587, 164.007,  ...,  60.799, 122.878, 253.948],
          ...,
          [212.581, 110.481, 134.270,  ..., 212.243, 190.888,  83.601],
          [ 32.025,  36.134,  39.228,  ..., 231.063, 226.065, 247.294],
          [ 80.693, 163.412,  63.359,  ..., 110.915,  49.039, 218.368]]]])""",
        )
        # The expected snapshot below shows the result comes back as float16.
        quantised_img = quantise_image(img)
        self.assertExpectedPretty(
            quantised_img,
            """\
tensor([[[[0.498, 0.769, 0.090,  ..., 0.996, 0.682, 0.514],
          [0.067, 0.749, 0.145,  ..., 0.914, 0.298, 0.647],
          [0.521, 0.051, 0.914,  ..., 0.902, 0.016, 0.427],
          ...,
          [0.008, 0.796, 0.682,  ..., 0.749, 0.216, 0.776],
          [0.686, 0.933, 0.977,  ..., 0.961, 0.549, 0.361],
          [0.039, 0.337, 0.306,  ..., 0.318, 0.600, 0.200]],

         [[0.977, 0.772, 0.227,  ..., 0.114, 0.917, 0.235],
          [0.184, 0.678, 0.239,  ..., 0.090, 0.588, 0.839],
          [0.871, 0.824, 0.921,  ..., 0.157, 0.733, 0.988],
          ...,
          [0.286, 0.804, 0.894,  ..., 0.828, 0.898, 0.631],
          [0.341, 0.514, 0.835,  ..., 0.655, 0.067, 0.659],
          [0.651, 0.341, 0.584,  ..., 0.898, 0.988, 0.353]],

         [[0.051, 0.427, 0.565,  ..., 0.925, 0.945, 0.741],
          [0.733, 0.784, 0.231,  ..., 0.165, 0.792, 0.977],
          [0.263, 0.463, 0.639,  ..., 0.239, 0.479, 0.992],
          ...,
          [0.832, 0.431, 0.525,  ..., 0.828, 0.745, 0.325],
          [0.125, 0.141, 0.153,  ..., 0.902, 0.882, 0.965],
          [0.314, 0.639, 0.247,  ..., 0.431, 0.192, 0.855]]]],
       dtype=torch.float16)""",
        )
        # When scaled by 255, the displayed quantised entries print as whole
        # numbers (e.g. 127., 196.). This is consistent with a 1/255 value
        # grid, though only the summarised corner entries are visible.
        self.assertExpectedPretty(
            quantised_img * 255,
            """\
tensor([[[[127., 196.,  23.,  ..., 254., 174., 131.],
          [ 17., 191.,  37.,  ..., 233.,  76., 165.],
          [133.,  13., 233.,  ..., 230.,   4., 109.],
          ...,
          [  2., 203., 174.,  ..., 191.,  55., 198.],
          [175., 238., 249.,  ..., 245., 140.,  92.],
          [ 10.,  86.,  78.,  ...,  81., 153.,  51.]],

         [[249., 197.,  58.,  ...,  29., 234.,  60.],
          [ 47., 173.,  61.,  ...,  23., 150., 214.],
          [222., 210., 235.,  ...,  40., 187., 252.],
          ...,
          [ 73., 205., 228.,  ..., 211., 229., 161.],
          [ 87., 131., 213.,  ..., 167.,  17., 168.],
          [166.,  87., 149.,  ..., 229., 252.,  90.]],

         [[ 13., 109., 144.,  ..., 236., 241., 189.],
          [187., 200.,  59.,  ...,  42., 202., 249.],
          [ 67., 118., 163.,  ...,  61., 122., 253.],
          ...,
          [212., 110., 134.,  ..., 211., 190.,  83.],
          [ 32.,  36.,  39.,  ..., 230., 225., 246.],
          [ 80., 163.,  63.,  ..., 110.,  49., 218.]]]], dtype=torch.float16)""",
        )
