import os
import random

import cv2
import h5py
import pickle
import editdistance
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from utils.util import _extract_frames_h5py, get_num_frames
from utils.config import argparser


# Data root directory per dataset name. The 'xxx' values look like
# placeholders that must be replaced with real local paths before use.
data_path_dict = dict(
    break_egg='xxx',
    pour_milk='xxx',
    pour_liquid='xxx',
    tennis='xxx',
)

# Native video resolution per dataset key. The pairs appear to be
# [width, height] (e.g. 1920 x 1080 for tennis_ego) -- TODO confirm order.
# 'tennis' has no single entry: its two views are stored separately under
# 'tennis_ego' and 'tennis_exo'.
data_video_resolution = dict(
    break_egg=[1024, 768],
    pouring=[300, 534],
    milk=[640, 360],
    grab=[640, 360],
    pour_milk=[640, 360],
    pour_liquid=[320, 240],
    tennis_ego=[1920, 1080],
    tennis_exo=[480, 360],
)


class VideoAlignmentDataset(Dataset):
    """Base class for paired-view video alignment datasets.

    Collects the video file lists for ``args.view1`` / ``args.view2`` and
    provides helpers for sampling frame indices and loading frames from the
    per-video HDF5 cache. Subclasses must implement ``__len__`` and
    ``__getitem__``.

    Args:
        args: config namespace. Read here: ``dataset``, ``tcc_num_frames``,
            ``tcc_frame_stride``, ``tcc_random_offset``, ``view1``, ``view2``
            and ``merge_all``; ``get_frames_h5py`` additionally reads
            ``tcc_input_size`` and ``imagenet_norm``.
        mode: split name. For ``break_egg`` / ``tennis`` it names the
            ``<mode>.csv`` split file inside each view directory; for the
            other datasets it is a sub-directory of the data root.
    """

    def __init__(self, args, mode):
        self.args = args
        self.mode = mode
        self.dataset = args.dataset
        self.data_path = data_path_dict[args.dataset]
        # 'tennis' has two views with different native resolutions.
        if args.dataset != 'tennis':
            self.video_res = data_video_resolution[args.dataset]
        else:
            self.video_res_ego = data_video_resolution['tennis_ego']
            self.video_res_exo = data_video_resolution['tennis_exo']

        self.num_steps = args.tcc_num_frames
        self.frame_stride = args.tcc_frame_stride
        self.random_offset = args.tcc_random_offset
        # discarded
        # self.num_context_steps = args.tcc_num_context_steps
        # self.frames_per_video = self.num_frames * self.num_context_steps

        if args.dataset in ['break_egg', 'tennis']:
            # Split membership is listed in <view_dir>/<mode>.csv.
            self.video_paths1 = self._construct_video_path_by_mode(os.path.join(self.data_path, args.view1), mode)
            self.video_paths2 = self._construct_video_path_by_mode(os.path.join(self.data_path, args.view2), mode)
            self.frame_save_path = self.data_path
        else:
            # Split membership comes from the layout <root>/<mode>/<view>/*.mp4.
            self.video_paths1 = self._construct_video_path(os.path.join(self.data_path, mode, args.view1))
            self.video_paths2 = self._construct_video_path(os.path.join(self.data_path, mode, args.view2))
            self.frame_save_path = os.path.join(self.data_path, mode)

        if args.merge_all:
            # Pool both views (deduplicated); both attributes then refer to
            # the same list object.
            tmp_path = list(set(self.video_paths1 + self.video_paths2))
            self.video_paths1 = tmp_path
            self.video_paths2 = tmp_path

        # ImageNet channel statistics, used when args.imagenet_norm is set.
        self.image_mean = np.array([0.485, 0.456, 0.406])
        self.image_std = np.array([0.229, 0.224, 0.225])

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _construct_video_path(self, dir_name):
        """Return the paths of all ``.mp4`` files directly inside ``dir_name``.

        The order follows ``os.listdir`` and is therefore not guaranteed.
        Raises AssertionError unless at least two videos are found.
        """
        video_paths = []
        for item in os.listdir(dir_name):
            if item.endswith('.mp4'):
                video_paths.append(os.path.join(dir_name, item))
        assert len(video_paths) > 1
        print(f'{len(video_paths)} videos in {dir_name}')
        return video_paths

    def _construct_video_path_by_mode(self, dir_name, mode):
        """Read ``<dir_name>/<mode>.csv`` and return the listed video paths.

        Each non-blank line (stripped of surrounding whitespace) is joined
        onto ``dir_name``. Blank lines, such as a trailing empty line, are
        skipped so they cannot become a bogus ``dir_name/`` entry. The file
        is closed even if reading fails.
        """
        split_file = os.path.join(dir_name, mode + '.csv')
        with open(split_file, 'r') as f_in:
            return [
                os.path.join(dir_name, line.strip())
                for line in f_in
                if line.strip()
            ]

    def get_frames_h5py(self, h5_file_path, frames_list, bbox_list=None):
        """Load, resize and normalise the requested frames from an HDF5 cache.

        Args:
            h5_file_path: HDF5 file containing an ``images`` dataset.
            frames_list: frame indices to read, in output order.
            bbox_list: optional per-frame box table, used only for the debug
                visualisation below.

        Returns:
            List of frames, each resized to
            ``(tcc_input_size, tcc_input_size)``. Pixels are ImageNet-
            normalised when ``args.imagenet_norm`` is set, otherwise scaled
            to ``[-1, 1]`` via ``frame / 127.5 - 1``.
        """
        final_frames = list()
        # Context manager closes the file even if a read or resize raises.
        with h5py.File(h5_file_path, 'r') as h5_file:
            frames = h5_file['images']
            for frame_num in frames_list:
                frame_ = frames[frame_num]
                frame = cv2.resize(
                    frame_,
                    (self.args.tcc_input_size, self.args.tcc_input_size),
                    interpolation=cv2.INTER_AREA
                )
                # Debug only: with bbox_list supplied, draw the boxes stored in
                # columns 0:4, 4:8 and 8:12 of frame 16 and display it with
                # matplotlib (columns 12:16 are not drawn).
                if bbox_list is not None and frame_num == 16:
                    for col in (0, 4, 8):
                        min_x, min_y, max_x, max_y = bbox_list[frame_num, col:col + 4]
                        cv2.rectangle(frame, (int(min_x), int(min_y)), (int(max_x), int(max_y)), (0, 0, 255), thickness=1)
                    plt.imshow(frame[:, :, ::-1])
                    plt.axis('off')
                    plt.show()

                if self.args.imagenet_norm:
                    frame = (frame / 255.0) - self.image_mean
                    frame = frame / self.image_std
                else:
                    frame = (frame / 127.5) - 1.0

                final_frames.append(frame)

        assert len(final_frames) == len(frames_list)
        return final_frames

    def get_steps(self, step):
        """Return ``[step - frame_stride, step]``.

        The result is the step plus one preceding context frame. Values are
        not clamped here; ``sample_frames`` clamps them.

        Raises:
            ValueError: if ``num_steps`` or ``frame_stride`` is below 1.
        """
        if self.num_steps < 1:
            raise ValueError('num_steps should be >= 1.')
        if self.frame_stride < 1:
            raise ValueError('stride should be >= 1.')
        # arange's end is exclusive, so this yields exactly two indices.
        steps = torch.arange(step - self.frame_stride, step + self.frame_stride, self.frame_stride)
        return steps

    def sample_frames(self, seq_len):
        """Sample ``num_steps`` main frame indices for a video of ``seq_len`` frames.

        If more than ``num_steps`` frames exist at or after ``random_offset``,
        ``num_steps`` distinct indices are drawn uniformly from
        ``[random_offset, seq_len)`` and sorted. Otherwise indices
        ``0 .. num_steps - 1`` are used; after clamping, a short video
        repeats its last frame.

        Returns:
            ``(chosen_steps, steps)``. ``chosen_steps`` holds the ``num_steps``
            main indices, clamped to ``[0, seq_len - 1]``. ``steps`` holds the
            ``2 * num_steps`` interleaved ``[context, main]`` indices from
            ``get_steps``, also clamped.
        """
        # sampling_strategy = 'offset_uniform' for now
        assert seq_len >= self.random_offset
        if self.num_steps < seq_len - self.random_offset:
            # random sample
            steps = torch.randperm(seq_len - self.random_offset) + self.random_offset
            steps = steps[:self.num_steps]
            steps = torch.sort(steps)[0]
        else:
            # sample all
            steps = torch.arange(0, self.num_steps, dtype=torch.int64)
        chosen_steps = torch.clamp(steps, 0, seq_len - 1)
        steps = torch.cat(list(map(self.get_steps, steps)), dim=-1)
        steps = torch.clamp(steps, 0, seq_len - 1)
        return chosen_steps, steps


