from dataclasses import dataclass
import os
import glob

from pytorch_lightning.utilities.types import EVAL_DATALOADERS
import nibabel as nib
import scipy as sp
import numpy as np
import pandas as pd
from tqdm import tqdm
import h5py
from collections import defaultdict

from typing import Union, List, Iterable, Optional, Dict, Tuple, Any

import torch
from torch.utils.data import random_split, default_collate, ConcatDataset

from tango.integrations.torch import DataLoader
from tango.common import Lazy
from tango.common import Registrable
import torch_geometric as pyg
from torch_geometric.data import Data
import pytorch_lightning as pl
from .utils import expand_task_and_dataset


# Maps an atlas name to the file name of its compressed NIfTI (.nii.gz) image.
# NOTE(review): the directory these names are resolved against is not visible
# in live code here — confirm against the caller before relying on a location.
atlas_to_file = {
    "aal3v1": "AAL3v1_3mm.nii.gz",
    "scheafer400": "scheafer400_HCP_3mm.nii.gz",
    "shen268": "Shen268_HCP_3mm.nii.gz",
    "shen368": "Shen368_HCP_3mm.nii.gz",
}

# Maps an atlas name to an integer size.
# Presumably the number of regions (or voxels for "raw") the atlas defines —
# TODO confirm. Note that this table has keys ("carddock", "raw", "difumo")
# that have no counterpart in atlas_to_file.
atlas_to_number = {
    "aal3v1": 164,
    "scheafer400": 400,
    "shen268": 268,
    "shen368": 368,
    "carddock": 200,
    "raw": 54971,
    "difumo": 1024,
}

