import math
import os
import pickle
import sys
import warnings
from copy import deepcopy
from pathlib import Path
from typing import List, Tuple, Union

import numpy as np
import scipy.sparse as sp
import torch
import wandb
from rdkit import Chem
from sklearn.metrics import balanced_accuracy_score
from torch import Tensor
from torch_sparse import SparseTensor

# Module-wide constants.
# On a 64-bit CPython build sys.maxsize == 2**63 - 1, so sys.maxsize * 2 + 2 == 2**64 and
# SMALL_INT evaluates to the float -64.0 (LARGE_INT to 64.0); the value depends on the
# platform word size. NOTE(review): despite the "_INT" suffix both are floats; presumably
# used as large-magnitude fill values (e.g. masking) — confirm at the call sites.
SMALL_INT = -math.log2(sys.maxsize * 2 + 2)
LARGE_INT = -SMALL_INT
# Small epsilon; presumably added to denominators / log arguments for numerical stability.
NUM_STABLE = 1e-8
# The directory containing this file is resolved (symlinks followed) and its parent taken,
# i.e. the project root sits one level above this module's directory.
ROOT_DIR = Path(__file__).parent.resolve().parent
# Local log directory (save_model_config_to_file writes under it) and data directory.
WB_LOG_PATH = ROOT_DIR / "logs"
DATA_PATH = ROOT_DIR / "data"
# Placeholder Weights & Biases entity; the literal value signals it must be set by the user.
WB_ENTITY = "TO-BE-REPLACED"
# Presumably a W&B collection name used elsewhere in the project — confirm.
WB_COLLECTION = "magnet-collection"


def calculate_balanced_acc(target, logits):
    """Compute the balanced accuracy of arg-max predictions.

    Args:
        target: Tensor of ground-truth class labels.
        logits: Tensor of per-class scores; the predicted class is the arg-max
            over the last dimension.

    Returns:
        The value of ``sklearn.metrics.balanced_accuracy_score`` on CPU copies
        of the labels and predictions.
    """
    with warnings.catch_warnings():
        # sklearn may emit warnings here (e.g. when predictions contain classes
        # absent from the targets); they are deliberately silenced.
        warnings.simplefilter("ignore")
        y_true = target.cpu()
        y_pred = logits.argmax(-1).cpu()
        score = balanced_accuracy_score(y_true, y_pred)
    return score


def calculate_class_weights(target, num_classes):
    """Return inverse-frequency class weights ``total_count / per_class_count``.

    Args:
        target: Tensor of integer-valued class labels (cast to ``long``).
        num_classes: Number of classes for the one-hot encoding.

    Returns:
        Float tensor whose entry ``c`` is the total number of labels divided by
        the number of labels equal to ``c`` (counts summed over dim 0). A class
        that never occurs gets ``inf``.
    """
    encoded = torch.nn.functional.one_hot(target.long(), num_classes=num_classes)
    total_count = encoded.sum()
    per_class_count = encoded.sum(0)
    return total_count / per_class_count


def parse_results(run_id, path):
    """Fetch a W&B run's logged history and average it per epoch.

    Args:
        run_id: Value stored verbatim under the ``"run_id"`` key of the result.
        path: Run path passed to ``wandb.Api().run`` (presumably
            ``"entity/project/run_id"`` — confirm against callers).

    Returns:
        dict with ``"run_id"`` plus one key per numeric history column, each
        mapped to the list of per-epoch means in ascending epoch order. The
        epoch values themselves become the group index and are not included.
    """
    results = {"run_id": run_id}
    run = wandb.Api().run(path)
    history = run.history()
    # numeric_only=True: since pandas 2.0, GroupBy.mean() defaults to
    # numeric_only=False and raises TypeError on non-numeric columns (which a
    # W&B history can contain) instead of silently dropping them.
    per_epoch = history.groupby(["epoch"]).mean(numeric_only=True)
    results.update(per_epoch.to_dict(orient="list"))
    return results


def sort_ndarray_by_dim(values: List[List], names: List[str], sort_by: List[str], argsort=False):
    """Zip parallel columns into a structured array and sort it by named fields.

    Args:
        values: One sequence per field, all of the same (non-zero) length. Each
            field's dtype is inferred from the first element of its sequence.
        names: Field name for each entry of ``values``.
        sort_by: Field names to sort by, in priority order.
        argsort: If True, return stable sorting indices instead of the sorted array.

    Returns:
        The sorted structured ``np.ndarray``, or the index array when ``argsort``.
    """
    dtype = []
    for name, column in zip(names, values):
        if isinstance(column[0], str):
            # np.str_ instead of np.unicode_ (removed in NumPy 2.0). Keep the
            # historic 64-char minimum width, but widen it for longer strings so
            # they are not silently truncated, which would also corrupt the sort.
            width = max(64, max(len(str(item)) for item in column))
            dtype.append((name, np.str_, width))
        else:
            dtype.append((name, type(column[0])))
    records = np.array(list(zip(*values)), dtype=dtype)
    if argsort:
        return np.argsort(records, order=sort_by, kind="stable")
    return np.sort(records, order=sort_by)


