import threading

import numpy as np
import torch
import random
import wandb
import signal
from torch.utils.data import Dataset
from PIL import Image
from typing import Dict, Iterable, List, Optional

class TimeoutException(Exception):
    """Exception raised by this module's ``handler`` function."""

def handler(signum, frame):
    """Signal-handler-shaped callback that always raises ``TimeoutException``.

    ``signum`` and ``frame`` follow the ``signal.signal`` handler signature
    and are ignored.
    """
    raise TimeoutException

def generate_integers_divider(n, total_sum=1000):
    """Randomly split ``total_sum`` into ``n`` positive integers.

    ``n - 1`` distinct cut points are drawn (via the global NumPy RNG) from
    ``1 .. total_sum - 1``. Together with ``0`` and ``total_sum`` they are
    sorted, and the gaps between consecutive points are returned. Because
    the cut points are distinct, every part is >= 1 and the parts sum to
    ``total_sum``.

    Args:
        n: Number of parts. ``n <= 0`` yields an empty integer array.
        total_sum: Value to split.

    Returns:
        1-D integer ndarray of length ``max(n, 0)``.

    Raises:
        ValueError: if ``n > total_sum`` (no split into positive parts exists).
    """
    if n <= 0:
        # Same integer dtype as the normal path, so callers see one dtype.
        return np.array([], dtype=np.int_)
    if n > total_sum:
        raise ValueError(
            f"cannot split {total_sum} into {n} positive integers"
        )

    cut_points = np.random.choice(np.arange(1, total_sum), n - 1, replace=False)
    all_points = np.sort(np.concatenate(([0], cut_points, [total_sum])))
    return np.diff(all_points)



def test(model, loader, ratio, device='cuda', ema=None, use_ema=False):
    """Evaluate top-1 accuracy (in percent) of ``model`` on ``loader``, scaled by ``ratio``.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    When ``use_ema`` is true and ``ema`` is given, the EMA weights are
    swapped in for the evaluation (``ema.store`` / ``ema.copy_to``). The
    live weights are always put back with ``ema.restore``, even if
    evaluation raises.

    Args:
        model: Module whose output's argmax over dim 1 is the prediction.
        loader: Iterable of ``(images, labels)`` batches.
        ratio: Multiplier applied to the accuracy percentage.
        device: Device the batches are moved to.
        ema: Optional object exposing ``store``, ``copy_to`` and ``restore``.
        use_ema: Whether to evaluate with the EMA weights.

    Returns:
        ``100 * correct / total * ratio``, or ``0.0`` if the loader yielded
        no samples.
    """
    swap_ema = ema is not None and use_ema
    if swap_ema:
        ema.store(model)
        ema.copy_to(model)

    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the live weights even if evaluation failed part-way.
        if swap_ema:
            ema.restore(model)

    if total == 0:
        return 0.0
    acc = 100.0 * correct / total
    return acc * ratio


# pytest collects any ``test*`` callable that ends up in a test module's
# namespace (e.g. via star-import); this is an evaluation helper, not a test.
test.__test__ = False


def calculate_l2_distance(model1, model2):
    """Return the Euclidean (L2) distance between two models' state dicts.

    Every entry of ``model1.state_dict()`` is compared with the same key in
    ``model2.state_dict()``. This includes buffers, not only parameters. The
    squared differences are summed over all entries, and the square root is
    returned.

    Returns:
        The distance as a Python float; ``0.0`` if the state dict is empty.

    Raises:
        KeyError: if ``model2`` lacks a key present in ``model1``.
    """
    params1 = model1.state_dict()
    params2 = model2.state_dict()
    if not params1:
        return 0.0

    distance = 0.0
    for key, a in params1.items():
        b = params2[key].to(a.device)
        if not a.is_floating_point():
            # Bool tensors cannot be subtracted and unsigned ints wrap around;
            # compare non-float entries (e.g. counter buffers) as floats.
            a, b = a.float(), b.float()
        distance += torch.sum((a - b) ** 2)

    return torch.sqrt(distance).item()