# Names of preprocessed resting-state scans (without file extension) that are
# flagged as not valid.
# NOTE(review): no live code in the visible part of this file reads this list;
# presumably it is meant to exclude these scans when building splits — confirm.
not_valid_files = ['3mm_R_sub-NDARINV2DLW1KXL_ses-baselineYear1Arm1_task-rsfMRI_run-20170429105507_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVJTAFGVNE_ses-baselineYear1Arm1_task-rsfMRI_run-20170306130844_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVMZP6BUUM_ses-baselineYear1Arm1_task-rsfMRI_run-20170217171849_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVBDGGPJYX_ses-baselineYear1Arm1_task-rsfMRI_run-20171002151140_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVFU4GY5DE_ses-baselineYear1Arm1_task-rsfMRI_run-20170608125600_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVGXEX8Z13_ses-baselineYear1Arm1_task-rsfMRI_run-20170318101059_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVBFJ6N0KJ_ses-baselineYear1Arm1_task-rsfMRI_run-20170909133010_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV9A111GA5_ses-baselineYear1Arm1_task-rsfMRI_run-20170909132528_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV6MB1G78W_ses-baselineYear1Arm1_task-rsfMRI_run-20170701100804_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVT2D0E2LE_ses-baselineYear1Arm1_task-rsfMRI_run-20171120131615_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVY6L80Y6V_ses-baselineYear1Arm1_task-rsfMRI_run-20170608132309_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV3XW4NJJV_ses-baselineYear1Arm1_task-rsfMRI_run-20170322143848_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVNBU3UA23_ses-baselineYear1Arm1_task-rsfMRI_run-20170330122918_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV8R4TWUN1_ses-baselineYear1Arm1_task-rsfMRI_run-20171010155721_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVLA4MCMBY_ses-baselineYear1Arm1_task-rsfMRI_run-20170608152845_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVK9YR13JW_ses-baselineYear1Arm1_task-rsfMRI_run-20171204101035_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVD06ADJ09_ses-baselineYear1Arm1_task-rsfMRI_run-20171026170021_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVKFJWT11B_ses-baselineYear1Arm1_task-rsfMRI_run-20170829161254_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV4DVR9EVD_ses-baselineYear1Arm1_task-rsfMRI_run-20171020102206_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVFCVZ66WZ_ses-baselineYear1Arm1_task-rsfMRI_run-20161112113932_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVP1WDXV8E_ses-baselineYear1Arm1_task-rsfMRI_run-20171005175601_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV0UD3BJFR_ses-baselineYear1Arm1_task-rsfMRI_run-20170603101603_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVK9RPEUC1_ses-baselineYear1Arm1_task-rsfMRI_run-20170412161920_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVKRG5L5EG_ses-baselineYear1Arm1_task-rsfMRI_run-20170909105002_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVVUD7L4FZ_ses-baselineYear1Arm1_task-rsfMRI_run-20170719113610_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVCB2Z3BZD_ses-baselineYear1Arm1_task-rsfMRI_run-20171125131715_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVAH92KG8X_ses-baselineYear1Arm1_task-rsfMRI_run-20170320101258_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVMH4PJ6L9_ses-baselineYear1Arm1_task-rsfMRI_run-20170602162159_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVTDLKZKEB_ses-baselineYear1Arm1_task-rsfMRI_run-20161128161844_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVZP36MY9W_ses-baselineYear1Arm1_task-rsfMRI_run-20170928114738_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVKX88859E_ses-baselineYear1Arm1_task-rsfMRI_run-20170816091822_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVXU9WL7P7_ses-baselineYear1Arm1_task-rsfMRI_run-20170610105159_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVFVP5BP99_ses-baselineYear1Arm1_task-rsfMRI_run-20170715143939_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVTVGCT5T2_ses-baselineYear1Arm1_task-rsfMRI_run-20170811151635_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVMTVTFDC0_ses-baselineYear1Arm1_task-rsfMRI_run-20161222134844_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVZVD13ZMG_ses-baselineYear1Arm1_task-rsfMRI_run-20161223124531_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVYFCLVGKY_ses-baselineYear1Arm1_task-rsfMRI_run-20171024172647_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV8TUJVDPE_ses-baselineYear1Arm1_task-rsfMRI_run-20170523122039_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVAU7FW44R_ses-baselineYear1Arm1_task-rsfMRI_run-20170507132009_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVEE6R9ATH_ses-baselineYear1Arm1_task-rsfMRI_run-20170726100830_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVPWL6YB2C_ses-baselineYear1Arm1_task-rsfMRI_run-20170225160616_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVB3VHJDXU_ses-baselineYear1Arm1_task-rsfMRI_run-20170430135249_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVUY636PFH_ses-baselineYear1Arm1_task-rsfMRI_run-20170331152931_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVNTB9DCL8_ses-baselineYear1Arm1_task-rsfMRI_run-20170327115453_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV53AA4BL6_ses-baselineYear1Arm1_task-rsfMRI_run-20171128143114_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV4Z38T1RY_ses-baselineYear1Arm1_task-rsfMRI_run-20170721125132_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVUMHGTLW8_ses-baselineYear1Arm1_task-rsfMRI_run-20170627114224_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVF4HCTG43_ses-baselineYear1Arm1_task-rsfMRI_run-20170104153704_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVHLKBLZ33_ses-baselineYear1Arm1_task-rsfMRI_run-20170713121400_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVL2RJ9HXF_ses-baselineYear1Arm1_task-rsfMRI_run-20171007152448_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVT9RH5R7B_ses-baselineYear1Arm1_task-rsfMRI_run-20170228101458_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV5VJVZMV6_ses-baselineYear1Arm1_task-rsfMRI_run-20170330163133_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV26V2CZ6Z_ses-baselineYear1Arm1_task-rsfMRI_run-20170422125851_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVJ1DZJB1L_ses-baselineYear1Arm1_task-rsfMRI_run-20170712094535_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV5GB3TA25_ses-baselineYear1Arm1_task-rsfMRI_run-20170527101135_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVZ62Y3LZB_ses-baselineYear1Arm1_task-rsfMRI_run-20171108095722_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV6BN8ZZN1_ses-baselineYear1Arm1_task-rsfMRI_run-20170526133716_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV0JWEE23L_ses-baselineYear1Arm1_task-rsfMRI_run-20170603160221_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV9HP92RX0_ses-baselineYear1Arm1_task-rsfMRI_run-20170516142324_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVCVNFXR7V_ses-baselineYear1Arm1_task-rsfMRI_run-20171014141200_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVT4BR7L9V_ses-baselineYear1Arm1_task-rsfMRI_run-20170711144948_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVXJ2454AG_ses-baselineYear1Arm1_task-rsfMRI_run-20170623092818_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVA8VDXCHY_ses-baselineYear1Arm1_task-rsfMRI_run-20171104152909_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVCFK3HDCF_ses-baselineYear1Arm1_task-rsfMRI_run-20171027163312_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV0FBX6KWX_ses-baselineYear1Arm1_task-rsfMRI_run-20170803104511_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINV5DXMXDFV_ses-baselineYear1Arm1_task-rsfMRI_run-20171014131443_bold_preprocessed_MNI',
 '3mm_R_sub-NDARINVN4H2VUF6_ses-baselineYear1Arm1_task-rsfMRI_run-20170817091242_bold_preprocessed_MNI']

