import torch
import torch.utils.data as data
import os
import numpy as np
import utils 
import pdb


class UCF_crime(data.Dataset):
    """Map-style dataset over pre-extracted UCF-Crime feature files.

    The list of feature files is read from ``list/UCF_<mode>.list`` (relative
    to the current working directory); the first whitespace-separated token of
    every line is used as the path of a ``.npy`` feature file.

    A file whose base name contains ``"Normal"`` gets label 0, every other file
    gets label 1.

    Item layout:
        * ``mode == "Test"``: ``(features, label, name)``
        * otherwise: ``(features, label, index)`` followed by the cluster
          labels when ``is_cluster`` is truthy and then by the uncertainty
          scores when ``is_uncertainty`` is truthy.
    """

    # In the Train list the first ABNORMAL_ENTRY_COUNT entries are used as the
    # abnormal subset and everything after them as the normal subset.
    ABNORMAL_ENTRY_COUNT = 8100
    # Item indices are integer-divided by this value to locate cluster and
    # uncertainty rows; presumably each video contributes this many
    # consecutive list entries (e.g. crops) - TODO confirm.
    ENTRIES_PER_VIDEO = 10

    def __init__(self, root_dir, modal, mode, num_segments, len_feature, seed=-1, is_normal=None, is_cluster=None, cluster_file=None, is_uncertainty=None, uncertainty_file=None):
        """Read the split list and prepare optional cluster/uncertainty data.

        Args:
            root_dir: Accepted for interface compatibility; not used here.
            modal: Stored on the instance; not otherwise used by this class.
            mode: Split name; selects ``list/UCF_<mode>.list``. ``"Train"``
                and ``"Test"`` trigger mode-specific behaviour.
            num_segments: Number of segments every Train video is pooled to.
            len_feature: Stored on the instance; not otherwise used here.
            seed: When ``>= 0``, passed to ``utils.set_seed``.
            is_normal: Train only. ``True`` keeps the normal subset, ``False``
                the abnormal subset; ``None`` prints a hint and leaves the
                dataset empty.
            is_cluster: When truthy, items carry cluster labels.
            cluster_file: Optional path handed to ``torch.load``.
            is_uncertainty: When truthy, items carry uncertainty scores.
            uncertainty_file: Accepted for interface compatibility; not used.

        Raises:
            ValueError: In Train mode, if ``is_normal`` is not ``True``,
                ``False`` or ``None``.
        """
        super().__init__()
        if seed >= 0:
            utils.set_seed(seed)
        self.mode = mode
        self.modal = modal
        self.num_segments = num_segments
        self.len_feature = len_feature
        self.is_cluster = is_cluster
        self.is_uncertainty = is_uncertainty

        split_path = os.path.join('list', 'UCF_{}.list'.format(self.mode))
        with open(split_path, 'r') as split_file:
            # Blank lines would yield empty entries that fail later at [0].
            self.vid_list = [line.split() for line in split_file if line.strip()]

        if self.mode == "Train":
            if is_normal is True:
                self.vid_list = self.vid_list[self.ABNORMAL_ENTRY_COUNT:]
            elif is_normal is False:
                self.vid_list = self.vid_list[:self.ABNORMAL_ENTRY_COUNT]
            elif is_normal is None:
                print("Please sure is_normal=[True/False]")
                self.vid_list = []
            else:
                raise ValueError(
                    f"is_normal must be True, False or None, got {is_normal!r}")

        if cluster_file is not None:
            # NOTE: torch.load unpickles its input; only load trusted files.
            self.cluster_data = torch.load(cluster_file)
        else:
            self.cluster_data = None

        # Always define the attribute so it can be read safely in every
        # configuration.
        if self.is_uncertainty:
            self.uncertainty_data = torch.zeros(len(self.vid_list), self.num_segments)
        else:
            self.uncertainty_data = None

    def update_uncertainty(self, new_uncertainty_tensor):
        """Replace the stored uncertainty scores with a CPU copy of the input.

        Args:
            new_uncertainty_tensor: Tensor of shape
                ``(len(self), num_segments)``.

        Raises:
            ValueError: If the tensor does not have the expected shape.
        """
        expected_shape = (len(self.vid_list), self.num_segments)
        if tuple(new_uncertainty_tensor.shape) != expected_shape:
            raise ValueError(
                f"Expected uncertainty tensor of shape {expected_shape}, "
                f"got {tuple(new_uncertainty_tensor.shape)}")
        self.uncertainty_data = new_uncertainty_tensor.cpu()
        print("Uncertainty tensor updated.")

    def __len__(self):
        """Return the number of entries kept from the split list."""
        return len(self.vid_list)

    def __getitem__(self, index):
        """Return the item tuple for ``index`` (see the class docstring)."""
        return self.get_data(index)

    def get_data(self, index):
        """Load one feature file and assemble the item tuple.

        Args:
            index: Position in ``self.vid_list``.

        Returns:
            ``(features, label, name)`` in Test mode; otherwise
            ``(features, label, index[, cluster_labels][, uncertainty_scores])``.

        Raises:
            ValueError: For an abnormal video when cluster labels or
                uncertainty scores are requested but unavailable, or when the
                mapped index cannot be resolved.
        """
        vid_info = self.vid_list[index][0]
        file_name = vid_info.split("/")[-1]
        name = file_name.split("_x264")[0]
        video_feature = np.load(vid_info).astype(np.float32)
        label = 0 if "Normal" in file_name else 1

        if self.mode == "Train":
            video_feature = self._pool_segments(video_feature)
        if self.mode == "Test":
            return video_feature, label, name

        sample = [video_feature, label, index]
        if self.is_cluster:
            # Normal videos get all-zero cluster labels without any lookup.
            if label == 0:
                sample.append(torch.zeros(len(video_feature)))
            else:
                sample.append(self._lookup_cluster_labels(index))
        if self.is_uncertainty:
            # Normal videos get all-zero uncertainty scores without any lookup.
            if label == 0:
                sample.append(torch.zeros(len(video_feature)))
            else:
                sample.append(self._lookup_uncertainty(index))
        return tuple(sample)

    def _pool_segments(self, video_feature):
        """Average-pool the rows of ``video_feature`` into ``num_segments`` rows."""
        pooled = np.zeros((self.num_segments, video_feature.shape[1]), dtype=np.float32)
        bounds = np.linspace(0, len(video_feature), self.num_segments + 1, dtype=np.int64)
        for seg, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
            if start != end:
                pooled[seg, :] = np.mean(video_feature[start:end, :], 0)
            else:
                # Empty window (fewer rows than segments): reuse a single row.
                pooled[seg:seg + 1, :] = video_feature[start:start + 1, :]
        return pooled

    def _lookup_cluster_labels(self, index):
        """Return the cluster labels stored for the group that ``index`` belongs to."""
        if self.cluster_data is None:
            raise ValueError(
                "is_cluster is set but no cluster file was loaded (cluster_file=None).")
        # Round down to the first entry of the group; that value is the key
        # searched for in "_idx_data".
        mapped_idx = (index // self.ENTRIES_PER_VIDEO) * self.ENTRIES_PER_VIDEO
        idx_data = self.cluster_data["_idx_data"].tolist()
        try:
            idx_pos = idx_data.index(mapped_idx)
        except ValueError:
            raise ValueError(f"Index {mapped_idx} not found in _idx_data.") from None
        return self.cluster_data["_cluster_labels"][idx_pos]

    def _lookup_uncertainty(self, index):
        """Return the uncertainty row stored for the group that ``index`` belongs to."""
        if self.uncertainty_data is None:
            raise ValueError(
                "is_uncertainty is set but uncertainty_data is None; "
                "call update_uncertainty() first.")
        mapped_idx = index // self.ENTRIES_PER_VIDEO
        num_rows = self.uncertainty_data.shape[0]
        if not 0 <= mapped_idx < num_rows:
            raise ValueError(
                f"Mapped index {mapped_idx} is out of range (0-{num_rows - 1})")
        return self.uncertainty_data[mapped_idx]

class XDVideo(data.Dataset):
    """Map-style dataset over pre-extracted XD-Violence feature files.

    The list of feature files is read from ``list/XD_<mode>.list`` (relative to
    the current working directory); the first whitespace-separated token of
    every line is a file name that is joined onto the feature directory.

    A file whose name contains ``"_label_A"`` gets label 0, every other file
    gets label 1.

    Item layout:
        * ``mode == "Test"``: ``(features, label, index)``
        * otherwise: ``(features, label, index)`` followed by the cluster
          labels when ``is_cluster`` is truthy and then by the uncertainty
          scores when ``is_uncertainty`` is truthy.
    """

    # In the Train list the first ABNORMAL_ENTRY_COUNT entries are used as the
    # abnormal subset and everything after them as the normal subset.
    ABNORMAL_ENTRY_COUNT = 9525
    # Item indices are integer-divided by this value to locate cluster and
    # uncertainty rows; presumably each video contributes this many
    # consecutive list entries (e.g. crops) - TODO confirm.
    ENTRIES_PER_VIDEO = 5

    def __init__(self, root_dir, mode, modal, num_segments, len_feature, seed=-1, is_normal=None, is_cluster=None, cluster_file=None, is_uncertainty=None):
        """Read the split list and prepare optional cluster/uncertainty data.

        Args:
            root_dir: Base directory of the feature files.
            mode: Split name; selects ``list/XD_<mode>.list``. ``"Train"``
                and ``"Test"`` trigger mode-specific behaviour.
            modal: ``'all'`` builds a list of RGB and Flow directories below
                ``<root_dir>/i3d-features``; any other value is joined
                directly onto ``root_dir``.
            num_segments: Number of segments every Train video is sampled to.
            len_feature: Feature width of the sampled Train array.
            seed: When ``>= 0``, passed to ``utils.set_seed``.
            is_normal: Train only. ``True`` keeps the normal subset, ``False``
                the abnormal subset; ``None`` prints a hint and leaves the
                dataset empty.
            is_cluster: When truthy, items carry cluster labels.
            cluster_file: Optional path handed to ``torch.load``.
            is_uncertainty: When truthy, items carry uncertainty scores.

        Raises:
            ValueError: In Train mode, if ``is_normal`` is not ``True``,
                ``False`` or ``None``.
        """
        super().__init__()
        if seed >= 0:
            utils.set_seed(seed)
        self.data_path = root_dir
        self.mode = mode
        self.modal = modal
        self.num_segments = num_segments
        self.len_feature = len_feature
        self.is_cluster = is_cluster
        self.is_uncertainty = is_uncertainty

        if self.modal == 'all':
            sub_dirs = ['RGB', 'Flow'] if self.mode == "Train" else ['RGBTest', 'FlowTest']
            self.feature_path = [
                os.path.join(self.data_path, "i3d-features", _modal)
                for _modal in sub_dirs
            ]
        else:
            self.feature_path = os.path.join(self.data_path, modal)

        split_path = os.path.join("list", 'XD_{}.list'.format(self.mode))
        with open(split_path, 'r', encoding="utf-8") as split_file:
            # Blank lines would yield empty entries that fail later at [0].
            self.vid_list = [line.split() for line in split_file if line.strip()]

        if self.mode == "Train":
            if is_normal is True:
                self.vid_list = self.vid_list[self.ABNORMAL_ENTRY_COUNT:]
            elif is_normal is False:
                self.vid_list = self.vid_list[:self.ABNORMAL_ENTRY_COUNT]
            elif is_normal is None:
                print("Please sure is_normal = [True/False]")
                self.vid_list = []
            else:
                raise ValueError(
                    f"is_normal must be True, False or None, got {is_normal!r}")

        if cluster_file is not None:
            # NOTE: torch.load unpickles its input; only load trusted files.
            self.cluster_data = torch.load(cluster_file)
        else:
            self.cluster_data = None

        if self.is_uncertainty:
            self.uncertainty_data = torch.zeros(len(self.vid_list), self.num_segments)
        else:
            self.uncertainty_data = None

    def update_uncertainty(self, new_uncertainty_tensor, step=None):
        """Replace the stored uncertainty scores with a CPU copy of the input.

        Args:
            new_uncertainty_tensor: Tensor of shape
                ``(len(self), num_segments)``.
            step: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If the tensor does not have the expected shape.
        """
        expected_shape = (len(self.vid_list), self.num_segments)
        if tuple(new_uncertainty_tensor.shape) != expected_shape:
            raise ValueError(
                f"Expected uncertainty tensor of shape {expected_shape}, "
                f"got {tuple(new_uncertainty_tensor.shape)}")
        self.uncertainty_data = new_uncertainty_tensor.cpu()
        print("Uncertainty tensor updated.")

    def __len__(self):
        """Return the number of entries kept from the split list."""
        return len(self.vid_list)

    def __getitem__(self, index):
        """Return the item tuple for ``index`` (see the class docstring)."""
        return self.get_data(index)

    def get_data(self, index):
        """Load one feature file and assemble the item tuple.

        Args:
            index: Position in ``self.vid_list``.

        Returns:
            ``(features, label, index)`` in Test mode; otherwise
            ``(features, label, index[, cluster_labels][, uncertainty_scores])``.

        Raises:
            ValueError: For an abnormal video when cluster labels or
                uncertainty scores are requested but unavailable, or when the
                mapped index cannot be resolved.
        """
        vid_name = self.vid_list[index][0]
        label = 0 if "_label_A" in vid_name else 1
        video_feature = np.load(
            os.path.join(self._feature_dir(), vid_name)).astype(np.float32)

        if self.mode == "Train":
            video_feature = self._sample_segments(video_feature)
        if self.mode == "Test":
            return video_feature, label, index

        sample = [video_feature, label, index]
        if self.is_cluster:
            # Normal videos get all-zero cluster labels without any lookup.
            if label == 0:
                sample.append(torch.zeros(len(video_feature)))
            else:
                sample.append(self._lookup_cluster_labels(index))
        if self.is_uncertainty:
            # Normal videos get all-zero uncertainty scores without any lookup.
            if label == 0:
                sample.append(torch.zeros(len(video_feature)))
            else:
                sample.append(self._lookup_uncertainty(index))
        return tuple(sample)

    def _feature_dir(self):
        """Return the directory that feature files are loaded from.

        ``feature_path`` is a list when ``modal == 'all'`` (only its first
        entry is read) and a plain string otherwise; indexing the string with
        ``[0]`` would yield just its first character.
        """
        if isinstance(self.feature_path, (list, tuple)):
            return self.feature_path[0]
        return self.feature_path

    def _sample_segments(self, video_feature):
        """Resample ``video_feature`` to ``num_segments`` rows.

        Segment boundaries come from ``utils.random_perturb``; its exact
        output is defined outside this file (presumably ``num_segments + 1``
        integer boundaries - TODO confirm).
        """
        sampled = np.zeros((self.num_segments, self.len_feature), dtype=np.float32)
        sample_index = utils.random_perturb(video_feature.shape[0], self.num_segments)
        for seg in range(len(sample_index) - 1):
            start, end = sample_index[seg], sample_index[seg + 1]
            if start == end:
                # Empty window: take the single row at the boundary.
                sampled[seg, :] = video_feature[start, :]
            else:
                sampled[seg, :] = video_feature[start:end, :].mean(0)
        return sampled

    def _lookup_cluster_labels(self, index):
        """Return the cluster labels stored for the group that ``index`` belongs to."""
        if self.cluster_data is None:
            raise ValueError(
                "is_cluster is set but no cluster file was loaded (cluster_file=None).")
        # Round down to the first entry of the group; that value is the key
        # searched for in "_idx_data".
        mapped_idx = (index // self.ENTRIES_PER_VIDEO) * self.ENTRIES_PER_VIDEO
        idx_data = self.cluster_data["_idx_data"].tolist()
        try:
            idx_pos = idx_data.index(mapped_idx)
        except ValueError:
            raise ValueError(f"Index {mapped_idx} not found in _idx_data.") from None
        return self.cluster_data["_cluster_labels"][idx_pos]

    def _lookup_uncertainty(self, index):
        """Return the uncertainty row stored for the group that ``index`` belongs to."""
        if self.uncertainty_data is None:
            raise ValueError(
                "is_uncertainty is set but uncertainty_data is None; "
                "call update_uncertainty() first.")
        mapped_idx = index // self.ENTRIES_PER_VIDEO
        num_rows = self.uncertainty_data.shape[0]
        if not 0 <= mapped_idx < num_rows:
            raise ValueError(
                f"Mapped index {mapped_idx} is out of range (0-{num_rows - 1})")
        return self.uncertainty_data[mapped_idx]
