import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from models.model import MCUCoder
from tqdm import tqdm
from dahuffman import HuffmanCodec
import torch.nn.functional as F
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
import argparse
import random
import tensorflow as tf



def encode_TFLite(model_path, X):
    x_data = np.copy(X.to('cpu').numpy()) # the function quantizes the input, so we must make a copy
    # Initialize the TFLite interpreter
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    # Inputs will be quantized
    input_scale, input_zero_point = input_details["quantization"]
    if (input_scale, input_zero_point) != (0.0, 0):
        x_data = x_data / input_scale + input_zero_point
        x_data = x_data.astype(input_details["dtype"])
    # Invoke the interpreter
    predictions = np.empty((x_data.shape[0],12,28,28), dtype=output_details["dtype"])
    for i in range(len(x_data)):
        interpreter.set_tensor(input_details["index"], [x_data[i]])
        interpreter.invoke()
        predictions[i] = np.copy(interpreter.get_tensor(output_details["index"])[0])
    # Dequantize output
    output_scale, output_zero_point = output_details["quantization"]
    if (output_scale, output_zero_point) != (0.0, 0):
        predictions = predictions.astype(np.float32)
        predictions = (predictions - output_zero_point) * output_scale
    # todo reshape output into array for each exit
    return torch.from_numpy(predictions).to('cuda')


    
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=1024, help='Batch size for training')
    parser.add_argument('--model_path', type=str, default=None, help='Path to the model')
    parser.add_argument("--video_folder", type=str, help="Path to the folder containing video files")
    parser.add_argument("--output_dir", type=str, default="./results/mcucoder", help="Directory to save output images and videos")
    return parser.parse_args()

class VideoDataset(Dataset):
    def __init__(self, video_path, transform=None):
        self.video_path = video_path
        self.transform = transform
        self.frames = self._load_frames()

    def _load_frames(self):
        cap = cv2.VideoCapture(self.video_path)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
        cap.release()
        return frames

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, frame_index):
        frame = self.frames[frame_index]
        if self.transform:
            frame = self.transform(frame)
        return frame, frame_index

def get_video_frames(video_path, resize):
    transform = transforms.Compose([
        transforms.Resize((resize, resize), antialias=True),
        transforms.ToTensor(),
    ])
    video_dataset = VideoDataset(video_path=video_path, transform=transform)
    return video_dataset, len(video_dataset)

def quantization(data, filter_number, codec, step=4):
    min_, max_ = codec['min'][filter_number], codec['max'][filter_number]
    data = (data - min_) / (max_ - min_)
    data = data * 255
    data = data.type(dtype=torch.uint8)
    
    quantization_step = step
    data = data / quantization_step
    data = data.type(dtype=torch.uint8)

    return data

def quantization_and_dequantization(data, filter_number, codec, step=4):
    min_, max_ = codec['min'][filter_number], codec['max'][filter_number]
    
    data = (data - min_) / (max_ - min_)
    data = data * 255
    data = data.type(dtype=torch.uint8)
    
    quantization_step = step
    data = data / quantization_step
    data = data.type(dtype=torch.uint8)
    data = data * quantization_step

    data = data / 255.0
    data = data * (max_ - min_) + min_
    return data

def quantization_and_huffman(data, filter_number, codec, step=4):
    data = data.reshape(-1)
    quantized_data = quantization(data, filter_number, codec, step).cpu().numpy()
    codec = codec['codec'][filter_number]
    encoded = codec.encode(quantized_data)
    return len(encoded) / 1024

def psnr_batch(img1, img2):
    mse = F.mse_loss(img1, img2, reduction='none')
    mse = mse.view(mse.size(0), -1).mean(dim=1)
    psnr_values = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr_values

def ms_ssim_to_db(ms_ssim):
    return -10 * np.log10(1 - ms_ssim)

def ms_ssim_batch(img1, img2, data_range=1.0):
    ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range)
    
    ms_ssim_values = [ms_ssim(img1[i].unsqueeze(0), img2[i].unsqueeze(0)).item() for i in range(img1.size(0))]
    ms_ssim_db_values = [ms_ssim_to_db(value) for value in ms_ssim_values]
    
    return ms_ssim_db_values

def save_image(img, path):
    plt.imsave(path, np.transpose(img, (1, 2, 0)))


    
