import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb
from collections import defaultdict 
from erdos.utils import get_diracs, decode_clique_final, decode_clique_final_speed




def evaluate_erdos(conf, model, sampler):
    """Evaluate the set decoder on every batch served by ``sampler``.

    For each batch, ``max_samples`` seed node indices are drawn per graph.
    For every seed round the model is run on an input whose features are all
    zero except for one seed node per graph (set to 1), the model output is
    decoded with ``decode_clique_final`` / ``decode_clique_final_speed``, and
    the largest decoded set cardinality seen per graph is kept as that
    graph's prediction.

    Args:
        conf: Run configuration. ``conf.model.speed`` selects the decoder and
            ``conf.model.decoder_steps`` is passed as ``beam`` to the fast one.
        model: Network called as ``model(None, None, None, None, data_prime,
            node_sizes, None, None)``; its result is read as
            ``result["output"][0]``.
        sampler: Dataset object providing ``create_batches``,
            ``fetch_batched_data_by_id`` and ``ground_truth``.

    Returns:
        Tuple ``(solver_mse_std, decoder_ratio, decoder_mse,
        decoder_rankcorr, decoder_mae, decoder_acc)`` where

        * ``solver_mse_std`` is the std of the squared errors divided by
          ``sqrt(len(targets))``,
        * ``decoder_ratio`` is the mean of ``round(pred) / target``,
        * ``decoder_mse`` / ``decoder_mae`` are the mean squared / absolute
          errors,
        * ``decoder_rankcorr`` is the first element returned by
          ``kendalltau(pred, target)``,
        * ``decoder_acc`` is the fraction of targets equal to ``round(pred)``.
    """
    model.eval()

    max_samples = 8  # number of seed nodes tried per graph
    best_cardinalities = []  # one array per batch: best set size per graph

    t_start = time.perf_counter()
    n_batches = sampler.create_batches(shuffle=False)

    # Pure inference: nothing below is backpropagated, so skip building
    # autograd graphs for the repeated forward passes.
    with torch.no_grad():
        for batch_id in range(n_batches):
            data, corpus_batch_data_node_sizes, _, _, _ = sampler.fetch_batched_data_by_id(batch_id)

            num_graphs = data.batch.max().item() + 1
            maxset = np.zeros((num_graphs))

            # Draw the seed indices (local to each graph) up front; sample
            # with replacement only when a graph has fewer nodes than seeds.
            total_samples = []
            for graph in range(num_graphs):
                g_size = (data.batch == graph).sum().item()
                total_samples.append(
                    np.random.choice(g_size, max_samples, replace=max_samples > g_size)
                )

            for k in range(max_samples):
                data_prime = get_diracs(data, 1, sparse=True, effective_volume_range=0.15, receptive_field=7)

                # Replace the generated features with our own seeds: one node
                # per graph is set to 1, everything else stays 0.
                data_prime.x = torch.zeros_like(data_prime.x)
                g_offset = 0  # running offset of the current graph's first node
                for graph in range(num_graphs):
                    g_size = (data_prime.batch == graph).sum().item()
                    data_prime.x[total_samples[graph][k] + g_offset] = 1.
                    g_offset += g_size

                retdz = model(None,
                              None,
                              None,
                              None,
                              data_prime,
                              corpus_batch_data_node_sizes,
                              None,
                              None)

                if conf.model.speed:
                    _, _, set_cardinality = decode_clique_final_speed(
                        data_prime, retdz["output"][0], weight_factor=0., draw=False,
                        beam=conf.model.decoder_steps)
                else:
                    _, _, set_cardinality = decode_clique_final(
                        data_prime, retdz["output"][0], weight_factor=0., draw=False)

                # Keep the best (largest) cardinality found so far per graph.
                for j in range(num_graphs):
                    if set_cardinality[j] > maxset[j]:
                        maxset[j] = set_cardinality[j].item()

            best_cardinalities.append(maxset)

    total_time = time.perf_counter() - t_start
    logger.info(f"Time taken: {total_time}")
    logger.info(f"Average time per graph: {total_time/(len(sampler.ground_truth))}")

    # Flatten the per-batch arrays into plain Python floats before building
    # the prediction tensor.
    all_pred = torch.tensor(
        [size.item() for batch_best in best_cardinalities for size in batch_best]
    )
    all_target = torch.tensor(sampler.ground_truth)

    decoder_ratio = (torch.round(all_pred) / all_target).mean()
    decoder_mse = torch.nn.functional.mse_loss(all_target, all_pred, reduction="mean").item()
    solver_mse_std = (((all_target - all_pred)**2).std() / np.sqrt(len(all_target))).item()

    decoder_mae = torch.nn.functional.l1_loss(all_target, all_pred, reduction="mean").item()
    decoder_rankcorr = kendalltau(all_pred.cpu().tolist(), all_target.cpu().tolist())[0]
    decoder_acc = (all_target == torch.round(all_pred)).sum() / len(all_target)

    return solver_mse_std, decoder_ratio, decoder_mse, decoder_rankcorr, decoder_mae, decoder_acc


def probabilty_evaluate(conf, model, sampler):
    """Evaluate the model by thresholding its per-node outputs.

    For each batch the model is run on the output of ``get_diracs``; the
    ``out['pre_norm_x'][0]`` values are split per graph, thresholded at 0.5
    and the number of entries above the threshold is used as the predicted
    set size for that graph.

    Args:
        conf: Run configuration. Reads ``conf.training.diracs_N``,
            ``conf.training.diracs_effective_range``,
            ``conf.model.numlayers`` and ``conf.training.device``.
        model: Network called as ``model(None, None, None, None, batch,
            batch_sizes, None, None)``; also provides
            ``model.conf.dataset.max_set_size`` used as the padded width.
        sampler: Dataset object providing ``create_batches`` and
            ``fetch_batched_data_by_id``.

    Returns:
        Tuple ``(std, ratio, mse, rankcorr, mae, prob_acc)`` computed from the
        predicted and target sizes: std of the squared errors divided by
        ``sqrt(len(targets))``, mean of ``round(pred) / target``, mean squared
        error, first element of ``kendalltau(pred, target)``, mean absolute
        error, and the fraction of targets equal to ``round(pred)``.
    """
    model.eval()
    n_batches = sampler.create_batches(shuffle=False)
    all_target = []
    all_pred = []

    # Pure inference: without no_grad every appended prediction would keep
    # its batch's autograd graph alive until the function returns.
    with torch.no_grad():
        for i in range(n_batches):
            batch_data, batch_data_sizes, _, target, _ = sampler.fetch_batched_data_by_id(i)
            all_target.append(target)
            batch_data_prime = get_diracs(batch_data, conf.training.diracs_N, sparse=True,
                                          effective_volume_range=conf.training.diracs_effective_range,
                                          receptive_field=conf.model.numlayers+1)
            # Kept from the original; the result is not used below.
            # NOTE(review): may be here to move the sampler's batch off the
            # accelerator in place — confirm before removing.
            batch_data = batch_data.cpu()
            # Use the configured device (the one the model is placed on in
            # train()) instead of a hard-coded CUDA device.
            batch_data_prime = batch_data_prime.to(conf.training.device)
            out = model(None,
                        None,
                        None,
                        None,
                        batch_data_prime,
                        batch_data_sizes,
                        None,
                        None)

            # Per-graph outputs; an all-zero graph is flipped to all ones.
            probs = [
                1 - graph_probs if (graph_probs == 0).all() else graph_probs
                for graph_probs in torch.split(out['pre_norm_x'][0], batch_data_sizes)
            ]

            # Zero-pad every graph to a common width so they can be stacked;
            # the padding is below the threshold and never counted.
            hard_indicators = torch.stack(
                [F.pad(x, pad=(0, model.conf.dataset.max_set_size - x.shape[0])) for x in probs]
            )
            hard_indicators[hard_indicators > 0.5] = 1
            hard_indicators[hard_indicators <= 0.5] = 0
            all_pred.append(hard_indicators.sum(dim=-1))

    all_target = torch.cat(all_target, dim=0)
    all_pred = torch.cat(all_pred, dim=0)
    ratio = (torch.round(all_pred) / all_target).mean()
    mse = torch.nn.functional.mse_loss(all_target, all_pred, reduction="mean").item()
    mae = torch.nn.functional.l1_loss(all_target, all_pred, reduction="mean").item()
    std = (((all_target - all_pred)**2).std() / np.sqrt(len(all_target))).item()
    rankcorr = kendalltau(all_pred.cpu().tolist(), all_target.cpu().tolist())[0]
    prob_acc = (all_target == torch.round(all_pred)).sum() / len(all_target)

    return std, ratio, mse, rankcorr, mae, prob_acc

        

def _evaluate_split(conf, model, data):
    """Run both evaluators on ``data`` and select the headline metrics.

    Returns:
        Tuple ``(mse, ratio, solver_mse, solver_ratio, sum_p_mse,
        sum_p_ratio)``. ``mse`` / ``ratio`` are the thresholded-output
        metrics when ``model.conf.model.SUP`` is set and the decoder
        ("solver") metrics otherwise.
    """
    # Both evaluators return (std, ratio, mse, rankcorr, mae, acc); only the
    # ratio and the MSE are needed here.
    _, solver_ratio, solver_mse, *_ = evaluate_erdos(conf, model, data)
    _, sum_p_ratio, sum_p_mse, *_ = probabilty_evaluate(conf, model, data)
    # The ratios come back as 0-dim tensors; log plain floats instead.
    solver_ratio = float(solver_ratio)
    sum_p_ratio = float(sum_p_ratio)

    if model.conf.model.SUP:
        mse, ratio = sum_p_mse, sum_p_ratio
    else:
        mse, ratio = solver_mse, solver_ratio
    return mse, ratio, solver_mse, solver_ratio, sum_p_mse, sum_p_ratio


def train():
    """Train the configured model, then report validation and test metrics.

    Reads the module-level ``conf``. Builds the train/val datasets, the model
    and an Adam optimizer, optionally resumes from the latest checkpoint,
    trains with early stopping on the validation MSE, reloads the best
    checkpoint and finally logs validation and test metrics to the logger
    and to wandb.
    """
    train_data = CliqueDataset(conf, "train", logger.info)
    val_data = CliqueDataset(conf, "val", logger.info)
    train_data.data_type = "pyg"
    val_data.data_type = "pyg"
    es = EarlyStoppingModule(conf.base_dir, conf.task.name, logger=logger)
    logger.info(es.__dict__)
    logger.info(f"This uses the {conf.model.classPath}.{conf.model.name} model")
    logger.info(conf)
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf).to(conf.training.device)
    logger.info(model)
    if conf.training.wandb_watch:
        wandb.watch(model, log_freq=1)
    logger.info(f"no. of params in model: {sum([p.numel() for p in model.parameters()])}")
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )

    # Resume from the latest checkpoint unless overwriting was requested.
    checkpoint = None if conf.training.overwrite else es.load_latest_model()

    if not checkpoint:
        es.save_initial_model(model)
        run = 0
    else:
        if es.should_stop_now:
            logger.info("Training has been completed. This logfile can be deleted.")
            return
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optim_state_dict"])
        run = checkpoint["epoch"] + 1

    best_val_mse = 1e5
    best_val_ratio = -1

    # NOTE(review): when conf.training.run_till_early_stopping is false the
    # loop body never runs, so no training happens — confirm this is intended.
    while conf.training.run_till_early_stopping and run < conf.training.num_epochs:
        model.train()
        n_batches = train_data.create_batches(shuffle=True)
        epoch_loss = 0

        start_time = time.time()
        for i in range(n_batches):
            (
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                batch_target,
                corpus_batch_adj
            ) = train_data.fetch_batched_data_by_id(i)

            optimizer.zero_grad()
            batch_data_prime = get_diracs(corpus_batch_data, conf.training.diracs_N, sparse=True,
                                          effective_volume_range=conf.training.diracs_effective_range,
                                          receptive_field=conf.model.numlayers+1)
            losses, pred = model(train_data.packed_query_graphs,
                                 train_data.query_graph_node_sizes,
                                 train_data.query_graph_edge_sizes,
                                 train_data.query_adj_list,
                                 batch_data_prime,
                                 corpus_batch_data_node_sizes,
                                 corpus_batch_data_edge_sizes,
                                 corpus_batch_adj)

            # Optional supervised term on top of the model's own loss.
            if model.conf.model.SUP:
                if model.conf.model.name in ['EGNORIG']:
                    sup_loss = torch.nn.functional.mse_loss(pred, batch_target)
                    losses = losses + sup_loss
                else:
                    raise NotImplementedError()
            losses.backward()
            optimizer.step()
            epoch_loss = epoch_loss + losses.item()

        logger.info(
            f"Run: {run} train loss: {epoch_loss/n_batches:.2f} Time: {time.time()-start_time:.2f}",
        )

        start_time = time.time()
        mse, ratio, solver_mse, solver_ratio, sum_p_mse, sum_p_ratio = _evaluate_split(conf, model, val_data)

        if mse < best_val_mse:
            best_val_mse = mse
            best_val_ratio = ratio

        logger.info(f"Run: {run} VAL mse: {mse:.6f}  ratio: {ratio:.6f} solver_mse: {solver_mse:.6f} solver_ratio: {solver_ratio:.6f} sum_p_mse: {sum_p_mse:.6f} sum_p_ratio: {sum_p_ratio:.6f} Time: {time.time()-start_time:.6f}")

        log_dict = {'val_solver_mse': solver_mse,
                    'val_solver_ratio': solver_ratio,
                    'val_sum_p_ratio': sum_p_ratio,
                    'val_sum_p_mse': sum_p_mse,
                    'best_val_mse': best_val_mse,
                    'best_val_ratio': best_val_ratio,
                    'train_loss': epoch_loss/n_batches
                    }

        wandb.log(log_dict)

        if conf.training.run_till_early_stopping:
            # The early-stopping module receives the negated MSE as its score.
            es_score = -mse
            if es.check([es_score], model, run, optimizer):
                break
            run += 1

    # Final report with the best checkpoint. The headline mse/ratio are
    # recomputed per split so they describe the reloaded model rather than
    # the last training epoch.
    ckpt = es.load_best_model()
    model.load_state_dict(ckpt["model_state_dict"])
    test_data = CliqueDataset(conf, "test", logger.info)
    test_data.data_type = "pyg"

    start_time = time.time()
    mse, ratio, solver_mse, solver_ratio, sum_p_mse, sum_p_ratio = _evaluate_split(conf, model, val_data)
    logger.info(f"Run: {run} VAL mse: {mse:.6f} ratio: {ratio:.6f} solver_mse: {solver_mse:.6f} solver_ratio: {solver_ratio:.6f} sum_p_mse: {sum_p_mse:.6f} sum_p_ratio: {sum_p_ratio:.6f} Time: {time.time()-start_time:.6f}")

    start_time = time.time()
    mse, ratio, solver_mse, solver_ratio, sum_p_mse, sum_p_ratio = _evaluate_split(conf, model, test_data)
    logger.info(f"Run: {run} TEST mse: {mse:.6f} ratio: {ratio:.6f} solver_mse: {solver_mse:.6f} solver_ratio: {solver_ratio:.6f} sum_p_mse: {sum_p_mse:.6f} sum_p_ratio: {sum_p_ratio:.6f} Time: {time.time()-start_time:.6f}")
    wandb.log({'test_mse_loss': mse, 'test_ratio': ratio, 'test_solver_mse': solver_mse, 'test_solver_ratio': solver_ratio, 'test_sum_p_mse': sum_p_mse, 'test_sum_p_ratio': sum_p_ratio})



