import os
import time
import wandb
import torch
import numpy as np
from pprint import pformat
from tqdm import tqdm, trange
from omegaconf import OmegaConf
from contextlib import contextmanager
from torch_scatter import scatter
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data

from parsers import Parser, get_config
from utils.loader import (
    load_seed, load_device, load_data, load_sampler, load_evaluater
)
from utils.logger import set_log
from dataset.misc import batched_to_list

torch.set_num_threads(32)


class Sampler(object):
    """Generate graphs from a trained checkpoint and evaluate them.

    On construction this sets up log directories, the seed, the device, the
    test split of the dataset (plus its dataloader), the sampler restored from
    ``config.sample.ckpt_path`` and an evaluater bound to the test split.
    Calling the instance runs sampling and returns the evaluater's metrics.
    """

    def __init__(self, config, exp_name='run', seed=10):
        """
        Args:
            config: Experiment config; the ``data``, ``sample`` and
                ``diffusion`` sections are read here and in ``__call__``.
            exp_name: Experiment name forwarded to the evaluater.
            seed: Seed passed to ``load_seed``.
        """
        super().__init__()

        self.config = config
        self.log_folder_name, self.log_dir, self.ckpt_dir = set_log(self.config)

        self.seed = load_seed(seed)
        self.device = load_device()
        # When load_device returns a list, sampling runs on its first entry.
        self.device_id = f'cuda:{self.device[0]}' if isinstance(self.device, list) else self.device
        self.full_dataset = load_data(self.config.data, return_loader=False)
        self.dataset = self.full_dataset.get_split_dataset('test')
        self.dataloader = self.dataset.get_dataloader(**self.config.data)

        self.sample_params = self.config.sample
        self.ckpt_path = self.sample_params.ckpt_path
        self.sampler = load_sampler(self.ckpt_path, self.device)
        self.evaluater = load_evaluater(
            self.dataset, exp_name=exp_name, device=self.device_id
        )

        if self.config.diffusion.target == 'binary':
            from diffusion.diffusion_utils import EmpiricalEmptyGraphGenerator

            # Binary diffusion draws its initial graphs from statistics of the
            # training split rather than the test split.
            self.sampler.initial_graph_sampler = EmpiricalEmptyGraphGenerator(
                self.full_dataset.get_split_dataset('train')
            )

    def __call__(self, sample_from_empty=False, num_samples=10):
        """Sample graphs, evaluate them, print and log the metrics.

        Args:
            sample_from_empty: Only consulted when ``config.diffusion.target``
                is ``'binary'``. If True, call
                ``self.sampler.sample(num_samples=...)`` without any test batch;
                otherwise sample once per batch of the test dataloader.
            num_samples: Number of graphs requested when ``sample_from_empty``
                is used (10 keeps the previous behaviour).

        Returns:
            Whatever the evaluater returns for the generated graph list.
        """
        # -------- Sampling --------
        self.sampler.eval()
        if self.config.diffusion.target == 'binary':
            if sample_from_empty:
                generated_data_list = (
                    self.sampler.sample(num_samples=num_samples).cpu().to_data_list()
                )
            else:
                generated_data_list = self._sample_binary_from_loader()
        else:
            generated_data_list = self._sample_edge_attr_from_loader()

        metrics = self.evaluater(generated_data_list)
        print(metrics)
        # wandb.log raises if wandb.init() was never called; only log when a
        # run is active so the class can also be used outside this script.
        if wandb.run is not None:
            wandb.log(metrics)
        return metrics

    def _sample_binary_from_loader(self):
        """Run the binary sampler on every test batch and unbatch the results."""
        generated = []
        # Iterate the loader directly so tqdm can read its length.
        for bdata in tqdm(self.dataloader):
            bdata = bdata.to(self.device_id)
            bdata = self.sampler.sample(bdata).cpu()
            generated.extend(batched_to_list(bdata))
        return generated

    def _sample_edge_attr_from_loader(self):
        """Sample edge attributes per test batch and rebuild ``edge_index``.

        Candidate edges in ``bdata.full_edge_index`` whose sampled attribute is
        nonzero are kept, then concatenated with their flipped copy so the
        resulting ``edge_index`` holds both directions of every kept edge.
        """
        generated = []
        for bdata in tqdm(self.dataloader):
            bdata = bdata.to(self.device_id)
            # Every key of config.sample (including ckpt_path) is forwarded
            # as a keyword argument to the sampler.
            full_edge_attr = self.sampler.sample(
                bdata, return_all=False, device=self.device_id, **self.sample_params
            )
            # NOTE(review): [0] keeps row indices only; if full_edge_attr is
            # 2-D with several columns a row may repeat — confirm its shape.
            is_edge_indices = full_edge_attr.nonzero(as_tuple=True)[0]
            edge_index = bdata.full_edge_index[:, is_edge_indices]
            edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
            bdata.edge_index = edge_index
            generated.extend(batched_to_list(bdata.cpu()))
        return generated

if __name__ == '__main__':
    # UTC timestamp used to make the wandb run name unique.
    ts = time.strftime('%b%d-%H:%M:%S', time.gmtime())
    args, unknown = Parser().parse()
    # Unparsed CLI tokens are treated as OmegaConf dotlist overrides
    # (e.g. ``sample.ckpt_path=...``) and merged over the loaded config.
    cli = OmegaConf.from_dotlist(unknown)
    config = get_config(args.config, args.seed)
    config = OmegaConf.merge(config, cli)
    # vars() on a DictConfig shows OmegaConf internals (_metadata, _content,
    # ...) instead of the settings; convert to plain containers first.
    print(pformat(OmegaConf.to_container(config)))

    # -------- Sample --------
    seed = config.train.seed
    exp_name = f'{args.prefix}-r.{seed}-{ts}'
    diff_type = config.diffusion.target
    run = wandb.init(
        project='AdjacencyDiff_sample',
        group=f'{diff_type}_{args.prefix}',
        name=exp_name,
    )

    # NOTE(review): the run name reports config.train.seed while the Sampler
    # is seeded with args.seed; they diverge if train.seed is overridden on
    # the command line — confirm which seed is intended.
    sampler = Sampler(config, exp_name=exp_name, seed=args.seed)
    sample_from_empty = config.sample.sample_from_empty
    metrics = sampler(sample_from_empty=sample_from_empty)
    # Close the run explicitly so the logged metrics are flushed.
    run.finish()