def eval_model(model, used_filter, test_dataset, batch_size, codec, output_dir, output_video_dir, video_name):
    size_list = []
    psnr_list = []
    ms_ssim_list = []
    all_feature_sizes = []
    
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1)
    
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_out_path = os.path.join(output_video_dir, f'{video_name}_reconstructed_filter_{used_filter}.mp4')
    video_writer = None
    
    for frame_index, data in enumerate(test_loader):
        images, indices = data
        images = images.to('cuda')
        
        #
        encoded = encode_TFLite("MCUCoder.tflite", images)
        #
        # encoded = model.encoder(images)
        
        model.replace_value = 0
        encoded = model.rate_less(encoded)

        for j in range(used_filter):
            encoded[0, j] = quantization_and_dequantization(encoded[0, j], j, codec) 

        outputs = model.decoder(encoded)

        if video_writer is None:
            h, w = outputs[0].shape[1:3]
            video_writer = cv2.VideoWriter(video_out_path, fourcc, 20, (w, h))
        
        output_frame = (outputs[0].cpu().detach().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        output_frame = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
        video_writer.write(output_frame)
        
        feature_map_size_list = []
        for feature_map in range(used_filter):
            data_size = quantization_and_huffman(encoded[0][feature_map], feature_map, codec)
            feature_map_size_list.append(data_size)
            
        size_list.append(np.sum(feature_map_size_list))
        
        if frame_index in [0, 50, 100]:
            save_image(outputs[0].to('cpu').detach().numpy(), os.path.join(output_dir, f"{video_name}_reconstructed_{int(model.p*12)}_frame_index_{frame_index}.png"))


        psnr_values = psnr_batch(outputs.detach().cpu(), images.detach().cpu())
        psnr_list.extend(psnr_values)
        
        ms_ssim_values = ms_ssim_batch(outputs.detach().cpu(), images.detach().cpu())
        ms_ssim_list.extend(ms_ssim_values)

    video_writer.release()

    return size_list, psnr_list, ms_ssim_list

def create_codec(test_dataset, model):
    codec_setting = {
        'min': {},
        'max': {},
        'codec': {}
    }
    temp_loader = DataLoader(test_dataset, batch_size=5000, num_workers=1)
    images, labels = next(iter(temp_loader))
    images = images.to('cuda')
    

    for i in range(12):
        # encoding on TFLite with int8 quantaized model
        encoded = encode_TFLite("MCUCoder.tflite", images)
        #
        # encoded = model.encoder(images)
        
        data = encoded[:, i, :, :].reshape(-1).detach().clone()
        min_, max_ = torch.min(data), torch.max(data)
        data = ((data - min_) / (max_ - min_) * 255).type(torch.uint8)
        data = (data / 4).type(torch.uint8).cpu().numpy()
        
        for j in range(0, 63):
            if j not in data:
                data = np.append(data, j)

        codec = HuffmanCodec.from_data(data)
        codec_setting['min'][i], codec_setting['max'][i], codec_setting['codec'][i] = min_, max_, codec

    del temp_loader
    return codec_setting
    
def main():
    args = parse_args()
    torch.manual_seed(0)
    random.seed(10)
    np.random.seed(0)
    
    model = MCUCoder()
    state_dict = torch.load(args.model_path, map_location='cuda')
    model.load_state_dict(state_dict)
    model = model.to('cuda')
    model.eval()
    
    video_files = [f for f in os.listdir(args.video_folder) if os.path.isfile(os.path.join(args.video_folder, f)) and f.endswith(('.mp4', '.avi'))]

    output_video_dir = os.path.join(args.output_dir, 'output_videos')
    os.makedirs(output_video_dir, exist_ok=True)
    
    all_for_print = []
    
    for used_filter in range(1, 13):
        model.p = used_filter / 12
        all_vids_avg_frames_bpps = []
        all_vids_avg_frames_psnrs = []
        all_vids_avg_frames_ms_ssims = []
        
        for video_file in video_files:
            video_path = os.path.join(args.video_folder, video_file)
            video_name, _ = os.path.splitext(video_file)
            test_dataset, number_of_frames = get_video_frames(video_path, 224)
            codec = create_codec(test_dataset, model)
    
            frames_sizes, frames_psnrs, frames_ms_ssims = eval_model(model,
                                                                    used_filter=used_filter,                   
                                                                    test_dataset=test_dataset,
                                                                    batch_size=args.batch_size,
                                                                    codec=codec,
                                                                    output_dir=args.output_dir,
                                                                    output_video_dir=output_video_dir,
                                                                    video_name=video_name)
    
            # Convert image size to bits
            frames_sizes = [size * 8 * 1024 for size in frames_sizes]
            frames_bpps = [size / (224 * 224) for size in frames_sizes]
            
            avg_frames_bpps = np.mean(frames_bpps)
            avg_frames_psnrs = np.mean(frames_psnrs)
            avg_frames_ms_ssims = np.mean(frames_ms_ssims)

            all_vids_avg_frames_bpps.append(avg_frames_bpps)
            all_vids_avg_frames_psnrs.append(avg_frames_psnrs)
            all_vids_avg_frames_ms_ssims.append(avg_frames_ms_ssims)

        all_for_print.append([float(np.mean(all_vids_avg_frames_bpps)), float(np.mean(all_vids_avg_frames_psnrs)), float(np.mean(all_vids_avg_frames_ms_ssims))])
        print(f'used_filter: {used_filter}, bpp: {np.mean(all_vids_avg_frames_bpps)}, '
              f'Average PSNR: {np.mean(all_vids_avg_frames_psnrs):.2f}dB, Average MS_SSIM: {np.mean(all_vids_avg_frames_ms_ssims):.2f}dB')
        
    print(all_for_print)
if __name__ == "__main__":
    main()
