import os, numpy as np, torch, sys, pytorch_lightning as pl, warnings
from pathlib import Path
file = Path(__file__).resolve()
# Project root is two directories above this file's parent; paths are kept as '/'-terminated strings.
path2project = f'{file.parents[2]}/'
path2BrainData = f'{path2project}data/brain_data/'
path2currDir = f'{Path.cwd()}/'
sys.path.append(path2project)  # make the top-level project directory (geom_dl/) importable


def scan_composition_text(lr, rl):
    """Return a lowercase label describing which scan directions are present.

    Returns 'ave' when both ``lr`` and ``rl`` are truthy, 'lr' or 'rl' when only
    that one is truthy, and None when neither is.
    """
    if not (lr or rl):
        return None
    if lr and rl:
        return 'ave'
    return 'lr' if lr else 'rl'


def parse_scandir(exist_LR, exist_RL):
    """Map the presence of LR / RL scans to a scan-direction label.

    Returns 'LR+RL' if both are truthy, 'LR' or 'RL' if only one is.

    Raises:
        ValueError: if neither scan direction is present.
    """
    labels = {(True, True): 'LR+RL', (True, False): 'LR', (False, True): 'RL'}
    flags = (bool(exist_LR), bool(exist_RL))
    if flags not in labels:
        raise ValueError('Must have some scan dir')
    return labels[flags]


def parse_task(exist_REST1, exist_REST2):
    """Map the presence of REST1 / REST2 scans to a task label.

    Returns '1+2' if both are truthy, '1' or '2' if only one is.

    Raises:
        ValueError: if neither task is present.
    """
    if not exist_REST1 and not exist_REST2:
        raise ValueError('Must have some task')
    if exist_REST1 and exist_REST2:
        return '1+2'
    return '1' if exist_REST1 else '2'


def find_scandir_and_task(which_scans_exist, scan_names, fc_construction='concat_all_timeseries', task=None, scan_dir=None):
    """Describe which scan directions and tasks contributed to an FC.

    Answers: does the FC contain LR timeseries, RL timeseries, or both? Does it
    contain REST1 timeseries, REST2 timeseries, or both? Membership is decided by
    substring tests on the names in ``scan_names``
    (e.g. ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL']).

    Args:
        which_scans_exist: indexable by scan name; a truthy value means the scan exists.
        scan_names: scan names to inspect.
        fc_construction: one of 'concat_all_timeseries', 'concat_task_timeseries',
            'concat_scandir_timeseries'.
        task: task name contained in the scan names (e.g. 'REST1'); required for
            'concat_task_timeseries' and returned unchanged in that mode.
        scan_dir: scan direction contained in the scan names (e.g. 'LR'); required
            for 'concat_scandir_timeseries' and returned unchanged in that mode.

    Returns:
        (scan_dir, task). Derived labels come from parse_scandir / parse_task.

    Raises:
        ValueError: unknown ``fc_construction``, missing required ``task`` /
            ``scan_dir``, or no matching scan exists (raised by parse_scandir /
            parse_task).
    """
    if fc_construction == 'concat_all_timeseries':
        present = [scan for scan in scan_names if which_scans_exist[scan]]
        scan_dir = parse_scandir(any('LR' in scan for scan in present),
                                 any('RL' in scan for scan in present))
        task = parse_task(any('REST1' in scan for scan in present),
                          any('REST2' in scan for scan in present))
        return scan_dir, task

    if fc_construction == 'concat_task_timeseries':
        if task is None:
            raise ValueError("task is required when fc_construction='concat_task_timeseries'")
        # Only scans belonging to the requested task contribute to this FC.
        present = [scan for scan in scan_names if which_scans_exist[scan] and task in scan]
        scan_dir = parse_scandir(any('LR' in scan for scan in present),
                                 any('RL' in scan for scan in present))
        return scan_dir, task

    if fc_construction == 'concat_scandir_timeseries':
        if scan_dir is None:
            raise ValueError("scan_dir is required when fc_construction='concat_scandir_timeseries'")
        # Only scans acquired in the requested direction contribute to this FC.
        present = [scan for scan in scan_names if which_scans_exist[scan] and scan_dir in scan]
        task = parse_task(any('REST1' in scan for scan in present),
                          any('REST2' in scan for scan in present))
        return scan_dir, task

    raise ValueError(f'fc_construction {fc_construction} not recognized')