class VideoAlignmentTrainDataset(VideoAlignmentDataset):
    """Training dataset yielding one randomly drawn clip per view.

    ``__len__`` reports ``len(self.video_paths1)``. ``__getitem__`` ignores
    ``idx``: each call picks one video at random from ``video_paths1`` and
    one from ``video_paths2``.

    Each item is a tuple of three arrays:
        * clip frames as float32, stacked along a new leading view axis
          (2 clips of ``2 * num_steps`` interleaved context/main frames);
        * the sampled main-frame indices, shape ``(2, num_steps)``;
        * each video's total frame count, shape ``(2,)``.
    """

    def __init__(self, args, mode):
        super().__init__(args, mode)

    def __len__(self):
        return len(self.video_paths1)

    def __getitem__(self, idx):
        picks = [random.sample(self.video_paths1, 1), random.sample(self.video_paths2, 1)]
        clip_arrays, step_arrays, lengths = [], [], []
        for picked in picks:
            video_path = picked[0]
            n_frames = get_num_frames(video_path)
            anchor_steps, frame_ids = self.sample_frames(n_frames)
            # Frames are read from an HDF5 cache returned by the extraction helper.
            h5_path = _extract_frames_h5py(video_path, self.frame_save_path)
            clip = np.array(self.get_frames_h5py(h5_path, frame_ids))
            clip_arrays.append(clip.astype(np.float32)[np.newaxis])
            step_arrays.append(np.array(anchor_steps)[np.newaxis])
            lengths.append(n_frames)

        return (
            np.concatenate(clip_arrays),
            np.concatenate(step_arrays),
            np.array(lengths),
        )


