import os
import csv
import pdb

import hydra
import numpy as np
import wandb
from omegaconf import DictConfig
from tqdm import tqdm

import torch
from utils.experiment import setup_wandb, get_data, sync_timer
from torch_geometric.utils import to_dense_batch, add_remaining_self_loops, to_undirected, degree
from torch_scatter import scatter_sum, scatter_min


def radii(edge_index, edge_weight, num_nodes, opening_cost=1., fill_value=1e10):
    """Compute the facility-location radius of every node.

    For node ``i`` let ``d_1 <= d_2 <= ...`` be the weights of the entries whose
    source ``edge_index[0]`` equals ``i``. The radius is the ``r`` solving
    ``sum_j max(0, r - d_j) = opening_cost``. When ``r`` lies in ``[d_k, d_{k+1})``
    this reduces to ``r = (opening_cost + d_1 + ... + d_k) / k``, which is what is
    evaluated for every ``k`` below before picking the matching interval.

    Args:
        edge_index: ``(2, E)`` integer tensor. Only row 0 (the node each weight
            belongs to) is used; it does NOT need to be sorted.
        edge_weight: ``(E,)`` floating tensor of edge weights.
        num_nodes: number of nodes (length of the returned tensor).
        opening_cost: right-hand side of the radius equation (defaults to 1).
        fill_value: padding weight for unused slots; should exceed every real
            weight so padding sorts to the end of each row.

    Returns:
        ``(num_nodes,)`` tensor with the same dtype as ``edge_weight``. A node
        with no entries gets ``fill_value + opening_cost``; the caller in this
        file adds self-loops so that every node has at least one entry.
    """
    if num_nodes == 0:
        return edge_weight.new_empty(0)

    # Group entries by source node. Sorting first makes the layout independent
    # of the input edge order (to_dense_batch would require pre-sorted input).
    row, perm = edge_index[0].sort()
    deg = torch.bincount(row, minlength=num_nodes)
    start = torch.cumsum(deg, dim=0) - deg          # first slot of each node
    slot = torch.arange(row.numel(), device=row.device) - start[row]

    # At least one padding column, so weights[:, k + 1] exists for the last real
    # weight of the highest-degree node (and the argmax below is never empty).
    width = max(int(deg.max()), 1) + 1
    weights = edge_weight.new_full((num_nodes, width), fill_value)
    weights[row, slot] = edge_weight[perm]
    weights = weights.sort(dim=-1).values

    # Candidate radius for every k: (opening_cost + d_1 + ... + d_k) / k.
    counts = torch.arange(1, width + 1, device=weights.device, dtype=weights.dtype)
    sums = (torch.cumsum(weights, dim=1) + opening_cost) / counts

    # The valid candidate is the one lying in [d_k, d_{k+1}).
    in_interval = (sums[:, :-1] >= weights[:, :-1]) & (sums[:, :-1] < weights[:, 1:])
    # argmax returns the first True; if none matched it falls back to slot 0.
    idx = in_interval.float().argmax(dim=1)
    return sums.gather(1, idx.unsqueeze(1)).squeeze(1)


def _prepare_graph(data):
    """Return ``(edge_index, edge_weight)`` of ``data`` with zero-weight self-loops
    added and edges symmetrized (weights of merged duplicate edges are averaged)."""
    edge_index, edge_weight = add_remaining_self_loops(data.edge_index, data.edge_weight,
                                                       fill_value=0,
                                                       num_nodes=data.num_nodes)
    return to_undirected(edge_index, edge_weight,
                         num_nodes=data.num_nodes,
                         reduce='mean')