def unpack_individual(subject_data_struct_list, scan_combination, scan_names, key='individual'):
    """Collect per-scan FCs, with matching SCs and labels, for every subject.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        scan_combination: only 'mean' is accepted (asserted). The FCs of all of a
            subject's existing scans are averaged into one entry. The per-scan
            branch below is reachable only if the assert is relaxed or stripped (-O).
        scan_names: scan names to look up, e.g. ['REST1_LR', 'REST1_RL', ...].
        key: entry of s['fcs'] holding the per-scan FCs.

    Returns:
        (fcs, scs, subject_ids, scan_dirs, tasks) as parallel lists. Tensors get a
        leading size-1 dim via .repeat(1, 1, 1) (assuming 2-D matrices) so they can
        be torch.cat-ed later. In 'mean' mode, labels are '|'-joined
        ('LR'/'RL', '1'/'2') in ``scan_names`` order.
    """
    assert scan_combination in ['mean']
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []

    if scan_combination == 'mean':
        for s in subject_data_struct_list:
            present = [scan for scan in scan_names if s['which_scans_exist'][scan]]
            if not present:  # nothing to average; also avoids torch.cat([])
                continue
            per_scan = s['fcs'][key].item()
            scans = [torch.tensor(per_scan[scan].item()).repeat(1, 1, 1) for scan in present]

            scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
            subject_ids.append(s['subject_id'])
            # .mean(dim=0) drops the stacking dim; repeat(1, 1, 1) restores a leading batch dim.
            fcs.append(torch.cat(scans, dim=0).mean(dim=0).repeat(1, 1, 1))
            scan_dirs.append('|'.join('LR' if 'LR' in scan else 'RL' for scan in present))
            tasks.append('|'.join('1' if 'REST1' in scan else '2' for scan in present))
    else:
        # One entry per existing scan (see the note on scan_combination).
        for s in subject_data_struct_list:
            for scan in scan_names:
                if not s['which_scans_exist'][scan]:
                    continue
                scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
                subject_ids.append(s['subject_id'])
                fcs.append(torch.tensor(s['fcs'][key].item()[scan].item()).repeat(1, 1, 1))
                scan_dirs.append('LR' if 'LR' in scan else 'RL')
                tasks.append('1' if 'REST1' in scan else '2')

    return fcs, scs, subject_ids, scan_dirs, tasks


