DEBUG = False
import os, sys, torch, wandb, pickle, numpy as np, pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, TensorDataset
import networkx as nx
from typing import Dict, Optional

from pathlib import Path
# Absolute path of this source file; the project root is taken to be the
# third ancestor directory (parents[2]) of this file.
file = Path(__file__).resolve()
path2project = str(file.parents[2]) + '/'
# Working directory at import time (with trailing slash).
path2currDir = str(Path.cwd()) + '/'
# Must run before the `data.school_networks...` import below so that the
# project-local package can be resolved.
sys.path.append(path2project) # add top level directory -> geom_dl/
# Directory passed to the school-network loaders as `path2dir`.
path2schooldata = path2project + 'data/school_networks/'
from data.school_networks.load_school_networks import parse_s1_graph, parse_s2_graph, parse_s3_graph, parse_s4_graph, \
    create_datasets, load_obj, save_obj, parse_school_network, load_dataset


def construct_gso(adj, gso, dtype=torch.float32):
    """Build a graph shift operator (GSO) from adjacency matrices.

    Args:
        adj: adjacency tensor, either a single ``(N, N)`` matrix or a batch
            ``(B, N, N)``.
        gso: ``'adjacency'`` to return ``adj`` unchanged (no dtype cast), or
            ``'laplacian'`` to return the combinatorial Laplacian ``D - A``.
        dtype: dtype of the returned Laplacian. Not applied in the
            ``'adjacency'`` case, where the input is passed through as-is.

    Returns:
        ``adj`` itself for ``'adjacency'``; otherwise a tensor of the same
        shape as ``adj`` and of dtype ``dtype`` holding ``D - A``.

    Raises:
        ValueError: if ``gso`` is not one of the supported names.
    """
    if gso == 'adjacency':
        return adj
    if gso == 'laplacian':
        # Degrees are summed along dim=1. For symmetric (undirected)
        # adjacencies this equals the usual row-sum degree for both 2-D and
        # batched 3-D input.
        degree = torch.diag_embed(adj.sum(dim=1))
        # Cast BOTH operands: previously only `adj` was cast, so a float64
        # input yielded a float64 result with float32-rounded off-diagonals
        # instead of honouring `dtype`.
        return degree.to(dtype) - adj.to(dtype)
    raise ValueError(f'unknown GSO {gso} given')