def calculate_l2_norm(model1):
    """Return the L2 norm of all entries of ``model1.state_dict()`` taken together.

    Buffers are included along with parameters. Non-floating-point entries
    (integer or bool buffers) are squared as floats.

    Returns:
        The norm as a Python float; ``0.0`` if the state dict is empty.
    """
    params1 = model1.state_dict()
    if not params1:
        return 0.0

    norm = 0.0
    for tensor in params1.values():
        if not tensor.is_floating_point():
            # ``bool ** 2`` is not meaningful and small int dtypes can
            # overflow when squared, so square non-float entries as floats.
            tensor = tensor.float()
        norm += torch.sum(tensor ** 2)

    return torch.sqrt(norm).item()


def split_iid_dataset(dataset, ratio_k, seed=42):
    """Randomly split ``dataset`` into two disjoint ``Subset``s.

    The first subset holds ``floor(ratio_k * len(dataset))`` samples and the
    second holds the rest. The permutation comes from a private
    ``np.random.RandomState(seed)``. With the default seed it matches the
    previous ``np.random.seed(42)`` + ``np.random.shuffle`` permutation,
    without resetting the global NumPy RNG as a side effect.

    Args:
        dataset: Any sized, indexable dataset.
        ratio_k: Fraction of samples in the first group; must be in (0, 1).
        seed: Seed of the shuffling RNG.

    Returns:
        ``(group1_dataset, group2_dataset)`` as ``torch.utils.data.Subset``.

    Raises:
        ValueError: if ``ratio_k`` is not strictly between 0 and 1.
    """
    if not (0 < ratio_k < 1):
        raise ValueError("ratio_k must be in (0, 1)")

    num_samples = len(dataset)
    split_point = int(np.floor(ratio_k * num_samples))

    indices = list(range(num_samples))
    np.random.RandomState(seed).shuffle(indices)

    group1_dataset = torch.utils.data.Subset(dataset, indices[:split_point])
    group2_dataset = torch.utils.data.Subset(dataset, indices[split_point:])
    return group1_dataset, group2_dataset

def iid_split_indices_ratio(dataset, old_ratios, seed=42):
    """Randomly partition the indices of ``dataset`` into shards sized by ratio.

    ``old_ratios`` is normalised to sum to 1. Shard sizes are rounded with
    the largest-remainder method, so they sum exactly to ``len(dataset)``.
    The indices are shuffled with ``np.random.default_rng(seed)`` and cut
    into consecutive shards.

    Args:
        dataset: Any sized dataset; only ``len(dataset)`` is used.
        old_ratios: Non-negative relative shard sizes (need not sum to 1).
        seed: Seed of the shuffling RNG.

    Returns:
        A list with one list of Python-int indices per ratio.

    Raises:
        ValueError: if ``old_ratios`` is empty, not 1-D, contains negative
            or non-finite values, or sums to zero.
    """
    ratios = np.array(old_ratios, dtype=float)
    if ratios.ndim != 1 or ratios.size == 0:
        raise ValueError("old_ratios must be a non-empty 1-D sequence")
    if not np.all(np.isfinite(ratios)) or np.any(ratios < 0) or ratios.sum() <= 0:
        raise ValueError("old_ratios must be finite, non-negative and not all zero")
    ratios /= ratios.sum()

    n_samples = len(dataset)
    ideal = ratios * n_samples
    sizes = np.floor(ideal).astype(int)

    # Largest-remainder rounding: give the leftover samples (at most one per
    # shard) to the shards whose ideal size lost the most to flooring.
    remainder = n_samples - sizes.sum()
    frac = ideal - np.floor(ideal)
    order = np.argsort(-frac)
    sizes[order[:remainder]] += 1

    rng = np.random.default_rng(seed)
    all_indices = np.arange(n_samples)
    rng.shuffle(all_indices)

    boundaries = np.cumsum(sizes)[:-1]
    return [shard.tolist() for shard in np.split(all_indices, boundaries)]


