import cv2
import numpy as np
import torch
import torchvision.transforms.v2 as transforms
from PIL import Image
from matplotlib import pyplot as plt

from lighting.relight import rescale_image
from lighting.relight_model import CroCoDecode, RelightModule


def save_image(tensor, path):
    """Write an image tensor to ``path`` as an 8-bit image via OpenCV.

    The tensor is passed through ``rescale_image``, its leading dimension is
    squeezed, and the result is transposed with ``(1, 2, 0)`` before being
    scaled by 255 and converted from RGB to BGR channel order for OpenCV.

    Args:
        tensor: Tensor accepted by ``rescale_image``. Assumed to be a
            channel-first RGB image with an optional leading batch dimension
            of size 1 -- TODO confirm against ``rescale_image``.
        path: Destination file path. The parent directory is created when it
            does not exist yet.

    Raises:
        OSError: If OpenCV reports that the image could not be written.
    """
    import os  # local import so this function does not rely on a file-level change

    image = rescale_image(tensor).squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)

    # The ``* 255`` scaling implies values in [0, 1]. Clip first so that any
    # value outside that range saturates instead of wrapping around in the
    # uint8 cast (which shows up as speckle artefacts in the saved image).
    pixels = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    # cv2.imwrite does not create missing directories.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # cv2.imwrite signals failure through its return value rather than an
    # exception, so check it explicitly instead of losing the output silently.
    if not cv2.imwrite(path, cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed to write image to {path!r}")


def swap_lighting(croco_relight: RelightModule,
                  img1_path='./data/Input1.png',
                  img2_path='./data/Input2.png',
                  out1_path="./data/out/compression/lighting3.png",
                  out2_path="./data/out/compression/intrinsics3.png"):
    """Run the relighting model on a pair of images and save its first two outputs.

    Both images are resized and centre-cropped to 448x448, converted to
    float32 in [0, 1], and normalised with the mean/std given below before
    being fed to ``croco_relight`` with tiling disabled.

    Args:
        croco_relight: The relighting model. It is called as
            ``croco_relight(img1, img2, do_tiling=False)`` and is expected to
            return six values, of which the first two are saved.
        img1_path: Path of the first input image.
        img2_path: Path of the second input image.
        out1_path: Where the first model output is written.
        out2_path: Where the second model output is written.
    """
    transform = transforms.Compose([
        transforms.ToImage(),
        transforms.Resize(448),
        transforms.CenterCrop(448),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Place the inputs on the same device as the model instead of relying on
    # a module-level ``device`` name, which only exists when this file is run
    # as a script. Fall back to the CPU for a model without parameters.
    first_param = next(croco_relight.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    def load(path):
        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(path) as source:
            return transform(source.convert('RGB')).unsqueeze(0).to(model_device)

    # blank_transform = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    # blank_img = torch.zeros(1, 3, 448, 448, dtype=torch.float32, device=device)
    # img1 = blank_transform(blank_img)
    # img2 = transform(Image.open('data/relighting/beach-clouds-coast-417211.jpg').convert('RGB')).unsqueeze(0).to(device)
    img1 = load(img1_path)
    img2 = load(img2_path)

    with torch.no_grad():
        # Only the first two outputs are used; the rest are unpacked to keep
        # the expected arity of the model's return value explicit.
        img1_relit, img2_relit, _static, _static_pos, _dyn, _ = croco_relight(img1, img2, do_tiling=False)

    # Disabled alternative experiment, kept for reference:
    # img_info = {'height': 448, 'width': 448}
    # with torch.no_grad():
    #     img_feat, img_pos, _ = croco_relight.croco._encode_image(img2, False, False)
    #     img_static, img_dyn, _ = croco_relight.lighting_extractor(img_feat, img_pos)
    #     blank_static = torch.zeros_like(img_static)
    #     blank_dyn = torch.zeros_like(img_dyn)
    #     lighting_feat = croco_relight.lighting_entangler(blank_static, img_pos, img_dyn)
    #     intrinsics_feat = croco_relight.lighting_entangler(img_static, img_pos, blank_dyn)
    #     lighting_img = croco_relight.croco.decode(lighting_feat, img_pos, img_info)
    #     intrinsics_img = croco_relight.croco.decode(intrinsics_feat, img_pos, img_info)

    save_image(img1_relit, out1_path)
    save_image(img2_relit, out2_path)


def remove_shadow(croco_relight, mapper_model,
                  image_path='./data/diagram/shadow/srd_271_input.jpg'):
    """Apply ``mapper_model`` to one image through the relighting model and plot the result.

    The image is resized (no centre crop), converted to float32 in [0, 1] and
    normalised, then passed to ``croco_relight.apply_mapper``. The input and
    the output are shown side by side in a blocking matplotlib window.

    Args:
        croco_relight: Model exposing ``apply_mapper(img, mapper_model)``.
        mapper_model: Mapper handed through to ``apply_mapper`` unchanged.
        image_path: Path of the input image.
    """
    transform = transforms.Compose([
        transforms.ToImage(),
        transforms.Resize(448),
        # transforms.CenterCrop(448),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Place the input on the same device as the model instead of relying on
    # a module-level ``device`` name, which only exists when this file is run
    # as a script. Fall back to the CPU for a model without parameters.
    first_param = next(croco_relight.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # Close the underlying file as soon as the pixels have been converted.
    with Image.open(image_path) as source:
        img = transform(source.convert('RGB')).unsqueeze(0).to(model_device)

    with torch.no_grad():
        shadow_free_img = croco_relight.apply_mapper(img, mapper_model)

    # save_image(shadow_free_img, './data/diagram/shadow/shadow6_free.jpg')

    plt.figure(figsize=(12, 8))

    # Left panel: the (de-normalised) input. permute(1, 2, 0) moves the
    # channel axis last, which is the layout imshow expects.
    plt.subplot(1, 2, 1)
    plt.title('Input Image')
    plt.imshow(rescale_image(img).squeeze(0).cpu().permute(1, 2, 0))
    plt.axis('off')

    # Right panel: the mapper output.
    plt.subplot(1, 2, 2)
    plt.title('Shadow Removed')
    plt.imshow(rescale_image(shadow_free_img).squeeze(0).cpu().permute(1, 2, 0))
    plt.axis('off')

    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    # Script entry point: build the relighting model, load its weights, and
    # run the lighting-swap demo.

    # Use the first CUDA device when one is available, otherwise the CPU.
    # NOTE: `device` is bound at module level here; keep it that way if other
    # code in this file refers to it as a global.
    device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu')

    # Load the CroCo checkpoint onto the CPU. Only its 'croco_kwargs' entry is
    # used below (an empty dict is substituted when the key is missing); the
    # 'model' weights from this checkpoint are not loaded (see the
    # commented-out line).
    ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
    croco_decode = CroCoDecode(**ckpt.get('croco_kwargs', {})).to(device)
    # croco_decode.load_state_dict(ckpt['model'])
    # NOTE(review): setup() is defined in lighting.relight_model; it runs
    # before the module is wrapped and before any weights are loaded --
    # confirm this ordering is required.
    croco_decode.setup()
    # decode_ckpt = torch.load('lighting/models/croco_relight_pretrained3.pth', 'cpu')
    # croco_decode.load_state_dict(decode_ckpt)

    # Wrap the decoder in the relighting module and load the relighting
    # state dict (read onto the CPU first) into the wrapper.
    croco_relight = RelightModule(croco_decode).to(device)
    relight_ckpt = torch.load('lighting/models/croco_relight_from_pretrain_all.pth', 'cpu')
    croco_relight.load_state_dict(relight_ckpt)

    # Put both modules into evaluation mode for inference.
    croco_decode.eval()
    croco_relight.eval()

    # Disabled: construction of the mapper used by remove_shadow().
    # NOTE(review): LightingEntangler is not imported in the lines visible
    # here; re-enabling this block would need that import.
    # mapper_model = LightingEntangler(patch_size=croco_relight.croco.enc_embed_dim, extractor_depth=8,
    #                                    rope=croco_relight.croco.rope).to(device)
    # mapper_model.load_state_dict(torch.load('lighting/models/shadow_mapper_all_intrinsic2.pth', 'cpu'))
    # mapper_model.eval()

    swap_lighting(croco_relight)
    # remove_shadow(croco_relight, mapper_model)
