import os
import pickle
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, TensorDataset, DataLoader
import scipy.stats
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse, to_dense_adj
from functools import partial
from torchvision.datasets.folder import  has_file_allowed_extension, is_image_file, IMG_EXTENSIONS, pil_loader, accimage_loader,default_loader
import torchvision.transforms as transforms


class Compcars(Dataset):
    """Image dataset over an explicit list of ``(path, class_index)`` samples.

    Images are read from disk on every ``__getitem__`` call via ``self.loader``.
    There is no single root directory: each sample already carries its full
    path, so ``__repr__`` reports only the sample count and the transforms.

    Args:
        samples: sequence of ``(path, target)`` tuples.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each target.
        device: stored on the instance; not used by this class itself.
    """

    def __init__(self, samples, transform=None, target_transform=None, device='cuda'):
        self.loader = default_loader
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform
        self.device = device

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        lines = [
            'Dataset ' + self.__class__.__name__,
            '    Number of datapoints: {}'.format(self.__len__()),
        ]
        # Continuation lines of multi-line transform reprs are indented under
        # their label. (No "Root Location" line: this dataset has no root attribute.)
        for label, value in (
            ('    Transforms (if any): ', self.transform),
            ('    Target Transforms (if any): ', self.target_transform),
        ):
            lines.append(label + repr(value).replace('\n', '\n' + ' ' * len(label)))
        return '\n'.join(lines)

class StandardScaler:
    """Affine z-score normalizer: ``(data - mean) / std`` and its inverse.

    ``mean`` and ``std`` may be scalars or arrays that broadcast against the
    data given to :meth:`transform` / :meth:`inverse_transform`.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Center ``data`` on ``mean`` and divide by ``std``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo :meth:`transform`: multiply by ``std``, then add ``mean`` back."""
        rescaled = data * self.std
        return rescaled + self.mean


