import os
import csv

import hydra
import numpy as np
import wandb
from omegaconf import DictConfig
from tqdm import tqdm

import torch
from utils.experiment import setup_wandb, get_data, sync_timer
from torch_geometric.utils import add_remaining_self_loops, to_undirected
from torch_scatter import scatter_sum, scatter_min
from sample_algo import radii


# Sentinel distance for nodes that have no open facility among their neighbours.
_UNREACHABLE = 1.e10
# A node stops being active once an open facility lies within this multiple of its radius.
_ASSIGN_RADIUS_FACTOR = 6.0
# Upper bound on the number of sampling rounds per Monte Carlo repeat.
_MAX_ROUNDS = 100
# Number of untimed passes executed before measurements start.
_WARMUP_ITERS = 20


def _with_self_loops_undirected(data):
    """Return ``(edge_index, edge_weight)`` of ``data`` with self-loops and symmetric edges.

    Missing self-loops are added with weight 0 through ``add_remaining_self_loops``;
    the result is then passed through ``to_undirected`` with ``reduce='mean'``.
    """
    edge_index, edge_weight = add_remaining_self_loops(
        data.edge_index, data.edge_weight,
        fill_value=0.,
        num_nodes=data.num_nodes,
    )
    return to_undirected(
        edge_index, edge_weight,
        num_nodes=data.num_nodes,
        reduce='mean',
    )


def _nearest_open_dist(edge_index, edge_weight, is_open, num_nodes):
    """Return, per source node, the minimum edge weight to a target node flagged in ``is_open``.

    Edges whose target is not open are set to ``_UNREACHABLE`` before the
    ``scatter_min`` over ``edge_index[0]``, so a node without any open
    neighbour ends up with that sentinel value.
    """
    masked = edge_weight.clone()
    masked[~is_open[edge_index[1]]] = _UNREACHABLE
    nearest, _ = scatter_min(masked, edge_index[0], dim_size=num_nodes)
    return nearest


def _sample_solution(edge_index, edge_weight, rad, num_nodes, param):
    """Draw one randomized facility set and return ``(fac, dist_to_F)``.

    Args:
        edge_index: edge list as produced by ``_with_self_loops_undirected``.
        edge_weight: weights aligned with ``edge_index``.
        rad: output of ``radii`` for this graph (assumed to hold one value per
            node, since it is combined elementwise with length-``num_nodes``
            tensors -- TODO confirm against ``sample_algo.radii``).
        num_nodes: number of nodes in the graph.
        param: scale applied to ``rad`` and to the current distances when
            forming the opening probabilities.

    Returns:
        ``fac``: float tensor of length ``num_nodes`` with 1 for opened nodes.
        ``dist_to_F``: per-node minimum edge weight to an opened neighbour.
    """
    device = edge_index.device
    fac_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    active_mask = torch.ones(num_nodes, dtype=torch.bool, device=device)
    dist_to_F = torch.full((num_nodes,), _UNREACHABLE, device=device)

    # These do not change between rounds, so compute them once.
    prob_term_rad = param * rad
    assign_radius = _ASSIGN_RADIUS_FACTOR * rad
    prob_cap = torch.ones_like(rad)

    for _ in range(_MAX_ROUNDS):
        if not active_mask.any():
            break

        # RecursiveUniformFL: opening probability is min(1, param * dist, param * rad).
        probs = torch.minimum(
            prob_cap,
            torch.minimum(param * dist_to_F, prob_term_rad),
        )

        # Only nodes that are still active may open as new facilities.
        new_facilities = torch.bernoulli(probs).bool() & active_mask
        fac_mask = fac_mask | new_facilities

        dist_to_F = torch.minimum(
            dist_to_F,
            _nearest_open_dist(edge_index, edge_weight, fac_mask, num_nodes),
        )
        # Nodes with an open facility close enough drop out of later rounds.
        active_mask = active_mask & ~(dist_to_F <= assign_radius)

    fac = fac_mask.float()
    # Any node that still has no open neighbour opens itself.
    open_neighbours = scatter_sum(fac[edge_index[1]], edge_index[0], dim_size=num_nodes)
    fac[open_neighbours == 0.] = 1.

    dist_to_F = torch.minimum(
        dist_to_F,
        _nearest_open_dist(edge_index, edge_weight, fac != 0., num_nodes),
    )
    return fac, dist_to_F


