import torch
import einops
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from vqlm.vqvae_muse import get_tokenizer_muse

class ActionParser:
    """Infer a grid-world action by comparing an input frame with a predicted frame.

    Frames are 256x256 images split into a ``level`` x ``level`` grid of
    square cells. Cells are addressed as ``(row, col)`` tuples; a flat
    ``state`` index maps to ``(state // level, state % level)``.

    The tokenizer passed to the constructor must expose a ``device`` attribute
    and a ``decode_code(ids)`` method returning a ``(b, c, h, w)`` tensor
    (presumably the MUSE VQ tokenizer -- confirm against the caller).
    """

    # (row_delta, col_delta) -> (action_id, action_name)
    _ACTIONS = {
        (-1, 0): (0, 'up'),
        (1, 0): (1, 'down'),
        (0, -1): (2, 'left'),
        (0, 1): (3, 'right'),
    }
    # Returned for any move that is not a single step in a cardinal direction
    # (this includes "no move at all").
    _INVALID_ACTION = (-1, 'invalid')

    def __init__(self, tokenizer):
        """Store the tokenizer used to decode token ids into images.

        Raises:
            ValueError: if ``tokenizer`` is None.
        """
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if tokenizer is None:
            raise ValueError("tokenizer must not be None")
        self.tokenizer = tokenizer

    @staticmethod
    def _decode_with(tokenizer, ids):
        """Decode ``ids`` of shape (batch, 256) into a (batch, h, w, c) float array in [0, 1].

        Shared by ``decode_tokens`` and ``visualize_ids`` so both validate and
        decode in exactly the same way.

        Raises:
            TypeError: if ``ids`` is not a ``torch.Tensor``.
            ValueError: if ``ids`` is not 2-D with 256 columns.
        """
        if not isinstance(ids, torch.Tensor):
            raise TypeError("ids must be a torch.Tensor")
        if ids.dim() != 2 or ids.size(1) != 256:
            raise ValueError("ids must have shape (batch size, 256)")

        # Work on a copy placed on the tokenizer's device so the caller's
        # tensor is never touched.
        ids_on_device = ids.clone().to(tokenizer.device)

        # The result is detached and converted to NumPy right away, so there
        # is no point in recording an autograd graph for the decode.
        with torch.no_grad():
            decoded = tokenizer.decode_code(ids_on_device)

        channels_last = einops.rearrange(
            torch.clamp(decoded, 0.0, 1.0),
            'b c h w -> b h w c'
        )
        return channels_last.detach().cpu().numpy()

    def decode_tokens(self, ids):
        """Decode token ids with this parser's tokenizer.

        Args:
            ids: ``torch.Tensor`` of shape (batch, 256).

        Returns:
            NumPy array of shape (batch, h, w, c) with values clamped to [0, 1].

        Raises:
            TypeError: if ``ids`` is not a ``torch.Tensor``.
            ValueError: if ``ids`` does not have shape (batch, 256).
        """
        return self._decode_with(self.tokenizer, ids)

    @classmethod
    def get_pixel_location(cls, img, level, crop_ratio=0.0):
        """Split a 256x256 image into a ``level`` x ``level`` grid of cell crops.

        Args:
            img: array whose first two dimensions are 256 x 256.
            level: number of cells per side; must be in ``[1, 256]``.
            crop_ratio: fraction of the cell size trimmed from every side of
                each cell; must be in ``[0, 0.5)`` so crops are never empty.

        Returns:
            Dict mapping ``(row, col)`` to the corresponding (cropped) cell.
            When 256 is not divisible by ``level`` the image is first resized
            to the largest multiple of ``level`` that fits.

        Raises:
            ValueError: if the image is not 256x256, or ``level`` /
                ``crop_ratio`` are out of range.
        """
        height, width = img.shape[0:2]
        if height != width or height != 256:
            raise ValueError("Resolution of the imgae should be 256 times 256!")
        if level < 1 or level > height:
            raise ValueError(f"level must be in [1, {height}], got {level}")
        if not 0.0 <= crop_ratio < 0.5:
            raise ValueError(f"crop_ratio must be in [0, 0.5), got {crop_ratio}")

        new_height = level * (height // level)
        new_width = level * (width // level)
        if new_height != height or new_width != width:
            img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

        cell_size = new_height // level
        margin = int(cell_size * crop_ratio)  # same for every cell, so computed once

        coordinates = {}
        for row in range(level):
            top = row * cell_size
            for col in range(level):
                left = col * cell_size
                coordinates[(row, col)] = img[
                    top + margin:top + cell_size - margin,
                    left + margin:left + cell_size - margin
                ]

        return coordinates

    @classmethod
    def get_action(cls, start_coord, next_coord):
        """Translate a move between two ``(row, col)`` cells into an action.

        Returns:
            ``(action_id, action_name)``: ``(0, 'up')``, ``(1, 'down')``,
            ``(2, 'left')`` or ``(3, 'right')`` for a single-cell step, and
            ``(-1, 'invalid')`` for anything else (including staying put).
        """
        delta = (next_coord[0] - start_coord[0], next_coord[1] - start_coord[1])
        return cls._ACTIONS.get(delta, cls._INVALID_ACTION)

    @classmethod
    def get_coordinate_from_state(cls, state, level):
        """Convert a flat state index into a ``(row, col)`` grid coordinate.

        Raises:
            ValueError: if ``state`` is not in ``[0, level * level)``.
        """
        # Negative states used to slip through and produce a negative row.
        if not 0 <= state < level * level:
            raise ValueError(
                f"state must satisfy 0 <= state < level*level, got state={state}, level={level}"
            )
        return (state // level, state % level)

    @classmethod
    def _resolve_coord(cls, coord, level):
        """Normalise a state index or coordinate sequence to a ``(row, col)`` tuple.

        Integers (Python or NumPy) are treated as flat state indices. Lists and
        NumPy arrays are converted to tuples so they compare equal to, and can
        be used as, the tuple keys produced by ``get_pixel_location``. Anything
        else (tuples, None) is returned unchanged.
        """
        if isinstance(coord, (int, np.integer)):
            return cls.get_coordinate_from_state(int(coord), level)
        if isinstance(coord, (list, np.ndarray)):
            return tuple(np.asarray(coord).tolist())
        return coord

    @staticmethod
    def _to_uint8(img):
        """Convert a bool, floating or uint8 image to uint8 for display.

        Floating images whose maximum is <= 1.0 are scaled by 255; other
        floating images are taken to be in 0..255 already. Values are clipped
        to [0, 255] before casting so out-of-range pixels do not wrap around.

        Raises:
            TypeError: for any other dtype.
        """
        if img.dtype == bool:
            return img.astype(np.uint8) * 255
        if img.dtype == np.uint8:
            return img
        if np.issubdtype(img.dtype, np.floating):
            scaled = img * 255 if img.max() <= 1.0 else img
            return np.clip(scaled, 0, 255).astype(np.uint8)
        raise TypeError(f"Unsupported image dtype: {img.dtype}")

    @classmethod
    def visualize_imgs(cls, input_img, pred_img):
        """Show the input and predicted images side by side with matplotlib.

        Both images may be bool, floating or uint8 arrays; they are converted
        to uint8 before plotting. Blocks on ``plt.show()``.
        """
        panels = (
            (cls._to_uint8(input_img), 'Input Image'),
            (cls._to_uint8(pred_img), 'Pred Image'),
        )

        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        for ax, (image, title) in zip(axs, panels):
            ax.imshow(image, cmap='gray', vmin=0, vmax=255)
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    @classmethod
    def visualize_ids(cls, input_ids, pred_ids, tokenizer):
        """Decode the first sample of each id batch and show them side by side.

        Args:
            input_ids, pred_ids: ``torch.Tensor`` of shape (batch, 256).
            tokenizer: object with ``device`` and ``decode_code``.

        Raises:
            TypeError / ValueError: same conditions as ``decode_tokens``.
        """
        input_img = cls._decode_with(tokenizer, input_ids)[0]
        pred_img = cls._decode_with(tokenizer, pred_ids)[0]
        cls.visualize_imgs(input_img, pred_img)

    @staticmethod
    def _to_gray_uint8(img):
        """Convert an RGB image to a uint8 grayscale image.

        Non-uint8 input is scaled by 255 before the cast (i.e. it is assumed
        to be in [0, 1], as produced by ``decode_tokens``). uint8 input is used
        as-is, because multiplying it by 255 would overflow.
        """
        rgb = img if img.dtype == np.uint8 else (img * 255).astype(np.uint8)
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    @classmethod
    def _cell_iou(cls, img1, img2, thresh=200, is_show=False):
        """IoU of the dark-pixel masks (pixels <= ``thresh``) of two grayscale cells.

        Returns 0.0 when neither cell contains a dark pixel. ``is_show`` plots
        the two masks for debugging.
        """
        _, bin1 = cv2.threshold(img1, thresh, 255, cv2.THRESH_BINARY)
        _, bin2 = cv2.threshold(img2, thresh, 255, cv2.THRESH_BINARY)
        # THRESH_BINARY sets pixels above ``thresh`` to 255, so ``== 0`` selects
        # the pixels at or below the threshold.
        mask1 = bin1 == 0
        mask2 = bin2 == 0
        if is_show:
            cls.visualize_imgs(mask1, mask2)
        union = np.logical_or(mask1, mask2).sum()
        if union == 0:
            return 0.0
        return np.logical_and(mask1, mask2).sum() / union

    @staticmethod
    def _cell_mse(img1, img2):
        """Mean squared pixel difference between two cells, computed in float32."""
        return np.mean((img1.astype(np.float32) - img2.astype(np.float32)) ** 2)

    def _parse_from_images(self, input_img, pred_img, level, start_coord, target_coord):
        """Shared implementation behind ``parse_action_in_ids`` / ``parse_action_in_imgs``.

        ``start_coord`` and ``target_coord`` must already be normalised with
        ``_resolve_coord``.
        """
        input_cells = self.get_pixel_location(self._to_gray_uint8(input_img), level)
        pred_cells = self.get_pixel_location(self._to_gray_uint8(pred_img), level)

        # How much each cell changed between the input and the prediction.
        mse_values = {
            coord: self._cell_mse(input_cells[coord], pred_cells[coord])
            for coord in input_cells
        }
        # How closely each predicted cell resembles the input's start cell.
        start_cell = input_cells[start_coord]
        iou_values = {
            coord: self._cell_iou(start_cell, pred_cells[coord])
            for coord in input_cells
        }

        ranked_by_change = sorted(mse_values.items(), key=lambda kv: kv[1], reverse=True)
        most_changed_coords = [coord for coord, _ in ranked_by_change[:2]]
        # ``max`` returns the first maximal key, matching a stable descending sort.
        best_match_coord = max(iou_values, key=iou_values.get)

        # Prefer the target if it is among the two most-changed cells; otherwise
        # the cell that best matches the start cell, provided it also changed;
        # otherwise assume no move happened.
        if target_coord in most_changed_coords:
            extracted_coord = target_coord
        elif best_match_coord in most_changed_coords:
            extracted_coord = best_match_coord
        else:
            extracted_coord = start_coord

        return {
            "action": self.get_action(start_coord, extracted_coord),
            "pred_coord": extracted_coord
        }

    def parse_action_in_ids(self, input_ids, pred_ids, level, start_coord, target_coord):
        """Infer the action taken between two token-encoded frames.

        Only the first sample of each batch is used.

        Args:
            input_ids, pred_ids: ``torch.Tensor`` of shape (batch, 256).
            level: number of grid cells per side.
            start_coord: flat state index or ``(row, col)`` of the start cell.
            target_coord: flat state index or ``(row, col)`` of the target cell.

        Returns:
            ``{"action": (action_id, action_name), "pred_coord": (row, col)}``.

        Raises:
            TypeError: if either id argument is not a ``torch.Tensor``.
            ValueError: for badly shaped ids or an out-of-range state index.
        """
        if not isinstance(input_ids, torch.Tensor) or not isinstance(pred_ids, torch.Tensor):
            raise TypeError("input_ids and pred_ids must be torch.Tensor")

        # Resolve coordinates first so bad states fail before the costly decode.
        start_coord = self._resolve_coord(start_coord, level)
        target_coord = self._resolve_coord(target_coord, level)

        input_img = self.decode_tokens(input_ids)[0]
        pred_img = self.decode_tokens(pred_ids)[0]

        return self._parse_from_images(input_img, pred_img, level, start_coord, target_coord)

    def parse_action_in_imgs(self, input_img, pred_img, level, start_coord, target_coord):
        """Infer the action taken between two RGB frames.

        Args:
            input_img, pred_img: 256x256 RGB arrays, either float in [0, 1]
                or uint8.
            level: number of grid cells per side.
            start_coord: flat state index or ``(row, col)`` of the start cell.
            target_coord: flat state index or ``(row, col)`` of the target cell.

        Returns:
            ``{"action": (action_id, action_name), "pred_coord": (row, col)}``.

        Raises:
            ValueError: for an out-of-range state index or a wrong image size.
        """
        start_coord = self._resolve_coord(start_coord, level)
        target_coord = self._resolve_coord(target_coord, level)

        return self._parse_from_images(input_img, pred_img, level, start_coord, target_coord)
    

if __name__ == "__main__":
    # Demo: decode the first few training samples, infer the action for each,
    # and show the input / target frames side by side.
    from itertools import islice

    # seed everything
    torch.manual_seed(42)
    np.random.seed(42)
    # Fall back to the CPU so the demo still runs on machines without a GPU.
    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = get_tokenizer_muse().to(torch_device)

    # Read the JSONL file
    jsonl_file_path = 'dataset/frozen_lake/tokenized_dataset/SFT/train_dataset.jsonl'
    num_samples = 5

    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        # Parse lazily and stop after ``num_samples`` records instead of
        # loading the whole dataset; blank lines are skipped.
        records = (json.loads(line) for line in file if line.strip())
        data = list(islice(records, num_samples))

    parser = ActionParser(tokenizer)
    # ``enumerate`` keeps the printed index in step with the record; the old
    # counter was never incremented.
    for idx, item in enumerate(data):
        meta = item['meta']
        print(f"Evaluating {idx}th data")
        input_ids_tensor = torch.tensor(item['input_tokens']).view(-1, 256).to(torch_device)
        pred_ids_tensor = torch.tensor(item['output_tokens']).view(-1, 256).to(torch_device)
        action_dict = parser.parse_action_in_ids(
            input_ids_tensor,
            pred_ids_tensor,
            meta['level'],
            item['input_state'],
            meta['target_pos'],
        )
        pred_coord = action_dict['pred_coord']

        print(f"Action: {action_dict['action']}")
        print(f"Pred coord: {(pred_coord)}")
        ActionParser.visualize_ids(input_ids_tensor, pred_ids_tensor, tokenizer)