import copy
import pickle
import numpy as np
from utils import get_random_idx_split, get_ood_split, set_seed
import torch
import shutil
from tqdm import tqdm
from torch_geometric.data import Data, InMemoryDataset
import os.path as osp
from pathlib import Path
from pathlib import Path


class HEP_OOD_Shift(InMemoryDataset):
    """Common base for the HEP distribution-shift datasets.

    The constructor records the configuration, lets ``InMemoryDataset`` set up
    ``root``, then loads the cached ``(data, slices, idx_split)`` triple from
    ``self.processed_paths[0]`` and derives the node-feature widths from it.

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        data_config: Mapping that must provide the keys ``data_name``,
            ``split``, ``iid_split``, ``pos_features``, ``other_features``,
            ``get_dataset_dir`` and ``feature_type``.
        seed: Stored as ``self.seed``.
        pileup: Stored as ``self.pileup``.
        tesla: Tag used in the raw signal directory name and in the
            processed directory name.
    """

    def __init__(self, root, data_config, seed, pileup, tesla='2T'):
        # All of these attributes are assigned before the base-class
        # constructor runs, because subclasses read them from their
        # ``process`` / ``processed_file_names`` hooks.
        self.tesla = tesla
        self.dataset_name = data_config['data_name']  # tau3mu or actstrack
        self.split = data_config['split']
        self.iid_split = data_config['iid_split']
        self.pos_features = data_config['pos_features']
        self.other_features = data_config['other_features']

        # Raw inputs: <dataset_dir>/background and <dataset_dir>/<name>/raw_<tesla>.
        self.dataset_dir = Path(data_config['get_dataset_dir'])
        self.bkg_dir = self.dataset_dir / 'background'
        signal_root = self.dataset_dir / f'{self.dataset_name}'
        self.sig_dir = signal_root / f'raw_{self.tesla}'

        self.seed = seed
        self.pileup = pileup

        super().__init__(root)

        # The processed file stores the collated graphs plus the split indices.
        cached = torch.load(self.processed_paths[0])
        self.data, self.slices, self.idx_split = cached
        self.x_dim = self.data.x.shape[1]
        self.pos_dim = self.data.pos.shape[1]
        self.feature_type = data_config['feature_type']
        self.signal_class = 1

        # Width of the scalar node-feature vector implied by ``feature_type``.
        chosen = self.feature_type
        if chosen == 'only_ones':
            scalar_width = 1
        elif chosen == 'only_x':
            scalar_width = self.x_dim
        elif chosen == 'only_pos':
            scalar_width = self.pos_dim
        else:
            assert self.feature_type == 'both_x_pos'
            scalar_width = self.x_dim + self.pos_dim

        self.feat_info = {'node_categorical_feat': [], 'node_scalar_feat': scalar_width}

    @property
    def processed_dir(self) -> str:
        """Directory for processed files: ``<root>/processed_<tesla>``."""
        subdir = f'processed_{self.tesla}'
        return osp.join(self.root, subdir)