class SchoolNetworks(pl.LightningDataModule):
    """Lightning data module yielding paired ``(x, y)`` graph shift operators.

    Two different school-network graph types are chosen (``x_graphs`` as the
    input, ``y_graphs`` as the target). ``setup`` loads the required graphs,
    asks ``load_dataset`` for ``train_size + val_size + test_size`` samples,
    converts both sides to the requested GSO and splits them into
    train/val/test ``DataLoader``s.

    Args:
        num_vertices: number of vertices per sampled graph.
        train_size, val_size, test_size: number of samples in each split.
        seed: seed passed to ``pl.seed_everything`` before sampling and
            before splitting.
        batch_size: batch size used for all three loaders.
        num_workers: worker count used for all three loaders.
        x_graphs, y_graphs: distinct names out of ``sensor_contact``,
            ``reported_contact``, ``reported_friend``, ``facebook_friend``.
        x_gso, y_gso: ``'adjacency'`` or ``'laplacian'``.
        x_min_eig, y_min_eig: if given, ``min_eig * I`` is added to the
            corresponding matrices (only allowed for ``'laplacian'``).

    Raises:
        AssertionError: on invalid GSO / graph names, or identical
            ``x_graphs`` and ``y_graphs``.
    """

    def __init__(self,
                 num_vertices: int,
                 train_size: int, val_size: int, test_size: int,
                 seed: int,
                 batch_size: int,
                 num_workers: int,
                 x_graphs: str, y_graphs: str,
                 x_gso: str, y_gso: str,
                 x_min_eig: Optional[float] = None, y_min_eig: Optional[float] = None):
        super().__init__()
        self.num_vertices = num_vertices
        self.train_size, self.val_size, self.test_size = train_size, val_size, test_size
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        assert x_gso in ['adjacency', 'laplacian'] and y_gso in ['adjacency', 'laplacian']
        valid_graphs = ["sensor_contact", "reported_contact", "reported_friend", "facebook_friend"]
        assert x_graphs in valid_graphs and \
               y_graphs in valid_graphs and x_graphs != y_graphs, f'x_graphs {x_graphs}, y_graphs {y_graphs}'
        self.x_graphs, self.y_graphs = x_graphs, y_graphs
        self.x_gso, self.y_gso = x_gso, y_gso
        self.label = self.y_gso  # for backward compatibility
        self.x_min_eig, self.y_min_eig = x_min_eig, y_min_eig
        # Populated by setup().
        self.train_dl, self.val_dl, self.test_dl = None, None, None
        self.train_ds, self.val_ds, self.test_ds = None, None, None

        # Single all-True mask covering every vertex pair.
        self.subnetwork_masks = {'full': torch.ones(size=(1, num_vertices, num_vertices), dtype=torch.bool)}
        # True when the target GSO is an adjacency matrix.
        self.non_neg_labels = (y_gso in ['adjacency'])
        # True when the target GSO is a Laplacian ('precision' is listed but
        # is rejected by the assertion above).
        self.self_loops = (y_gso in ['laplacian', 'precision'])

    def _load_graphs(self):
        """Parse only the graphs named by ``x_graphs`` / ``y_graphs``.

        Returns a dict keyed by display name. The two ``reported_*`` graphs
        are converted with ``nx.to_undirected``.
        """
        graphs_used = (self.x_graphs, self.y_graphs)
        # (display name, raw name, parser, convert to undirected?)
        sources = (
            ('sensor_contact', 's1', parse_s1_graph, False),
            ('reported_contact', 's2', parse_s2_graph, True),
            ('reported_friend', 's3', parse_s3_graph, True),
            ('facebook_friend', 's4', parse_s4_graph, False),
        )
        graphs = {}
        for display_name, raw_name, loading_func, make_undirected in sources:
            if display_name not in graphs_used:
                continue
            graph = parse_school_network(display_name=display_name, raw_name=raw_name,
                                         loading_func=loading_func, path2dir=path2schooldata)
            # Real assignment: the previous code used `graphs[...]: value`
            # for 'reported_contact', which is an annotation and stored nothing.
            graphs[display_name] = nx.to_undirected(graph) if make_undirected else graph
        return graphs

    @staticmethod
    def _rescale(graph_name, adj):
        """Apply ``log2(1 + w)`` and divide by the global max for sensor contacts.

        Tensors of every other graph type are returned unchanged.
        """
        if graph_name != 'sensor_contact':
            return adj
        compressed = torch.log2(adj + 1)
        return compressed / compressed.max()

    def _add_diagonal(self, mat, gso, min_eig, total_samples):
        """Return ``mat + min_eig * I`` (per sample), or ``mat`` if ``min_eig`` is None.

        Raises:
            AssertionError: if ``min_eig`` is given but ``gso`` is not ``'laplacian'``.
        """
        if min_eig is None:
            return mat
        assert gso == 'laplacian', 'must find eigenvalues for adj'
        n = self.num_vertices
        identity = torch.eye(n).expand(total_samples, n, n)
        return mat + min_eig * identity

    def setup(self, stage: Optional[str] = None) -> None:
        """Load, preprocess and split the dataset; build the three loaders.

        Only acts for ``stage`` equal to ``'fit'`` or ``None``.

        Raises:
            AssertionError: unless ``x_graphs == 'sensor_contact'`` and
                ``y_graphs == 'facebook_friend'`` -- the sampling ranges and
                dataset name below are only made for that pairing.
        """
        if stage not in ('fit', None):
            return

        graphs = self._load_graphs()
        total_samples = self.train_size + self.val_size + self.test_size
        kwargs = {'total_samples': total_samples,
                  'num_vertices': self.num_vertices,
                  'names2check': [self.y_graphs, self.x_graphs],
                  'ave_degree_range': (8, 17),
                  'sparsity_range': (.1, .2)
                  }
        assert self.x_graphs == 'sensor_contact' and self.y_graphs == 'facebook_friend', \
            f'x_graphs {self.x_graphs}, y_graphs {self.y_graphs}: above ave_degree range & sparsity range and ' \
            f'overall load_dataset function only made for this case, including name'
        pl.seed_everything(self.seed)
        # Sampling and creation of the dataset. tensor_dict is indexed by
        # graph name; presumably each entry is
        # (total_samples, num_vertices, num_vertices) -- TODO confirm against
        # load_dataset.
        tensor_dict, nx_graphs_dict = load_dataset(graphs=graphs, name=f'sensor2facebook_{total_samples}',
                                                   path2dir=path2schooldata, kwargs=kwargs)

        x = self._rescale(self.x_graphs, tensor_dict[self.x_graphs])
        y = self._rescale(self.y_graphs, tensor_dict[self.y_graphs])

        x = construct_gso(x, gso=self.x_gso)
        y = construct_gso(y, gso=self.y_gso)

        x = self._add_diagonal(x, self.x_gso, self.x_min_eig, total_samples)
        y = self._add_diagonal(y, self.y_gso, self.y_min_eig, total_samples)

        ds = TensorDataset(x, y)
        # Re-seed so the split does not depend on how much randomness the
        # sampling above consumed.
        pl.seed_everything(self.seed)
        train, val, test = random_split(ds, [self.train_size, self.val_size, self.test_size])
        self.train_ds, self.val_ds, self.test_ds = train, val, test
        self.train_dl = DataLoader(train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
        self.val_dl = DataLoader(val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        self.test_dl = DataLoader(test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled training loader built by ``setup`` (None before that)."""
        return self.train_dl

    def val_dataloader(self):
        """Return the validation loader built by ``setup`` (None before that)."""
        return self.val_dl

    def test_dataloader(self):
        """Return the test loader built by ``setup`` (None before that)."""
        return self.test_dl


if __name__ == "__main__":
    # Smoke test: build the data module and run its 'fit' setup.
    # Sizes depend on whether the working directory path contains 'max'
    # (presumably a specific local machine/user -- confirm); evaluated once.
    in_max_dir = 'max' in os.getcwd()

    # The class is `SchoolNetworks`; the previous `school_networks(...)` call
    # referenced an undefined name and raised NameError.
    dm = SchoolNetworks(num_vertices=100 if in_max_dir else 120,
                        train_size=100, val_size=50, test_size=50,
                        x_graphs="sensor_contact", y_graphs="facebook_friend",
                        x_gso='adjacency', y_gso='adjacency',
                        seed=50, batch_size=50 if in_max_dir else 120,
                        num_workers=0 if in_max_dir else 4,
                        x_min_eig=None, y_min_eig=None)

    dm.setup('fit')