def compute_grad_dict(model, images, labels, criterion, weight=1.0, weight_decay=0.0):
    """Run one forward/backward pass and return a detached copy of every gradient.

    The model is put in train mode and its existing gradients are cleared
    (set to ``None``). The loss ``weight * criterion(model(images), labels)``
    is back-propagated, which leaves ``.grad`` populated on the model.

    When ``weight_decay > 0``, ``weight_decay * param`` is added to the copied
    gradient of every parameter with more than one dimension whose name
    contains ``"weight"``. This applies only to parameters that actually
    received a gradient.

    Returns:
        ``(grads, loss)``. ``grads`` maps each parameter name to its gradient
        copy, or to zeros if the parameter got no gradient. ``loss`` is the
        weighted loss as a Python float.
    """
    model.train()
    model.zero_grad(set_to_none=True)
    loss = weight * criterion(model(images), labels)
    loss.backward()

    grads = {}
    for param_name, param in model.named_parameters():
        if param.grad is None:
            grads[param_name] = torch.zeros_like(param.data)
            continue

        grad_copy = param.grad.detach().clone()
        apply_decay = weight_decay > 0.0 and param.ndim > 1 and "weight" in param_name
        if apply_decay:
            grad_copy = grad_copy + weight_decay * param.data.detach()
        grads[param_name] = grad_copy

    return grads, float(loss.detach().item())


def zero_like_param_dict(model):
    """Map every parameter name of ``model`` to a zero tensor of the same shape, dtype and device."""
    zeros = {}
    for param_name, param in model.named_parameters():
        data = param.data
        zeros[param_name] = data.new_zeros(data.shape)
    return zeros