class HEP_Pileup_Shift(HEP_OOD_Shift):
    """Pileup-shift dataset: in-distribution events at ``pileup``, shifted events at ``pileup_ood``.

    ``setting`` selects the split layout that :meth:`process` writes:

    * ``"OOD"`` -- ``train``, ``iid_val``, ``iid_test``, ``ood_val``, ``ood_test``.
    * ``"DA"``  -- ``train_source``, ``train_target``, ``iid_val``, ``iid_test``,
      ``ood_val``, ``ood_test``.
    * ``"TL"``  -- ``train``, ``val``, ``test``, all drawn from the shifted
      (``pileup_ood``) events; ``train`` holds ``int(restrict_TL_train / 2)``
      background/signal pairs. The setting is stored as
      ``f'TL_#{restrict_TL_train}'`` so each training budget gets its own
      processed file.

    Args:
        root: Dataset root directory.
        data_config: Configuration mapping consumed by ``HEP_OOD_Shift``.
        seed: Stored by the base class.
        pileup: Pileup tag of the in-distribution event files.
        pileup_ood: Pileup tag of the shifted event files.
        setting: ``"OOD"``, ``"DA"`` or ``"TL"``.
        restrict_TL_train: Training budget for the ``"TL"`` setting.
        tesla: Tag used in file and directory names.
    """

    def __init__(self, root, data_config, seed, pileup, pileup_ood, setting, restrict_TL_train=0, tesla='2T'):
        self.shift_type = 'pileup_shift'
        self.pileup_ood = pileup_ood
        self.restrict_TL_train = restrict_TL_train
        self.setting = setting  # option: "OOD", "DA", "TL"
        self.domain_splits = 5  # predefined parameter
        if self.setting == 'TL':
            self.setting = f'TL_#{restrict_TL_train}'
        super().__init__(root, data_config, seed, pileup, tesla)

    @property
    def processed_file_names(self):
        """Single processed file, keyed by shift type, both pileups and the setting."""
        return [f'{self.shift_type}_{self.pileup}_{self.pileup_ood}_{self.setting}.pt']

    def process(self):
        """Build every split for the configured setting and save it to ``processed_paths[0]``.

        The saved object is the tuple ``(data, slices, idx_split)`` where
        ``idx_split`` maps each split name to the positions of its graphs in
        the collated dataset.

        Raises:
            ValueError: if ``self.setting`` is not ``"OOD"``, ``"DA"`` or a
                name whose first ``_``-separated token is ``"TL"``.
        """
        setting_kind = self.setting.split('_')[0]
        if setting_kind != 'TL' and self.setting not in ('OOD', 'DA'):
            # Fail before touching the (large) pickles instead of crashing
            # later while collating an empty list.
            raise ValueError(
                f"Unknown setting {self.setting!r}; expected 'OOD', 'DA' or 'TL'.")

        def obtain_list(event_type, pileups):
            # event_type = 'bkg' or 'signal'
            base_dir = self.bkg_dir if event_type == 'bkg' else self.sig_dir
            file_path = base_dir / f'{event_type}_events_{self.tesla}_pileups_{pileups}.pkl'
            # NOTE(security): unpickling executes arbitrary code; these files
            # must come from a trusted source.
            with open(file_path, 'rb') as handle:
                return pickle.load(handle)

        bkg_list = obtain_list('bkg', self.pileup)
        sig_list = obtain_list('signal', self.pileup)
        bkg_list_ood = obtain_list('bkg', self.pileup_ood)
        sig_list_ood = obtain_list('signal', self.pileup_ood)

        # NOTE(review): the split seed is the constant 0, not ``self.seed`` --
        # presumably so every run shares the same split; confirm.
        # The result is indexed below with 'train' / 'val' / 'test'.
        bkg_split = get_random_idx_split(len(bkg_list), self.iid_split, 0)
        sig_split = get_random_idx_split(len(sig_list), self.iid_split, 0)

        # One label for every signal event, whichever list it comes from.
        signal_event_type = 'tau3mu' if self.dataset_name == 'tau3mu' else 'z2mu'

        def build_ood_pair(idx):
            # Background first, then signal: this fixes the order inside each split.
            return [
                build_data_object(bkg_list_ood[idx], signal=False, event_type='bkg'),
                build_data_object(sig_list_ood[idx], signal=True, event_type=signal_event_type),
            ]

        def domain_of(event, bins):
            # ``np.digitize`` returns ``len(bins)`` for a value equal to the
            # last edge, which would put the largest event into an extra
            # domain; clamp so ids stay within 0 .. domain_splits - 1.
            raw_id = int(np.digitize(event[0].shape[0], bins)) - 1
            return min(raw_id, self.domain_splits - 1)

        # Shifted-event index ranges shared by the settings.
        eval_val = range(0, 2500)
        eval_test = range(2500, 5000)

        if setting_kind == 'TL':
            train_pairs = int(self.restrict_TL_train / 2)
            # Insertion order (train, val, test) is the order in the saved dataset.
            pair_ranges = {
                'train': range(5000, 5000 + train_pairs),
                'val': eval_val,
                'test': eval_test,
            }
            final_splits = {
                name: [obj for idx in idx_range for obj in build_ood_pair(idx)]
                for name, idx_range in pair_ranges.items()
            }
        else:
            dataset_dict = {'train': [], 'iid_val': [], 'iid_test': [], 'ood_val': [], 'ood_test': []}
            bkg_min, bkg_max = get_size_extremum(bkg_list)
            sig_min, sig_max = get_size_extremum(sig_list)
            # Equal-width bins over the event size (rows of ``event[0]``).
            bkg_bins = np.linspace(start=bkg_min, stop=bkg_max, num=self.domain_splits + 1)
            sig_bins = np.linspace(start=sig_min, stop=sig_max, num=self.domain_splits + 1)

            # in-distribution dataset
            for item, split_key in (('train', 'train'), ('iid_val', 'val'), ('iid_test', 'test')):
                for idx in bkg_split[split_key]:
                    event = bkg_list[idx]
                    dataset_dict[item].append(
                        build_data_object(event, signal=False, event_type='bkg',
                                          domain_id=domain_of(event, bkg_bins)))
                for idx in sig_split[split_key]:
                    event = sig_list[idx]
                    dataset_dict[item].append(
                        build_data_object(event, signal=True, event_type=signal_event_type,
                                          domain_id=domain_of(event, sig_bins)))

            # out-of-distribution dataset, with 2500 `bkg` and `sig` data respectively
            data_num = 10000 if self.pileup_ood == 50 else 7700
            # Only "DA" consumes pairs beyond the evaluation range, so "OOD"
            # stops once the validation and test pairs are built.
            num_pairs = data_num if self.setting == 'DA' else eval_test.stop
            train_target_list = []
            for idx in range(num_pairs):
                pair = build_ood_pair(idx)
                train_target_list += pair
                if idx in eval_val:
                    dataset_dict['ood_val'] += pair
                elif idx in eval_test:
                    dataset_dict['ood_test'] += pair

            if self.setting == 'OOD':
                final_splits = dataset_dict
            else:  # 'DA'
                final_splits = {'train_source': dataset_dict['train'],
                                'train_target': train_target_list,
                                'iid_val': dataset_dict['iid_val'],
                                'iid_test': dataset_dict['iid_test'],
                                'ood_val': dataset_dict['ood_val'],
                                'ood_test': dataset_dict['ood_test']}

        # Concatenate the splits and remember which positions belong to each.
        idx_split, data_list = {}, []
        for item, objects in final_splits.items():
            start = len(data_list)
            idx_split[item] = list(range(start, start + len(objects)))
            data_list += objects

        data, slices = self.collate(data_list)
        torch.save((data, slices, idx_split), self.processed_paths[0])


