
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import torch
import random
import numpy as np
import torchvision
from torchvision import transforms
import argparse
from torch.utils.data import DataLoader
from models.model import MCUCoder
from tqdm import tqdm
from dahuffman import HuffmanCodec
import matplotlib.pyplot as plt  # Import matplotlib
from torchvision import datasets, transforms
import torch.nn.functional as F
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import tensorflow as tf



def encode_TFLite(model_path, X):
    x_data = np.copy(X.to('cpu').numpy()) # the function quantizes the input, so we must make a copy
    # Initialize the TFLite interpreter
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    # Inputs will be quantized
    input_scale, input_zero_point = input_details["quantization"]
    if (input_scale, input_zero_point) != (0.0, 0):
        x_data = x_data / input_scale + input_zero_point
        x_data = x_data.astype(input_details["dtype"])
    # Invoke the interpreter
    predictions = np.empty((x_data.shape[0],12,28,28), dtype=output_details["dtype"])
    for i in range(len(x_data)):
        interpreter.set_tensor(input_details["index"], [x_data[i]])
        interpreter.invoke()
        predictions[i] = np.copy(interpreter.get_tensor(output_details["index"])[0])
    # Dequantize output
    output_scale, output_zero_point = output_details["quantization"]
    if (output_scale, output_zero_point) != (0.0, 0):
        predictions = predictions.astype(np.float32)
        predictions = (predictions - output_zero_point) * output_scale
    # todo reshape output into array for each exit
    return torch.from_numpy(predictions).to('cuda')

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with ``batch_size``, ``model_path`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    # Required: the value is handed straight to torch.load, which cannot load None.
    parser.add_argument('--model_path', type=str, required=True, help='Path to the model')
    parser.add_argument("--output_dir", type=str, default="./results/image_compression/mcucoder/", help="Directory to save output images")
    return parser.parse_args()

class CustomImageDataset(Dataset):
    """Flat-directory image dataset yielding ``(image, 0)`` pairs.

    Every regular file directly inside ``root_dir`` is treated as an image.
    File names are sorted so the sample order (and therefore which image sits
    at index 0) is reproducible across filesystems.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort for determinism.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, f))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # Close the file handle deterministically. convert() returns an
        # in-memory copy that stays valid after the file is closed and gives
        # every sample the same 3-channel RGB mode.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        # Rotate portrait images to landscape so all samples share one
        # orientation (e.g. 512x768 -> 768x512).
        if image.size[0] < image.size[1]:
            image = image.transpose(method=Image.Transpose.ROTATE_90)

        if self.transform:
            image = self.transform(image)

        return image, 0

def get_test_data(args, resize, root_dir='/data22/aho/KODAK/'):
    """Build the evaluation dataset from the images in ``root_dir``.

    Args:
        args: parsed CLI arguments; currently unused.
        resize: currently unused; images are only converted to tensors, not
            resized.
        root_dir: directory holding the test images (defaults to the KODAK
            location this script was written for).

    Returns:
        CustomImageDataset whose images are converted with ``ToTensor``.
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    return CustomImageDataset(root_dir=root_dir, transform=to_tensor)


def quantization(data, filter_number, codec_setting, quantization_step=4):
    """Map one filter's feature values to integer Huffman symbols.

    Values are min-max normalised with the filter's calibration range from
    ``codec_setting``, scaled to 0..255, and coarsened by
    ``quantization_step`` (default 4, giving symbols 0..63).

    Args:
        data: tensor of feature values for filter ``filter_number``.
        filter_number: key into ``codec_setting['min']`` / ``['max']``.
        codec_setting: dict with per-filter ``'min'`` and ``'max'`` entries.
        quantization_step: integer bin width on the 0..255 scale.

    Returns:
        ``torch.uint8`` tensor of symbols with the same shape as ``data``.
    """
    min_, max_ = codec_setting['min'][filter_number], codec_setting['max'][filter_number]
    normalized = (data - min_) / (max_ - min_)
    # Clamp before the uint8 cast: values outside the calibration range would
    # otherwise wrap around or hit an undefined float->uint8 conversion.
    data = (normalized * 255).clamp(0, 255).type(dtype=torch.uint8)
    # Division + uint8 cast truncates, i.e. floor-divides the 0..255 value.
    data = data / quantization_step
    return data.type(dtype=torch.uint8)

def quantization_and_dequantization(data, filter_number, codec_setting, quantization_step=4):
    """Quantize one filter's values and map them straight back to floats.

    This simulates transmitting the feature map: values are normalised with
    the filter's calibration range, scaled to 0..255, floor-binned by
    ``quantization_step``, and mapped back to the lower edge of their bin in
    the original value range.

    Args:
        data: tensor of feature values for filter ``filter_number``.
        filter_number: key into ``codec_setting['min']`` / ``['max']``.
        codec_setting: dict with per-filter ``'min'`` and ``'max'`` entries.
        quantization_step: integer bin width on the 0..255 scale.

    Returns:
        float tensor with the same shape as ``data``.
    """
    min_, max_ = codec_setting['min'][filter_number], codec_setting['max'][filter_number]

    normalized = (data - min_) / (max_ - min_)
    # Clamp before the uint8 cast so out-of-range values saturate instead of
    # wrapping around.
    data = (normalized * 255).clamp(0, 255).type(dtype=torch.uint8)

    data = data / quantization_step
    data = data.type(dtype=torch.uint8)
    data = data * quantization_step

    data = data / 255.0
    data = data * (max_ - min_) + min_
    return data


def quantization_and_huffman(data, filter_number, codec_setting):
    """Return the Huffman-coded size of one filter's feature map, in KiB.

    The map is flattened, turned into symbols by ``quantization`` with this
    filter's calibration range, and encoded with
    ``codec_setting['codec'][filter_number]``. The result is the length of
    the encoded output divided by 1024.
    """
    flat = data.reshape(-1)
    symbols = quantization(flat, filter_number, codec_setting).cpu().numpy()
    huffman = codec_setting['codec'][filter_number]
    bitstream = huffman.encode(symbols)
    return len(bitstream) / 1024

def psnr_batch(img1, img2):
    """Per-image PSNR in dB between two batches, assuming a peak value of 1.0.

    Returns a 1-D tensor with one entry per item along dim 0.
    """
    squared_error = F.mse_loss(img1, img2, reduction='none')
    per_image_mse = squared_error.reshape(squared_error.size(0), -1).mean(dim=1)
    rmse = torch.sqrt(per_image_mse)
    return 20 * torch.log10(1.0 / rmse)

def ms_ssim_to_db(ms_ssim):
    """Convert an MS-SSIM score to decibels as ``-10 * log10(1 - ms_ssim)``.

    A score of exactly 1 gives +inf, with a NumPy divide-by-zero warning.
    """
    dissimilarity = 1 - ms_ssim
    return -10 * np.log10(dissimilarity)

def ms_ssim_batch(img1, img2, data_range=1.0):
    """Return the per-image MS-SSIM between two batches, converted to dB.

    Args:
        img1, img2: batches indexed along dim 0; image ``i`` of each is
            compared as a batch of one.
        data_range: pixel value range passed to the torchmetrics metric.

    Returns:
        list with one dB value (via ``ms_ssim_to_db``) per image in ``img1``.
    """
    metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range)

    raw_scores = []
    for i in range(img1.size(0)):
        pair_score = metric(img1[i].unsqueeze(0), img2[i].unsqueeze(0))
        raw_scores.append(pair_score.item())

    return list(map(ms_ssim_to_db, raw_scores))
    
def save_image(img, path):
    """Save a (C, H, W) float image array to ``path`` via matplotlib.

    The array is transposed to (H, W, C) and clipped to [0, 1] first:
    ``plt.imsave`` raises ValueError for float RGB data outside that range,
    and nothing upstream bounds the decoder output.
    """
    hwc = np.transpose(img, (1, 2, 0))
    plt.imsave(path, np.clip(hwc, 0.0, 1.0))


def divide_into_patches(tensor, patch_size=224):
    """Split a (B, C, H, W) tensor into non-overlapping square patches.

    Trailing rows/columns that do not fill a whole patch are discarded.

    Returns:
        tensor of shape (B, n_h * n_w, C, patch_size, patch_size), with the
        patches of each image ordered row by row.
    """
    batch, channels, height, width = tensor.shape
    rows, cols = height // patch_size, width // patch_size

    tiles = (
        tensor
        .unfold(2, patch_size, patch_size)
        .unfold(3, patch_size, patch_size)
        .permute(0, 2, 3, 1, 4, 5)
        .contiguous()
    )
    return tiles.view(batch, rows * cols, channels, patch_size, patch_size)
    
def pad(x, p):
    """Zero-pad the H and W of a (B, C, H, W) tensor up to multiples of ``p``.

    Padding is split as evenly as possible; an odd extra pixel goes to the
    right/bottom side.

    Returns:
        ``(padded, (left, right, top, bottom))``, where the tuple is what
        ``crop`` needs to undo the padding.
    """
    def _split(length):
        extra = (length + p - 1) // p * p - length
        before = extra // 2
        return before, extra - before

    left, right = _split(x.size(3))
    top, bottom = _split(x.size(2))
    padding = (left, right, top, bottom)
    return F.pad(x, padding, mode="constant", value=0), padding

def crop(x, padding):
    """Undo ``pad``: strip (left, right, top, bottom) pixels from the last two dims.

    Negative ``F.pad`` amounts remove pixels instead of adding them.
    """
    shrink = tuple(-padding[k] for k in range(4))
    return F.pad(x, shrink)

def divide_image_to_patches(image, patch_size=224, stride=224):
    """Cut a (B, C, H, W) batch into one flat stack of square patches.

    Args:
        image: tensor of shape (B, C, H, W).
        patch_size: side length of each square patch.
        stride: step between patch origins; equal to ``patch_size`` for
            non-overlapping tiles.

    Returns:
        tensor of shape (N, C, patch_size, patch_size), ordered by image, then
        patch row, then patch column. N follows from ``stride``, so
        overlapping strides are supported too.

    Raises:
        ValueError: if H or W is not divisible by ``patch_size``.
    """
    b, c, h, w = image.size()
    # Raise rather than assert so the check survives ``python -O``.
    if h % patch_size or w % patch_size:
        raise ValueError("Image dimensions should be divisible by the patch size")

    patches = image.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    # -1 lets the patch count follow the actual unfold result instead of
    # assuming stride == patch_size.
    return patches.reshape(-1, c, patch_size, patch_size)


def merge_patches_to_image(patches, original_shape, patch_size=224):
    """Reassemble non-overlapping patches into images of ``original_shape``.

    This inverts the patch order produced by ``divide_image_to_patches`` with
    ``stride == patch_size``: by image, then patch row, then patch column.

    Args:
        patches: tensor of shape (B * n_h * n_w, C, patch_size, patch_size).
        original_shape: (B, C, H, W) of the images, with H and W multiples of
            ``patch_size``.
        patch_size: side length of each square patch.

    Returns:
        tensor of shape (B, C, H, W).
    """
    b, c, h, w = original_shape
    num_patches_h = h // patch_size
    num_patches_w = w // patch_size
    grid = patches.reshape(b, num_patches_h, num_patches_w, c, patch_size, patch_size)
    # (b, nh, nw, c, ph, pw) -> (b, c, nh, ph, nw, pw): interleave patch rows
    # with pixel rows (and columns likewise) before flattening to H and W.
    return grid.permute(0, 3, 1, 4, 2, 5).reshape(b, c, h, w)
    
def create_codec(test_dataset, model, tflite_path="MCUCoder.tflite"):
    """Calibrate per-filter quantization ranges and fit Huffman codecs.

    The first batch (up to 32 images) of ``test_dataset`` is padded to a
    multiple of 224, cut into 224x224 patches and encoded once with the
    TFLite encoder. For every output filter, the value range is recorded and
    a Huffman codec is fitted on the symbols produced by ``quantization``.

    Args:
        test_dataset: dataset yielding ``(image_tensor, label)`` pairs.
        model: unused; kept for API compatibility (the PyTorch encoder path,
            ``model.encoder(patches)``, is replaced by the TFLite encoder).
        tflite_path: path to the TFLite encoder model.

    Returns:
        dict with keys ``'min'``, ``'max'`` and ``'codec'``, each mapping a
        filter index to that filter's calibration min, max and HuffmanCodec.
    """
    codec_setting = {
        'min': {},
        'max': {},
        'codec': {}
    }
    temp_loader = DataLoader(test_dataset, batch_size=32, num_workers=1)
    images, _ = next(iter(temp_loader))
    del temp_loader
    images = images.to('cuda')

    # KODAK: pad to a multiple of 224 and cut into 224x224 patches.
    images, _ = pad(images, 224)
    patches = divide_image_to_patches(images)

    # The encoder output does not depend on the filter index, so run it once.
    encoded = encode_TFLite(tflite_path, patches)

    # Every symbol quantization() can emit with its step of 4: 0 .. 255 // 4.
    all_symbols = np.arange(255 // 4 + 1, dtype=np.uint8)

    for i in tqdm(range(encoded.size(1))):
        data = encoded[:, i, :, :].reshape(-1).detach().clone()
        min_, max_ = torch.min(data), torch.max(data)
        codec_setting['min'][i], codec_setting['max'][i] = min_, max_

        symbols = quantization(data, i, codec_setting).cpu().numpy()
        # Append one dummy occurrence of every unseen symbol (0..63 inclusive)
        # so the codec can encode any value seen later.
        symbols = np.concatenate([symbols, np.setdiff1d(all_symbols, symbols)])

        codec_setting['codec'][i] = HuffmanCodec.from_data(symbols)

    return codec_setting
    
def eval_model(model, used_filter, test_dataset, batch_size, codec_setting, output_dir,
               tflite_path="MCUCoder.tflite"):
    """Measure compressed size and quality when ``used_filter`` filters are sent.

    Each image is zero-padded to a multiple of 224, cut into 224x224 patches
    and encoded with the TFLite encoder. The result is passed through
    ``model.rate_less``, the first ``used_filter`` filters are
    quantized/dequantized, and ``model.decoder`` reconstructs the patches.
    The patches are then merged back and the padding is cropped away.

    Args:
        model: MCUCoder instance providing ``rate_less``, ``decoder`` and ``p``.
        used_filter: number of leading filters that are sized and
            quantized/dequantized.
        test_dataset: dataset yielding ``(image, label)`` pairs. Images must
            share one size so they can be batched, which ``create_codec``
            already requires.
        batch_size: number of images per DataLoader batch.
        codec_setting: per-filter ranges and Huffman codecs from ``create_codec``.
        output_dir: directory for the first reconstructed image; created if
            missing.
        tflite_path: path to the TFLite encoder model.

    Returns:
        ``(mean Huffman-coded size of one 224x224 patch in KiB,
        mean PSNR in dB, mean MS-SSIM in dB)``.
    """
    size_list = []
    psnr_list = []
    ms_ssim_list = []
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=1)

    # Inference only: don't build an autograd graph through the decoder.
    with torch.no_grad():
        for idx, (images, _) in enumerate(test_loader):
            images = images.to('cuda')

            # KODAK: pad to a multiple of 224 and cut into 224x224 patches.
            padded_images, padding = pad(images, 224)
            patches = divide_image_to_patches(padded_images)

            # The encoder runs as the quantized TFLite model.
            encoded = encode_TFLite(tflite_path, patches)

            # KODAK: dropping. replace_value is presumably the fill value
            # rate_less uses for dropped filters (see MCUCoder).
            model.replace_value = 0
            encoded = model.rate_less(encoded)

            # Compressed size per patch, measured on the raw encoder output.
            # Re-quantizing already dequantized floats can land just below a
            # bin edge, and the uint8 truncation then picks a lower symbol.
            for patch in encoded:
                size_list.append(np.sum([
                    quantization_and_huffman(patch[f], f, codec_setting)
                    for f in range(used_filter)
                ]))

            # Quantize + dequantize the transmitted filters. Ranges are
            # per filter, so each filter is processed for all patches at once.
            for j in range(used_filter):
                encoded[:, j] = quantization_and_dequantization(encoded[:, j], j, codec_setting)

            outputs = model.decoder(encoded)

            # KODAK: merge the patches back and remove the padding.
            outputs = merge_patches_to_image(outputs, padded_images.shape)
            outputs = crop(outputs, padding)

            if idx == 0:
                os.makedirs(output_dir, exist_ok=True)
                save_image(outputs[0].cpu().numpy(),
                           os.path.join(output_dir, f"reconstructed_{int(model.p*12)}.png"))

            outputs_cpu = outputs.cpu()
            images_cpu = images.cpu()
            psnr_list.extend(psnr_batch(outputs_cpu, images_cpu).tolist())
            # extend (not append) keeps the list flat even if the last batch is smaller.
            ms_ssim_list.extend(ms_ssim_batch(outputs_cpu, images_cpu))

    return np.mean(size_list), np.mean(psnr_list), np.mean(ms_ssim_list)



def main():
    """Evaluate MCUCoder for 1..12 transmitted filters and print the results.

    For each filter count, prints the bpp, PSNR and MS-SSIM, then prints all
    results as ``[bpp, psnr_db, ms_ssim_db]`` rows.
    """
    args = parse_args()
    torch.manual_seed(0)
    random.seed(10)
    np.random.seed(0)

    model = MCUCoder()
    state_dict = torch.load(args.model_path, map_location='cuda')
    model.load_state_dict(state_dict)
    model = model.to('cuda')
    model.eval()

    test_dataset = get_test_data(args, 224)

    codec_setting = create_codec(test_dataset, model)

    # eval_model reports the mean size of one 224x224 patch of the zero-padded
    # image. Dividing that by 224*224 would count the padding as image pixels
    # and understate the bitrate, so convert it to bits per original pixel.
    # All images share one size, since create_codec already stacks them into
    # a single batch, so the first image is representative.
    sample, _ = test_dataset[0]
    img_h, img_w = sample.shape[-2:]
    patches_per_image = ((img_h + 223) // 224) * ((img_w + 223) // 224)

    results = []

    for used_filter in range(1, 13):
        model.p = used_filter / 12
        patch_size_kib, average_psnr, average_ms_ssim = eval_model(
            model,
            used_filter=used_filter,
            test_dataset=test_dataset,
            batch_size=args.batch_size,
            codec_setting=codec_setting,
            output_dir=args.output_dir,
        )
        # KiB per patch -> bits per image -> bits per original pixel.
        bits_per_image = patch_size_kib * 8 * 1024 * patches_per_image
        avg_bpp = bits_per_image / (img_h * img_w)

        results.append([float(avg_bpp), float(average_psnr), float(average_ms_ssim)])
        print(f'used_filter: {used_filter}, bpp: {avg_bpp:.2f}, PSNR: {average_psnr:.2f}dB, MS_SSIM: {average_ms_ssim:.2f}dB' )
    print(results)
    
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