# Subject identifiers listed as having no label.
# NOTE(review): no live code in the visible part of this file reads this list;
# presumably these subjects are dropped for supervised (downstream) runs — confirm.
no_label_subjects = ['NDARINV0LN1KD13',
 'NDARINV0P4XZMZA',
 'NDARINV0PHTY15N',
 'NDARINV0U23M45G',
 'NDARINV2T6WFDFJ',
 'NDARINV40YFT2B2',
 'NDARINV486GVBCD',
 'NDARINV4DVGGJE9',
 'NDARINV5LM4TFDU',
 'NDARINV7ATXXJUG',
 'NDARINV93J5H7BU',
 'NDARINVC33HFY1D',
 'NDARINVDDKXYJ2D',
 'NDARINVFB2YC0F5',
 'NDARINVG5PC310E',
 'NDARINVG6WRD4X4',
 'NDARINVJDVAX1CU',
 'NDARINVNRVAZLB0',
 'NDARINVPYJ78DEY',
 'NDARINVR2R52JP7',
 'NDARINVT1VTW1HM',
 'NDARINVT68R8W1J',
 'NDARINVTYZ21CK1',
 'NDARINVULT8KHWH',
 'NDARINVWZ9XXBWY',
 'NDARINVX3PYJZPC',
 'NDARINVZJYGP0WE']

def get_pearson_correlation(x, y, epsilon=1e-10):
    """Unimplemented placeholder for a Pearson-correlation helper.

    All arguments are accepted and ignored; the function always returns
    ``None``.
    """
    return None

def read_func(file):
    """Load the object stored in *file* with ``torch.load`` and cast it to float32.

    Args:
        file: anything ``torch.load`` accepts (a path or a file-like object).

    Returns:
        The result of calling ``.float()`` on the loaded object.
    """
    return torch.load(file).float()

def extract_NDAR(fn: str):
    """Pull the subject identifier out of a scan file name.

    The name is split on "_"; the third field (e.g. ``sub-NDARINV...``) is
    then split on "-" and its second piece is returned.

    Raises:
        IndexError: if *fn* has fewer than three "_"-separated fields or the
            third field contains no "-".
    """
    third_field = fn.split("_")[2]
    pieces = third_field.split("-")
    return pieces[1]

def normalize(data, eps=1e-7):
    """Scale *data* by its mean absolute value.

    Args:
        data: a torch tensor.
        eps: added to the mean absolute value so that an all-zero input does
            not divide by zero.

    Returns:
        ``data / (mean(|data|) + eps)``.
    """
    scale = torch.abs(data).mean() + eps
    return data / scale

def downsample(x, f):
    """Average consecutive groups of ``f`` samples along the last axis.

    x - [..., seq_len]; any number of leading axes (including none) is accepted
    f - downsample factor, positive int

    The last axis is padded with NaN up to the next multiple of ``f`` and each
    group of ``f`` values is reduced with ``np.nanmean``, so a trailing partial
    group is averaged over its real samples only.

    Returns:
        float64 array of shape ``[..., ceil(seq_len / f)]``.
    """
    x = np.asarray(x)
    *lead, seq_len = x.shape
    pad = -seq_len % f
    # Always concatenate (even when pad == 0): this promotes the data to
    # float64, so the result dtype does not depend on whether padding was needed.
    filler = np.full((*lead, pad), np.nan)
    padded = np.concatenate([x, filler], axis=-1)
    n_groups = (seq_len + pad) // f
    # At most f - 1 NaNs are appended, so no group is NaN-only.
    return np.nanmean(padded.reshape(*lead, n_groups, f), axis=-1)
    

# Module-level slot initialised to None.
# NOTE(review): nothing in the visible part of this file reads or assigns it;
# presumably a cache for a fully connected graph's edge index — confirm or remove.
complete_edge_index = None


