import torch
import dgl
import numpy as np
from loguru import logger
import time
from models.DIFUSCO_baselines import mc_decode_np
from models.SCAT_baseline import getclicnum
from erdos.utils import get_diracs, decode_clique_final_speed, decode_clique_final
import wandb
from utils.data_utils import *
from omegaconf import OmegaConf
import datetime
from utils.utils import *
from torch_geometric.utils.convert import to_scipy_sparse_matrix
from functools import partial
import torch_geometric as pyg

from GFlowNet_CombOpt.gflownet.util import MaxCliqueMDP
from GFlowNet_CombOpt.gflownet.algorithm import sample_from_logits
from einops import rearrange

class dotdict(dict):
    """A ``dict`` whose keys can also be used as attributes.

    ``d.key`` reads ``d.get(key)``, so a missing key yields ``None`` instead of
    raising ``AttributeError``. ``d.key = v`` stores ``d[key] = v``, and
    ``del d.key`` deletes the item, raising ``KeyError`` if it is absent.
    """

    def __getattr__(self, key):
        return self.get(key)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]


def gflownet_nnb_decoder(gr_hmap, gr, num_repeat=20):
    """Decode a clique size for ``gr`` by rolling out the GFlowNet MaxClique MDP.

    The graph is replicated ``num_repeat`` times into one DGL batch. The heatmap
    is tiled once per copy and used as (fixed) policy logits. Actions are then
    sampled until every copy's episode is done, and the best metric over the
    copies is returned.

    Args:
        gr: networkx graph (converted with ``dgl.from_networkx``).
        gr_hmap: NumPy heatmap for ``gr``. Presumably one score per node, so that
            ``torch.hstack`` yields one score per node of the batch; confirm.
        num_repeat: number of independent rollouts.

    Returns:
        The maximum of ``env.batch_metric`` over the rollouts, as a Python float.
    """
    # The same DGL graph object is repeated num_repeat times in the batch.
    batched_graph = dgl.batch([dgl.from_networkx(gr)] * num_repeat)
    env = MaxCliqueMDP(batched_graph, dotdict({"task": "MaxClique"}))

    state = env.state
    temperature = 1
    # Tile the heatmap so every replica receives identical logits.
    logits = torch.hstack([torch.from_numpy(gr_hmap)] * num_repeat)

    while not all(env.done):
        action = sample_from_logits(
            logits / temperature, batched_graph, state, env.done, rand_prob=0.
        )
        state = env.step(action)

    per_rollout = torch.tensor(env.batch_metric(state))
    # Group the flat per-replica metrics as (graph, repeat).
    per_graph = rearrange(per_rollout, "(rep b) -> b rep", rep=num_repeat).float()
    best, _ = per_graph.max(dim=1)
    return best.item()



def NNB_evaluate(all_pred, all_target):
    """Score predicted clique sizes against ground-truth clique sizes.

    Args:
        all_pred: predicted sizes, one per graph (tensor or sequence).
        all_target: ground-truth sizes with the same number of elements.

    Returns:
        ``(mse, ratio)``:
        - ``mse`` is the mean squared error as a Python float.
        - ``ratio`` is a 0-dim tensor holding ``mean(round(pred) / target)``.

    Raises:
        ValueError: if the two inputs hold a different number of elements.
    """
    # mse_loss is not implemented for integer tensors (clique sizes are often
    # built from Python ints). Shapes such as (N, 1) vs (N,) would silently
    # broadcast to (N, N). Compare flat float64 copies instead.
    pred = torch.as_tensor(all_pred).reshape(-1).to(torch.float64)
    target = torch.as_tensor(all_target).reshape(-1).to(torch.float64)
    if pred.numel() != target.numel():
        raise ValueError(
            f"prediction/target size mismatch: {pred.numel()} vs {target.numel()}"
        )
    mse = torch.nn.functional.mse_loss(target, pred, reduction="mean").item()
    ratio = (torch.round(pred) / target).mean()
    return mse, ratio

def get_all_node_degrees(gr):
    """Return the degree of every node of ``gr`` as a NumPy array.

    Entry ``i`` of the result is the degree of node ``i``. The nodes must
    therefore be labelled ``0 .. n-1`` in iteration order.

    Args:
        gr: networkx-style graph exposing ``degree()``, ``nodes()`` and
            ``number_of_nodes()``.

    Returns:
        NumPy array of length ``gr.number_of_nodes()``.

    Raises:
        ValueError: if the nodes are not labelled ``0 .. n-1`` in order.
    """
    num_nodes = gr.number_of_nodes()
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if list(gr.nodes()) != list(range(num_nodes)):
        raise ValueError("get_all_node_degrees expects nodes labelled 0..n-1 in order")
    degrees = gr.degree()
    return np.array([degrees[node] for node in range(num_nodes)])