def unpack_task_concat(subject_data_struct_list, scan_combination, scan_names, key, grouped_fc_names, fc_construction):
    """Collect task-grouped FCs (e.g. REST1 / REST2) with matching SCs and labels.

    A group in ``grouped_fc_names`` exists for a subject when any scan in
    ``scan_names`` whose name contains the group name exists. For example,
    'REST1' exists if 'REST1_LR' or 'REST1_RL' exists.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        scan_combination: 'mean' averages a subject's existing group FCs into one
            entry; any other value emits one entry per existing group.
        scan_names: individual scan names, e.g. ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL'].
        key: entry of s['fcs'] holding the group FCs.
        grouped_fc_names: group names to look up, e.g. ['REST1', 'REST2'].
        fc_construction: forwarded to find_scandir_and_task to build the labels.

    Returns:
        (fcs, scs, subject_ids, scan_dirs, tasks) as parallel lists. In 'mean'
        mode the labels of the averaged groups are joined with '|'.
    """
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []

    for s in subject_data_struct_list:
        exist = s['which_scans_exist']
        # Only groups backed by at least one existing scan have an FC to load.
        present = [
            name for name in grouped_fc_names
            if any(exist[scan].item() for scan in scan_names if name in scan)
        ]
        if not present:
            continue

        grouped = s['fcs'][key].item()
        group_fcs = [torch.tensor(grouped[name].item()).repeat(1, 1, 1) for name in present]
        labels = [
            find_scandir_and_task(exist, scan_names, fc_construction=fc_construction, task=name)
            for name in present
        ]

        if scan_combination == 'mean':
            scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
            # .mean(dim=0) drops the stacking dim; repeat(1, 1, 1) restores a leading batch dim.
            fcs.append(torch.cat(group_fcs, dim=0).mean(dim=0).repeat(1, 1, 1))
            scan_dirs.append('|'.join(scan_dir for scan_dir, _ in labels))
            tasks.append('|'.join(task for _, task in labels))
            subject_ids.append(s['subject_id'])
        else:
            for fc, (scan_dir, task) in zip(group_fcs, labels):
                scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
                subject_ids.append(s['subject_id'])
                fcs.append(fc)
                scan_dirs.append(scan_dir)
                tasks.append(task)

    return fcs, scs, subject_ids, scan_dirs, tasks


def unpack_scandir_concat(subject_data_struct_list, scan_combination, scan_names, key, grouped_fc_names, fc_construction):
    """Collect scan-direction-grouped FCs (e.g. LR / RL) with matching SCs and labels.

    A group in ``grouped_fc_names`` exists for a subject when any scan in
    ``scan_names`` whose name contains the group name exists. For example, 'LR'
    exists if 'REST1_LR' or 'REST2_LR' exists.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        scan_combination: 'mean' averages a subject's existing group FCs into one
            entry; any other value emits one entry per existing group.
        scan_names: individual scan names, e.g. ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL'].
        key: entry of s['fcs'] holding the group FCs.
        grouped_fc_names: group names to look up, e.g. ['LR', 'RL'].
        fc_construction: forwarded to find_scandir_and_task to build the labels.

    Returns:
        (fcs, scs, subject_ids, scan_dirs, tasks) as parallel lists. In 'mean'
        mode the labels of the averaged groups are joined with '|'.
    """
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []

    for s in subject_data_struct_list:
        exist = s['which_scans_exist']
        # Only groups backed by at least one existing scan have an FC to load.
        present = [
            name for name in grouped_fc_names
            if any(exist[scan].item() for scan in scan_names if name in scan)
        ]
        if not present:
            continue

        grouped = s['fcs'][key].item()
        group_fcs = [torch.tensor(grouped[name].item()).repeat(1, 1, 1) for name in present]
        labels = [
            find_scandir_and_task(exist, scan_names, fc_construction=fc_construction, scan_dir=name)
            for name in present
        ]

        if scan_combination == 'mean':
            scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
            # .mean(dim=0) drops the stacking dim; repeat(1, 1, 1) restores a leading batch dim.
            fcs.append(torch.cat(group_fcs, dim=0).mean(dim=0).repeat(1, 1, 1))
            scan_dirs.append('|'.join(scan_dir for scan_dir, _ in labels))
            tasks.append('|'.join(task for _, task in labels))
            subject_ids.append(s['subject_id'])
        else:
            for fc, (scan_dir, task) in zip(group_fcs, labels):
                scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
                subject_ids.append(s['subject_id'])
                fcs.append(fc)
                scan_dirs.append(scan_dir)
                tasks.append(task)

    return fcs, scs, subject_ids, scan_dirs, tasks


