import pickle
import torch
import networkx as nx
from torch.utils.data import Dataset, DataLoader
from omegaconf import DictConfig, open_dict, OmegaConf
import os 
import dgl
from GFlowNet_CombOpt.gflownet.main import get_alg_buffer

class GraphDataset(Dataset):
    """Map-style dataset over a pickled sequence of graphs.

    Items are converted on access with ``dgl.from_networkx``, so the pickle is
    presumably a list of networkx graphs -- TODO confirm against the producer.
    """

    def __init__(self, data_path, size=None):
        """Load the pickled graph sequence from ``data_path``.

        Args:
            data_path: path to the pickle file; must not be None.
            size: if given, keep only the first ``size`` graphs (must be > 0).
        """
        assert data_path is not None
        self.data_path = data_path
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point this at trusted, locally produced files.
        # The context manager closes the handle right after loading (the
        # previous bare ``open(...)`` leaked it).
        with open(data_path, 'rb') as f:
            self.graphs = pickle.load(f)

        if size is not None:
            assert size > 0
            self.graphs = self.graphs[:size]

        self.num_graphs = len(self.graphs)

    def __getitem__(self, idx):
        """Return graph ``idx`` converted to a DGL graph."""
        return dgl.from_networkx(self.graphs[idx])

    def __len__(self):
        """Return the number of graphs kept (after any ``size`` truncation)."""
        return self.num_graphs


