import os, random
import torch
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image
import torchvision.transforms as transforms
from pathlib import Path
from einops import rearrange
import json


class DIR_GOD_ImageDataset(torch.utils.data.Dataset):
    """Dataset pairing fMRI surface arrays with stimulus images and captions.

    The behaviour is selected by ``phase``:

    * ``'pretrain'`` -- one item per distinct image name found in the label
      files of the subjects in ``subjs``. Subjects ``<= 3`` are read from
      ``<root_dir>/fmri_npy``; larger ids are read from the sibling directory
      obtained by replacing ``'DIR'`` with ``'GOD'`` in ``root_dir``, using the
      id minus 3.
    * ``'train'`` -- one item per position in the per-subject image lists read
      from ``unique_triallabel.npy`` for the subjects in ``subjs_nsd``. Shorter
      lists are padded by cyclic repetition so all subjects have equal length.
    * anything else -- validation over the image indices stored in
      ``sub1257_shared_triallabel.npy``, optionally truncated by
      ``val_data_fraction``.

    Args:
        root_dir: Root of the DIR dataset (``fmri_npy`` and ``stimuli`` live
            underneath it).
        root_dir_nsd: Root of the NSD dataset.
        subjs: Subject ids used by the ``'pretrain'`` and validation phases.
        subjs_nsd: Subject ids used by the ``'train'`` phase.
        image_norm: If true, images are normalised with mean 0.5 / std 0.5 per
            channel after ``ToTensor``.
        phase: ``'pretrain'``, ``'train'``, or any other value for validation.
        val_data_fraction: Leading fraction of the shared-trial list to keep
            in the validation phase.
    """

    # Phases for which ``is_train`` is True.
    _TRAIN_PHASES = ('pretrain', 'train')
    # Subject ids <= this value live under ``data_path``; larger ids are
    # shifted down by this amount and read from ``data_path_god``.
    _NUM_DIR_SUBJECTS = 3
    # Added to the NSD subject id to form the emitted subject label.
    # NOTE(review): presumably keeps NSD labels disjoint from DIR/GOD ones -- confirm.
    _NSD_SUBJ_LABEL_OFFSET = 7

    def __init__(self, root_dir='./DIR_dataset', root_dir_nsd='../Datasets/NSD',
                 subjs=(1,),
                 subjs_nsd=(1,),
                 image_norm=True,
                 phase='train', val_data_fraction=1.0):
        super().__init__()

        data_path = f'{root_dir}/fmri_npy'
        data_path_god = f"{root_dir.replace('DIR', 'GOD')}/fmri_npy"
        image_path = f'{root_dir}/stimuli/images'
        cap_label_path = f'{root_dir}/stimuli/id_to_full_name.json'

        test_image_path = f'{root_dir_nsd}/nsddata_stimuli/stimuli/images'
        data_path_nsd = f'{root_dir_nsd}/fmri_npy'
        test_cap_label_path = f'{root_dir_nsd}/COCO_73k_annots_curated.npy'
        shared_trial = f'{root_dir_nsd}/sub1257_shared_triallabel.npy'
        unique_trial = f'{root_dir_nsd}/unique_triallabel.npy'

        image_transform_list = [transforms.Resize((224, 224)), transforms.ToTensor()]
        if image_norm:
            image_transform_list.append(
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        self.image_transform = transforms.Compose(image_transform_list)

        self.image_norm = image_norm
        self.data_path = Path(data_path)
        self.data_path_god = Path(data_path_god)
        self.data_path_nsd = Path(data_path_nsd)
        self.data = dict()
        self.images_all_subjs = dict()      # subj -> per-trial image index array (validation)
        self.image_subj_fmri_dict = dict()  # image name -> {subj -> [fmri idx, ...]} (pretrain)
        self.unique_images = dict()         # subj -> image index array (train)
        self.unique_trial_dict = dict()     # subj -> {image idx -> fmri idx} (train)
        self.shared_trial_dict = dict()     # subj -> {image idx -> fmri idx} (validation)
        self.train_images_path = image_path
        self.test_images_path = test_image_path
        # Both caption sources are loaded regardless of phase.
        with open(cap_label_path, 'r') as cap_file:
            self.cap_label = json.load(cap_file)
        self.test_cap_label = np.load(test_cap_label_path, allow_pickle=True)
        # Copy so the instance never aliases the caller's (or the default) sequence.
        self.sub = list(subjs)
        self.sub_nsd = list(subjs_nsd)
        self.phase = phase

        if phase == 'pretrain':
            self._build_pretrain_index()
        elif phase == 'train':
            self._build_train_index(unique_trial)
        else:
            self._build_val_index(shared_trial, val_data_fraction)

        # The original condition used ``and`` and could never be true.
        self.is_train = phase in self._TRAIN_PHASES
        print(f'Data length:{self.__len__()}')

    def _build_pretrain_index(self):
        """Fill ``image_subj_fmri_dict`` from each subject's label file."""
        for subj_idx in self.sub:
            if subj_idx <= self._NUM_DIR_SUBJECTS:
                label_file = self.data_path / f'{subj_idx:02d}_label.npy'
            else:
                label_file = self.data_path_god / f'{subj_idx - self._NUM_DIR_SUBJECTS:02d}_label.npy'
            images_cur_subj = np.array(np.load(label_file, allow_pickle=True))

            # Position in the label array is the fMRI trial index.
            for fmri_idx, raw_name in enumerate(images_cur_subj):
                per_subject = self.image_subj_fmri_dict.setdefault(raw_name.decode(), {})
                per_subject.setdefault(subj_idx, []).append(fmri_idx)

        # Cached once so __getitem__ does not rebuild the key list per sample.
        self.image_names = list(self.image_subj_fmri_dict)

    def _build_train_index(self, unique_trial):
        """Fill ``unique_images`` / ``unique_trial_dict`` and pad to equal length.

        Args:
            unique_trial: Path to the ``.npy`` file whose single pickled object
                is indexed by ``subj - 1`` to get that subject's image indices.
        """
        train_label = np.load(unique_trial, allow_pickle=True).item()
        self.train_dict_map = dict()
        for subj in self.sub_nsd:
            temp_image = np.array(list(train_label[subj - 1]))  # keyed from 0, hence subj - 1

            # Label file maps fMRI trial position -> image index; invert it.
            # For repeated images the last trial wins.
            temp_data = np.array(np.load(self.data_path_nsd / f'{subj:02d}_label.npy'))
            label_dict = {img_idx: fmri_idx for fmri_idx, img_idx in enumerate(temp_data)}

            # Keep only images that actually have an fMRI trial; otherwise
            # __getitem__ would hit a KeyError on the missing ones.
            has_fmri = np.array([img_idx in label_dict for img_idx in temp_image], dtype=bool)
            usable_images = temp_image[has_fmri]

            self.unique_images[subj] = usable_images
            self.unique_trial_dict[subj] = {img_idx: label_dict[img_idx] for img_idx in usable_images}

        self.max_length = max(len(self.unique_images[subj]) for subj in self.unique_images)
        print(f"\033[92m {self.max_length} \033[0m")

        self.unique_images = self.pad_data(self.unique_images, self.max_length)

    def _build_val_index(self, shared_trial, val_data_fraction):
        """Fill ``images_all_subjs``, ``shared_trial_dict`` and ``val_label``.

        Args:
            shared_trial: Path to the ``.npy`` file of validation image indices.
            val_data_fraction: Leading fraction of that list to keep.
        """
        for subj_idx in self.sub:
            self.images_all_subjs[subj_idx] = np.array(
                np.load(self.data_path_nsd / f'{subj_idx:02d}_label.npy'))

        val_image_idx = np.array(np.load(shared_trial, allow_pickle=True))
        # Validate on only the leading part of the list when a fraction is given.
        val_image_num = int(val_data_fraction * len(val_image_idx))
        val_image_idx = val_image_idx[:val_image_num]

        labels_per_subject = []
        for subj_idx in self.sub:
            # image idx -> fmri idx; for repeated images the last trial wins.
            image_to_fmri_idx_dict = {
                img_idx: fmri_idx
                for fmri_idx, img_idx in enumerate(self.images_all_subjs[subj_idx])
            }
            self.shared_trial_dict[subj_idx] = image_to_fmri_idx_dict
            # Keep validation images this subject has a trial for (dict lookup
            # instead of scanning the label array once per image).
            labels_per_subject.append(
                [idx for idx in val_image_idx if idx in image_to_fmri_idx_dict])
        self.val_label = np.concatenate(labels_per_subject)

    def pad_data(self, images, max_length):
        """Pad every per-subject array to ``max_length`` by cyclic repetition.

        Args:
            images: Dict mapping subject -> 1-D array; modified in place.
            max_length: Target length. Arrays already at least this long are
                left unchanged.

        Returns:
            The same ``images`` dict.

        Raises:
            ValueError: If a subject has no entries but padding is required
                (there is nothing to repeat).
        """
        for subj in images:
            original_imgs = images[subj]
            current_length = len(original_imgs)
            if current_length >= max_length:
                continue
            if current_length == 0:
                raise ValueError(
                    f'Cannot pad subject {subj}: it has no images to repeat.')

            padding_length = max_length - current_length
            repeat_times = (padding_length // current_length) + 1
            padding_imgs = np.tile(original_imgs, repeat_times)[:padding_length]
            images[subj] = np.concatenate([original_imgs, padding_imgs])

        return images

    def __getitem__(self, i):
        """Return the sample at index ``i`` as a dict; layout depends on ``phase``."""
        if self.phase == "pretrain":
            return self._get_pretrain_item(i)
        if self.phase == "train":
            return self._get_train_item(i)
        return self._get_val_item(i)

    def _get_pretrain_item(self, i):
        """Return one image with a randomly chosen fMRI trial per subject.

        Raises:
            KeyError: If a subject in ``self.sub`` has no trial for the image.
        """
        image_name = self.image_names[i]

        fMRIs_shared = []
        subj_label_shared = []
        for sub_idx in self.sub:
            # Several trials may exist for the same image; pick one at random.
            fMRI_idx = random.choice(self.image_subj_fmri_dict[image_name][sub_idx])

            if sub_idx <= self._NUM_DIR_SUBJECTS:
                fMRI_file = self.data_path / f'{sub_idx:02d}_norm/surf_{fMRI_idx:06d}.npy'
            else:
                fMRI_file = self.data_path_god / f'{sub_idx - self._NUM_DIR_SUBJECTS:02d}_norm/surf_{fMRI_idx:06d}.npy'
            fMRIs_shared.append(np.load(fMRI_file))
            subj_label_shared.append(sub_idx)

        # Stack over subjects, then add a leading singleton axis.
        fMRIs_shared = torch.from_numpy(np.stack(fMRIs_shared))[None]
        subj_label_shared = torch.from_numpy(np.stack(subj_label_shared))

        image_filename_s = os.path.join(self.train_images_path, f'{image_name}.JPEG')
        natural_image_s = Image.open(image_filename_s).convert('RGB')
        gt_image_s = self.image_transform(natural_image_s)

        # Captions are keyed by the part of the image name before the first '_'.
        cls_identifier = image_name.split('_')[0]
        random_caption_s = random.choice(self.cap_label[cls_identifier])
        return {
            "fMRIs": fMRIs_shared, "subj_lbl": subj_label_shared, "txt": random_caption_s, "gt_image": gt_image_s,
        }

    def _get_train_item(self, i):
        """Return, for every NSD subject, the i-th image with its fMRI and a caption."""
        fMRIs_unique = []
        subj_label_unique = []
        gt_image_unique = []
        random_caption_unique = []

        for sub_idx in self.sub_nsd:
            idx_u = int(self.unique_images[sub_idx][i])        # image index
            fMRI_idx = self.unique_trial_dict[sub_idx][idx_u]  # fmri index
            fMRIs_unique.append(
                np.load(self.data_path_nsd / f'{sub_idx:02d}_norm/surf_{fMRI_idx:06d}.npy'))

            image_filename_u = os.path.join(self.test_images_path, f'image_{idx_u:06d}.png')
            # Force three channels, as the other phases do, so the 3-channel
            # Normalize cannot fail on grayscale / RGBA files.
            natural_image_u = Image.open(image_filename_u).convert('RGB')
            gt_image_unique.append(self.image_transform(natural_image_u))

            annots_u = self.test_cap_label[idx_u]
            caption_u = list(annots_u[annots_u != ''])  # drop empty caption slots
            random_caption_unique.append(random.choice(caption_u))

            subj_label_unique.append(sub_idx + self._NSD_SUBJ_LABEL_OFFSET)

        fMRIs_unique = torch.from_numpy(np.stack(fMRIs_unique))[None]
        subj_label_unique = torch.from_numpy(np.stack(subj_label_unique))
        # NOTE: unlike the other phases, "gt_image" here is a NumPy array
        # (np.stack over the transformed tensors); kept for compatibility.
        gt_image_unique = np.stack(gt_image_unique)

        return {
            "fMRIs": fMRIs_unique, "subj_lbl": subj_label_unique, "txt": random_caption_unique, "gt_image": gt_image_unique,
        }

    def _get_val_item(self, i):
        """Return the i-th validation image with the first subject's fMRI.

        Only ``self.sub[0]`` is read here, although ``val_label`` concatenates
        the usable images of every subject in ``self.sub``.
        """
        idx = int(self.val_label[i])
        sub_idx = self.sub[0]
        fMRI_idx = self.shared_trial_dict[sub_idx][idx]  # fmri index

        fMRI = np.load(self.data_path_nsd / f'{self.sub[0]:02d}_norm/surf_{fMRI_idx:06d}.npy')
        # Two leading singleton axes are added in total.
        fMRIs = torch.from_numpy(fMRI[None])[None]

        # gt images
        image_filename = os.path.join(self.test_images_path, f'image_{idx:06d}.png')
        natural_image = Image.open(image_filename).convert('RGB')
        gt_image = self.image_transform(natural_image)

        # coco caption
        annots = self.test_cap_label[idx]
        caption = list(annots[annots != ''])  # drop empty caption slots
        random_caption = random.choice(caption)

        return {
            "fMRIs": fMRIs, "txt": random_caption,
            "gt_image": gt_image,
        }

    def __len__(self):
        """Return the number of samples for the current ``phase``."""
        if self.phase == "pretrain":
            return len(self.image_subj_fmri_dict)
        elif self.phase == "train":
            return self.max_length
        else:
            return len(self.val_label)