# util funcs
def _unpack_all_concat(subject_data_struct_list, scan_names, key='fcs_all'):
    """Collect one all-scans-concatenated FC per subject, with matching SC and labels."""
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []
    for s in subject_data_struct_list:
        scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
        subject_ids.append(s['subject_id'])
        # NOTE(review): this FC is nested one level deeper (.item().item()[0]) than the
        # grouped keys; presumably a MATLAB struct/cell wrapper - confirm against the loader.
        fc = torch.tensor(s['fcs'][key].item().item()[0])
        fcs.append(fc.repeat(1, 1, 1))
        scan_dir, task = find_scandir_and_task(s['which_scans_exist'], scan_names)
        scan_dirs.append(scan_dir)
        tasks.append(task)
    return fcs, scs, subject_ids, scan_dirs, tasks


def matlab_struct_to_tensors(subject_data_struct_list, fc_construction='concat_all_timeseries', scan_combination='mean', datatype=torch.float32):
    """Map per-subject records to stacked FC (input) and SC (label) tensors.

    Each subject can contribute several FCs (REST1/REST2, scan direction LR/RL),
    depending on ``fc_construction``. The per-entry subject ids are also returned so
    the dataset can be split by subject for train/validation/test.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        fc_construction: 'individual_timeseries' (s['fcs']['individual']),
            'concat_all_timeseries' (s['fcs']['fcs_all']),
            'concat_task_timeseries' (s['fcs']['task_grouped'], groups REST1/REST2) or
            'concat_scandir_timeseries' (s['fcs']['scandir_grouped'], groups LR/RL).
        scan_combination: forwarded to the unpack helper for all modes except
            'concat_all_timeseries', which yields one FC per subject regardless.
        datatype: dtype of the returned FC and SC tensors.

    Returns:
        (fcs, scs, subject_ids, scan_dirs, tasks): FC and SC tensors concatenated
        along dim 0, an int32 tensor of subject ids, and numpy arrays of the
        scan-direction and task labels, all with one row per entry.

    Raises:
        ValueError: unrecognized ``fc_construction`` (when asserts are disabled).
        RuntimeError: presumably from torch.cat if no entries were collected
            (e.g. empty input).
    """
    assert fc_construction in ['individual_timeseries', 'concat_all_timeseries', 'concat_task_timeseries', 'concat_scandir_timeseries']

    scan_names = ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL']
    if fc_construction == 'individual_timeseries':
        fcs, scs, subject_ids, scan_dirs, tasks = \
            unpack_individual(subject_data_struct_list, scan_combination=scan_combination, scan_names=scan_names)
    elif fc_construction == 'concat_all_timeseries':
        fcs, scs, subject_ids, scan_dirs, tasks = \
            _unpack_all_concat(subject_data_struct_list, scan_names=scan_names, key='fcs_all')
    elif fc_construction == 'concat_task_timeseries':
        fcs, scs, subject_ids, scan_dirs, tasks = \
            unpack_task_concat(subject_data_struct_list, scan_combination=scan_combination, scan_names=scan_names,
                               key='task_grouped', grouped_fc_names=['REST1', 'REST2'],
                               fc_construction=fc_construction)
    elif fc_construction == 'concat_scandir_timeseries':
        fcs, scs, subject_ids, scan_dirs, tasks = \
            unpack_scandir_concat(subject_data_struct_list, scan_combination=scan_combination, scan_names=scan_names,
                                  key='scandir_grouped', grouped_fc_names=['LR', 'RL'],
                                  fc_construction=fc_construction)
    else:
        raise ValueError(f'fc_construction {fc_construction} not recognized')

    assert len(fcs) == len(scs) == len(subject_ids) == len(scan_dirs) == len(tasks), f'all lens should be same: fcs {len(fcs)}, scs {len(scs)}, ids {len(subject_ids)}, sd {len(scan_dirs)} tasks {len(tasks)}'

    fcs_tensor = torch.cat(fcs, dim=0)
    scs_tensor = torch.cat(scs, dim=0)

    return fcs_tensor.to(datatype), scs_tensor.to(datatype), torch.tensor(subject_ids).to(torch.int32), \
           np.array(scan_dirs), np.array(tasks)
