from sys import stderr
from torch.utils import data
from PIL import Image
import torch
import os
import random
import numpy as np



class VFSD(data.Dataset):
    """Video stress dataset: a binary label plus video, emotion and text inputs.

    Samples are discovered on disk as ``<data_dir>/<person_id>/<sample_name>/``
    where ``sample_name`` ends in ``_stressed`` or ``_unstressed``. When
    ``use_video`` is set, every sample directory must contain a ``video`` and
    an ``emo`` sub-directory with image files; the per-sample text features
    are read from ``<_TXT_ROOT>/<person_id>/<sample_name>/``.
    """

    # File (relative to the working directory) listing the person ids of each
    # split as space-separated tokens on its first line.
    _SPLIT_FILES = {
        "train": "train_sample.txt",
        "val": "val_sample.txt",
        "test": "test_sample.txt",
    }
    # Root directory of the serialized per-sample text features.
    _TXT_ROOT = "/root/dataset/stressvideo/rsl/txt"
    _LABEL2CLASS = {"stressed": 0, "unstressed": 1}
    _NUM_EMOTIONS = 7  # emotion ids parsed from emo file names index a 7-slot table
    _NUM_EMO_CLIPS = 8  # emotion frames returned per sample
    _NUM_TEXT_FEATURES = 8  # text feature tensors stacked per sample

    def __init__(
        self,
        params,
        data_dir,
        transform,
        mode,
        num_clip,
        sample,
        augmentation: bool,
        use_face: bool,
        use_video: bool,
        use_flow: bool,
        dataset_ratio: float = 1.0,
    ):
        """Store the configuration and index the samples of the given split.

        Args:
            params: Mapping read at item time; ``params["countEmo"]`` selects
                between majority-vote and random emotion-frame selection.
            data_dir: Root directory holding one sub-directory per person id.
            transform: Callable applied to every opened PIL image.
            mode: ``"train"``, ``"val"`` or ``"test"``.
            num_clip: Number of frame paths kept per sample.
            sample: Number of video frames actually loaded per item.
            augmentation: Stored on the instance; not read by this class.
            use_face: Stored on the instance; not read by this class.
            use_video: Whether video, emotion and text inputs are indexed
                and returned in addition to the label.
            use_flow: Stored on the instance; not read by this class.
            dataset_ratio: Stored on the instance; not read by this class.

        Raises:
            ValueError: If ``mode`` is not a known split name.
        """
        super().__init__()
        self.params = params
        self.data_dir = data_dir
        self.transform = transform
        self.num_clip = num_clip
        self.sample = sample
        self.augmentation = augmentation

        self.use_face = use_face
        self.use_video = use_video
        self.use_flow = use_flow

        self.dataset_ratio = dataset_ratio

        self._preprocess(mode)

    def _list_clip_files(self, directory):
        """Return ``self.num_clip`` sorted file paths taken from ``directory``.

        The listing is sorted *before* it is padded or truncated so that the
        selection does not depend on the arbitrary order of ``os.listdir``.
        Short listings are padded by repetition. An empty list is returned
        when the directory is missing or contains no entries.
        """
        if not os.path.isdir(directory):
            return []
        paths = sorted(os.path.join(directory, name) for name in os.listdir(directory))
        if not paths:
            # Nothing to repeat: padding an empty list would never terminate.
            return []
        while len(paths) < self.num_clip:
            paths = paths + paths
        paths = paths[: self.num_clip]
        # Re-sort so that repeated frames sit next to each other.
        paths.sort()
        return paths

    def _preprocess(self, mode):
        """Build ``self.dataset`` for ``mode`` and print the class statistics.

        Each entry is ``[class_id]`` or, when ``use_video`` is set,
        ``[class_id, frame_paths, emo_paths, txt_dir]``. Samples whose
        ``video`` or ``emo`` directory is missing or empty are skipped.

        Raises:
            ValueError: If ``mode`` is not a known split name.
            KeyError: If a sample directory name does not end in a known label.
        """
        try:
            split_file = self._SPLIT_FILES[mode]
        except KeyError:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self._SPLIT_FILES)}"
            ) from None

        # Only the split that is actually requested is read. split() without
        # an argument also drops the trailing newline and repeated blanks.
        with open(split_file, "r") as file:
            self.indices = file.readline().split()

        label_count = {"stressed": 0, "unstressed": 0}
        self.class_count = [0, 0]

        self.dataset = []
        for person_id in self.indices:
            person_path = os.path.join(self.data_dir, str(person_id))
            txt_path = os.path.join(self._TXT_ROOT, str(person_id))
            # Sorted so that the index -> sample mapping is reproducible.
            for sample_name in sorted(os.listdir(person_path)):
                # sample_name: {video_id}_{stress level}
                label = sample_name.split("_")[-1]
                class_id = self._LABEL2CLASS[label]
                sample = [class_id]

                if self.use_video:
                    sample_path = os.path.join(person_path, sample_name)
                    frames = self._list_clip_files(os.path.join(sample_path, "video"))
                    emos = self._list_clip_files(os.path.join(sample_path, "emo"))
                    if not frames or not emos:
                        continue
                    sample.append(frames)
                    sample.append(emos)
                    sample.append(os.path.join(txt_path, sample_name))

                self.dataset.append(sample)
                label_count[label] += 1
                self.class_count[class_id] += 1
        self.num_samples = len(self.dataset)

        print(f"{mode} length: {self.num_samples}, class_count: {self.class_count}")
        print(f"labels: {label_count}")

    def _load_image(self, path):
        """Open the image at ``path`` and return ``self.transform`` applied to it."""
        with Image.open(path) as image:
            return self.transform(image)

    def __getitem__(self, index):
        """Return the inputs of sample ``index`` as a list.

        Returns:
            ``[label]`` when ``use_video`` is false, otherwise
            ``[label, video, emo, text]`` where

            * ``label`` is a scalar tensor holding the class id,
            * ``video`` stacks ``self.sample`` evenly spaced transformed frames,
            * ``emo`` stacks one transformed emotion frame per window,
            * ``text`` stacks the first eight (by sorted file name) tensors
              stored in the sample's text-feature directory.

            ``video`` and ``emo`` have axes 0 and 1 swapped after stacking
            (assumes the transform yields (C, H, W) tensors, giving
            (C, T, H, W) -- TODO confirm against ``Transforms``).

        Raises:
            IndexError: If the text-feature directory holds fewer than eight files.
        """
        entry = self.dataset[index]
        res = [torch.tensor(entry[0])]
        if not self.use_video:
            return res

        video_files, emo_files, txt_dir = entry[1], entry[2], entry[3]

        interval_video = int(self.num_clip / self.sample)
        video = [
            self._load_image(video_files[i * interval_video])
            for i in range(self.sample)
        ]
        res.append(torch.stack(video, dim=0).permute((1, 0, 2, 3)))

        # One emotion frame is picked from each of the equally sized windows.
        interval_emo = int(self.num_clip / self._NUM_EMO_CLIPS)
        emo = []
        for i in range(self._NUM_EMO_CLIPS):
            window = emo_files[i * interval_emo : (i + 1) * interval_emo]
            if self.params["countEmo"]:
                emo_clip, _ = self.__countEmo__(window)
            else:
                emo_clip = self.__random__(window)
            emo.append(self._load_image(emo_clip))
        res.append(torch.stack(emo, dim=0).permute((1, 0, 2, 3)))

        # Sorted for a deterministic feature order; only the files that are
        # actually stacked get loaded.
        feature_names = sorted(os.listdir(txt_dir))[: self._NUM_TEXT_FEATURES]
        if len(feature_names) < self._NUM_TEXT_FEATURES:
            raise IndexError(
                f"{txt_dir} holds {len(feature_names)} feature files, "
                f"expected at least {self._NUM_TEXT_FEATURES}"
            )
        # NOTE: torch.load unpickles its input; only use it on trusted dataset files.
        features = [torch.load(os.path.join(txt_dir, name)) for name in feature_names]
        res.append(torch.stack(features, dim=0))

        return res

    def __len__(self):
        """Return the number of indexed samples."""
        return self.num_samples

    @staticmethod
    def _parse_emo_name(path):
        """Return ``(emotion_id, intensity)`` encoded in an emo file name.

        The file name is expected to look like ``...label<id>score<value><ext>``.
        Only the base name is parsed so that directory names containing
        "label" or "score" cannot confuse the split.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        fields = stem.split("label")[1].split("score")
        return int(fields[0]), float(fields[1])

    def __countEmo__(self, emo_set):
        """Pick the representative frame of a window by majority vote.

        The dominant emotion is the most frequent emotion id in ``emo_set``
        (lowest id on ties). Among its frames, the first one with the highest
        intensity is selected.

        Returns:
            Tuple ``(frame_path, dominant_emotion_id)``.

        Raises:
            ValueError: If ``emo_set`` is empty.
        """
        if not emo_set:
            raise ValueError("emo_set must contain at least one frame path")

        parsed = [self._parse_emo_name(path) for path in emo_set]

        emo_count_list = [0] * self._NUM_EMOTIONS
        for emo, _ in parsed:
            emo_count_list[emo] += 1
        max_emo = emo_count_list.index(max(emo_count_list))

        best_path = None
        best_intensity = None
        for path, (emo, intensity) in zip(emo_set, parsed):
            if emo != max_emo:
                continue
            # Strict comparison keeps the first frame on equal intensity.
            if best_intensity is None or intensity > best_intensity:
                best_path, best_intensity = path, intensity
        return best_path, max_emo

    def __random__(self, emo_set):
        """Return one uniformly chosen element of ``emo_set``.

        Raises:
            ValueError: If ``emo_set`` is empty (raised by ``random.randint``).
        """
        return emo_set[random.randint(0, len(emo_set) - 1)]


from lib.transforms import Transforms
# from transforms import Transforms

def get_test_loader(
    params,
    data_dir,
    crop_size=None,
    batch_size=16,
    num_workers=1,
    num_clip=50,
    sample=10,
    augmentation=False,
    use_face=False,
    use_video=True,
    use_flow=False,
    drop_last=False,
):
    """Build the DataLoader over the ``"test"`` split.

    Args:
        params: Forwarded to ``VFSD``.
        data_dir: Dataset root forwarded to ``VFSD``.
        crop_size: Forwarded to ``Transforms``; ``None`` selects ``[224]``.
        batch_size: Batch size of the loader.
        num_workers: Worker processes; workers are kept alive between
            epochs whenever this is greater than zero.
        num_clip: Forwarded to ``VFSD``.
        sample: Forwarded to ``VFSD``.
        augmentation: Forwarded to both ``Transforms`` and ``VFSD``.
        use_face: Forwarded to ``VFSD``.
        use_video: Forwarded to ``VFSD``.
        use_flow: Forwarded to ``VFSD``.
        drop_last: Whether the last incomplete batch is dropped.

    Returns:
        A non-shuffling ``DataLoader`` over the test dataset.
    """
    # A fresh list per call: a list literal in the signature would be shared
    # between all calls and could be mutated by the callee.
    if crop_size is None:
        crop_size = [224]

    transform = Transforms(crop_size, augmentation)

    test_static_dataset = VFSD(
        params,
        data_dir,
        transform,
        "test",
        num_clip,
        sample,
        augmentation,
        use_face,
        use_video,
        use_flow,
    )

    test_static_data_loader = data.DataLoader(
        dataset=test_static_dataset,
        batch_size=batch_size,
        shuffle=False,  # explicit, matching get_val_loader (also the default)
        num_workers=num_workers,
        drop_last=drop_last,
        persistent_workers=num_workers > 0,
    )
    return test_static_data_loader


def get_val_loader(
    params,
    data_dir,
    crop_size=None,
    batch_size=16,
    num_workers=1,
    num_clip=50,
    sample=10,
    augmentation=True,
    use_face=False,
    use_video=True,
    use_flow=False,
    drop_last=False,
):
    """Build the DataLoader over the ``"val"`` split.

    Args:
        params: Forwarded to ``VFSD``.
        data_dir: Dataset root forwarded to ``VFSD``.
        crop_size: Forwarded to ``Transforms``; ``None`` selects ``[224]``.
        batch_size: Batch size of the loader.
        num_workers: Worker processes; workers are kept alive between
            epochs whenever this is greater than zero.
        num_clip: Forwarded to ``VFSD``.
        sample: Forwarded to ``VFSD``.
        augmentation: Forwarded to both ``Transforms`` and ``VFSD``.
        use_face: Forwarded to ``VFSD``.
        use_video: Forwarded to ``VFSD``.
        use_flow: Forwarded to ``VFSD``.
        drop_last: Whether the last incomplete batch is dropped.

    Returns:
        A non-shuffling ``DataLoader`` over the validation dataset.
    """
    # NOTE(review): augmentation defaults to True here but to False in
    # get_test_loader -- confirm this is intended for validation.

    # A fresh list per call: a list literal in the signature would be shared
    # between all calls and could be mutated by the callee.
    if crop_size is None:
        crop_size = [224]

    transform = Transforms(crop_size, augmentation)

    val_static_dataset = VFSD(
        params,
        data_dir,
        transform,
        "val",
        num_clip,
        sample,
        augmentation,
        use_face,
        use_video,
        use_flow,
    )

    val_static_data_loader = data.DataLoader(
        shuffle=False,
        dataset=val_static_dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        drop_last=drop_last,
        persistent_workers=num_workers > 0,
    )

    return val_static_data_loader

def test_normalized(feature):
    """Center ``feature`` on its mean and scale it to a peak magnitude of 1.

    Args:
        feature: Sequence or array of numbers.

    Returns:
        A ``float32`` tensor with the shape of ``feature``. Its values are
        ``(feature - mean) / max(|feature - mean|)``. When every entry is
        identical (including all zeros) or the input is empty there is
        nothing to scale, and a zero tensor of the same shape is returned
        instead of dividing by zero.
    """
    values = np.asarray(feature, dtype=np.float64)
    if values.size == 0:
        return torch.zeros(values.shape, dtype=torch.float32)

    centered = values - values.mean()
    peak = np.max(np.abs(centered))
    if peak == 0:
        # Constant input: 0 / 0 would otherwise produce NaN.
        return torch.zeros(values.shape, dtype=torch.float32)

    return torch.from_numpy((centered / peak).astype(np.float32))


# Despite its name this is a helper, not a test case: keep pytest from
# collecting it (and failing on the unknown "feature" fixture) when a test
# module imports it.
test_normalized.__test__ = False

def get_train_loader(
    params,
    data_dir,
    crop_size,
    batch_size,
    num_workers,
    num_clip,
    sample,
    augmentation=True,
    use_face=False,
    use_video=True,
    use_flow=False,
    training_set_ratio=1.0,
    shuffle=True,
    drop_last=False,
):
    """Build the DataLoader over the ``"train"`` split.

    ``crop_size`` and ``augmentation`` configure the ``Transforms`` applied
    to every image; ``training_set_ratio`` is handed to ``VFSD`` as its
    ``dataset_ratio``. The remaining dataset arguments are forwarded to
    ``VFSD`` unchanged. Worker processes are kept alive between epochs
    whenever ``num_workers`` is greater than zero.
    """
    dataset = VFSD(
        params=params,
        data_dir=data_dir,
        transform=Transforms(crop_size, augmentation),
        mode="train",
        num_clip=num_clip,
        sample=sample,
        augmentation=augmentation,
        use_face=use_face,
        use_video=use_video,
        use_flow=use_flow,
        dataset_ratio=training_set_ratio,
    )

    keep_workers_alive = num_workers > 0
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "num_workers": num_workers,
        "persistent_workers": keep_workers_alive,
    }
    return data.DataLoader(dataset, **loader_options)

# if __name__ == '__main__':
    # main()
    # 
    


    # chatpgt_token()
    # get_train_loader(params,
    #     params['dataset'],
    #     [224,224],
    #     # [64,64],
    #     params['batch_size'],
    #     params['num_workers'],
    #     params['clip_len'],
    #     64,
    #     False,
    #     False,
    #     True,
    #     False,)