def manual_batch_to_device(batch, device):
    """Move the tensor-valued entries of ``batch`` to ``device`` in place.

    Every entry that is a ``torch.Tensor`` or a ``torch_sparse.SparseTensor`` is
    replaced by its ``.to(device)`` result; all other entries are left as-is.
    Nothing is returned.
    """
    movable_types = (Tensor, SparseTensor)
    for field in batch.keys():
        value = batch[field]
        if not isinstance(value, movable_types):
            continue
        batch[field] = value.to(device)


def _fragment_size_key(fragment: str):
    """Rank a dot-free SMILES fragment by atom count, then by string length."""
    # sanitize=False: only the atom count is needed, and skipping sanitization
    # avoids rejecting fragments that would fail chemistry checks in isolation.
    fragment_mol = Chem.MolFromSmiles(fragment, sanitize=False)
    num_atoms = fragment_mol.GetNumAtoms() if fragment_mol is not None else -1
    return num_atoms, len(fragment)


def mol_standardize(mol: Chem.Mol, largest_comp: bool):
    """Return a non-isomeric, Kekulé-form SMILES for ``mol``.

    Args:
        mol: RDKit molecule to serialize.
        largest_comp: If True and the SMILES has several dot-separated
            fragments, keep only the fragment with the most atoms (ties broken by
            longer string, then by first occurrence).

    Returns:
        The SMILES string.
    """
    smiles = Chem.MolToSmiles(mol, isomericSmiles=False, kekuleSmiles=True)
    if largest_comp and "." in smiles:
        # Rank by atom count rather than raw string length: for "CCO.[Na+]" the
        # string-length heuristic would keep the counter-ion "[Na+]".
        smiles = max(smiles.split("."), key=_fragment_size_key)
    return smiles


def extract_blockdiag_sparse(matrix: sp.coo_matrix, block_sizes: Union[List, Tuple]):
    """Split a block-diagonal (sparse) matrix into its consecutive diagonal blocks.

    Args:
        matrix: Square matrix whose diagonal blocks are to be extracted. COO
            input is supported: it is sliced through a CSR copy, and the blocks
            are returned in COO format.
        block_sizes: Size of each consecutive diagonal block. The sizes must sum
            to ``matrix.shape[0]``; this is not checked when there is exactly one
            block, in which case ``[matrix]`` is returned unchanged.

    Returns:
        List of the ``size x size`` diagonal blocks, in order.

    Raises:
        ValueError: If ``block_sizes`` does not sum to ``matrix.shape[0]``.
    """
    if len(block_sizes) == 1:
        return [matrix]

    # np.sum (not cumsum()[-1]) so that an empty block_sizes is handled too.
    if int(np.sum(block_sizes)) != matrix.shape[0]:
        raise ValueError("Sum of block sizes does not match matrix size.")

    # COO matrices are not reliably sliceable across SciPy versions (older
    # releases raise TypeError), so slice a CSR copy and convert back.
    is_coo = getattr(matrix, "format", None) == "coo"
    sliceable = matrix.tocsr() if is_coo else matrix

    blocks = []
    start = 0
    for size in block_sizes:
        end = start + size
        block = sliceable[start:end, start:end]
        blocks.append(block.tocoo() if is_coo else block)
        start = end

    return blocks


def smiles_from_file(file_path: Union[str, Path]):
    """Read one SMILES string per line from ``file_path`` and validate them.

    Trailing line terminators (``\\n`` and ``\\r``, so CRLF files work) are
    removed and blank lines are skipped. RDKit parses an empty string into an
    empty molecule, so without the skip, blank lines would pass validation.

    Args:
        file_path: Path to a text file containing one SMILES per line.

    Returns:
        List of SMILES strings in file order.

    Raises:
        ValueError: If any line cannot be parsed by ``Chem.MolFromSmiles``.
            (An explicit raise rather than ``assert``, which ``python -O`` strips.)
    """
    with open(file_path, "r") as file:
        lines = [line.rstrip("\r\n") for line in file]
    smiles = [line for line in lines if line]
    invalid = [s for s in smiles if Chem.MolFromSmiles(s) is None]
    if invalid:
        raise ValueError(f"{len(invalid)} invalid SMILES in {file_path}, e.g. {invalid[:5]}")
    return smiles


def save_model_config_to_file(wandb_project, run_id, config_in, model):
    """Pickle the part of ``config_in`` whose keys are instance attributes of ``model``.

    The file is written to
    ``WB_LOG_PATH / wandb_project / str(run_id) / "checkpoints" / "load_config.pkl"``,
    and the directory is created if needed. ``config_in`` is deep-copied and is
    not modified.

    Args:
        wandb_project: Project sub-directory name under ``WB_LOG_PATH``.
        run_id: Run identifier; converted with ``str`` for the sub-directory name.
        config_in: Mapping-like config supporting ``keys()`` and ``pop()``.
        model: Object whose ``__dict__`` keys decide which config entries are kept.
    """
    config = deepcopy(config_in)
    save_dir = WB_LOG_PATH / wandb_project / str(run_id) / "checkpoints"
    # exist_ok=True so that saving again for the same run does not raise FileExistsError.
    os.makedirs(save_dir, exist_ok=True)
    model_attrs = set(model.__dict__.keys())
    # Pop in place (rather than rebuilding a dict) so the config keeps its original type.
    for key in [k for k in config.keys() if k not in model_attrs]:
        config.pop(key)
    with open(save_dir / "load_config.pkl", "wb") as handle:
        pickle.dump(config, handle, protocol=pickle.HIGHEST_PROTOCOL)