def get_all_clustering_coeff(gr):
    """Return ``nx.clustering`` of every node of ``gr`` as a NumPy array.

    Entry ``i`` of the result is the clustering coefficient of node ``i``. The
    nodes must therefore be labelled ``0 .. n-1`` in iteration order.

    Args:
        gr: networkx graph.

    Returns:
        NumPy array of length ``gr.number_of_nodes()``.

    Raises:
        ValueError: if the nodes are not labelled ``0 .. n-1`` in order.
    """
    num_nodes = gr.number_of_nodes()
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if list(gr.nodes()) != list(range(num_nodes)):
        raise ValueError("get_all_clustering_coeff expects nodes labelled 0..n-1 in order")
    # NOTE(review): ``nx`` is not imported in this file directly; presumably it
    # comes from one of the star imports.
    clustering = nx.clustering(gr)
    return np.array([clustering[node] for node in range(num_nodes)])

def get_pagerank(gr):
    """Return ``nx.pagerank`` of every node of ``gr`` as a NumPy array.

    Entry ``i`` of the result is the PageRank score of node ``i``. The nodes
    must therefore be labelled ``0 .. n-1`` in iteration order.

    Args:
        gr: networkx graph.

    Returns:
        NumPy array of length ``gr.number_of_nodes()``.

    Raises:
        ValueError: if the nodes are not labelled ``0 .. n-1`` in order.
    """
    num_nodes = gr.number_of_nodes()
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if list(gr.nodes()) != list(range(num_nodes)):
        raise ValueError("get_pagerank expects nodes labelled 0..n-1 in order")
    # NOTE(review): ``nx`` is not imported in this file directly; presumably it
    # comes from one of the star imports.
    pagerank = nx.pagerank(gr)
    return np.array([pagerank[node] for node in range(num_nodes)])

def scattering_decoder(out, edge_index, num_walkers=3, samplength=300):
    """Decode a clique size from a heatmap with the scattering ``getclicnum`` walker.

    ``getclicnum`` is run once for each start node ``0 .. k-1``, where
    ``k = min(num_walkers, number of nodes)``. The largest value it reports is
    returned.

    Args:
        out: NumPy heatmap. Its first dimension is taken as the node count.
        edge_index: PyG-style ``(2, E)`` edge index of the same graph.
        num_walkers: maximum number of start nodes to try.
        samplength: forwarded to ``getclicnum`` as ``thresholdloopnodes``.

    Returns:
        The best value from ``getclicnum``, or ``0`` when no walker can start
        (empty graph or ``num_walkers <= 0``).
    """
    out = torch.from_numpy(out)
    # Pass num_nodes explicitly. Otherwise the matrix size is inferred from the
    # largest index in edge_index, which drops trailing isolated nodes, so the
    # adjacency and the heatmap would disagree on the node count.
    adj = to_scipy_sparse_matrix(edge_index, num_nodes=out.shape[0])
    num_starts = min(num_walkers, adj.get_shape()[0])
    clique_sizes = [
        getclicnum(adj, out, walkerstart=start, thresholdloopnodes=samplength).item()
        for start in range(num_starts)
    ]
    return max(clique_sizes, default=0)

def egn_orig_decoder(gr_hmap, gr_data, speed=True):
    """Decode a clique size with the original Erdos-GNN (EGN) clique decoder.

    For each graph in the batch, ``max_samples`` (8) seed node indices are drawn.
    They are drawn without replacement when the graph has at least that many
    nodes, and with replacement otherwise.

    For each of the 8 rounds, a batch is produced with ``get_diracs``, and its
    ``x`` is replaced by a one-hot vector on that round's seed of every graph.
    The heatmap is then decoded with ``decode_clique_final_speed``
    (``speed=True``) or ``decode_clique_final`` (``speed=False``, heatmap cast
    to float). The largest ``set_cardinality`` seen per graph is kept.

    Args:
        gr_hmap: NumPy heatmap. Presumably one score per node of ``gr_data``;
            confirm.
        gr_data: PyG ``Batch`` with a ``batch`` vector.
        speed: choose the fast or the slow decoder.

    Returns:
        The best cardinality as a Python float (via a float32 tensor). Only a
        batch holding exactly one graph is supported: the final ``.item()``
        raises otherwise.
    """
    data = gr_data
    max_samples = 8
    num_graphs = data.batch.max().item() + 1
    maxset = np.zeros((num_graphs))

    # Draw every graph's seed nodes (local indices) up front, in graph order.
    total_samples = []
    for graph in range(num_graphs):
        g_size = (data.batch == graph).sum().item()
        total_samples.append(
            np.random.choice(g_size, max_samples, replace=max_samples > g_size)
        )

    for k in range(max_samples):
        # get_diracs is still called every round even though its x is
        # overwritten below. Presumably it matters for the rest of the batch it
        # builds and for its RNG use; kept to preserve behaviour.
        data_prime = get_diracs(data, 1, sparse=True, effective_volume_range=0.15, receptive_field=7)
        data_prime.x = torch.zeros_like(data_prime.x)
        g_offset = 0
        for graph in range(num_graphs):
            g_size = (data_prime.batch == graph).sum().item()
            # Seed indices are local to the graph; shift to batch-global indices.
            data_prime.x[total_samples[graph][k] + g_offset] = 1.
            g_offset += g_size

        hmap = torch.from_numpy(gr_hmap).to(data_prime.x.device)
        if speed:
            _, _, set_cardinality = decode_clique_final_speed(
                data_prime, hmap, weight_factor=0., draw=False, beam=1
            )
        else:
            _, _, set_cardinality = decode_clique_final(
                data_prime, hmap.float(), weight_factor=0., draw=False
            )

        for j in range(num_graphs):
            if set_cardinality[j] > maxset[j]:
                maxset[j] = set_cardinality[j].item()

    return torch.tensor(maxset.tolist()).item()