def generate_random_ratios(n: int, total: float, min_val: float, max_val: float, seed: Optional[int] = None):
    """Draw ``n`` random values in ``[min_val, max_val]`` that sum to ``total``.

    The function uses rejection sampling. It draws ``n`` uniforms in
    ``[min_val, max_val]``, rescales them to sum to ``total``, and retries
    until every rescaled value is still within the bounds.

    With ``seed`` given, draws come from a private ``random.Random(seed)``.
    This gives the same sequence as ``random.seed(seed)`` without resetting
    the global ``random`` state. Without a seed, the global ``random`` module
    is used.

    Note: at the exact feasibility boundary (e.g. ``n * min_val == total``
    with ``min_val < max_val``) acceptance has probability ~0, so avoid it.

    Raises:
        ValueError: if ``n <= 0``, ``min_val > max_val``, or no solution
            exists (``n * min_val > total`` or ``n * max_val < total``).
            Without this check the sampling loop would never terminate.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")
    if min_val > max_val:
        raise ValueError("min_val must not exceed max_val")
    if n * min_val > total or n * max_val < total:
        raise ValueError(
            f"cannot draw {n} values in [{min_val}, {max_val}] summing to {total}"
        )

    rng = random.Random(seed) if seed is not None else random

    while True:
        nums = [rng.uniform(min_val, max_val) for _ in range(n)]
        s = sum(nums)
        ratios = [x * total / s for x in nums]
        if all(min_val <= r <= max_val for r in ratios):
            return ratios


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic.

    cuDNN autotuning (``benchmark``) is disabled because it may select
    different kernels from run to run.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def average_params(models: List[torch.nn.Module], weights: Optional[List[float]] = None) -> Dict[str, torch.Tensor]:
    """Return the (weighted) average of the parameters of ``models``.

    Only ``named_parameters()`` are averaged; buffers are not included. The
    parameter names are taken from ``models[0]``. Weights are placed on the
    device of the first parameter of ``models[0]`` and normalised to sum to 1.

    Args:
        models: Non-empty list of modules with identical parameter names/shapes.
        weights: Optional per-model weights; uniform if ``None``.

    Returns:
        Dict mapping parameter name to the averaged tensor. Empty if the
        models have no parameters.

    Raises:
        ValueError: if ``models`` is empty, ``len(weights) != len(models)``
            (a single weight would otherwise broadcast and *sum* all models),
            or the weights sum to a non-positive value.
    """
    if not models:
        raise ValueError("models must be a non-empty list")
    if weights is not None and len(weights) != len(models):
        raise ValueError(f"got {len(weights)} weights for {len(models)} models")

    with torch.no_grad():
        name_to_param = [dict(m.named_parameters()) for m in models]
        names = list(name_to_param[0].keys())
        if not names:
            return {}

        device = next(models[0].parameters()).device
        if weights is None:
            w = torch.full((len(models),), 1.0 / len(models), device=device)
        else:
            w = torch.tensor(weights, dtype=torch.float32, device=device)
            w_sum = w.sum()
            if not w_sum.item() > 0:
                raise ValueError("weights must sum to a positive value")
            w = w / w_sum

        avg_state = {}
        for name in names:
            stacked = torch.stack([params[name].detach() for params in name_to_param], dim=0)
            # Reshape w to (num_models, 1, ..., 1) so it broadcasts over each parameter.
            w_b = w.view(-1, *([1] * (stacked.dim() - 1)))
            avg_state[name] = (w_b * stacked).sum(dim=0)
        return avg_state


def grad_over_k_batches(models: List[torch.nn.Module],
                        loader: Iterable,
                        criterion,
                        device,
                        K: int = 8,
                        weights: Optional[List[float]] = None) -> Dict[str, torch.Tensor]:
    """Gradient of the loss at the (weighted) parameter average of ``models``.

    The parameters of ``models`` are averaged with ``average_params``. A new
    ``type(models[0])()`` (no-argument constructor) is loaded with those
    parameters, moved to ``device`` and put in eval mode. Then ``loss / K``
    is back-propagated for ``K`` batches drawn from ``loader``. If the loader
    runs out early it is restarted, so the result is the mean of exactly
    ``K`` per-batch gradients; batches may repeat.

    NOTE(review): only parameters are averaged. Buffers (e.g. BatchNorm
    running statistics) keep the freshly constructed model's initial values,
    and eval mode uses them. Confirm this is intended.

    Returns:
        Dict mapping parameter name to its accumulated gradient. Parameters
        that received no gradient map to zeros (all of them when ``K <= 0``).

    Raises:
        ValueError: if ``loader`` yields no batches at all.
    """
    avg_state = average_params(models, weights)
    with torch.no_grad():
        clone_model = type(models[0])().to(device)
        full_state = clone_model.state_dict()
        full_state.update(avg_state)
        clone_model.load_state_dict(full_state, strict=True)
    clone_model.eval()

    batches = iter(loader)
    for _ in range(K):
        try:
            x, y = next(batches)
        except StopIteration:
            # Loader exhausted before K batches: start another pass over it.
            batches = iter(loader)
            try:
                x, y = next(batches)
            except StopIteration:
                raise ValueError("loader yielded no batches") from None

        x, y = x.to(device), y.to(device)
        loss = criterion(clone_model(x), y)
        (loss / K).backward()

    grad = {n: (p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p))
            for n, p in clone_model.named_parameters()}
    return grad

def grad_norm(grad: dict, norm_type: int = 2) -> float:
    grads = [g.view(-1) for g in grad.values()]
    all_grads = torch.cat(grads)
    return all_grads.norm(norm_type).pow(2).item()

def aggregate_histories(history_both_all):
    """Mean and standard deviation of paired metric histories across runs.

    Each element of ``history_both_all`` is a pair ``(history_I, history_II)``
    of dicts that map a metric name to a per-step sequence. For each side,
    the sequences of every metric key of the first run are stacked across
    runs, and the per-step mean and population std (``ddof=0``) are taken.

    Returns:
        A 4-tuple ``(avg_I, std_I, avg_II, std_II)`` of dicts of lists.
    """
    def _mean_and_std(runs):
        means, stds = {}, {}
        for metric in runs[0]:
            table = np.stack([np.array(run[metric]) for run in runs])
            means[metric] = table.mean(axis=0).tolist()
            stds[metric] = table.std(axis=0).tolist()
        return means, stds

    first_side = [entry[0] for entry in history_both_all]
    second_side = [entry[1] for entry in history_both_all]
    avg_first, std_first = _mean_and_std(first_side)
    avg_second, std_second = _mean_and_std(second_side)
    return avg_first, std_first, avg_second, std_second

def log_algo_results(avg, std, algo_name: str, lr_: float, cfg):
    """Log one algorithm's averaged metrics to a fresh wandb run.

    Starts a run (``reinit=True``) named ``"<algo_name>_lr<lr_>"``. The
    config records ``lr_`` and ``cfg.num_users`` / ``topology`` /
    ``eval_every``. For every position ``i`` of ``avg["iters"]``, the run
    logs ``avg["loss"][i]`` and ``avg["grad_norm"][i]`` at wandb step ``i``.
    The run is always finished via ``try_finish_wandb``, even if logging
    raises.

    Args:
        avg: Dict with list-valued "iters", "loss" and "grad_norm" keys
            (presumably of equal length; TODO confirm with callers).
        std: Per-metric standard deviations. NOTE(review): currently not logged.
        algo_name: Algorithm label, used in the run name.
        lr_: Learning rate, used in the run name and config.
        cfg: Object with entity, project, num_users, topology, eval_every.
    """
    run = wandb.init(
        entity=cfg.entity,
        project=cfg.project,
        # Previously an empty f-string, which left algo_name unused and gave
        # every run the same empty name.
        name=f"{algo_name}_lr{lr_}",
        reinit=True,
        config={
            "lr": lr_,
            "num_users": cfg.num_users,
            "topology": cfg.topology,
            "eval_every": cfg.eval_every,
        },
    )

    try:
        # NOTE(review): logs the position i, not the value avg["iters"][i].
        for i, _iter_value in enumerate(avg["iters"]):
            run.log({"iter": i,
                     "loss/avg": avg["loss"][i],
                     "grad_norm/avg": avg["grad_norm"][i]}, step=i)
    finally:
        try_finish_wandb(run)

def _timeout_handler(signum, frame):
    """Signal handler that raises ``TimeoutError`` to abort a hanging ``wandb.finish()``.

    ``signum`` and ``frame`` follow the ``signal.signal`` handler signature
    and are ignored.
    """
    message = "wandb.finish() timeout"
    raise TimeoutError(message)

def try_finish_wandb(run, timeout=10):
    """Call ``run.finish()``, aborting it after ``timeout`` seconds when possible.

    The timeout uses ``SIGALRM`` with ``_timeout_handler``. ``signal.signal``
    may only be called from the main thread and ``SIGALRM`` is POSIX-only,
    so elsewhere ``run.finish()`` is called without a timeout. The previous
    ``SIGALRM`` handler is restored afterwards. Exceptions, including the
    timeout, are printed rather than raised.

    Args:
        run: A wandb run, or ``None`` (no-op).
        timeout: Whole seconds to wait (``signal.alarm`` takes an int).
    """
    if run is None:
        return

    use_alarm = (
        hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    previous_handler = None
    if use_alarm:
        previous_handler = signal.signal(signal.SIGALRM, _timeout_handler)
        signal.alarm(timeout)

    try:
        run.finish()
        print("[INFO] wandb.finish")
    except TimeoutError:
        print(f"Error {timeout}s")
    except Exception as e:
        print(f"Error:{e}")
    finally:
        if use_alarm:
            signal.alarm(0)
            # getsignal/signal return None for handlers not installed from
            # Python; SIG_DFL is the closest restorable equivalent.
            signal.signal(
                signal.SIGALRM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )







