import os
import csv

import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils import remove_self_loops, degree
from torch_scatter import scatter
from torch_scatter import scatter_sum, scatter_min
from tqdm import tqdm

import wandb
from utils.experiment import setup_wandb, get_data, sync_timer


# algorithm in https://link.springer.com/article/10.1007/s00453-012-9690-y
def algorithm_o1(edge_index, edge_weight, num_nodes, rng, f=1.):
    """Randomly choose which nodes to open, using doubling phases.

    With ``K = ceil(log2(num_nodes)) + 1`` phases, phase ``k`` draws every
    node independently with probability ``min(2**k / num_nodes, 1)`` and uses
    the test radius ``min(2**k * f / num_nodes, 1)``. A node opens when, in at
    least one phase, it is drawn and none of its edges whose weight is within
    that phase's radius leads to a node drawn in a strictly earlier phase.
    Nodes that never occur in ``edge_index[0]`` are always opened.

    Args:
        edge_index: Integer tensor; row 0 holds the node each edge is
            attributed to, row 1 the neighbour it points at.
        edge_weight: Per-edge weights aligned with the columns of
            ``edge_index`` (indexed as a 1-D tensor here).
        num_nodes: Number of nodes in the graph.
        rng: Random source exposing ``binomial`` (``main`` passes a
            ``np.random.RandomState``).
        f: Scale factor applied to the test radii.

    Returns:
        Float tensor of length ``num_nodes`` on ``edge_index.device`` with
        1. for opened nodes and 0. for the rest.
    """
    device = edge_index.device
    owner, neighbor = edge_index[0], edge_index[1]

    # Number of phases K (kept as the NumPy integer produced by ceil/log2).
    num_phases = np.ceil(np.log2(num_nodes)).astype(np.int32) + 1

    # drawn[v, k] == 1 iff node v is sampled in phase k.
    phase_probs = np.minimum(2 ** np.arange(num_phases) / num_nodes, 1)
    drawn = rng.binomial(1, np.tile(phase_probs, (num_nodes, 1)))

    # Exclusive running count: how many phases before k sampled node v.
    drawn_before = np.cumsum(drawn, axis=1) - drawn

    drawn = torch.from_numpy(drawn).to(device)
    drawn_before = torch.from_numpy(drawn_before).to(device)

    # Per-phase radius, capped at 1.
    radius = torch.clamp(
        2 ** torch.arange(0, num_phases, device=device) * f / num_nodes,
        max=1.)

    # Both masks are (num_edges, K).
    within_radius = edge_weight[:, None] <= radius[None]
    neighbor_drawn_earlier = drawn_before[neighbor] != 0

    # An edge vetoes phase k only when it is short enough AND its neighbour
    # was already drawn in an earlier phase.
    edge_allows = ~(within_radius & neighbor_drawn_earlier)

    # Multiplying 0/1 flags per owning node acts as a logical AND over all of
    # that node's edges.
    node_allows = scatter(edge_allows.float(), owner, dim=0,
                          reduce='mul', dim_size=num_nodes).bool()

    # Open if any phase both draws the node and is not vetoed.
    opens = ((drawn > 0) & node_allows).any(dim=1).float()

    # Nodes without any edge in row 0 are opened unconditionally.
    isolated = degree(owner, num_nodes=num_nodes) == 0
    opens[isolated] = 1.
    return opens


