import os
import csv

import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected, add_remaining_self_loops
from torch_scatter import scatter_sum, scatter_min
from tqdm import tqdm

import wandb
from models import UnsupervisedGNN
from utils.experiment import save_run_config, setup_wandb, get_data, sync_timer

# Allow float32 matrix multiplications to use faster reduced-precision internal
# math (e.g. TF32 on GPUs that support it), trading a little accuracy for speed.
torch.set_float32_matmul_precision(precision="high")


def _sample_costs(rad, edge_index, edge_weight, num_nodes):
    """Draw one randomized rounding of ``rad`` and return ``(open_cost, trans_cost)``.

    Each facility is opened with probability ``min(rad, 1)``. Any node that ends
    up with no open facility among its neighbours (its self-loop included) opens
    its own facility, so every node can be served. The opening cost is the number
    of open facilities; the transport cost sums, over all nodes, the weight of the
    cheapest edge to an open facility.

    Args:
        rad: per-node output of ``model.predict``, used as Bernoulli
            probabilities after clipping at 1 (presumably non-negative -- confirm).
        edge_index: connectivity with row 0 = source node and row 1 = facility
            node; expected to contain a self-loop on every node (the caller adds them).
        edge_weight: weight of each edge in ``edge_index``.
        num_nodes: number of nodes in the graph.

    Returns:
        ``(open_cost, trans_cost)`` as Python floats.
    """
    fac = torch.bernoulli(torch.minimum(rad, torch.ones_like(rad)))
    # Count sampled-open facilities adjacent to each node.
    opened = scatter_sum(fac[edge_index[1]], edge_index[0], dim_size=num_nodes)
    # Repair step: an unserved node opens its own facility.
    fac[opened == 0.] = 1.
    open_cost = fac.sum().item()

    trans_cost = edge_weight.clone()
    # Edges into closed facilities get a prohibitive cost so scatter_min skips them;
    # after the repair step every node has at least one open neighbour.
    trans_cost[fac[edge_index[1]] == 0.] = 1.e10
    trans_cost, _ = scatter_min(trans_cost, edge_index[0], dim_size=num_nodes)
    return open_cost, trans_cost.sum().item()