@hydra.main(version_base=None, config_path='./config', config_name="algo")
def main(args: DictConfig):
    """Evaluate the RecursiveUniformFL sampling procedure on the test split.

    For each test graph the procedure is repeated ``args.train.repeats`` times
    with a fixed seed; the opening cost (number of opened nodes) and the
    transport cost (sum of per-node distances to an opened neighbour) are
    averaged over the repeats. Per-graph averages are optionally written to a
    CSV file under ``./logs`` and summary statistics are logged to wandb.

    The recorded time per graph is the ``radii`` computation plus the last
    Monte Carlo repeat only.

    Raises:
        ValueError: if the test split contains no graphs.
    """
    setup_wandb(args)

    _, _, test_set = get_data(args.train.datapath)

    if args.train.debug:
        test_set = test_set[:20]

    # Fail early with a clear message instead of a bare StopIteration in the warm-up.
    warmup_sample = next(iter(test_set), None)
    if warmup_sample is None:
        raise ValueError(f"Test set loaded from {args.train.datapath!r} is empty")

    open_costs = []
    trans_costs = []
    total_costs = []
    timings = []

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # warm up: run the pipeline a few times untimed so later measurements
    # do not include one-off initialisation costs
    warmup_sample = warmup_sample.to(device)
    for _ in range(_WARMUP_ITERS):
        edge_index, edge_weight = _with_self_loops_undirected(warmup_sample)
        radii(edge_index, edge_weight, warmup_sample.num_nodes)

    if args.train.write:
        log_dir = "./logs"
        output_file = os.path.join(log_dir, f"{args.train.datapath.split('/')[-1]}_RecurUFL_results.csv")
        os.makedirs(log_dir, exist_ok=True)

        # Start a fresh file with the header; rows are appended per graph below.
        with open(output_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['Graph_ID', 'Open_Cost', 'Trans_Cost', 'Total_Cost'])

    last_rep = args.train.repeats - 1

    pbar = tqdm(test_set)
    for i, data in enumerate(pbar):
        data = data.to(device)
        edge_index, edge_weight = _with_self_loops_undirected(data)

        # radii
        t1 = sync_timer()
        rad = radii(edge_index, edge_weight, data.num_nodes)
        times = sync_timer() - t1

        # Re-seed per graph (after radii) so every graph sees the same random stream.
        torch.manual_seed(args.train.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.train.seed)

        opens_list = []
        trans_list = []

        # Monte Carlo Repeats
        for rep in range(args.train.repeats):
            timed = rep == last_rep  # only the final repeat contributes to the timing
            if timed:
                t1 = sync_timer()

            fac, dist_to_F = _sample_solution(
                edge_index, edge_weight, rad, data.num_nodes, args.train.param
            )

            if timed:
                times += sync_timer() - t1

            opens_list.append(fac.sum().item())
            trans_list.append(dist_to_F.sum().item())

        # results
        avg_open = np.mean(opens_list)
        avg_trans = np.mean(trans_list)

        open_costs.append(avg_open)
        trans_costs.append(avg_trans)
        total_costs.append(avg_open + avg_trans)
        timings.append(times)

        if args.train.write:
            # Append and close per graph so partial results survive an interrupted run.
            with open(output_file, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    i,  # Graph ID
                    f"{avg_open}",
                    f"{avg_trans}",
                    f"{avg_open + avg_trans}"
                ])

        pbar.set_postfix({'open': avg_open, 'trans': avg_trans, 'total': avg_open + avg_trans})

    wandb.log({
        'open_mean': np.mean(open_costs),
        'open_std': np.std(open_costs),
        'trans_mean': np.mean(trans_costs),
        'trans_std': np.std(trans_costs),
        'total_mean': np.mean(total_costs),
        'total_std': np.std(total_costs),
        'time_mean': np.mean(timings),
        'time_std': np.std(timings),
    })


# Script entry point: run the evaluation only when this file is executed
# directly. `main` is called without arguments because the `hydra.main`
# decorator supplies its `args` (configured with config_path='./config',
# config_name="algo").
if __name__ == '__main__':
    main()
