import torch
import torch.utils.data as data
import os
import numpy as np
import utils 
import pdb
from sklearn.cluster import KMeans, DBSCAN
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import pdb
import numpy as np
import torch.utils.data as data
import utils
from options import *
from config import *
from ucf_test import test
from model import *
import os
from dataset_loader import *
from tqdm import tqdm
import torch.nn.functional as F


# Quantile level handed to UCF_crime.temp_clustering: the threshold that splits
# clusters is the q-quantile of the cosine similarities between adjacent segments.
q = 0.07



class UCF_crime():
    """Video feature list for UCF-Crime with temporal cluster-label generation.

    The constructor reads ``list/UCF_<mode>.list``; the first whitespace-separated
    token of every line is used as the path of a ``.npy`` feature file. In
    ``"Train"`` mode the list is split into an abnormal head and a normal tail.
    """

    # Entries [0, 8100) of the Train list are treated as abnormal, the rest as normal.
    _NUM_ABNORMAL_ENTRIES = 8100
    # getitem1 only visits every 10th entry.
    # NOTE(review): presumably the list holds 10 crops per video - confirm.
    _INDEX_STRIDE = 10

    def __init__(self, root_dir, modal, mode, num_segments, len_feature, seed=-1, is_normal=None, is_cluster=None):
        """Load the split list for ``mode`` and select its normal/abnormal part.

        Args:
            root_dir: accepted for interface compatibility; not used here.
            modal: stored on the instance as ``self.modal``.
            mode: split name; selects ``list/UCF_<mode>.list``.
            num_segments: number of temporal segments features are pooled to.
            len_feature: stored on the instance as ``self.len_feature``.
            seed: when >= 0, forwarded to ``utils.set_seed``.
            is_normal: in "Train" mode, True keeps the normal tail of the list,
                False keeps the abnormal head; None leaves the list empty.
            is_cluster: stored on the instance as ``self.is_cluster``.
        """
        if seed >= 0:
            utils.set_seed(seed)
        self.mode = mode
        self.modal = modal
        self.num_segments = num_segments
        self.len_feature = len_feature
        self.is_cluster = is_cluster
        split_path = os.path.join('list', 'UCF_{}.list'.format(self.mode))
        # Context manager guarantees the handle is closed even if reading fails.
        with open(split_path, 'r') as split_file:
            self.vid_list = [line.split() for line in split_file]
        if self.mode == "Train":
            if is_normal is True:
                self.vid_list = self.vid_list[self._NUM_ABNORMAL_ENTRIES:]
            elif is_normal is False:
                self.vid_list = self.vid_list[:self._NUM_ABNORMAL_ENTRIES]
            else:
                assert is_normal is None
                print("Please sure is_normal=[True/False]")
                self.vid_list = []

    def __len__(self):
        """Return the number of entries kept from the split list."""
        return len(self.vid_list)

    def _resample_segments(self, video_feature):
        """Pool a (T, D) feature array to ``self.num_segments`` rows.

        Row ``i`` is the mean of the source rows in the i-th of
        ``num_segments`` integer-rounded, evenly spaced intervals over ``T``;
        an empty interval copies the single source row at its start instead.
        """
        pooled = np.zeros((self.num_segments, video_feature.shape[1]), dtype=np.float32)
        bounds = np.linspace(0, len(video_feature), self.num_segments + 1, dtype=int)
        for i in range(self.num_segments):
            start, stop = bounds[i], bounds[i + 1]
            if start != stop:
                pooled[i, :] = np.mean(video_feature[start:stop, :], 0)
            else:
                # Fewer source rows than segments: reuse the row at the boundary.
                pooled[i:i + 1, :] = video_feature[start:start + 1, :]
        return pooled

    def temp_clustering(self, video_features, q):
        """Group consecutive segments into temporally contiguous clusters.

        The threshold is the ``q``-quantile of the cosine similarities between
        adjacent segments. Walking through the segments in order, a segment
        joins the current cluster when its mean cosine similarity to that
        cluster's members reaches the threshold; otherwise it opens a new one.

        Args:
            video_features: array-like of shape (num_segments, dim).
            q: quantile level in [0, 1] passed to ``torch.quantile``.

        Returns:
            A list with one non-decreasing integer cluster id per segment,
            starting at 0. Inputs with fewer than two segments have no adjacent
            pair to derive a threshold from and yield ``[0] * num_segments``.
        """
        # as_tensor avoids a copy (and a warning) when a tensor is passed in.
        video_features = torch.as_tensor(video_features, dtype=torch.float32)
        num_segments = video_features.size(0)
        if num_segments < 2:
            # torch.quantile rejects an empty similarity tensor.
            return [0] * num_segments

        similarity = F.cosine_similarity(video_features[:-1], video_features[1:], dim=1)
        threshold = torch.quantile(similarity, q, dim=0, keepdim=True).to(video_features.device)

        print(threshold)
        threshold_value = threshold.item()

        clusters = [0]  # the first segment always opens cluster 0
        current_members = [0]  # segment indices belonging to the current cluster
        current_cluster = 0

        for i in range(1, num_segments):
            similarities_to_cluster = F.cosine_similarity(
                video_features[i].unsqueeze(0),
                video_features[current_members]
            )
            if similarities_to_cluster.mean().item() >= threshold_value:
                current_members.append(i)
            else:
                current_cluster += 1
                current_members = [i]
            clusters.append(current_cluster)
        return clusters

    def getitem1(self):
        """Compute cluster labels for every 10th list entry and save them.

        Visits indices 0, 10, 20, ... below ``min(8100, len(self.vid_list))``,
        collects the index, the feature path and the cluster labels of each
        visited entry, and writes them once at the end to
        ``ucf_temporal_cluster/ucf_temporal_cluster_q_<q>_<last index + 1>.pth``
        as a dict with keys ``_idx_data``, ``_video_info`` and
        ``_cluster_labels``. Nothing is written when no entry was visited.

        ``get_data`` only produces cluster labels in "Train" mode, so this
        method is only usable on a "Train" instance.
        """
        # Path of an earlier result to extend; the empty path never exists,
        # so by default the lists start empty.
        previous_file = ''

        if os.path.exists(previous_file):
            print(f"Loading previous data from {previous_file}...")
            previous_data = torch.load(previous_file)
            _idx_data = [torch.tensor(x) for x in previous_data['_idx_data'].tolist()]
            _video_info = previous_data['_video_info']
            _cluster_labels = [torch.tensor(x) for x in previous_data['_cluster_labels'].tolist()]
        else:
            _idx_data = []
            _video_info = []  # feature paths stay plain strings
            _cluster_labels = []

        # Never index past the end of the list when it is shorter than expected.
        total_steps = min(self._NUM_ABNORMAL_ENTRIES, len(self.vid_list))
        save_directory = "ucf_temporal_cluster"
        os.makedirs(save_directory, exist_ok=True)

        last_index = None
        for index in range(0, total_steps, self._INDEX_STRIDE):
            print(f"Processing index: {index}")

            _, _, idx, cluster_labels = self.get_data(index)

            _idx_data.append(torch.tensor(idx))
            _video_info.append(self.vid_list[index][0])
            _cluster_labels.append(torch.tensor(cluster_labels))
            last_index = index

        if last_index is not None:
            print(f"Saving at step {last_index + 1}...")

            to_save = {
                '_idx_data': torch.stack(_idx_data, dim=0),
                '_video_info': _video_info,  # keep paths as strings
                '_cluster_labels': torch.stack(_cluster_labels, dim=0)
            }

            file_path = os.path.join(save_directory, f"ucf_temporal_cluster_q_{q}_{last_index + 1}.pth")
            torch.save(to_save, file_path)
            print(f"Files saved at step {last_index + 1}!")

        print("Finished processing and saving at the specified intervals.")

    def get_data(self, index):
        """Load one feature file and, in "Train" mode, its cluster labels.

        Args:
            index: position in ``self.vid_list``.

        Returns:
            Tuple ``(video_feature, label, index, cluster_labels)``.
            ``label`` is 0 when the file name contains "Normal", else 1.
            In "Train" mode ``video_feature`` is pooled to ``num_segments``
            rows and ``cluster_labels`` is an all-zero array for normal videos
            or the result of ``temp_clustering`` (using the module-level ``q``)
            otherwise. In any other mode the features are returned as loaded
            and ``cluster_labels`` is None.
        """
        vid_info = self.vid_list[index][0]
        video_feature = np.load(vid_info).astype(np.float32)

        is_normal_video = "Normal" in vid_info.split("/")[-1]
        label = 0 if is_normal_video else 1

        # Bound up front so the return below is valid in every mode.
        cluster_labels = None
        if self.mode == "Train":
            video_feature = self._resample_segments(video_feature)
            if is_normal_video:
                cluster_labels = np.zeros(len(video_feature))
            else:
                cluster_labels = self.temp_clustering(video_feature, q)

        return video_feature, label, index, cluster_labels

    def visual_cluster(self, index):
        """Print the video name and cluster one entry for inspection.

        Loads the feature file at ``index``, pools it to ``num_segments`` rows,
        prints the file name up to "_x264" and runs ``temp_clustering`` with
        the module-level ``q`` (which prints the threshold).

        Returns:
            The cluster id list produced by ``temp_clustering``.
        """
        vid_info = self.vid_list[index][0]
        name = vid_info.split("/")[-1].split("_x264")[0]
        video_feature = np.load(vid_info).astype(np.float32)
        video_feature = self._resample_segments(video_feature)
        print(name)

        return self.temp_clustering(video_feature, q)


        

if __name__ == "__main__":
    # Parse command-line options; optionally stop in the debugger first.
    args = parse_args()
    if args.debug:
        pdb.set_trace()

    config = Config(args)
    worker_init_fn = None

    # Abnormal part of the training list, pooled to 200 segments per video.
    dataset_kwargs = dict(
        root_dir=config.root_dir,
        mode='Train',
        modal=config.modal,
        num_segments=200,
        len_feature=config.len_feature,
        is_normal=False,
        is_cluster=True,
    )
    abnormal_train_loader = UCF_crime(**dataset_kwargs)

    # Compute the temporal cluster labels and write them to disk.
    abnormal_train_loader.getitem1()

    # To inspect individual videos instead:
    # abnormal_train_loader.visual_cluster(0)
    # abnormal_train_loader.visual_cluster(10)
