import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from feeders import tools

class MultimodalFeeder(Dataset):
    """Dataset pairing skeleton sequences with DVS event videos.

    Skeleton sequences and one-hot labels are read from a ``.npz`` archive
    (keys ``x_train``/``y_train`` or ``x_test``/``y_test``). Each non-blank
    line of ``label_path`` is ``<name> <label>`` and names the event video
    ``<event_data_dir>/<name>_dvs.avi``. Samples are paired by position:
    line ``i`` of the label file must describe skeleton sample ``i``;
    ``__getitem__`` raises ``ValueError`` when the two labels disagree.

    ``__getitem__`` returns ``({'skeleton': ndarray, 'event': Tensor}, label, index)``.
    """

    def __init__(self, skeleton_data_path, event_data_dir, label_path, p_interval=1, split='train', 
                 random_choose=False, random_shift=False, random_move=False, random_rot=False, 
                 window_size=-1, normalization=False, debug=False, use_mmap=False, bone=False, vel=False):
        """Store the configuration and eagerly load skeletons and the event-file index.

        Args:
            skeleton_data_path: path to the ``.npz`` skeleton archive.
            event_data_dir: directory holding the ``<name>_dvs.avi`` event videos.
            label_path: text file with one ``<name> <label>`` pair per line.
            p_interval: crop ratio(s) forwarded to ``tools.valid_crop_resize`` and
                ``valid_crop_resize_event``. For the event stream, one value gives a
                centre crop and two values give a random crop whose ratio is drawn
                from ``[p_interval[0], p_interval[1])``. NOTE(review):
                ``tools.valid_crop_resize`` may require a sequence rather than the
                scalar default -- confirm.
            split: ``'train'`` or ``'test'``; any other value raises ``NotImplementedError``.
            random_rot: apply ``tools.random_rot`` to each skeleton sample.
            window_size: frame count both modalities are resized to; must be a
                positive integer for ``valid_crop_resize_event`` (the default -1 is not).
            normalization: if true, compute ``self.mean_map`` / ``self.std_map``
                (they are not applied by ``__getitem__``).
            bone: replace joint coordinates with bone vectors (``ntu_pairs``).
            vel: replace the skeleton with frame-to-frame differences.
            random_choose, random_shift, random_move, debug, use_mmap: stored but
                not read by any method of this class.
        """
        self.debug = debug
        self.skeleton_data_path = skeleton_data_path
        self.event_data_dir = event_data_dir
        self.label_path = label_path
        self.split = split
        self.random_choose = random_choose
        self.random_shift = random_shift
        self.random_move = random_move
        self.window_size = window_size
        self.normalization = normalization
        self.use_mmap = use_mmap
        self.p_interval = p_interval
        self.random_rot = random_rot
        self.bone = bone
        self.vel = vel

        self.load_data()  # Load skeleton data and event file mappings
        if normalization:
            self.get_mean_map()

    def load_data(self):
        """Load the skeleton split and build the index-aligned event-file list.

        Sets ``self.skeleton_data`` (shape ``(N, 3, T, 25, 2)``), ``self.label``
        (class indices), ``self.event_file_map`` and ``self.event_labels``.

        Raises:
            NotImplementedError: if ``self.split`` is not ``'train'``/``'test'``.
            FileNotFoundError: if a video named in the label file does not exist.
        """
        # Context manager closes the NpzFile handle; indexing it reads each
        # array fully into memory, so the arrays stay valid after closing.
        with np.load(self.skeleton_data_path) as npz_data:
            if self.split == 'train':
                self.skeleton_data = npz_data['x_train']
                # Labels are stored one-hot; column index of the positive entry = class id.
                self.label = np.where(npz_data['y_train'] > 0)[1]
            elif self.split == 'test':
                self.skeleton_data = npz_data['x_test']
                self.label = np.where(npz_data['y_test'] > 0)[1]
            else:
                raise NotImplementedError('Data split only supports train/test')

        # (N, T, 150) -> (N, T, 2, 25, 3) -> (N, C=3, T, V=25, M=2).
        # Presumably 2 bodies x 25 joints x 3 coordinates (NTU layout) -- inferred from the reshape.
        N, T, _ = self.skeleton_data.shape
        self.skeleton_data = self.skeleton_data.reshape((N, T, 2, 25, 3)).transpose(0, 4, 1, 3, 2)

        # Event video paths and labels, in label-file order (position i <-> skeleton sample i).
        self.event_file_map = []
        self.event_labels = []

        with open(self.label_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                name, label = fields

                # Names may have been written as a bytes repr, e.g. b'S001C001...'
                if name.startswith("b'") and name.endswith("'"):
                    name = name[2:-1]

                video_filepath = os.path.join(self.event_data_dir, f"{name}_dvs.avi")
                if not os.path.exists(video_filepath):
                    raise FileNotFoundError(f"Video file {video_filepath} not found.")
                self.event_file_map.append(video_filepath)
                self.event_labels.append(int(label))

    def get_mean_map(self):
        """Compute per-(channel, joint) mean and std maps of the skeleton data.

        ``mean_map`` averages over time, bodies and samples (shape ``(C, 1, V, 1)``);
        ``std_map`` is the std over all ``(sample, frame, body)`` rows, reshaped
        to ``(C, 1, V, 1)``.
        """
        data = self.skeleton_data
        N, C, T, V, M = data.shape
        self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0)
        self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1))

    def __len__(self):
        """Number of skeleton samples in the selected split."""
        return len(self.label)

    def __getitem__(self, index):
        """Return ``({'skeleton': ndarray, 'event': Tensor}, label, index)`` for one sample.

        Raises:
            ValueError: if the skeleton label and the label-file label differ.
        """
        label = self.label[index]
        event_label = self.event_labels[index]
        # Validate the pairing before the expensive preprocessing and video decode.
        if event_label != label:
            raise ValueError(f"Mismatch between skeleton label ({label}) and event label ({event_label}) at index {index}.")

        # Copy so the in-place edits below never write into self.skeleton_data.
        skeleton_data_numpy = np.array(self.skeleton_data[index])
        # Frames whose coordinates sum to non-zero (presumably padding frames are all-zero).
        valid_frame_num = np.sum(skeleton_data_numpy.sum(0).sum(-1).sum(-1) != 0)

        skeleton_data_numpy = tools.valid_crop_resize(skeleton_data_numpy, valid_frame_num, self.p_interval, self.window_size)
        if self.random_rot:
            skeleton_data_numpy = tools.random_rot(skeleton_data_numpy)
        if self.bone:
            from .bone_pairs import ntu_pairs
            bone_data_numpy = np.zeros_like(skeleton_data_numpy)
            for v1, v2 in ntu_pairs:
                # Joint indices in ntu_pairs are 1-based.
                bone_data_numpy[:, :, v1 - 1] = skeleton_data_numpy[:, :, v1 - 1] - skeleton_data_numpy[:, :, v2 - 1]
            skeleton_data_numpy = bone_data_numpy
        if self.vel:
            # Forward temporal difference; the last frame has no successor and is zeroed.
            skeleton_data_numpy[:, :-1] = skeleton_data_numpy[:, 1:] - skeleton_data_numpy[:, :-1]
            skeleton_data_numpy[:, -1] = 0

        # NOTE(review): the event crop is computed independently of the skeleton
        # crop (full video length vs. valid skeleton frames, and a separate random
        # draw for a two-value p_interval), so the two windows are not guaranteed
        # to cover the same time span.
        event_data_tensor = self.read_event_video(self.event_file_map[index])
        event_data_tensor = self.valid_crop_resize_event(event_data_tensor, event_data_tensor.shape[0], self.p_interval, self.window_size)

        return {'skeleton': skeleton_data_numpy, 'event': event_data_tensor}, label, index

    def read_event_video(self, video_path, positive_threshold=170, negative_threshold=80):
        """Decode a video into a two-channel binary tensor.

        Each frame is converted to grayscale; channel 0 marks pixels
        ``>= positive_threshold`` and channel 1 marks pixels
        ``<= negative_threshold`` (presumably ON/OFF event polarities).

        Returns:
            ``torch.int8`` tensor of shape ``(num_frames, 2, H, W)``.

        Raises:
            IOError: if the file cannot be opened or yields no frames.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            if not cap.isOpened():
                raise IOError(f"Cannot open video file {video_path}")
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                # Size comes from the decoded frame itself, not container metadata.
                polarity = np.stack((gray_frame >= positive_threshold, gray_frame <= negative_threshold))
                frames.append(polarity.astype(np.int8))
        finally:
            cap.release()

        if not frames:
            raise IOError(f"No frames could be decoded from video file {video_path}")
        return torch.from_numpy(np.stack(frames))

    def valid_crop_resize_event(self, event_tensor, valid_frame_num, p_interval, window):
        """Temporally crop the first ``valid_frame_num`` frames and resize to ``window`` frames.

        Args:
            event_tensor: tensor of shape ``(T, C, H, W)``.
            valid_frame_num: number of leading frames eligible for cropping.
            p_interval: a scalar or 1-element sequence gives a deterministic centre
                crop keeping ratio ``p``; a 2-element sequence draws the ratio
                uniformly from ``[p_interval[0], p_interval[1])`` and picks a random
                offset, with a minimum of 64 frames (capped at ``valid_frame_num``).
            window: positive output frame count.

        Returns:
            ``torch.int32`` tensor of shape ``(window, C, H, W)`` with values in {0, 1},
            resampled along time by nearest-neighbour interpolation.
        """
        _, C, H, W = event_tensor.shape
        if np.isscalar(p_interval):
            p_interval = [p_interval]
        begin = 0
        end = int(valid_frame_num)
        valid_size = end - begin

        if len(p_interval) == 1:
            p = p_interval[0]
            bias = int((1 - p) * valid_size / 2)
            data = event_tensor[begin + bias:end - bias]  # center crop
        else:
            # Scalar draw: same RNG consumption as rand(1), without array->int conversion.
            p = np.random.rand() * (p_interval[1] - p_interval[0]) + p_interval[0]
            cropped_length = int(np.minimum(np.maximum(int(np.floor(valid_size * p)), 64), valid_size))
            bias = np.random.randint(0, valid_size - cropped_length + 1)
            data = event_tensor[begin + bias:begin + bias + cropped_length]
        cropped_length = data.shape[0]

        # Flatten to (1, 1, C*H*W, T) so a 2-D nearest interpolation with an
        # unchanged first spatial size only resamples the time axis.
        flat = data.permute(1, 2, 3, 0).reshape(C * H * W, cropped_length)[None, None].float()
        resized = F.interpolate(flat, size=(C * H * W, window), mode='nearest')[0, 0]

        resized = (resized > 0.5).int()  # re-binarize and fix dtype to int32
        return resized.view(C, H, W, window).permute(3, 0, 1, 2).contiguous()
