
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import torch
import random
import numpy as np
import torchvision
from torchvision import transforms
import argparse
from torch.utils.data import DataLoader
from models.model import MCUCoder
from tqdm import tqdm
from dahuffman import HuffmanCodec
import matplotlib.pyplot as plt  # Import matplotlib
from torchvision import datasets, transforms
import torch.nn.functional as F
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import tensorflow as tf
import seaborn as sns

# Use font type 42 (TrueType) for PDF and PostScript output so text stays
# embedded as real fonts in saved figures.
plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Global seaborn theme first, then the legend override, then the grid style;
# the order is kept because later calls overwrite earlier rc settings.
sns.set(font_scale=1.2)
plt.rc('legend', fontsize=10)
sns.set_style("whitegrid")

def parse_args():
    """Define the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: carries ``batch_size``, ``model_path``,
        ``imagenet_root`` and ``output_dir``.
    """
    # One (flag, keyword-arguments) entry per option, registered in order.
    option_table = (
        ('--batch_size', dict(type=int, default=256, help='Batch size for training')),
        ('--model_path', dict(type=str, default=None, help='Path to the model')),
        ("--imagenet_root", dict(type=str, default="/data22/datasets/ilsvrc2012/", help="ImageNet dataset root directory")),
        ("--output_dir", dict(type=str, default="./results/image_compression/mcucoder", help="Directory to save output images")),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def predict_TFLite(model_path, X):
    """Run a TFLite model on a batch, one sample at a time.

    The input is quantized with the model's input scale/zero-point when the
    model declares one, and the output is dequantized the same way.

    Parameters:
        model_path (str): Path to the ``.tflite`` model file.
        X (torch.Tensor): Batch of inputs; the first dimension indexes samples.

    Returns:
        torch.Tensor: Stacked per-sample model outputs, moved to ``'cuda'``.
    """
    # Quantization below produces a new array, but copy first so the caller's
    # tensor memory is never aliased.
    x_data = np.copy(X.to('cpu').numpy())

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # A (0.0, 0) quantization entry means the tensor is not quantized.
    input_scale, input_zero_point = input_details["quantization"]
    if (input_scale, input_zero_point) != (0.0, 0):
        x_data = x_data / input_scale + input_zero_point
        x_data = x_data.astype(input_details["dtype"])

    # The interpreter is invoked with a batch of one; collect each output and
    # stack afterwards so the result shape follows the model instead of a
    # hard-coded (12, 28, 28).
    per_sample_outputs = []
    for sample in x_data:
        interpreter.set_tensor(input_details["index"], sample[np.newaxis, ...])
        interpreter.invoke()
        # get_tensor's result is copied before the next invoke overwrites it.
        per_sample_outputs.append(np.copy(interpreter.get_tensor(output_details["index"])[0]))

    if per_sample_outputs:
        predictions = np.stack(per_sample_outputs)
    else:
        # Empty batch: keep a well-formed (0, ...) array using the declared output shape.
        predictions = np.empty((0, *tuple(output_details["shape"][1:])), dtype=output_details["dtype"])

    output_scale, output_zero_point = output_details["quantization"]
    if (output_scale, output_zero_point) != (0.0, 0):
        predictions = predictions.astype(np.float32)
        predictions = (predictions - output_zero_point) * output_scale
    # todo reshape output into array for each exit
    return torch.from_numpy(predictions).to('cuda')

import torch
import matplotlib.pyplot as plt
import numpy as np

import torch
import matplotlib.pyplot as plt
import numpy as np

def save_latent_images(latent_tensor, path):
    """
    Saves the first 12 latent channels of a tensor as grayscale images in a single 4x3 plot.

    Parameters:
        latent_tensor (torch.Tensor): Latent feature maps; entries 0..11 along
            the first dimension are each drawn as one 2-D image.
        path (str): The path to save the output plot. Missing parent
            directories are created.
    """
    # Move off the GPU (no-op on CPU tensors) and convert to numpy for plotting.
    latent_images = latent_tensor.detach().cpu().numpy()

    # savefig does not create directories, so make sure the target exists.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig, axes = plt.subplots(4, 3, figsize=(8*1.5, 6*1.5))
    try:
        for i, ax in enumerate(axes.flatten()):
            ax.imshow(latent_images[i], cmap='gray')
            ax.axis('off')
            ax.set_title(f'Channel {i+1}')

        fig.subplots_adjust(hspace=-.7, wspace=-1)
        fig.tight_layout()
        fig.savefig(path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


    
def get_test_data(args, resize):
    """Build the evaluation dataset of center-cropped square images.

    Parameters:
        args: Parsed command-line arguments. Currently unused; the image
            directory is the fixed path below.
        resize (int): Target side length. The shorter image edge is resized to
            this value and a ``resize`` x ``resize`` center crop is taken.

    Returns:
        CustomImageDataset: Dataset yielding ``(image_tensor, 0)`` pairs.
    """
    image_transform = transforms.Compose([
        # Honour the requested size instead of a hard-coded 224.
        transforms.Resize(resize),
        transforms.CenterCrop(resize),
        transforms.ToTensor(),
    ])
    return CustomImageDataset(root_dir='/data22/aho/KODAK/', transform=image_transform)
    
class CustomImageDataset(Dataset):
    """Dataset over every regular file directly inside ``root_dir``.

    Each item is ``(image, 0)``: the image is opened with PIL, converted to
    RGB, passed through ``transform`` when one is given, and paired with a
    dummy label of 0.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices (and anything derived from "the first image") are reproducible.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, f))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # Close the file handle promptly and force 3 channels so grayscale or
        # RGBA files still batch together with RGB ones.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0

# def get_test_data(args, resize):
#     transform_test = transforms.Compose([
#         transforms.Resize((resize, resize), antialias=True),
#         transforms.ToTensor(),
#     ])

#     ImageNet_data = datasets.ImageFolder(root=f'{args.imagenet_root}/val/', transform=transform_test)
#     ImageNet_val, _ = torch.utils.data.random_split(ImageNet_data, [1_000, len(ImageNet_data) - 1_000], generator=torch.Generator().manual_seed(41),)
#     return ImageNet_val


def quantization(data, filter_number, codec_setting):
    """Quantize a feature map to 64 integer levels (0..63).

    The data is normalized with the stored min/max of the given filter,
    scaled to 0..255, truncated to uint8, then divided by a step of 4 and
    truncated again.

    Parameters:
        data (torch.Tensor): Values to quantize.
        filter_number: Key into ``codec_setting['min']`` / ``['max']``.
        codec_setting (dict): Holds per-filter ``'min'`` and ``'max'`` values.

    Returns:
        torch.Tensor: uint8 tensor of quantization indices, same shape as ``data``.
    """
    min_, max_ = codec_setting['min'][filter_number], codec_setting['max'][filter_number]
    data = (data - min_) / (max_ - min_)
    # Saturate before the cast: values outside the calibrated [min, max] range
    # would otherwise wrap around (or be undefined) when converted to uint8.
    data = torch.clamp(data * 255, 0, 255)
    data = data.type(dtype=torch.uint8)

    quantization_step = 4
    data = data / quantization_step
    data = data.type(dtype=torch.uint8)

    return data

def quantization_and_dequantization(data, filter_number, codec_setting):
    """Quantize a feature map to 64 levels and map it back to its value range.

    Applies the same normalize / scale-to-255 / step-of-4 truncation as the
    quantizer, then reverses the scaling so the result is the value the
    decoder side would reconstruct.

    Parameters:
        data (torch.Tensor): Values to quantize.
        filter_number: Key into ``codec_setting['min']`` / ``['max']``.
        codec_setting (dict): Holds per-filter ``'min'`` and ``'max'`` values.

    Returns:
        torch.Tensor: Floating-point tensor of reconstructed values, same shape as ``data``.
    """
    min_, max_ = codec_setting['min'][filter_number], codec_setting['max'][filter_number]

    data = (data - min_) / (max_ - min_)
    # Saturate before the cast: values outside the calibrated [min, max] range
    # would otherwise wrap around (or be undefined) when converted to uint8.
    data = torch.clamp(data * 255, 0, 255)
    data = data.type(dtype=torch.uint8)

    quantization_step = 4
    data = data / quantization_step
    data = data.type(dtype=torch.uint8)
    # Largest index is 63, so 63 * 4 = 252 still fits in uint8.
    data = data * quantization_step

    data = data / 255.0
    data = data * (max_ - min_) + min_
    return data

def quantization_and_huffman(data, filter_number, codec_setting):
    """Return the Huffman-coded size of one quantized feature map.

    The input is flattened, quantized with ``quantization`` and encoded with
    the codec stored for ``filter_number``; the encoded byte count is divided
    by 1024.
    """
    flat = data.reshape(-1)
    symbols = quantization(flat, filter_number, codec_setting).cpu().numpy()
    payload = codec_setting['codec'][filter_number].encode(symbols)
    n_bytes = len(payload)
    return n_bytes / 1024

def psnr_batch(img1, img2):
    """Compute one PSNR value (in dB) per image pair, with a peak value of 1.0.

    Returns a 1-D tensor whose length is the batch size.
    """
    squared_error = F.mse_loss(img1, img2, reduction='none')
    batch = squared_error.size(0)
    # Average the squared error over everything except the batch dimension.
    per_image_mse = squared_error.view(batch, -1).mean(dim=1)
    rmse = torch.sqrt(per_image_mse)
    return 20 * torch.log10(1.0 / rmse)

def ms_ssim_to_db(ms_ssim):
    """Convert an MS-SSIM score to decibels: ``-10 * log10(1 - ms_ssim)``."""
    distortion = 1 - ms_ssim
    return -10 * np.log10(distortion)

def ms_ssim_batch(img1, img2, data_range=1.0):
    """Compute MS-SSIM in dB for each image pair of a batch.

    Each pair is scored individually (as a batch of one) and the score is
    converted with ``ms_ssim_to_db``.

    Returns:
        list: One dB value per image, in batch order.
    """
    metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range)

    db_scores = []
    for k in range(img1.size(0)):
        # unsqueeze(0) restores the batch dimension the metric is called with.
        score = metric(img1[k].unsqueeze(0), img2[k].unsqueeze(0)).item()
        db_scores.append(ms_ssim_to_db(score))

    return db_scores
    
def save_image(img, path):
    """Save a channel-first (C, H, W) image array as an image file.

    Parameters:
        img (numpy.ndarray): Image with the channel axis first; it is
            transposed to (H, W, C) before saving.
        path (str): Destination file. Missing parent directories are created.
    """
    # imsave does not create directories, so make sure the target exists.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    hwc = np.transpose(img, (1, 2, 0))
    # matplotlib rejects floating-point RGB data outside [0, 1]; a network
    # output that overshoots slightly would otherwise abort the whole run.
    if np.issubdtype(hwc.dtype, np.floating):
        hwc = np.clip(hwc, 0.0, 1.0)
    plt.imsave(path, hwc)
    
def eval_model(model, used_filter, test_dataset, batch_size, codec_setting, output_dir):
    """Evaluate compression size and reconstruction quality for one filter count.

    For every batch the images are encoded with the TFLite encoder, passed
    through ``model.rate_less``, the first ``used_filter`` channels are
    quantized/dequantized, and the result is decoded with ``model.decoder``.

    Parameters:
        model: Model providing ``rate_less``, ``decoder`` and the attribute ``p``.
        used_filter (int): Number of leading latent channels that are coded.
        test_dataset: Dataset yielding ``(image, label)`` pairs.
        batch_size (int): Batch size for the data loader.
        codec_setting (dict): Per-channel ``'min'``, ``'max'`` and ``'codec'`` entries.
        output_dir (str): Directory for the reconstructed sample image.

    Returns:
        tuple: (mean coded size per image in units of 1024 bytes,
        mean PSNR in dB, mean MS-SSIM in dB).
    """
    size_list = []
    psnr_list = []
    ms_ssim_list = []
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=1)

    # Pure inference: no autograd graph is needed for the decoder.
    with torch.no_grad():
        for idx, (images, _labels) in enumerate(test_loader):
            images = images.to('cuda')

            encoded = predict_TFLite("MCUCoder.tflite", images)
            # encoded = model.encoder(images)

            # dropping
            encoded = model.rate_less(encoded)

            for i in range(encoded.size(0)):
                image_size = 0.0
                for j in range(used_filter):
                    # Measure the coded size on the original latent BEFORE it is
                    # overwritten: re-quantizing an already dequantized value can
                    # land one bin lower because of float rounding plus truncation.
                    image_size += quantization_and_huffman(encoded[i, j], j, codec_setting)
                    encoded[i, j] = quantization_and_dequantization(encoded[i, j], j, codec_setting)
                size_list.append(image_size)

            outputs = model.decoder(encoded)

            # Save the first image of the first batch, plus its latent when all
            # filters are in use, so both files describe the same sample.
            if idx == 0:
                # round() instead of int(): p * 12 may sit just below the integer.
                n_filters = int(round(model.p * 12))
                save_image(outputs[0].to('cpu').detach().numpy(), os.path.join(output_dir, f"reconstructed_{n_filters}.png"))
                if model.p == 12/12:
                    save_latent_images(encoded[0].to('cpu').detach(), 'Plots/latent.pdf')

            # Calculate PSNR for the batch
            psnr_values = psnr_batch(outputs.detach().cpu(), images.detach().cpu())
            psnr_list.extend(psnr_values.tolist())

            # Calculate ms ssim for the batch; extend (not append) keeps one flat
            # list so the mean stays per-image even when the last batch is smaller.
            ms_ssim_values = ms_ssim_batch(outputs.detach().cpu(), images.detach().cpu())
            ms_ssim_list.extend(ms_ssim_values)
    return np.mean(size_list), np.mean(psnr_list), np.mean(ms_ssim_list)

def create_codec(test_dataset, model):
    """Calibrate per-channel quantization ranges and Huffman codecs.

    The first batch (up to 5000 samples) of ``test_dataset`` is encoded with
    the TFLite encoder. For each of the 12 latent channels the min/max are
    recorded and a Huffman codec is built from the channel's 64-level
    quantized values.

    Parameters:
        test_dataset: Dataset yielding ``(image, label)`` pairs.
        model: Unused while the TFLite encoder is active; kept for the
            PyTorch-encoder variant.

    Returns:
        dict: ``{'min': {ch: tensor}, 'max': {ch: tensor}, 'codec': {ch: HuffmanCodec}}``.
    """
    codec_setting = {
        'min': {},
        'max': {},
        'codec': {}
    }
    temp_loader = DataLoader(test_dataset, batch_size=5000, num_workers=1)
    images, _labels = next(iter(temp_loader))
    images = images.to('cuda')

    # The encoder output does not depend on the channel being processed, so
    # run the (slow, per-sample) TFLite inference once instead of 12 times.
    encoded = predict_TFLite("MCUCoder.tflite", images)
    # encoded = model.encoder(images)

    # Every symbol the quantizer can emit: floor(255 / 4) = 63, so 0..63.
    all_symbols = np.arange(64, dtype=np.uint8)

    for i in tqdm(range(12)):
        data = encoded[:, i, :, :].reshape(-1).detach().clone()
        min_, max_ = torch.min(data), torch.max(data)
        data = ((data - min_) / (max_ - min_) * 255).type(torch.uint8)
        data = (data / 4).type(torch.uint8).cpu().numpy()

        # Add one dummy occurrence of each symbol missing from the calibration
        # data so the codec can encode any value seen later.
        missing = np.setdiff1d(all_symbols, data)
        data = np.concatenate([data, missing])

        codec = HuffmanCodec.from_data(data)
        codec_setting['min'][i], codec_setting['max'][i], codec_setting['codec'][i] = min_, max_, codec

    return codec_setting

def main():
    """Evaluate the codec for 1..12 active latent filters and print the results.

    Loads the model weights, builds the test dataset and the per-channel
    codecs, then for each filter count prints bits-per-pixel, PSNR and
    MS-SSIM, followed by the full result list.
    """
    args = parse_args()
    if args.model_path is None:
        # Fail with a clear message instead of an obscure error inside torch.load.
        raise SystemExit("--model_path is required: pass the path to the model checkpoint")

    torch.manual_seed(0)
    random.seed(10)
    np.random.seed(0)

    model = MCUCoder()
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(args.model_path, map_location='cuda')
    model.load_state_dict(state_dict)
    model = model.to('cuda')
    model.eval()

    # Single source of truth for the evaluated resolution: used for the
    # dataset request and for the bits-per-pixel normalization below.
    image_size = 224
    test_dataset = get_test_data(args, image_size)

    codec_setting = create_codec(test_dataset, model)

    results = []

    for used_filter in range(1, 13):
        # Fraction of latent filters in use; read by the model and by eval_model.
        model.p = used_filter / 12
        avg_size_kib, avg_psnr, avg_ms_ssim = eval_model(
            model,
            used_filter=used_filter,
            test_dataset=test_dataset,
            batch_size=args.batch_size,
            codec_setting=codec_setting,
            output_dir=args.output_dir,
        )
        # eval_model reports size in units of 1024 bytes; convert to bits.
        avg_size_bits = avg_size_kib * 8 * 1024

        # Calculate bpp (bits per pixel)
        avg_bpp = avg_size_bits / (image_size * image_size)

        results.append((float(avg_bpp), float(avg_psnr), float(avg_ms_ssim)))
        print(f'used_filter: {used_filter}, bpp: {avg_bpp:.2f}, PSNR: {avg_psnr:.2f}dB, MS_SSIM: {avg_ms_ssim:.2f}dB' )
    print(results)
    
    
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