def _sample_solution(rad, edge_index, edge_weight, num_nodes, param):
    """Run one randomized rounding pass and return ``(open_cost, trans_cost)`` as floats.

    Each node opens with probability ``min(param * rad, 1)``. A node with no open
    facility among its targets ``edge_index[1]`` is opened itself. The opening cost
    is the number of open facilities. The transport cost sums, over nodes, the
    cheapest edge to an open facility; this relies on the self-loops added by
    ``_prepare_graph`` so that every node has such an edge.
    """
    fac = torch.bernoulli(torch.minimum(param * rad, torch.ones_like(rad)))
    opened = scatter_sum(fac[edge_index[1]], edge_index[0], dim_size=num_nodes)
    fac[opened == 0.] = 1.
    open_cost = fac.sum().item()
    trans_cost = edge_weight.clone()
    # Edges to closed facilities must never be chosen by the min below.
    trans_cost[fac[edge_index[1]] == 0.] = 1.e10
    trans_cost, _ = scatter_min(trans_cost, edge_index[0], dim_size=num_nodes)
    return open_cost, trans_cost.sum().item()


@hydra.main(version_base=None, config_path='./config', config_name="algo")
def main(args: DictConfig):
    """Evaluate the radius-based randomized heuristic on the test split.

    For each test graph: compute radii, run ``args.train.repeats`` rounding passes
    (re-seeded per graph), average the opening/transport costs, optionally append
    them to a CSV under ``./logs``, and finally log mean/std of costs and timings
    to wandb. The per-graph time covers ``radii`` plus one rounding pass.
    """
    setup_wandb(args)

    _, _, test_set = get_data(args.train.datapath)

    if args.train.debug:
        test_set = test_set[:20]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    repeats = args.train.repeats
    param = args.train.param

    # Warm up GPU kernels on the first graph so they do not inflate the first
    # timing. Torch RNG is re-seeded per graph below, so this does not change results.
    warm = next(iter(test_set)).to(device)
    for _ in range(20):
        edge_index, edge_weight = _prepare_graph(warm)
        rad = radii(edge_index, edge_weight, warm.num_nodes)
        _sample_solution(rad, edge_index, edge_weight, warm.num_nodes, param)

    if args.train.write:
        log_dir = "./logs"
        output_file = os.path.join(log_dir, f"{args.train.datapath.split('/')[-1]}_simpleUFL_results.csv")
        os.makedirs(log_dir, exist_ok=True)

        with open(output_file, 'w', newline='') as f:
            csv.writer(f).writerow(['Graph_ID', 'Open_Cost', 'Trans_Cost', 'Total_Cost'])

    open_costs, trans_costs, total_costs, timings = [], [], [], []

    pbar = tqdm(test_set)
    for i, data in enumerate(pbar):
        data = data.to(device)
        edge_index, edge_weight = _prepare_graph(data)

        t1 = sync_timer()
        rad = radii(edge_index, edge_weight, data.num_nodes)
        time_single = sync_timer() - t1

        # Re-seed per graph so each graph's result is reproducible on its own.
        torch.manual_seed(args.train.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.train.seed)

        opens = []
        trans = []
        for rep in range(repeats):
            # Time exactly one rounding pass (the last one); range() never reaches `repeats`.
            timed = rep == repeats - 1
            if timed:
                t1 = sync_timer()
            open_cost, trans_cost = _sample_solution(rad, edge_index, edge_weight,
                                                     data.num_nodes, param)
            if timed:
                time_single += sync_timer() - t1

            opens.append(open_cost)
            trans.append(trans_cost)

        opens = np.mean(opens)
        trans = np.mean(trans)

        open_costs.append(opens)
        trans_costs.append(trans)
        total_costs.append(opens + trans)
        timings.append(time_single)

        if args.train.write:
            # Append per graph so partial results survive an interrupted run.
            with open(output_file, 'a', newline='') as f:
                csv.writer(f).writerow([
                    i,  # Graph ID
                    f"{opens}",
                    f"{trans}",
                    f"{opens + trans}"
                ])

        pbar.set_postfix({'open': opens, 'trans': trans, 'total': opens + trans})

    wandb.log({
        'open_mean': np.mean(open_costs),
        'open_std': np.std(open_costs),
        'trans_mean': np.mean(trans_costs),
        'trans_std': np.std(trans_costs),
        'total_mean': np.mean(total_costs),
        'total_std': np.std(total_costs),
        'time_mean': np.mean(timings),
        'time_std': np.std(timings),
    })


if __name__ == '__main__':
    # main() is called with no arguments: the @hydra.main decorator builds the
    # DictConfig (config_name="algo" under ./config) and passes it in.
    main()