def build_data_object(event, signal, event_type, domain_id=-1):
    """Convert one raw event into a ``Data`` graph object.

    Args:
        event: Indexable pair. ``event[0]`` is the hits table (presumably a
            pandas ``DataFrame`` -- it is indexed by column name and
            ``.to_numpy()`` is called on the result) with the columns ``tx``,
            ``ty``, ``tz``, ``tt``, ``tpx``, ``tpy``, ``tpz``, ``te``,
            ``deltapx``, ``deltapy``, ``deltapz``, ``deltae``, ``node_label``
            and ``particle_id``. ``event[1]`` is stored unchanged as
            ``signal_im``.
        signal: Truthy for a signal event (``y`` = 1.0), falsy for background
            (``y`` = 0.0).
        event_type: Stored unchanged on the returned object.
        domain_id: Domain index, stored as a tensor. Defaults to ``-1``.

    Returns:
        A ``Data`` object with ``x`` (the nine non-positional features),
        ``pos``, ``y`` of shape ``(1, 1)``, ``node_label``, ``node_dir``,
        ``num_tracks``, ``track_ids``, ``signal_im``, ``event_type`` and
        ``domain_id``.

    Raises:
        AssertionError: if some hit has no track id (for example a missing
            ``particle_id``).
    """
    other_features = ['tt', 'tpx', 'tpy', 'tpz', 'te', 'deltapx', 'deltapy', 'deltapz', 'deltae']
    hits = event[0]
    signal_im = event[1]
    domain_id = torch.tensor(domain_id)
    y = torch.tensor(1).float().view(-1, 1) if signal else torch.tensor(0).float().view(-1, 1)

    pos = torch.tensor(hits[['tx', 'ty', 'tz']].to_numpy()).float()
    x = torch.tensor(hits[other_features].to_numpy()).float()
    node_label = torch.tensor(hits['node_label'].to_numpy()).float().view(-1)
    node_dir = torch.tensor(hits[['tpx', 'tpy', 'tpz']].to_numpy()).float()

    # ``factorize`` numbers the particles 0, 1, ... in order of first
    # appearance, giving each hit the index of the track it belongs to in one
    # pass and without adding a helper column to the caller's hits table.
    # Missing particle ids are coded as -1 and caught by the check below.
    track_codes, particles = hits['particle_id'].factorize()
    track_ids = torch.tensor(track_codes, dtype=torch.long)  # indices which track the node belongs to
    num_tracks = len(particles)
    assert -1 not in track_ids, 'every hit must belong to a track'
    return Data(x=x, pos=pos, y=y, node_label=node_label,
                node_dir=node_dir, num_tracks=num_tracks, track_ids=track_ids, signal_im=signal_im,
                event_type=event_type, domain_id=domain_id)


def get_size_extremum(dataset):
    """Return ``(smallest, largest)`` event size found in ``dataset``.

    The size of an event is ``event[0].shape[0]``, i.e. the number of rows of
    its first element.
    """
    row_counts = np.array([event[0].shape[0] for event in dataset])
    smallest = row_counts.min()
    largest = row_counts.max()
    return smallest, largest