def load_pickle(pickle_file):
    """Load and return the object stored in ``pickle_file``.

    If the default (ASCII) decoding fails with ``UnicodeDecodeError`` -- e.g.
    for pickles written by Python 2 that contain byte strings -- the file is
    re-read with ``encoding='latin1'``.

    Any failure, including one from the latin1 retry, is reported on stdout and
    then re-raised unchanged.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    try:
        try:
            with open(pickle_file, 'rb') as f:
                return pickle.load(f)
        except UnicodeDecodeError:
            with open(pickle_file, 'rb') as f:
                return pickle.load(f, encoding='latin1')
    except Exception as e:
        print('Unable to load data ', pickle_file, ':', e)
        raise


def load_pems_data(rootpath, name, adj_mx_name, ratio=1.0, one_node=False, one_sample=False):
    """Load a METR-LA / PEMS-BAY style traffic dataset as normalized tensors.

    Args:
        rootpath: data root, resolved relative to this file's directory.
        name: dataset sub-directory of ``rootpath`` holding ``train.npz``,
            ``val.npz`` and ``test.npz`` (each with arrays ``'x'`` and ``'y'``).
        adj_mx_name: pickle under ``rootpath/sensor_graph`` that unpacks into
            three items, the last being the dense adjacency matrix (ndarray).
        ratio: fraction of nodes used for training. A value < 1.0 enables the
            inductive setting: the node subset stored under key ``str(ratio)``
            in ``sensor_graph/{name}_partial_nodes.npz`` defines the training
            graph and the data used to fit the feature scaler.
        one_node: debugging flag; keep only node 0 in every split.
        one_sample: debugging flag; keep only the first sample in every split.

    Returns:
        dict with ``'feature_scaler'``, ``'attr_scaler'`` and one entry per
        split (``'train'``, ``'val'``, ``'test'``). Each split is a dict of float
        tensors ``'x'``, ``'y'``, ``'x_attr'``, ``'y_attr'`` plus ``'edge_index'``
        / ``'edge_attr'`` (the induced sub-graph for train when inductive, the
        full graph otherwise). In the inductive setting the train split also
        has ``'selected'``, a boolean mask over all nodes.

    The raw ``x``/``y`` arrays are 4-D with nodes on axis 2 and samples on
    axis 0 (axis 1 is presumably time -- TODO confirm). Channel 0 is the
    feature (z-scored with train statistics). Channel 1 is the attribute
    (passed through unscaled).
    """
    rootpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), rootpath)
    splits = ('train', 'val', 'test')

    inductive = ratio < 1.0
    if inductive:
        partial_nodes_path = os.path.join(
            rootpath, 'sensor_graph', '{}_partial_nodes.npz'.format(name))
        # allow_pickle lets np.load unpickle object arrays; only use trusted files.
        with np.load(partial_nodes_path, allow_pickle=True) as partial_nodes:
            # The entry unpacks into two parts; only the first (training nodes) is used.
            selected_nodes, _ = partial_nodes[str(ratio)]

    adj_mx_path = os.path.join(rootpath, 'sensor_graph', adj_mx_name)
    _, _, adj_mx = load_pickle(adj_mx_path)
    adj_mx_ts = torch.from_numpy(adj_mx).float()
    eval_adj_mx_ts = adj_mx_ts
    if inductive:
        # Training sees only the sub-graph induced by the selected nodes.
        train_adj_mx_ts = adj_mx_ts[selected_nodes, :][:, selected_nodes]
    else:
        train_adj_mx_ts = adj_mx_ts
    train_edge_index, train_edge_attr = dense_to_sparse(train_adj_mx_ts)
    eval_edge_index, eval_edge_attr = dense_to_sparse(eval_adj_mx_ts)

    datapath = os.path.join(rootpath, name)
    raw_data = {}
    for split in splits:
        # Materialize the arrays once and close the archive: an NpzFile keeps
        # its file handle open until closed, and re-indexing it decompresses again.
        with np.load(os.path.join(datapath, '{}.npz'.format(split))) as archive:
            raw_data[split] = {'x': archive['x'], 'y': archive['y']}

    if inductive:
        selected_ts = torch.zeros(eval_adj_mx_ts.shape[0], dtype=torch.bool)
        selected_ts[selected_nodes] = True

    FEATURE_START, FEATURE_END = 0, 1
    ATTR_START, ATTR_END = 1, 2

    # Fit the feature scaler on training data only (restricted to the selected
    # nodes in the inductive setting), pooled over every axis but the channel.
    train_features = raw_data['train']['x'][..., FEATURE_START:FEATURE_END]
    if inductive:
        train_features = train_features[:, :, selected_nodes, :]
    train_features = train_features.reshape(-1, train_features.shape[-1])
    feature_scaler = StandardScaler(
        mean=train_features.mean(axis=0), std=train_features.std(axis=0)
    )
    # mean=0, std=1: attributes are left unscaled.
    attr_scaler = StandardScaler(mean=0, std=1)
    loaded_data = {
        'feature_scaler': feature_scaler,
        'attr_scaler': attr_scaler,
    }

    for split in splits:
        split_x = raw_data[split]['x']
        split_y = raw_data[split]['y']
        x = feature_scaler.transform(split_x[..., FEATURE_START:FEATURE_END])
        y = feature_scaler.transform(split_y[..., FEATURE_START:FEATURE_END])
        x_attr = attr_scaler.transform(split_x[..., ATTR_START:ATTR_END])
        y_attr = attr_scaler.transform(split_y[..., ATTR_START:ATTR_END])

        # for debugging
        if one_node:
            x = x[:, :, 0:1, :]
            y = y[:, :, 0:1, :]
            x_attr = x_attr[:, :, 0:1, :]
            y_attr = y_attr[:, :, 0:1, :]
        if one_sample:
            x, y, x_attr, y_attr = x[0:1], y[0:1], x_attr[0:1], y_attr[0:1]

        if split == 'train':
            edge_index, edge_attr = train_edge_index, train_edge_attr
        else:
            edge_index, edge_attr = eval_edge_index, eval_edge_attr
        data = {
            'x': torch.from_numpy(x).float(),
            'y': torch.from_numpy(y).float(),
            'x_attr': torch.from_numpy(x_attr).float(),
            'y_attr': torch.from_numpy(y_attr).float(),
            'edge_index': edge_index,
            'edge_attr': edge_attr,
        }
        if split == 'train' and inductive:
            data['selected'] = selected_ts
        loaded_data[split] = data

    return loaded_data

def _build_available_datasets():
    """Return the registry mapping dataset names to zero-arg loader callables.

    Each entry is a ``functools.partial`` of ``load_pems_data``. Every traffic
    dataset has a full-graph entry plus one inductive entry per node ratio
    (``'<name>-<ratio>'``). METR-LA additionally has one-node and one-sample
    debugging entries.
    """
    registry = {}
    for dataset, adj_file in (('METR-LA', 'adj_mx.pkl'), ('PEMS-BAY', 'adj_mx_bay.pkl')):
        common = dict(rootpath='data/traffic/data', name=dataset, adj_mx_name=adj_file)
        registry[dataset] = partial(load_pems_data, **common)
        if dataset == 'METR-LA':
            registry['METR-LA-onenode'] = partial(
                load_pems_data, **common, one_node=True, one_sample=False)
            registry['METR-LA-onesample'] = partial(
                load_pems_data, **common, one_node=False, one_sample=True)
        for ratio in (0.25, 0.5, 0.75, 0.05, 0.9):
            registry[f'{dataset}-{ratio}'] = partial(load_pems_data, **common, ratio=ratio)
    return registry


available_datasets = _build_available_datasets()

class BaseNodes:
    """Container of per-domain ("node") data loaders plus a domain adjacency matrix.

    After loading, index ``k`` of ``domains``, ``train_loaders``,
    ``val_loaders``, ``test_loaders`` and ``num_samples`` all refer to the same
    domain. ``num_samples[k]`` is the size of that domain's train split.
    ``A`` is a square adjacency tensor over the domains.

    Args:
        data_name: 'DG15', 'DG60', 'DG60_TRUNCNORM', 'TPT48', 'TPT48_TRUNCNORM',
            'CompCars', or a name containing 'METR-LA' / 'PEMS-BAY' that is a
            key of ``available_datasets``.
        batch_size: batch size of the (shuffled) train loaders.
        eval_batch_size: batch size of the val/test loaders.
        num_workers: ``num_workers`` passed to every DataLoader.
        device: stored on the instance and forwarded to ``Compcars``.
        data_dir: CompCars image root; its preprocessed cache is written below it.

    Raises:
        ValueError: for an unknown ``data_name``.
    """

    def __init__(
            self,
            data_name,
            batch_size=32,
            eval_batch_size=512,
            num_workers=0,
            device='cuda',
            data_dir='data'
    ):
        self.batch_size = batch_size
        self.eval_batch_size = eval_batch_size
        self.device = device
        self.domains = []
        self.train_loaders, self.val_loaders, self.test_loaders = [], [], []
        self.num_samples = []
        self.global_weights = {}
        self.global_buffers = {}
        self.weights = {}
        self.buffers = {}
        self.num_workers = num_workers
        self.data_dir = data_dir
        if data_name == 'DG15':
            self.load_dg_dataset("data/toy_d15_spiral_tight_boundary.pkl")
        elif data_name == 'DG60':
            self.load_dg_dataset("data/toy_d60_spiral.pkl")
        elif data_name == 'DG60_TRUNCNORM':
            self.load_dg_dataset("data/toy_d60_truncnorm.pkl")
        elif data_name == 'TPT48':
            self.load_tpt48_dataset()
        elif data_name == 'TPT48_TRUNCNORM':
            self.load_tpt48_dataset(is_hetero=True)
        elif 'METR-LA' in data_name or 'PEMS-BAY' in data_name:
            self.load_seq_dataset(data_name)
        elif data_name == 'CompCars':
            self.load_compcars(data_dir)
        else:
            # ValueError subclasses Exception, so existing handlers still catch it.
            raise ValueError(f'Unknown dataset name {data_name}.')
        # One boolean flag per domain, all False initially.
        self.joined = np.zeros(self.n_nodes, dtype=bool)

    def _add_domain_split(self, domain_dataset, num_train, num_val):
        """Randomly split one domain's dataset and register its loaders.

        The test split receives the remaining
        ``len(domain_dataset) - num_train - num_val`` samples. Only the train
        loader shuffles. ``num_train`` is appended to ``self.num_samples``.
        """
        num_test = len(domain_dataset) - num_train - num_val
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
            domain_dataset, [num_train, num_val, num_test])
        self.train_loaders.append(DataLoader(
            train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True))
        self.val_loaders.append(DataLoader(
            val_dataset, batch_size=self.eval_batch_size, num_workers=self.num_workers))
        self.test_loaders.append(DataLoader(
            test_dataset, batch_size=self.eval_batch_size, num_workers=self.num_workers))
        self.num_samples.append(num_train)

    def load_dg_dataset(self, data_path):
        """Load a toy domain-generalization pickle (path relative to this file).

        The pickle is a dict with 'domain', 'data', 'label' and 'A' arrays.
        Each domain is split 80% / 10% / 10% into train / val / test.
        """
        data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), data_path)
        with open(data_path, 'rb') as f:
            data = pickle.load(f)  # NOTE: unpickling runs code; trusted files only.
        self.domains = list(set(data['domain']))
        self.n_nodes = len(self.domains)
        self.A = torch.from_numpy(data['A']).long() + torch.eye(self.n_nodes)
        self.label_dim = len(set(data['label']))
        features = torch.from_numpy(data['data']).float()
        labels = torch.from_numpy(data['label']).long()
        self.data_dim = features.shape[-1]
        train_ratio, val_ratio = 0.8, 0.1
        for d in self.domains:
            indices = np.where(data['domain'] == d)[0]
            num = len(indices)
            self._add_domain_split(
                TensorDataset(features[indices], labels[indices]),
                int(train_ratio * num), int(val_ratio * num))

    def _read_tpt48_frame(self):
        """Read the NOAA climate-division file and set domains/dims/A for TPT48.

        Keeps rows with ``noaa_state_order <= 48`` and ``dont_care == 0``.
        Returns the filtered DataFrame.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        df = pd.read_fwf(
            os.path.join(here, 'data/climdiv-tmpcst-v1.0.0-20200106'),
            widths=[1, 2, 1, 2, 4, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7],
            names=[
                "dont_care", "noaa_state_order", "divisional_number", "code", "year",
                "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
            ])
        df = df.loc[(df['noaa_state_order'] <= 48) & (df['dont_care'] == 0)]
        self.domains = list(set(df['noaa_state_order']))
        self.n_nodes = len(self.domains)
        self.data_dim = 6
        self.label_dim = 6
        self.A = torch.from_numpy(np.load(os.path.join(here, 'data/state_adj.npy'))) + torch.eye(self.n_nodes)
        return df

    def _tpt48_train_ratios(self, is_hetero):
        """Per-domain train fractions.

        Returns 0.8 everywhere, or, when ``is_hetero`` is set, draws from a
        normal(0.5, 1) truncated to [0.2, 0.8].
        """
        if not is_hetero:
            return [0.8] * self.n_nodes
        lower, upper, mu, sigma = 0.2, 0.8, 0.5, 1
        return scipy.stats.truncnorm.rvs(
            (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma, size=self.n_nodes)

    def load_tpt48_dataset(self, is_hetero=False):
        """TPT48: for each state predict Jul-Dec from Jan-Jun of the same row.

        Values are min-max normalized per state. The non-train part is halved
        into val / test.
        """
        df = self._read_tpt48_frame()
        train_ratios = self._tpt48_train_ratios(is_hetero)
        for i, d in enumerate(self.domains):
            X = torch.FloatTensor(df[df['noaa_state_order'] == d].loc[:, "Jan":].to_numpy())
            X = (X - torch.min(X)) / (torch.max(X) - torch.min(X))
            domain_dataset = TensorDataset(X[:, :6], X[:, 6:])
            num = len(domain_dataset)
            num_train = int(train_ratios[i] * num)
            self._add_domain_split(domain_dataset, num_train, int((num - num_train) / 2))

    def load_tpt48_dataset_hard(self, is_hetero=False):
        """TPT48 variant over the flattened monthly series of each state.

        Draws 100 random 12-month windows per state (without replacement) and
        predicts the last 6 months of each window from its first 6.
        """
        df = self._read_tpt48_frame()
        train_ratios = self._tpt48_train_ratios(is_hetero)
        for i, d in enumerate(self.domains):
            series = torch.FloatTensor(df[df['noaa_state_order'] == d].loc[:, "Jan":].to_numpy()).flatten()
            series = (series - torch.min(series)) / (torch.max(series) - torch.min(series))
            start_idxs = np.random.choice(np.arange(len(series) - 12), 100, replace=False).astype(np.int32)
            X = torch.stack([series[s:s + 6] for s in start_idxs])
            Y = torch.stack([series[s + 6:s + 12] for s in start_idxs])
            domain_dataset = TensorDataset(X, Y)
            num = len(domain_dataset)
            num_train = int(train_ratios[i] * num)
            self._add_domain_split(domain_dataset, num_train, int((num - num_train) / 2))

    def load_seq_dataset(self, data_name):
        """Load a traffic dataset from ``available_datasets``; one domain per sensor node.

        Each node's loaders yield ``(x, y, x_attr, y_attr)`` sliced to that node
        while keeping a singleton node axis (axis 2).
        """
        data = available_datasets[data_name]()
        train = data['train']
        self.data_dim = train['x'].shape[-1] + train['x_attr'].shape[-1]
        self.label_dim = train['y'].shape[-1]
        self.n_nodes = train['x'].shape[2]
        self.domains = list(range(self.n_nodes))
        seq_keys = ('x', 'y', 'x_attr', 'y_attr')
        loaders = {}
        for split in ('train', 'val', 'test'):
            split_data = data[split]
            is_train = split == 'train'
            loaders[split] = [
                DataLoader(
                    TensorDataset(*(split_data[key][:, :, i:i + 1, :] for key in seq_keys)),
                    batch_size=self.batch_size if is_train else self.eval_batch_size,
                    shuffle=is_train,
                    num_workers=self.num_workers,
                )
                for i in range(self.n_nodes)
            ]
        self.train_loaders = loaders['train']
        self.val_loaders = loaders['val']
        self.test_loaders = loaders['test']
        self.num_samples = [len(loader.dataset) for loader in loaders['train']]
        self.A = to_dense_adj(data['test']['edge_index'], edge_attr=data['test']['edge_attr']).squeeze(0)
        self.train_edge_index = train['edge_index']
        self.eval_edge_index = data['test']['edge_index']
        self.train_nodes = np.arange(self.n_nodes)
        if 'selected' in train:
            # Inductive setting: restrict to nodes flagged in the boolean mask.
            self.train_nodes = self.train_nodes[np.asarray(train['selected'])]

    @staticmethod
    def _compcars_adjacency(domains, years=range(2009, 2015), view_chain=('1', '4', '3', '5', '2')):
        """Build the CompCars domain graph as a symmetric float64 tensor.

        Domains are named ``'<year>-<view>'``. Two present domains are connected
        when their years differ by at most one AND their views are at most one
        step apart in ``view_chain`` (both neighbourhoods inclusive). Every
        present domain gets a self-loop. Row/column ``k`` is ``domains[k]``;
        names outside ``years`` x ``view_chain`` stay isolated (zero rows), and
        absent year-view combinations never become extra rows.
        """
        index = {d: k for k, d in enumerate(domains)}
        A = torch.zeros(len(domains), len(domains), dtype=torch.float64)
        years = list(years)
        first_year, last_year = years[0], years[-1]
        last_view = len(view_chain) - 1
        for year in years:
            for i, view in enumerate(view_chain):
                src = index.get(f'{year}-{view}')
                if src is None:
                    continue
                for y in range(max(year - 1, first_year), min(year + 1, last_year) + 1):
                    for j in range(max(i - 1, 0), min(i + 1, last_view) + 1):
                        dst = index.get(f'{y}-{view_chain[j]}')
                        if dst is not None:
                            A[src, dst] = A[dst, src] = 1.0
        return A

    def load_compcars(self, data_dir):
        """Load CompCars: one domain per ``'<year>-<view>'``, 4 classes, 224x224 images.

        Images are listed in ``data/car-images.txt`` as ``"<file> <domain> <class>"``.
        Each domain's transformed images are cached to
        ``<data_dir>/compcars/preprocessed/<domain>.ckpt`` and reused on later
        runs. Splits are 80% train, with the rest halved into val / test.
        """
        class_to_idx = {'1': 0, '2': 1, '3': 2, '4': 3}
        data_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        images = {}
        with open(os.path.join('data', 'car-images.txt')) as f:
            for line in f:
                fname, domain, target = line.strip().split(' ')
                if has_file_allowed_extension(fname, IMG_EXTENSIONS):
                    images.setdefault(domain, []).append(
                        (os.path.join(data_dir, fname), class_to_idx[target]))
        self.domains = sorted(images)
        self.n_nodes = len(self.domains)
        self.A = self._compcars_adjacency(self.domains)

        self.data_dim = (224, 224)
        self.label_dim = 4
        train_ratio = 0.8
        preprocess_path = os.path.join(data_dir, 'compcars', 'preprocessed')
        os.makedirs(preprocess_path, exist_ok=True)
        for d in self.domains:
            ckpt_path = os.path.join(preprocess_path, f"{d}.ckpt")
            if os.path.exists(ckpt_path):
                cached = torch.load(ckpt_path)
            else:
                raw = Compcars(images[d], transform=data_transform, device=self.device)
                xs, ys = zip(*[raw[j] for j in range(len(raw))])
                cached = {'data': torch.stack(xs), 'target': torch.tensor(ys, dtype=torch.long)}
                torch.save(cached, ckpt_path)
            # Train from the cached tensors on every run, including the first,
            # instead of re-decoding and re-transforming images each epoch.
            domain_dataset = TensorDataset(cached['data'], cached['target'])
            num = len(domain_dataset)
            num_train = int(train_ratio * num)
            self._add_domain_split(domain_dataset, num_train, int((num - num_train) / 2))

    def __len__(self):
        return self.n_nodes

    def shuffle(self, new_order=None):
        """Permute all per-domain state in place and return the permutation.

        ``new_order[k]`` is the old index of the domain placed at position k.
        If omitted, a random permutation is drawn with ``random.shuffle``.
        NOTE(review): ``joined``, ``train_nodes`` and the stored edge indices
        are not remapped here.
        """
        if new_order is None:
            new_order = list(range(len(self.domains)))
            random.shuffle(new_order)
        self.domains = [self.domains[idx] for idx in new_order]
        self.train_loaders = [self.train_loaders[idx] for idx in new_order]
        self.val_loaders = [self.val_loaders[idx] for idx in new_order]
        self.test_loaders = [self.test_loaders[idx] for idx in new_order]
        self.num_samples = [self.num_samples[idx] for idx in new_order]
        # Permute rows and columns of the tensor directly. An undirected networkx
        # round-trip would symmetrize directed adjacencies and return an ndarray,
        # which breaks the next call.
        perm = torch.as_tensor(new_order, dtype=torch.long)
        self.A = self.A[perm][:, perm]
        return new_order
    