@hydra.main(version_base=None, config_path='./config', config_name="mpnn")
def main(args: DictConfig):
    """Evaluate every ``best*.pt`` checkpoint on the test split via randomized rounding.

    For each checkpoint (in sorted filename order) and each test graph, the
    model's ``predict`` output is rounded ``args.train.repeats`` times with
    ``_sample_costs``. The mean opening / transport / total costs per graph are
    appended to ``./logs/<dataset>_GNNsimpleUFL_results.csv``. Per-checkpoint
    averages are then reduced to mean/std statistics, printed, and logged to wandb.

    Raises:
        ValueError: if ``args.train.repeats < 1`` or the test split is empty.
        FileNotFoundError: if ``args.train.modelpath`` contains no ``best*.pt`` file.
    """
    setup_wandb(args)

    repeats = args.train.repeats
    if repeats < 1:
        raise ValueError(f"args.train.repeats must be >= 1, got {repeats}")

    _, _, test_set = get_data(args.train.datapath)
    if args.train.debug:
        test_set = test_set[:20]
    if len(test_set) == 0:
        raise ValueError(f"empty test split for datapath {args.train.datapath}")

    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UnsupervisedGNN(hid_dim=args.gnn.hidden,
                            num_encode_layers=args.gnn.num_encode_layers,
                            num_conv_layers=args.gnn.num_conv_layers,
                            edge_encode_layers=args.gnn.edge_encode_layers,
                            gnn_mlp_layers=args.gnn.gnn_mlp_layers,
                            num_pred_layers=args.gnn.num_pred_layers,
                            aggr=args.gnn.aggr,
                            square_dist=args.gnn.square_dist).to(device)
    # load_state_dict() below does not change the train/eval flag, so one call suffices.
    model.eval()

    # Warm up the device so the first timed graph is not penalized by one-off setup.
    with torch.no_grad():
        warmup_data = next(iter(test_loader)).to(device)
        for _ in range(20):
            model.predict(warmup_data)

    # Sorted: os.listdir order is arbitrary, and we want reproducible CSV row order.
    model_dicts = sorted(m for m in os.listdir(args.train.modelpath)
                         if m.startswith('best') and m.endswith('.pt'))
    if not model_dicts:
        raise FileNotFoundError(f"no best*.pt checkpoints found in {args.train.modelpath}")

    log_dir = "./logs"
    os.makedirs(log_dir, exist_ok=True)
    # normpath drops a trailing separator, which would otherwise yield an empty name.
    dataset_name = os.path.basename(os.path.normpath(args.train.datapath))
    output_file = os.path.join(log_dir, f"{dataset_name}_GNNsimpleUFL_results.csv")

    open_costs = []
    trans_costs = []
    total_costs = []
    gnn_timings = []

    # Pure inference: no autograd graphs, which saves memory and keeps autograd
    # bookkeeping out of the measured timings.
    with torch.no_grad(), open(output_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Graph_ID', 'model', 'Open_Cost', 'Trans_Cost', 'Total_Cost'])
        f.flush()

        for model_dict in model_dicts:
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load
            # checkpoints from a trusted modelpath.
            state_dict = torch.load(os.path.join(args.train.modelpath, model_dict),
                                    map_location=device, weights_only=False)
            model.load_state_dict(state_dict)

            batch_opens = []
            batch_trans = []
            batch_totals = []
            batch_gnn_times = []
            pbar = tqdm(test_loader)
            for i, data in enumerate(pbar):
                data = data.to(device)

                t1 = sync_timer()
                rad = model.predict(data)
                gnn_time = sync_timer() - t1

                # The model's own expected costs; only shown in the progress bar.
                l1, l2, l3 = model(data)
                edge_index, edge_weight = add_remaining_self_loops(data.edge_index, data.edge_weight,
                                                                   fill_value=0,
                                                                   num_nodes=data.num_nodes)
                edge_index, edge_weight = to_undirected(edge_index, edge_weight,
                                                        num_nodes=data.num_nodes,
                                                        reduce='mean')

                # Re-seed per graph so the rounding samples do not depend on
                # checkpoint or graph order.
                torch.manual_seed(args.train.seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(args.train.seed)

                expect_open = (l1 + l2).item()
                expect_trans = l3.item()
                opens = []
                trans = []
                for rep in range(repeats):
                    # Only the final rounding pass counts toward the reported time.
                    timed = rep == repeats - 1
                    if timed:
                        t1 = sync_timer()
                    open_cost, trans_cost = _sample_costs(rad, edge_index, edge_weight,
                                                          data.num_nodes)
                    if timed:
                        gnn_time += sync_timer() - t1
                    opens.append(open_cost)
                    trans.append(trans_cost)

                mean_open = np.mean(opens)
                mean_trans = np.mean(trans)
                batch_opens.append(mean_open)
                batch_trans.append(mean_trans)
                batch_totals.append(mean_open + mean_trans)
                batch_gnn_times.append(gnn_time)

                writer.writerow([
                    i,  # Graph ID
                    model_dict,
                    f"{batch_opens[-1]}",
                    f"{batch_trans[-1]}",
                    f"{batch_totals[-1]}"
                ])
                # Flush per graph so partial results survive an interrupted run.
                f.flush()

                pbar.set_postfix({'open': batch_opens[-1], 'trans': batch_trans[-1], 'total': batch_totals[-1],
                                  'exp_open': expect_open, 'exp_trans': expect_trans})

            open_costs.append(np.mean(batch_opens))
            trans_costs.append(np.mean(batch_trans))
            total_costs.append(np.mean(batch_totals))
            gnn_timings.append(np.mean(batch_gnn_times))

    # Mean/std across checkpoints of each checkpoint's per-graph average.
    stats = {
        'open_mean': np.mean(open_costs),
        'open_std': np.std(open_costs),
        'trans_mean': np.mean(trans_costs),
        'trans_std': np.std(trans_costs),
        'total_mean': np.mean(total_costs),
        'total_std': np.std(total_costs),
        'time_mean': np.mean(gnn_timings),
        'time_std': np.std(gnn_timings),
    }
    print(stats)
    wandb.log(stats)


if __name__ == '__main__':
    # Called without arguments: the @hydra.main decorator on main() builds the
    # DictConfig from ./config/mpnn plus any command-line overrides.
    main()