@hydra.main(version_base=None, config_path='./config', config_name="algo")
def main(args: DictConfig):
    """Evaluate ``algorithm_o1`` on the test split and log cost/time statistics.

    Each test graph is solved ``args.train.repeats`` times. The per-graph mean
    opening cost, connection ("trans") cost and their sum are appended to a
    CSV file in ``./logs``; the mean and standard deviation over all graphs,
    plus the timing of the final repeat of each graph, are logged to wandb.

    Args:
        args: Hydra config. This function reads ``args.train.datapath``,
            ``args.train.debug``, ``args.train.seed`` and
            ``args.train.repeats``; the full config is passed to
            ``setup_wandb``.

    Raises:
        ValueError: If ``args.train.repeats`` is smaller than 1 (the averages
            would otherwise be NaN).
    """
    repeats = args.train.repeats
    if repeats < 1:
        raise ValueError(f"train.repeats must be >= 1, got {repeats}")

    setup_wandb(args)

    _, _, test_set = get_data(args.train.datapath)

    if args.train.debug:
        test_set = test_set[:20]

    open_costs = []
    trans_costs = []
    total_costs = []
    timings = []

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # warm up GPU
    for _ in range(20):
        data = next(iter(test_set)).to(device)
        edge_index, edge_weight = remove_self_loops(data.edge_index, data.edge_weight)
        rng = np.random.RandomState(args.train.seed)
        algorithm_o1(edge_index, edge_weight, data.num_nodes, rng)

    log_dir = "./logs"
    # normpath drops a trailing separator, so "data/foo/" still yields "foo"
    # instead of an empty dataset name (which made all such runs share one file).
    dataset_name = os.path.basename(os.path.normpath(args.train.datapath))
    output_file = os.path.join(log_dir, f"{dataset_name}_O1_results.csv")
    os.makedirs(log_dir, exist_ok=True)

    # Start a fresh results file containing only the header.
    with open(output_file, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerow(['Graph_ID', 'Open_Cost', 'Trans_Cost', 'Total_Cost'])

    pbar = tqdm(test_set)
    for i, data in enumerate(pbar):
        data = data.to(device)
        edge_index, edge_weight = remove_self_loops(data.edge_index, data.edge_weight)
        # Re-seeded for every graph, so each graph starts from the same stream.
        rng = np.random.RandomState(args.train.seed)

        # Self loops (weight 0 where added) let a node connect to itself.
        loop_edge_index, loop_edge_weights = add_remaining_self_loops(data.edge_index, data.edge_weight,
                                                                      fill_value=0,
                                                                      num_nodes=data.num_nodes)

        rep_open_costs = []
        rep_trans_costs = []
        time_data = 0.
        for rep in range(repeats):
            # Only the final repeat is timed.
            timed = rep == repeats - 1
            if timed:
                t_start = sync_timer()

            fac = algorithm_o1(edge_index, edge_weight, data.num_nodes, rng)

            # Open every node that has no opened neighbour.
            opened = scatter_sum(fac[edge_index[1]], edge_index[0], dim_size=data.num_nodes)
            fac[opened == 0.] = 1.

            # Opening cost: one unit per opened node.
            open_cost = fac.sum().item()

            # Connection cost: per node, the cheapest edge to an opened node.
            # Edges into closed nodes get a huge sentinel so they lose the min.
            trans_cost = loop_edge_weights.clone()
            trans_cost[fac[loop_edge_index[1]] == 0.] = 1.e10
            trans_cost, _ = scatter_min(trans_cost, loop_edge_index[0], dim_size=data.num_nodes)
            trans_cost = trans_cost.sum().item()

            if timed:
                time_data = sync_timer() - t_start

            rep_open_costs.append(open_cost)
            rep_trans_costs.append(trans_cost)

        mean_open = np.mean(rep_open_costs)
        mean_trans = np.mean(rep_trans_costs)
        mean_total = mean_open + mean_trans

        open_costs.append(mean_open)
        trans_costs.append(mean_trans)
        total_costs.append(mean_total)
        timings.append(time_data)

        # Re-open in append mode per graph; leaving the ``with`` block closes
        # (and thereby flushes) the file, so partial results survive a crash.
        with open(output_file, 'a', newline='') as csv_file:
            csv.writer(csv_file).writerow([
                i,  # Graph ID
                f"{mean_open}",
                f"{mean_trans}",
                f"{mean_total}",
            ])

        pbar.set_postfix({'open': mean_open, 'trans': mean_trans, 'total': mean_total})

    stats = {
        'open_mean': np.mean(open_costs),
        'open_std': np.std(open_costs),
        'trans_mean': np.mean(trans_costs),
        'trans_std': np.std(trans_costs),
        'total_mean': np.mean(total_costs),
        'total_std': np.std(total_costs),
        'time_mean': np.mean(timings),
        'time_std': np.std(timings),
    }
    wandb.log(stats)


# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here and receives its config from Hydra.
if __name__ == '__main__':
    main()