def run():
    """Evaluate one heatmap-heuristic / decoder combination on the test split.

    Reads the module-level ``conf``. For every test graph it builds a heatmap
    with ``conf.model.hmap_heuristic`` and decodes it with
    ``conf.model.decoder_heuristic``. It then logs MSE and ratio against the
    ground truth to loguru and wandb.

    Raises:
        NotImplementedError: for an unknown heatmap or decoder name.
    """
    logger.info(conf)

    test_data = CliqueDataset(conf, "test", logger.info)
    test_data.data_type = "pyg"
    all_gt = test_data.ground_truth

    hmap_funcs = {
        "node_degree": get_all_node_degrees,
        "clustering_coeff": get_all_clustering_coeff,
        "pagerank": get_pagerank,
    }
    hmap_name = conf.model.hmap_heuristic
    if hmap_name not in hmap_funcs:
        raise NotImplementedError()
    hmap_func = hmap_funcs[hmap_name]

    decoder_funcs = {
        "erdos_nsfe": mc_decode_np,
        "scattering_1walk": partial(scattering_decoder, num_walkers=1, samplength=300),
        "scattering_2walk": partial(scattering_decoder, num_walkers=2, samplength=300),
        "scattering_3walk": partial(scattering_decoder, num_walkers=3, samplength=300),
        "egn_orig_speed": partial(egn_orig_decoder, speed=True),
        "egn_orig_slow": partial(egn_orig_decoder, speed=False),
        "gflownet_1repeat": partial(gflownet_nnb_decoder, num_repeat=1),
        "gflownet_20repeat": partial(gflownet_nnb_decoder, num_repeat=20),
    }
    decoder_name = conf.model.decoder_heuristic
    if decoder_name not in decoder_funcs:
        raise NotImplementedError()
    decoder_func = decoder_funcs[decoder_name]

    # EGN decoders take a PyG Batch and GFlowNet decoders take the networkx
    # graph; every other decoder takes the edge index.
    egn_decoders = ("egn_orig_speed", "egn_orig_slow")
    gflownet_decoders = ("gflownet_1repeat", "gflownet_20repeat")

    start_time = time.time()

    all_op = []
    for gidx, gr in enumerate(test_data.corpus_graphs):
        graph_data = test_data.graph_data_list[gidx]
        elist = graph_data.edge_index
        if decoder_name in egn_decoders:
            op = decoder_func(hmap_func(gr), Batch.from_data_list([graph_data]))
        elif decoder_name in gflownet_decoders:
            op = decoder_func(hmap_func(gr), gr)
        else:
            op = decoder_func(hmap_func(gr), elist)
        all_op.append(op)

    mse, ratio = NNB_evaluate(torch.tensor(all_op), torch.tensor(all_gt))
    logger.info(f"Run: N/A TEST mse: {mse:.6f} ratio: {ratio:.6f} Time: {time.time()-start_time:.6f}")
    wandb.log({'test_mse_loss': mse, 'test_ratio': ratio})



if __name__ == "__main__":
    import os

    # Layered configuration: global defaults, then the dataset- and
    # model-specific YAML files, with command-line overrides merged last so
    # they take precedence. (A second merge of cli_conf was a no-op.)
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)
    run_time = "{date:%Y-%m-%d||%H:%M:%S}".format(date=datetime.datetime.now())
    print(OmegaConf.to_yaml(conf))

    # Encode every model setting except the bookkeeping keys in the run name.
    task_name = ",".join(
        "{}={}".format(key, value)
        for key, value in conf.model.items()
        if key not in ('classPath', 'name')
    )

    conf.task.name = f"{conf.model.name}_{conf.dataset.name}_{task_name}_{run_time}"
    # Create the log directory if needed; open() fails on a missing directory.
    os.makedirs(conf.log.dir, exist_ok=True)
    log_path = f"{conf.log.dir}/{conf.task.name}.log"
    open(log_path, "w").close()  # Clear log file
    logger.add(log_path)
    logger.info(OmegaConf.to_yaml(conf))
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.wandb_group,
        config={
            'seed': conf.training.seed,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
        }
    )

    set_seed(conf.training.seed, conf)

    run()
