import torch
import torch.utils.data as data
import os
import numpy as np
import utils 
import pdb
import pdb
import numpy as np
import torch.utils.data as data
import utils
from options import *
from config import *
from ucf_test import test
from model import *
import os
from dataset_loader import *
from tqdm import tqdm
import torch.nn.functional as F


# Quantile level passed to XDVideo.temp_clustering: the cluster-split
# threshold is the q-quantile of the cosine similarities between adjacent
# segments. The value is also embedded in the name of the saved .pth file.
q = 0.02



class XDVideo():
    """Per-video feature access plus temporal clustering of video segments.

    The video list is read from ``list/XD_<mode>.list`` (relative to the
    current working directory) and feature arrays are loaded with ``np.load``
    from a directory derived from ``root_dir`` and ``modal``.
    """

    # Row of the training list at which ``is_normal`` splits it: rows before
    # this index are kept for ``is_normal=False``, rows from it on for
    # ``is_normal=True``.
    TRAIN_SPLIT_INDEX = 9525
    # ``getitem1`` visits every INDEX_STRIDE-th list entry only.
    # NOTE(review): presumably each video is listed as 5 consecutive crops,
    # so one crop per video is clustered -- confirm against the list file.
    INDEX_STRIDE = 5

    def __init__(self, root_dir, modal, mode, num_segments, len_feature, seed=-1, is_normal=None, is_cluster=None):
        """Read the split list and resolve the feature directories.

        Args:
            root_dir: Base directory of the dataset.
            modal: ``'all'`` selects the ``i3d-features/<modality>`` folders;
                any other value is joined onto ``root_dir`` as a sub-path.
            mode: Split name; ``"Train"`` enables the ``is_normal`` slicing
                and the fixed-length resampling in ``get_data``.
            num_segments: Number of rows of the resampled feature matrix.
            len_feature: Number of columns of the resampled feature matrix.
            seed: When ``>= 0`` it is forwarded to ``utils.set_seed``.
            is_normal: In ``"Train"`` mode, ``True`` keeps the list rows from
                ``TRAIN_SPLIT_INDEX`` on, ``False`` keeps the rows before it,
                and ``None`` prints a hint and leaves the list empty.
            is_cluster: Stored on the instance; not read inside this class.
        """
        if seed >= 0:
            utils.set_seed(seed)
        self.data_path = root_dir
        self.mode = mode
        self.modal = modal
        self.num_segments = num_segments
        self.len_feature = len_feature
        self.is_cluster = is_cluster
        if self.modal == 'all':
            modal_dirs = ['RGB', 'Flow'] if self.mode == "Train" else ['RGBTest', 'FlowTest']
            self.feature_path = [
                os.path.join(self.data_path, "i3d-features", _modal)
                for _modal in modal_dirs
            ]
        else:
            self.feature_path = os.path.join(self.data_path, modal)
        split_path = os.path.join("list", 'XD_{}.list'.format(self.mode))
        # The context manager closes the file even if reading fails part-way.
        with open(split_path, 'r', encoding="utf-8") as split_file:
            self.vid_list = [line.split() for line in split_file]
        if self.mode == "Train":
            if is_normal is True:
                self.vid_list = self.vid_list[self.TRAIN_SPLIT_INDEX:]
            elif is_normal is False:
                self.vid_list = self.vid_list[:self.TRAIN_SPLIT_INDEX]
            else:
                assert is_normal is None
                print("Please sure is_normal = [True/False]")
                self.vid_list = []

    def __len__(self):
        """Return the number of entries kept from the split list."""
        return len(self.vid_list)

    def _feature_dir(self):
        """Return the directory that features are loaded from.

        ``feature_path`` is a list when ``modal == 'all'`` (its first entry is
        used) and a plain string otherwise; indexing the string with ``[0]``
        would yield only its first character.
        """
        if isinstance(self.feature_path, (list, tuple)):
            return self.feature_path[0]
        return self.feature_path

    def _resample(self, video_feature):
        """Resample ``video_feature`` to ``(num_segments, len_feature)``.

        Consecutive boundaries returned by ``utils.random_perturb`` delimit
        the source rows that are averaged into one output row; when two
        boundaries coincide the single row at that position is copied.
        """
        new_feature = np.zeros((self.num_segments, self.len_feature), dtype=np.float32)
        sample_index = utils.random_perturb(video_feature.shape[0], self.num_segments)
        for i in range(len(sample_index) - 1):
            start, stop = sample_index[i], sample_index[i + 1]
            if start == stop:
                new_feature[i, :] = video_feature[start, :]
            else:
                new_feature[i, :] = video_feature[start:stop, :].mean(0)
        return new_feature

    def temp_clustering(self, video_features, q):
        """Group consecutive segments into temporally contiguous clusters.

        The threshold is the ``q``-quantile of the cosine similarities between
        adjacent segments. Walking through the segments in order, a segment
        joins the current cluster when its mean cosine similarity to all
        members of that cluster reaches the threshold; otherwise it opens a
        new cluster.

        Args:
            video_features: 2-D array-like or tensor, one row per segment.
            q: Quantile level in ``[0, 1]`` used to derive the threshold.

        Returns:
            A list with one cluster id per segment; ids start at 0 and never
            decrease. Inputs with fewer than two segments have no adjacent
            pair to derive a threshold from and are returned as a single
            cluster (or an empty list for zero segments).
        """
        # as_tensor avoids the copy-construct warning when a tensor is passed.
        video_features = torch.as_tensor(video_features, dtype=torch.float32)
        num_segments = video_features.size(0)
        if num_segments < 2:
            # torch.quantile rejects an empty similarity tensor.
            return [0] * num_segments

        similarity = F.cosine_similarity(video_features[:-1], video_features[1:], dim=1)
        threshold = torch.quantile(similarity, q, dim=0, keepdim=True)

        print(threshold)
        threshold_value = threshold.item()

        clusters = [0]  # the first segment always opens cluster 0
        current_cluster = 0
        current_members = [0]  # segment indices of the cluster being grown

        for i in range(1, num_segments):
            similarities_to_cluster = F.cosine_similarity(
                video_features[i].unsqueeze(0),
                video_features[current_members]
            )
            avg_similarity = similarities_to_cluster.mean().item()

            if avg_similarity >= threshold_value:
                current_members.append(i)
            else:
                current_cluster += 1
                current_members = [i]
            clusters.append(current_cluster)
        return clusters

    def getitem1(self):
        """Cluster every ``INDEX_STRIDE``-th video and save the results.

        Visits the list entries ``0, INDEX_STRIDE, 2 * INDEX_STRIDE, ...`` up
        to ``TRAIN_SPLIT_INDEX`` (or the end of the list, whichever comes
        first), collects the index, list name and cluster labels of each, and
        writes them once to
        ``xd_temporal_cluster/xd_temporal_cluster_q_<q>_<last index + 1>.pth``.

        Raises:
            RuntimeError: If the dataset was not built with ``mode="Train"``;
                ``get_data`` produces cluster labels for that mode only.
        """
        if self.mode != "Train":
            raise RuntimeError("getitem1 requires mode == 'Train'")

        # Set this to an earlier output file to extend it; the empty default
        # never exists, so a fresh run starts from empty lists.
        previous_file = ''

        if os.path.exists(previous_file):
            print(f"Loading previous data from {previous_file}...")
            previous_data = torch.load(previous_file)
            _idx_data = [torch.tensor(x) for x in previous_data['_idx_data'].tolist()]
            _video_info = previous_data['_video_info']
            _cluster_labels = [torch.tensor(x) for x in previous_data['_cluster_labels'].tolist()]
        else:
            _idx_data = []
            _video_info = []  # list names stay plain strings
            _cluster_labels = []

        save_directory = "xd_temporal_cluster"
        os.makedirs(save_directory, exist_ok=True)

        # Never run past the end of the list when it is shorter than expected.
        total_steps = min(self.TRAIN_SPLIT_INDEX, len(self.vid_list))
        last_index = None

        for index in range(0, total_steps, self.INDEX_STRIDE):
            print(f"Processing index: {index}")

            _, _, idx, cluster_labels = self.get_data(index)

            _idx_data.append(torch.tensor(idx))
            _video_info.append(self.vid_list[index][0])
            _cluster_labels.append(torch.tensor(cluster_labels))
            last_index = index

        # Save once after the last visited entry. For a full list this is the
        # same file as before, without depending on the loop reaching one
        # exact index.
        if last_index is not None:
            print(f"Saving at step {last_index + 1}...")

            to_save = {
                '_idx_data': torch.stack(_idx_data, dim=0),
                '_video_info': _video_info,
                '_cluster_labels': torch.stack(_cluster_labels, dim=0)
            }
            file_path = os.path.join(
                save_directory, f"xd_temporal_cluster_q_{q}_{last_index + 1}.pth"
            )
            torch.save(to_save, file_path)
            print(f"Files saved at step {last_index + 1}!")

        print("Finished processing and saving at the specified intervals.")

    def get_data(self, index):
        """Load the features of one list entry together with its labels.

        Args:
            index: Position in ``self.vid_list``.

        Returns:
            ``(video_feature, label, index, cluster_labels)``. ``label`` is 0
            when the list name contains ``"_label_A"`` and 1 otherwise. In
            ``"Train"`` mode the features are resampled to ``num_segments``
            rows and ``cluster_labels`` is an all-zero array for label 0 or
            the result of ``temp_clustering`` for label 1. In any other mode
            the features are returned as loaded and ``cluster_labels`` is
            ``None``.
        """
        vid_name = self.vid_list[index][0]
        label = 0 if "_label_A" in vid_name else 1
        video_feature = np.load(os.path.join(self._feature_dir(), vid_name)).astype(np.float32)

        # Defined up front so that non-training modes can return as well.
        cluster_labels = None
        if self.mode == "Train":
            video_feature = self._resample(video_feature)
            if label == 0:
                cluster_labels = np.zeros(len(video_feature))
            else:
                cluster_labels = self.temp_clustering(video_feature, q)

        return video_feature, label, index, cluster_labels

    def visual_cluster(self, index):
        """Resample and cluster one list entry for inspection.

        ``temp_clustering`` prints the similarity threshold it derives.

        Args:
            index: Position in ``self.vid_list``.

        Returns:
            The list of cluster ids produced by ``temp_clustering``.
        """
        vid_name = self.vid_list[index][0]
        video_feature = np.load(os.path.join(self._feature_dir(), vid_name)).astype(np.float32)
        video_feature = self._resample(video_feature)
        return self.temp_clustering(video_feature, q)


        

if __name__ == "__main__":
    # Script entry point: build the dataset from the first TRAIN_SPLIT rows
    # of the training list (is_normal=False) and write the temporal cluster
    # labels to disk.
    args = parse_args()
    if args.debug:
        pdb.set_trace()

    config = Config(args)
    worker_init_fn = None

    dataset_kwargs = dict(
        root_dir=config.root_dir,
        mode='Train',
        modal=config.modal,
        num_segments=200,
        len_feature=config.len_feature,
        is_normal=False,
        is_cluster=True,
    )
    abnormal_train_loader = XDVideo(**dataset_kwargs)

    abnormal_train_loader.getitem1()
    # To inspect a single video instead, call
    # abnormal_train_loader.visual_cluster(<index>), e.g. with 0 or 10.


    # pdb.set_trace()
