import glob
import random
import re
from typing import Dict

import numpy as np
from sklearn.cluster import KMeans
import torch
import os
import pandas as pd
import matplotlib.pyplot as plt
import datetime
from config import CONSTRAINT_TYPES
from torch.nn import DataParallel
import math


def split_sample_by_blocks(sample_files, train_rate, block_size, seed=0):
    """Split files into train/validation lists block by block.

    Files are ordered by the integer value of the first run of digits in their
    string form; files whose name contains no digits are placed after all
    numbered files, ordered by their string form. The ordered list is cut into
    consecutive blocks of ``block_size`` files. Each block is shuffled and its
    first ``int(train_rate * len(block))`` files go to the training list, the
    remainder to the validation list, so both lists draw from every block.

    Args:
        sample_files: Iterable of file names or path-like objects.
        train_rate: Fraction of each block assigned to training, in [0, 1].
        block_size: Number of files per block; must be positive.
        seed: Value passed to ``random.seed`` before shuffling (default 0).

    Returns:
        A tuple ``(train_files, valid_files)`` of lists.

    Raises:
        ValueError: If ``block_size`` is not positive or ``train_rate`` is
            outside [0, 1].
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if not 0 <= train_rate <= 1:
        raise ValueError(f"train_rate must be in [0, 1], got {train_rate}")

    digit_run = re.compile(r'\d+')

    def sort_key(path):
        # (0, number) for numbered files, (1, name) otherwise: the leading tag
        # keeps ints and strs from ever being compared with each other.
        match = digit_run.search(str(path))
        if match is None:
            return (1, str(path))
        return (0, int(match.group()))

    ordered = sorted(sample_files, key=sort_key)

    # NOTE: this reseeds the module-global RNG (as the original did), so the
    # split is reproducible and any later global `random` use starts from a
    # known state.
    random.seed(seed)

    train_files = []
    valid_files = []
    for start in range(0, len(ordered), block_size):
        # Slicing clamps at the end, so the last block may be shorter.
        block = ordered[start:start + block_size]
        random.shuffle(block)
        split_index = int(train_rate * len(block))
        train_files.extend(block[:split_index])
        valid_files.extend(block[split_index:])

    return train_files, valid_files



def focal_loss(pre_cons, labels, weight, alpha=0.75, gamma=2):
    """Element-wise, per-sample-weighted focal loss (scaled by 2).

    For entries with ``labels == 1`` the loss is
    ``-2 * alpha * (1 - p)^gamma * log(p)``; for ``labels == 0`` it is
    ``-2 * (1 - alpha) * p^gamma * log(1 - p)``. Entries with any other label
    value contribute zero. No reduction is applied.

    Args:
        pre_cons: Tensor of predicted probabilities (presumably in [0, 1],
            e.g. sigmoid outputs -- confirm with callers).
        labels: Tensor broadcastable against ``pre_cons``.
        weight: Weights multiplied into the loss. A 1-D tensor is treated as
            one weight per entry along dim 0 of ``pre_cons`` and is expanded
            with trailing singleton dims to match ``pre_cons.dim()``; a scalar
            or already-broadcastable tensor is used as-is.
        alpha: Balance factor between positive and negative terms.
        gamma: Focusing exponent that down-weights well-classified entries.

    Returns:
        The unreduced weighted loss tensor.
    """
    eps = 1e-8  # guards log(0)
    pos_mask = (labels == 1).float()
    neg_mask = (labels == 0).float()

    pos_loss = - 2 * alpha * ((1 - pre_cons + eps) ** gamma) * torch.log(pre_cons + eps) * pos_mask
    neg_loss = - 2 * (1 - alpha) * (pre_cons ** gamma) * torch.log(1 - pre_cons + eps) * neg_mask

    if weight.dim() == 1:
        # Shape (B,) -> (B, 1, ..., 1) so it scales along dim 0 for any rank.
        # For 2-D inputs this equals the former `weight[:, None]`; for 1-D
        # inputs it avoids an accidental (B, B) outer broadcast.
        sample_weight = weight.reshape(-1, *([1] * (pre_cons.dim() - 1)))
    else:
        sample_weight = weight

    return (pos_loss + neg_loss) * sample_weight


def normalize_to_range(data, new_min=0, new_max=2):
    """Linearly rescale values so their minimum maps to ``new_min`` and their
    maximum maps to ``new_max``.

    Args:
        data: Any finite iterable of numbers (list, tuple, 1-D numpy array,
            pandas Series, generator). It is materialized into a list first,
            so array-likes no longer hit the ambiguous-truth-value error that
            ``if not data`` raised for them.
        new_min: Value the smallest input maps to.
        new_max: Value the largest input maps to.

    Returns:
        A new list of rescaled values. If all inputs are equal there is no
        spread to rescale, and every output is ``new_min``.

    Raises:
        ValueError: If ``data`` yields no values.
    """
    values = list(data)
    if not values:
        raise ValueError("The input list is empty.")

    old_min = min(values)
    old_max = max(values)

    if old_min == old_max:
        return [new_min] * len(values)

    # Expression order kept as-is so floating-point results are unchanged.
    return [
        new_min + (x - old_min) * (new_max - new_min) / (old_max - old_min)
        for x in values
    ]


def grb_config(m, TimeLimit=800, Threads=1, MIPFocus=1):
    """Set common solver parameters on a model, in place.

    Args:
        m: Model object exposing a ``Params`` namespace (presumably a
            ``gurobipy.Model`` -- confirm with callers).
        TimeLimit: Assigned to ``m.Params.TimeLimit`` (Gurobi interprets this
            in seconds).
        Threads: Assigned to ``m.Params.Threads``.
        MIPFocus: Assigned to ``m.Params.MIPFocus``. The default of 1 preserves
            the previously hard-coded value; in Gurobi, 1 presumably means
            "focus on finding feasible solutions" -- see the Gurobi docs.

    Returns:
        None; ``m`` is modified in place.
    """
    m.Params.TimeLimit = TimeLimit
    m.Params.Threads = Threads
    m.Params.MIPFocus = MIPFocus