class VideoAlignmentDownstreamDataset(VideoAlignmentDataset):
    """Frame-level dataset with per-frame labels for downstream evaluation.

    Every frame of every video is one sample. Both views are pooled
    (``merge_all`` is forced on) and the pooled list is sorted, so the
    sample order is deterministic.

    NOTE(review): ``__init__`` sets ``args.merge_all = True`` on the caller's
    ``args`` object, so later users of the same namespace see the change.
    """

    def __init__(self, args, mode):
        args.merge_all = True
        super(VideoAlignmentDownstreamDataset, self).__init__(args, mode)
        self.modify_data = args.modify_data
        # Fraction of the video kept around the transition by crop_transition_region.
        self.percentage = 0.5
        self.video_paths1 = sorted(self.video_paths1)
        self._load_label()
        self._construct_frame_path()

    def _construct_frame_path(self):
        """Build the flat per-frame index.

        Fills three lists:
            * ``frame_path_list`` with ``[video_path, frame_id, label]`` entries;
            * ``video_len_list`` with each video's frame count;
            * ``video_ego_id`` with one entry per frame: 1 when the video's
              parent directory is ``ego``, else 0.

        Videos missing from ``label_dict`` get label -1 for every frame.
        """
        self.frame_path_list = []
        self.video_len_list = []
        self.video_ego_id = []
        for video in self.video_paths1:
            video_frames_count = get_num_frames(video)
            self.video_len_list.append(video_frames_count)
            video_name = video.replace('.mp4', '').split('/')[-1]
            # The view is the name of the video's parent directory.
            view = video.split('/')[-2]
            if video_name not in self.label_dict:
                print(f'{video_name} not in dict')
                labels = -1 * np.ones(video_frames_count, dtype=int)  # no ground truth label
            else:
                labels = self.label_dict[video_name]
            assert video_frames_count == len(labels)
            for frame_id in range(video_frames_count):
                self.frame_path_list.append([video, frame_id, labels[frame_id]])
            # The view flag is identical for all frames of a video; compute it once.
            is_ego = 1 if view == 'ego' else 0
            self.video_ego_id.extend([is_ego] * video_frames_count)
        print(f'Finish constructing frames path list, total len {len(self.frame_path_list)}')

    def crop_transition_region(self, vector):
        """Mask labels far from the first 0 -> 1 transition with -1.

        A window of ``int(len(vector) * self.percentage)`` frames, centred on
        the transition, is kept. ``vector`` is modified IN PLACE and also
        returned. Raises IndexError if ``vector`` has no 0 -> 1 step.
        """
        # Find the index of the transition frame (0 -> 1)
        transition_index = np.where(np.diff(vector) == 1)[0][0] + 1

        # Calculate the percentage region around the transition frame
        total_length = len(vector)
        region_length = int(total_length * self.percentage)

        # Calculate the start and end indices of the region
        start_index = max(0, transition_index - region_length // 2)
        end_index = min(total_length, transition_index + region_length // 2)

        # Replace all values outside the region with -1
        vector[:start_index] = -1
        vector[end_index:] = -1
        return vector

    def modify_vector(self, vector):
        """Refine a two-phase 0/1 label vector into four phases 0/1/2/3.

        The span before the first 0 -> 1 transition is split in half into
        0 and 1; the span from the transition onward is split in half into
        2 and 3. Returns a new int array. Raises IndexError if ``vector`` has
        no 0 -> 1 step.
        """
        modified_vector = np.zeros(len(vector), dtype=int)

        # Find the index of the transition frame (0 -> 1)
        transition_index = np.where(np.diff(vector) == 1)[0][0] + 1

        # Replace the first half of the original zeros with 0 and the second half with 1
        half_zeros = transition_index // 2
        modified_vector[:half_zeros] = 0
        modified_vector[half_zeros:transition_index] = 1

        # Replace the first half of the original ones with 2 and the second half with 3
        half_ones = (len(vector) - transition_index) // 2
        modified_vector[transition_index:transition_index + half_ones] = 2
        modified_vector[transition_index + half_ones:] = 3
        return modified_vector

    def _load_label(self):
        """Load the per-video label dict into ``self.label_dict``.

        Reads ``label_all.pickle`` when ``args.label_all`` is set, otherwise
        ``label_new.pickle``. With ``modify_data`` enabled, vectors whose key
        contains ``'subject'`` are refined by ``modify_vector``; all other
        entries are kept unchanged.
        """
        file_name = 'label_all.pickle' if self.args.label_all else 'label_new.pickle'
        print(f'Loading {file_name}')
        if self.args.dataset == 'pouring':
            file_path = os.path.join(self.data_path, self.mode, file_name)
        else:
            file_path = os.path.join(self.data_path, file_name)
        # NOTE: pickle can execute arbitrary code -- only load trusted label files.
        with open(file_path, 'rb') as handle:
            label_dict_orig = pickle.load(handle)
        if self.modify_data:
            self.label_dict = {}
            print('Modifying')
            for key, value in label_dict_orig.items():
                # Only run modify_vector for keys whose result is kept.
                # Running it on every key wasted work and raised IndexError
                # for non-subject vectors without a 0 -> 1 step.
                self.label_dict[key] = self.modify_vector(value) if 'subject' in key else value
        else:
            self.label_dict = label_dict_orig

    def __len__(self):
        return len(self.frame_path_list)

    def __getitem__(self, idx):
        """Return ``(frames, label, video_path)`` for one frame.

        ``frames`` is a float32 array of two frames: the frame
        ``frame_stride`` steps earlier (clamped at 0) followed by the frame
        itself.
        """
        video_path, frame_id, frame_label = self.frame_path_list[idx]
        h5_file_name = _extract_frames_h5py(video_path, self.frame_save_path)
        context_frame_id = max(0, frame_id - self.frame_stride)
        frame = self.get_frames_h5py(h5_file_name, [context_frame_id, frame_id])
        frame = np.array(frame).astype(np.float32)  # (2, tcc_input_size, tcc_input_size, C)
        return frame, frame_label, video_path


def visualize_shuffle(frames, shuffle_num=4, permute_num=4):
    """Plot a few main frames in original, shuffled and reversed order.

    Debug helper for eyeballing temporal-order perturbations of one sampled
    clip. In such a clip context and main frames alternate, so the main
    frames are taken from the odd indices.

    Figure rows, top to bottom:
        * the selected frames in order;
        * ``permute_num`` random non-identity permutations of them;
        * the reversed order.

    Args:
        frames: array of shape (T, H, W, C) indexed along axis 0. Values are
            assumed to lie in [-1, 1] (the non-ImageNet normalisation), and
            the channel axis is flipped before display -- TODO confirm for
            ``imagenet_norm`` runs.
        shuffle_num: the main frames are split into ``shuffle_num`` equal
            segments; the centre frame of each of the first
            ``shuffle_num // 2`` segments is shown. Must be >= 2.
        permute_num: number of random permutation rows.

    Raises:
        ValueError: if ``shuffle_num < 2`` (no frame would be selected).
    """
    if shuffle_num < 2:
        raise ValueError('shuffle_num must be >= 2 to select at least one frame.')

    # Main frames sit at odd indices, each preceded by its context frame.
    # Derive the range from the input rather than assuming exactly 64 frames.
    main_frames = frames[list(range(1, frames.shape[0], 2))]

    segment_size = main_frames.shape[0] // shuffle_num
    # Only the first half of the segments contributes a frame.
    shuffle_num = shuffle_num // 2
    subset_indices = [int((i + 0.5) * segment_size) for i in range(shuffle_num)]
    print('subset indices', subset_indices)

    x = main_frames[subset_indices]
    # Invert the (frame / 127.5) - 1.0 normalisation back to uint8 pixels.
    x = ((x + 1.0) * 127.5).astype(np.uint8)
    num_shown = x.shape[0]
    base_idx = torch.arange(num_shown)
    images_list = [x]
    for _ in range(permute_num):
        random_idx = torch.randperm(num_shown)
        # Redraw until the permutation differs from the identity. With a
        # single frame the identity is the only permutation, so retrying
        # would loop forever.
        while num_shown > 1 and torch.all(random_idx == base_idx):
            random_idx = torch.randperm(num_shown)
        edit_distance = editdistance.eval(base_idx.numpy(), random_idx.numpy()) / num_shown
        print('random indices', random_idx, edit_distance)
        images_list.append(x[random_idx])
    reverse_idx = torch.arange(num_shown - 1, -1, -1)
    images_list.append(x[reverse_idx])
    images = np.array(images_list)

    n_rows = permute_num + 2
    # squeeze=False keeps axs 2-D even with a single column, so axs[i, j] is valid.
    fig, axs = plt.subplots(
        n_rows, shuffle_num, figsize=(3 * shuffle_num, 3 * n_rows), squeeze=False
    )
    for i in range(n_rows):
        for j in range(shuffle_num):
            axs[i, j].imshow(images[i, j, :, :, ::-1])
            axs[i, j].axis('off')
    plt.show()


if __name__ == '__main__':
    # Smoke test: build the training split of one task and visually inspect
    # the frame order of the first two samples, one figure per view.
    cli_args = argparser.parse_args()
    cli_args.dataset = 'pour_milk'
    split = 'train'
    # split = 'test_new'

    train_set = VideoAlignmentTrainDataset(cli_args, split)
    # train_set = VideoAlignmentTestDataset(cli_args)
    # train_set = VideoAlignmentDownstreamDataset(cli_args, split)
    loader = DataLoader(train_set, batch_size=2, num_workers=0)
    print(f'Data loader len {len(loader)}')

    for sample_idx in range(2):
        print(sample_idx)
        clip_frames, _, _ = train_set[sample_idx]
        for view_frames in (clip_frames[0], clip_frames[1]):
            visualize_shuffle(view_frames, 10, 0)

    # for i, batch in enumerate(loader):
    #     print(i, len(loader))
    #     break