if __name__ == "__main__":
    # Layer the configuration: base file first, then the dataset- and
    # model-specific files named on the command line, then the command-line
    # overrides themselves (merged last so they win).
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)
    conf = OmegaConf.merge(conf, cli_conf)
    run_time = "{date:%Y-%m-%d||%H:%M:%S}".format(date=datetime.datetime.now())
    print(OmegaConf.to_yaml(conf))

    # The task name encodes every model hyper-parameter except the ones that
    # already appear elsewhere in the name.
    excluded_keys = ('classPath', 'name', 'EQ')
    task_name = ",".join(
        "{}={}".format(key, value)
        for key, value in conf.model.items()
        if key not in excluded_keys
    )

    conf.task.name = (
        f"{conf.model.name}_EQ_{conf.dataset.name}_{task_name}_{conf.training.loss_fn}_{run_time}"
        if conf.model.EQ
        else f"{conf.model.name}_{conf.dataset.name}_{task_name}_{conf.training.loss_fn}_{run_time}"
    )

    # Start from an empty log file, then attach it as a logger sink.
    log_path = f"{conf.log.dir}/{conf.task.name}.log"
    with open(log_path, "w"):
        pass
    logger.add(log_path)
    logger.info(OmegaConf.to_yaml(conf))

    # Turn off TF32 for matmul and cuDNN.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Hyper-parameters recorded with the wandb run.
    wandb_config = {
        'learning_rate': conf.training.learning_rate,
        'weight_decay': conf.training.weight_decay,
        'dropout': conf.training.dropout,
        'num_epochs': conf.training.num_epochs,
        'seed': conf.training.seed,
        'batch_size': conf.training.batch_size,
        'dataset_name': conf.dataset.name,
        'dataset_max_node_set_size': conf.dataset.max_node_set_size,
        'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
        'model_name': conf.model.name,
    }
    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.wandb_group,
        config=wandb_config,
    )

    set_seed(conf.training.seed, conf)
    train()
