import os

import hydra
import numpy as np
import torch
from omegaconf import DictConfig
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_undirected, add_remaining_self_loops
from torch_scatter import scatter_sum, scatter_min
from tqdm import tqdm

import wandb
from models import UnsupervisedGNN
from utils.experiment import save_run_config, setup_wandb, get_data, sync_timer

from sklearn.cluster import KMeans
from sklearn_extra.cluster import KMedoids
from sklearn.metrics.pairwise import euclidean_distances


# Let float32 matrix multiplications use faster reduced-precision internal
# algorithms where the hardware supports them (PyTorch's 'high' setting).
torch.set_float32_matmul_precision(precision='high')


@hydra.main(version_base=None, config_path='./config', config_name="mpnn")
def main(args: DictConfig):
    setup_wandb(args)

    _, _, test_set = get_data(args.train.datapath)

    if args.train.debug:
        test_set = test_set[:20]

    test_loader = DataLoader(test_set,
                             batch_size=1,
                             shuffle=False)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UnsupervisedGNN(hid_dim=args.gnn.hidden,
                            num_encode_layers=args.gnn.num_encode_layers,
                            num_conv_layers=args.gnn.num_conv_layers,
                            edge_encode_layers=args.gnn.edge_encode_layers,
                            gnn_mlp_layers=args.gnn.gnn_mlp_layers,
                            num_pred_layers=args.gnn.num_pred_layers,
                            aggr=args.gnn.aggr,
                            square_dist=True).to(device)
    data = next(iter(test_loader)).to(device)
    # warm up GPU
    for _ in range(20):
        _ = model.predict(data)

    open_costs = []
    gnn_trans_costs = []
    gnn_timings = []
    kmean_trans_costs = []
    kmedoid_trans_costs = []
    kmean_timings = []
    kmedoid_timings = []
    gnn_init_kmean_trans_costs = []
    gnn_init_kmedoid_trans_costs = []
    gnn_init_kmean_timings = []
    gnn_init_kmedoid_timings = []

    model_dicts = os.listdir(args.train.modelpath)
    model_dicts = [m for m in model_dicts if m.startswith('best') and m.endswith('.pt')]

    for run, model_dict in enumerate(model_dicts):
        state_dict = torch.load(os.path.join(args.train.modelpath, model_dict), map_location=device, weights_only=False)
        model.load_state_dict(state_dict)

        model.eval()

        batch_opens = []
        batch_graph_trans = []
        batch_gnn_times = []
        batch_kmean_trans = []
        batch_kmedoid_trans = []
        batch_gnn_init_kmean_trans = []
        batch_gnn_init_kmedoid_trans = []
        batch_kmeans_times = []
        batch_kmedoid_times = []
        batch_gnn_init_kmeans_times = []
        batch_gnn_init_kmedoid_times = []
        pbar = tqdm(test_loader)
        for i, data in enumerate(pbar):
            # ================= GNN prediction =====================
            data = data.to(device)
            assert hasattr(data, 'pos')

            t1 = sync_timer()
            rad = model.predict(data)
            t2 = sync_timer()
            gnn_time = t2 - t1

            edge_index, edge_weight = add_remaining_self_loops(data.edge_index, data.edge_weight,
                                                               fill_value=0,
                                                               num_nodes=data.num_nodes)
            edge_index, edge_weight = to_undirected(edge_index, edge_weight,
                                                    num_nodes=data.num_nodes,
                                                    reduce='mean')

            rng = np.random.RandomState(args.train.seed)
            torch.manual_seed(args.train.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(args.train.seed)

            opens = []
            trans = []
            pos = data.pos.cpu().numpy()
            D_squared = euclidean_distances(pos, pos) ** 2
            kmeans_inertias = []
            kmedoid_inertias = []
            gnn_init_kmeans_inertia = []
            gnn_init_kmedoids_inertia = []
            for rep in range(args.train.repeats):
                # for GNN
                if rep == args.train.repeats - 1:
                    t1 = sync_timer()
                fac = torch.bernoulli(torch.minimum(rad, torch.ones_like(rad)))
                opened = scatter_sum(fac[edge_index[1]], edge_index[0], dim_size=data.num_nodes)
                fac[opened == 0.] = 1.
                open_cost = fac.sum().item()
                trans_cost = edge_weight.clone()
                trans_cost[fac[edge_index[1]] == 0.] = 1.e10
                trans_cost, _ = scatter_min(trans_cost, edge_index[0])
                trans_cost = trans_cost ** 2
                trans_cost = trans_cost.sum().item()
                if rep == args.train.repeats - 1:
                    t2 = sync_timer()
                    batch_gnn_times.append(gnn_time + t2 - t1)
                opens.append(open_cost)
                trans.append(trans_cost)

                open_cost = int(open_cost)
                # for Kmeans
                if rep == args.train.repeats - 1:
                    t1 = sync_timer()
                kmeans = KMeans(n_clusters=open_cost, init='k-means++', n_init=1,
                                random_state=rng).fit(pos)
                if rep == args.train.repeats - 1:
                    t2 = sync_timer()
                    batch_kmeans_times.append(t2 - t1)
                kmeans_inertias.append(kmeans.inertia_)

                # for GNN init kmeans
                mask = fac.cpu().numpy() > 0.
                if rep == args.train.repeats - 1:
                    t1 = sync_timer()
                kmeans = KMeans(n_clusters=open_cost, init=pos[mask], n_init=1,
                                random_state=rng).fit(pos)
                if rep == args.train.repeats - 1:
                    t2 = sync_timer()
                    batch_gnn_init_kmeans_times.append(gnn_time + t2 - t1)
                gnn_init_kmeans_inertia.append(kmeans.inertia_)

                # for Kmedoids
                if rep == args.train.repeats - 1:
                    t1 = sync_timer()
                kmedoids = KMedoids(n_clusters=open_cost, metric="precomputed", init='k-medoids++',
                                    random_state=rng).fit(D_squared)
                if rep == args.train.repeats - 1:
                    t2 = sync_timer()
                    batch_kmedoid_times.append(t2 - t1)
                kmedoid_inertias.append(kmedoids.inertia_)

                # for gnn guided Kmedoids
                if rep == args.train.repeats - 1:
                    t1 = sync_timer()
                kmedoids = KMedoids(n_clusters=open_cost, metric="precomputed", init=D_squared[mask],
                                    random_state=rng).fit(D_squared)
                if rep == args.train.repeats - 1:
                    t2 = sync_timer()
                    batch_gnn_init_kmedoid_times.append(gnn_time + t2 - t1)
                gnn_init_kmedoids_inertia.append(kmedoids.inertia_)

            batch_opens.append(np.mean(opens))
            batch_graph_trans.append(np.mean(trans))
            batch_kmean_trans.append(np.mean(kmeans_inertias))
            batch_kmedoid_trans.append(np.mean(kmedoid_inertias))
            batch_gnn_init_kmean_trans.append(np.mean(gnn_init_kmeans_inertia))
            batch_gnn_init_kmedoid_trans.append(np.mean(gnn_init_kmedoids_inertia))

            pbar.set_postfix({'open': batch_opens[-1],
                              'gnn_trans': batch_graph_trans[-1],
                              'kmeans': batch_kmean_trans[-1],
                              'kmedoid': batch_kmedoid_trans[-1]})

        open_costs.append(np.mean(batch_opens))
        gnn_trans_costs.append(np.mean(batch_graph_trans))
        kmean_trans_costs.append(np.mean(batch_kmean_trans))
        kmedoid_trans_costs.append(np.mean(batch_kmedoid_trans))
        gnn_timings.append(np.mean(batch_gnn_times))
        kmean_timings.append(np.mean(batch_kmeans_times))
        kmedoid_timings.append(np.mean(batch_kmedoid_times))
        gnn_init_kmean_timings.append(np.mean(batch_gnn_init_kmeans_times))
        gnn_init_kmedoid_timings.append(np.mean(batch_gnn_init_kmedoid_times))
        gnn_init_kmean_trans_costs.append(np.mean(batch_gnn_init_kmean_trans))
        gnn_init_kmedoid_trans_costs.append(np.mean(batch_gnn_init_kmedoid_trans))

    stats = {
        'open_mean': np.mean(open_costs),
        'open_std': np.std(open_costs),
        'gnn_trans_mean': np.mean(gnn_trans_costs),
        'gnn_trans_std': np.std(gnn_trans_costs),
        'kmeans_trans_mean': np.mean(kmean_trans_costs),
        'kmeans_trans_std': np.std(kmean_trans_costs),
        'kmedoid_trans_mean': np.mean(kmedoid_trans_costs),
        'kmedoid_trans_std': np.std(kmedoid_trans_costs),
        'gnn_kmeans_trans_mean': np.mean(gnn_init_kmean_trans_costs),
        'gnn_kmeans_trans_std': np.std(gnn_init_kmean_trans_costs),
        'gnn_kmedoid_trans_mean': np.mean(gnn_init_kmedoid_trans_costs),
        'gnn_kmedoid_trans_std': np.std(gnn_init_kmedoid_trans_costs),

        'gnn_time_mean': np.mean(gnn_timings),
        'gnn_time_std': np.std(gnn_timings),
        'kmeans_time_mean': np.mean(kmean_timings),
        'kmeans_time_std': np.std(kmean_timings),
        'kmedoid_time_mean': np.mean(kmedoid_timings),
        'kmedoid_time_std': np.std(kmedoid_timings),
        'gnn_kmeans_time_mean': np.mean(gnn_init_kmean_timings),
        'gnn_kmeans_time_std': np.std(gnn_init_kmean_timings),
        'gnn_kmedoid_time_mean': np.mean(gnn_init_kmedoid_timings),
        'gnn_kmedoid_time_std': np.std(gnn_init_kmedoid_timings),

    }
    wandb.log(stats)


if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator composes the config and
    # passes it to main, so no arguments are given here.
    main()