def get_data_loaders(conf, our_conf, printer=None):
    """Build the training and validation DataLoaders for the configured dataset.

    Graphs are read from ``<dataset.path>/<dataset.name>/graphs/train.pkl`` and
    ``.../val.pkl``; every batch is collated into one graph via ``dgl.batch``.

    Args:
        conf: GFlowNet-side config; read for ``trainsize``, ``testsize``,
            ``same_graph_across_batch``, ``batch_size_interact``, ``shuffle``,
            ``test_batch_size`` and ``num_workers``.
        our_conf: project config; read for ``dataset.path`` and ``dataset.name``.
        printer: optional callable that receives a progress message.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    root = our_conf.dataset.path
    name = our_conf.dataset.name
    if printer:
        printer(f'Loading graphs from {root}/{name}')

    trainset = GraphDataset(f'{root}/{name}/graphs/train.pkl', size=conf.trainsize)
    valset = GraphDataset(f'{root}/{name}/graphs/val.pkl', size=conf.testsize)

    def batch_graphs(graphs):
        # Merge the sampled graphs into a single batched DGL graph.
        return dgl.batch(graphs)

    # When the same graph is reused across the batch, load one graph per step.
    if conf.same_graph_across_batch:
        interact_bs = 1
    else:
        interact_bs = conf.batch_size_interact

    train_loader = DataLoader(
        trainset,
        batch_size=interact_bs,
        shuffle=conf.shuffle,
        collate_fn=batch_graphs,
        drop_last=False,
        num_workers=conf.num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        valset,
        batch_size=conf.test_batch_size,
        shuffle=False,
        collate_fn=batch_graphs,
        num_workers=conf.num_workers,
        pin_memory=True,
    )
    return train_loader, val_loader


def get_test_data_loaders(conf, our_conf, printer=None):
    """Build the DataLoader over the test split of the configured dataset.

    Graphs are read from ``<dataset.path>/<dataset.name>/graphs/test.pkl``;
    every batch is collated into one graph via ``dgl.batch``.

    Args:
        conf: GFlowNet-side config; read for ``testsize``, ``test_batch_size``
            and ``num_workers``.
        our_conf: project config; read for ``dataset.path`` and ``dataset.name``.
        printer: optional callable that receives a progress message.

    Returns:
        The test ``DataLoader`` (not shuffled).
    """
    root = our_conf.dataset.path
    name = our_conf.dataset.name
    if printer:
        printer(f'Loading graphs from {root}/{name}')

    testset = GraphDataset(f'{root}/{name}/graphs/test.pkl', size=conf.testsize)

    def batch_graphs(graphs):
        # Merge the sampled graphs into a single batched DGL graph.
        return dgl.batch(graphs)

    return DataLoader(
        testset,
        batch_size=conf.test_batch_size,
        shuffle=False,
        collate_fn=batch_graphs,
        num_workers=conf.num_workers,
        pin_memory=True,
    )




class GFNET(torch.nn.Module):
    """Wrapper that builds a GFlowNet-CombOpt algorithm and its buffer.

    The GFlowNet default config is loaded from YAML, partially overwritten with
    values from the project config (seed, batch size, device,
    ``model.num_repeat``) plus a few hard-coded overrides, and then handed to
    ``get_alg_buffer``.
    """

    # Relative path: resolved against the current working directory.
    DEFAULT_CONFIG_PATH = "GFlowNet_CombOpt/gflownet/configs/main.yaml"

    def __init__(self, conf, gmn_config=None):
        """Build the GFlowNet config, algorithm and buffer.

        Args:
            conf: project config; read for ``training.seed``,
                ``training.batch_size``, ``training.device`` and
                ``model.num_repeat``.
            gmn_config: unused here; kept for interface compatibility.
        """
        super().__init__()
        self.our_conf = conf
        self.init_GFNET_defaut_config()
        self.overwrite_relevant_args()
        self.device = self.our_conf.training.device
        self.alg, self.buffer = get_alg_buffer(self.conf, self.device)

    def init_GFNET_defaut_config(self, config_path=None):
        """Load the GFlowNet default config into ``self.conf``.

        Args:
            config_path: YAML file to load; defaults to
                ``DEFAULT_CONFIG_PATH`` (relative to the working directory).
        """
        if config_path is None:
            config_path = self.DEFAULT_CONFIG_PATH
        self.conf = OmegaConf.load(config_path)

    def overwrite_relevant_args(self):
        """Adapt the loaded GFlowNet config to this project.

        Copies values from ``self.our_conf`` into ``self.conf``. Short YAML keys
        (``d``, ``rexp``, ``rexpit``, ``bs``, ``bsit``, ``lc``, ``sameg``,
        ``tbs``) are copied to their long names and then deleted, so the YAML
        must define ``rexp``, ``rexpit``, ``bsit``, ``lc``, ``sameg``,
        ``arch`` and ``anneal``.

        Raises:
            AssertionError: if ``arch`` is not ``"gin"``.
        """
        # ``open_dict`` is a context manager. The previous bare
        # ``open_dict(self.conf)`` call never entered it, so struct mode was
        # never relaxed. Adding and deleting keys only worked because the
        # loaded YAML happened not to be in struct mode.
        with open_dict(self.conf):
            self.conf.seed = self.our_conf.training.seed
            self.conf.eval = True
            self.conf.hidden_dim = 10
            self.conf.hidden_layers = 5
            self.conf.bs = self.our_conf.training.batch_size
            self.conf.tbs = self.our_conf.training.batch_size
            self.conf.d = self.our_conf.training.device

            self.conf.work_directory = os.getcwd()
            assert self.conf.arch in ["gin"]
            self.conf.device = self.conf.d
            # log reward shape
            self.conf.reward_exp = self.conf.rexp
            self.conf.reward_exp_init = self.conf.rexpit
            if self.conf.anneal in ["lin"]:
                self.conf.anneal = "linear"
            # training
            self.conf.batch_size = self.conf.bs
            self.conf.batch_size_interact = self.conf.bsit
            self.conf.leaf_coef = self.conf.lc
            self.conf.same_graph_across_batch = self.conf.sameg

            # data
            self.conf.test_batch_size = self.conf.tbs

            # Sizes stay None so the data loaders keep every graph.
            self.conf.trainsize = None
            self.conf.testsize = None

            # add task
            self.conf.task = "MaxClique"

            # epochs
            self.conf.epochs = 20

            # Drop the short aliases now that the long names are populated.
            del self.conf.d, self.conf.rexp, self.conf.rexpit, self.conf.bs
            del self.conf.bsit, self.conf.lc, self.conf.sameg, self.conf.tbs

            self.conf.num_repeat = self.our_conf.model.num_repeat

    