class Dataset(torch.utils.data.Dataset, Registrable):
    """Registrable base class for the datasets in this module.

    Combines ``torch.utils.data.Dataset`` with ``Registrable`` so subclasses
    can be registered under a name with ``@Dataset.register("...")``.
    """
    ...
    

@Dataset.register("multi_task")
class MultiTaskDataset(Dataset):
    """Bundle of per-task datasets whose items are addressed by ``(task, index)``."""

    def __init__(self, datasets: Dict[str, Dataset]):
        # Keyed by task name; each value is indexed with a plain position.
        self.datasets: dict = datasets

    def __len__(self):
        """Return the number of items summed over every task."""
        total = 0
        for per_task in self.datasets.values():
            total += len(per_task)
        return total

    def __getitem__(self, idx: Tuple[str, int]):
        """Return the item at ``position`` of the dataset registered for ``task``."""
        task_name, position = idx
        return self.datasets[task_name][position]


class MultiTaskBatchSampler(torch.utils.data.Sampler):
    """Batch sampler whose batches each contain items of a single task.

    Every batch is a list of ``(task, index)`` pairs. Each task's indices are
    chunked into batches independently, so no batch mixes tasks.

    Args:
        data_source: object exposing a ``datasets`` mapping of task name to a
            sized dataset.
        batch_size: maximum number of items per batch.
        shuffle: if true, shuffle indices within each task and then shuffle
            the order of all batches (using ``np.random``).
        drop_last: if true, discard each task's trailing batch when it holds
            fewer than ``batch_size`` items.
    """

    def __init__(self, data_source: Dataset, batch_size: int, shuffle: bool = False, drop_last: bool = False):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def _batches_for(self, n_items: int) -> int:
        """Return how many batches a task with ``n_items`` items produces."""
        if self.drop_last:
            return n_items // self.batch_size
        return (n_items + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        all_batches: List[List[Tuple[str, int]]] = []
        for task, dataset in self.data_source.datasets.items():
            indices: List[int] = list(range(len(dataset)))
            if self.shuffle:
                np.random.shuffle(indices)
            for start in range(0, len(indices), self.batch_size):
                chunk = indices[start:start + self.batch_size]
                # Only a task's final chunk can be short.
                if self.drop_last and len(chunk) < self.batch_size:
                    continue
                all_batches.append([(task, i) for i in chunk])
        if self.shuffle:
            # Interleave tasks across the epoch.
            np.random.shuffle(all_batches)

        yield from all_batches

    def __len__(self):
        # Batches are formed per task, so the count is the sum of per-task
        # batch counts; ceil(total / batch_size) would under-report whenever
        # more than one task ends with a partial batch.
        return sum(self._batches_for(len(d)) for d in self.data_source.datasets.values())


@Dataset.register("brain")
class BrainDataset(Dataset):
    """Dataset over the records of one opened HDF5 file.

    Each record is a group read through its ``"x"`` (signal) and ``"y"``
    (target) entries. The dataset name ("ABCD1"/"ABCD2") and the task
    ("Rest"/"nBack"/"MID"/"SST") are inferred from substrings of the file's
    path.

    Args:
        file: opened HDF5 file object; must expose ``filename``, ``items()``
            and item access by record name.
        downsample: factor handed to the module-level ``downsample`` helper
            when greater than 1.
        truncate: if set, a random window of this length is cut from the last
            axis of ``x`` on every access (skipped when ``split == "predict"``).
        split: ``"predict"`` uses every record that has a ``"y"`` entry; any
            other value names a ``<split>[<split_id>].split`` file located
            next to the HDF5 file and listing one record name per line.
        feature: ``"raw"`` returns the signal as is, ``"correlation"`` returns
            its ``torch.corrcoef`` matrix with NaNs replaced by 0.
        split_method: stored but not used by this class.
        split_id: optional suffix appended to the split file's stem.

    Raises:
        ValueError: if the dataset name or the task cannot be inferred from
            the file path.
    """

    def __init__(self, file, downsample: int = 1, truncate: Optional[int] = None, split: Optional[str] = None, feature: str = "raw", split_method: str = "predefine", split_id: Optional[int] = None):
        self.file = file
        self.in_memory = True
        self.downsample = downsample
        self.truncate = truncate
        self.split = split
        self.feature = feature
        self.split_method = split_method
        self.split_id = split_id

        data_path = self.file.filename
        if "ABCD1" in data_path:
            self.dataset_name = "ABCD1"
        elif "ABCD2" in data_path:
            self.dataset_name = "ABCD2"
        if "Rest" in data_path:
            self.task = "Rest"
        elif "nBack" in data_path:
            self.task = "nBack"
        elif "MID" in data_path:
            self.task = "MID"
        elif "SST" in data_path:
            self.task = "SST"
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if not (hasattr(self, "dataset_name") and hasattr(self, "task")):
            raise ValueError(
                f"cannot infer dataset name and task from file path {data_path!r}"
            )

        self.dataset = self._read_data()

    def __len__(self):
        return len(self.dataset)

    def augment(self, x):
        """Return a random window of ``self.truncate`` samples from the last axis of ``x``.

        ``x`` is returned unchanged when ``truncate`` is ``None`` or the split
        is ``"predict"``.
        """
        if self.truncate is not None and self.split != "predict":
            # torch.randint's upper bound is exclusive; the "+ 1" makes the
            # last window reachable and supports inputs of exactly
            # ``truncate`` samples.
            last_start = x.size(-1) - self.truncate
            start = torch.randint(0, last_start + 1, size=(1,)).item()
            x = x[..., start:start + self.truncate]
        return x

    def process(self, data):
        """Convert one record into an ``(x, y)`` pair of float tensors.

        ``x`` is optionally downsampled; ``y`` is rescaled as
        ``(y - 60) / (121 - 60)``.
        """
        x = data["x"][()]
        if self.downsample > 1:
            x = downsample(x, self.downsample)
        x = torch.from_numpy(x).float()
        y = torch.tensor(data['y'][()]).float()
        # NOTE(review): 60 and 121 look like the expected min/max of the
        # target, mapping it to [0, 1] — confirm against the label source.
        return x, (y - 60) / (121 - 60)

    def __getitem__(self, idx):
        """Return ``{"x", "y", "meta"}`` for the item at ``idx``.

        Raises:
            NotImplementedError: if ``self.feature`` is neither ``"raw"`` nor
                ``"correlation"``.
        """
        if self.in_memory:
            x, y = self.dataset[idx]
        else:
            idx = self.idx_to_filename[idx]
            x, y = self.process(self.dataset[idx])
        x = self.augment(x)
        if self.feature == "correlation":
            x = torch.corrcoef(x)
            x = torch.nan_to_num(x, 0)
        elif self.feature == "raw":
            ...
        else:
            raise NotImplementedError

        # NOTE(review): "subject" is the integer position in the in-memory
        # path but the record name otherwise — confirm which one consumers
        # expect.
        meta = {
            "subject": idx,
            "dataset_name": self.dataset_name,
            "task": self.task,
        }
        return {
            'x': x,
            'y': y,
            'meta': meta
        }

    def _read_data(self):
        """Resolve the record names of this split and, if in-memory, load them.

        Sets ``self.idx_to_filename`` to the sorted record names. Returns the
        list of processed ``(x, y)`` pairs when ``self.in_memory`` is true,
        otherwise the file object itself.
        """
        data = self.file

        if self.split == "predict":
            file_names = [key for key, value in data.items() if "y" in value]
        else:
            suffix = "" if self.split_id is None else str(self.split_id)
            split_dir = os.path.split(self.file.filename)[0]
            train_test_path = os.path.join(split_dir, f"{self.split}{suffix}.split")
            with open(train_test_path, "r") as handle:
                # Skip blank lines: a trailing newline would otherwise yield
                # an empty record name.
                file_names = [name for name in handle.read().split("\n") if name]

        file_names.sort()
        self.idx_to_filename = file_names

        if self.in_memory:
            data = [self.process(self.file[fn]) for fn in tqdm(file_names)]

        print(f"# instances of {self.dataset_name} {self.task}: ", len(file_names))

        return data

@Dataset.register("fake")
class FakeDataset(BrainDataset):
    """Variant of ``BrainDataset`` that serves random tensors instead of file contents."""

    def _read_data(self):
        """Return 2000 ``(x, y)`` pairs drawn from a standard normal distribution.

        ``x`` has shape ``(164, 375)`` and ``y`` is a 0-d tensor.
        """

        def _draw():
            signal = torch.randn(164, 375, dtype=torch.float)
            target = torch.randn(1, dtype=torch.float).squeeze()
            return signal, target

        return [_draw() for _ in range(2000)]


class DataModule(pl.LightningDataModule, Registrable):
    """Registrable base for Lightning data modules.

    Subclasses register themselves with ``@DataModule.register("...")``.
    """
    # NOTE(review): these two flags are not read anywhere in the visible part
    # of this file; presumably they are meant for the framework's caching
    # logic — confirm that a Registrable (as opposed to a step) honours them.
    CACHEABLE = False
    # NOTE(review): "DETERMINISTC" looks like a misspelling of "DETERMINISTIC";
    # left unchanged because renaming a class attribute alters the interface.
    DETERMINISTC = False
    
    
@DataModule.register("brain")
class BrainDataModule(pl.LightningDataModule):
    """Lightning data module that builds multi-task datasets from HDF5 files.

    Args:
        file_path: path template; ``expand_task_and_dataset`` turns it into
            the concrete HDF5 paths to load. Its ``[TSK]`` and ``[DAT]``
            placeholders are substituted with "Rest" and "ABCD1" to probe the
            number of regions.
        dataset: lazy dataset factory, constructed once per file and split
            with ``file=<open h5 file>`` and ``split=<name>``.
        batch_size: batch size for every dataloader.
        num_workers: worker count for every dataloader.
        split_method: stored but not used by this class.
    """

    def __init__(self, file_path: str, dataset: Lazy[Dataset], batch_size: int, num_workers: int, split_method: str = "predefine") -> None:
        super().__init__()
        self.file_path = file_path

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_method = split_method

        # Only the shape of the first record's "x" is needed, so the probe
        # file is closed again right away instead of being left open.
        probe_path = self.file_path.replace('[TSK]', 'Rest').replace('[DAT]', "ABCD1")
        with h5py.File(probe_path, "r") as probe:
            first_record = next(iter(probe.values()))
            self.n_region = first_record['x'].shape[-2]

    def _construct_all(self, split: str) -> list:
        """Construct one dataset per expanded file path for ``split``."""
        # The h5 handles stay open on purpose: each constructed dataset
        # receives the file object and may keep reading from it.
        return [
            self.dataset.construct(file=h5py.File(path, "r"), split=split)
            for path in expand_task_and_dataset(self.file_path)
        ]

    def setup(self, stage: str) -> None:
        """Build the datasets needed for ``stage``.

        ``"fit"`` sets ``self.train``, ``self.val`` and ``self.test`` to
        ``MultiTaskDataset`` objects that group the per-file datasets by their
        ``task`` attribute. ``"predict"`` sets ``self.predict`` to the list of
        per-file datasets.

        Raises:
            NotImplementedError: for any other stage.
        """
        if stage == "fit":
            for split in ['train', 'val', 'test']:
                by_task = defaultdict(list)
                for single in self._construct_all(split):
                    by_task[single.task].append(single)
                merged = {task: ConcatDataset(items) for task, items in by_task.items()}
                setattr(self, split, MultiTaskDataset(merged))
        elif stage == "predict":
            self.predict = self._construct_all("predict")
        else:
            # The previous bare ``raise`` surfaced as "RuntimeError: No active
            # exception to reraise". NotImplementedError subclasses
            # RuntimeError, so handlers catching that still work.
            raise NotImplementedError(f"unsupported stage: {stage!r}")

    def train_dataloader(self):
        """Return the training dataloader (shuffled, single-task batches)."""
        batch_sampler = MultiTaskBatchSampler(self.train, batch_size=self.batch_size, shuffle=True)
        return DataLoader(self.train, batch_sampler=batch_sampler, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation dataloader (unshuffled, single-task batches)."""
        batch_sampler = MultiTaskBatchSampler(self.val, batch_size=self.batch_size, shuffle=False)
        return DataLoader(self.val, batch_sampler=batch_sampler, num_workers=self.num_workers)

    def predict_dataloader(self):
        """Return one unshuffled dataloader per dataset in ``self.predict``."""
        return [DataLoader(data, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) for data in self.predict]
    