import math
import numpy as np
from typing import Dict, Union, List, Set

import torch
import torch.optim as optim
from torch import nn
from torch_geometric.data import Data
import pdb

class AverageMeter:
    """Track the latest value of a metric together with its running weighted mean.

    Attributes:
        val: The value passed to the most recent ``update`` call.
        sum: Sum of every recorded value multiplied by its weight ``n``.
        count: Sum of all weights ``n`` recorded so far.
        avg: ``sum / count`` after the most recent ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val: float, n: int = 1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def adjust_learning_rate(
    args: Dict,
    optimizer: Union[optim.Optimizer, List[optim.Optimizer]],
    epoch: int,
    lr: float,
):
    """Set the learning rate for ``epoch`` by cosine annealing or step decay.

    Args:
        args (Dict): Parsed arguments. Attributes read: ``cosine``,
            ``lr_decay_rate``, plus ``epochs`` (cosine mode) or
            ``lr_decay_epochs`` (step mode).
        optimizer (Optimizer | List[Optimizer]): The optimizer to update, or
            the list wrapper returned by ``set_optimizer`` (only its first
            optimizer is updated).
        epoch (int): Current epoch.
        lr (float): Base learning rate the schedule is computed from.
    """
    if args.cosine:
        # Anneal from ``lr`` (epoch 0) down to ``lr * lr_decay_rate**3``
        # (epoch == args.epochs) along a half cosine.
        eta_min = lr * (args.lr_decay_rate**3)
        lr = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        # Decay once for every milestone that ``epoch`` has strictly passed.
        steps = int(np.sum(epoch > np.asarray(args.lr_decay_epochs)))
        if steps > 0:
            lr = lr * (args.lr_decay_rate**steps)

    # Accept a bare optimizer as well as the list returned by set_optimizer.
    target = optimizer if hasattr(optimizer, "param_groups") else optimizer[0]
    for param_group in target.param_groups:
        param_group["lr"] = lr


def warmup_learning_rate(
    opt: Dict[str, Union[str, float, int, List]],
    epoch: int,
    batch_id: int,
    total_batches: int,
    optimizer: Union[optim.Optimizer, List[optim.Optimizer]],
):
    """Linearly ramp the learning rate during the warm-up epochs.

    When ``opt.warm`` is truthy and ``epoch <= opt.warm_epochs``, the learning
    rate is interpolated from ``opt.warmup_from`` to ``opt.warmup_to`` by the
    fraction of warm-up batches seen so far. Otherwise nothing is changed.

    Args:
        opt (Dict[str,Union[str,float,int,List]]): Parsed arguments. Attributes
            read: ``warm``, ``warm_epochs``, ``warmup_from``, ``warmup_to``.
        epoch (int): Current epoch. The ``epoch - 1`` term suggests epochs are
            counted from 1 (verify against the caller).
        batch_id (int): Index of the current batch within the epoch.
        total_batches (int): Number of batches per epoch.
        optimizer (Optimizer | List[Optimizer]): The optimizer to update, or
            the list wrapper returned by ``set_optimizer`` (only its first
            optimizer is updated).
    """
    if not (opt.warm and epoch <= opt.warm_epochs):
        return

    # Fraction of the whole warm-up period completed so far.
    p = (batch_id + (epoch - 1) * total_batches) / (opt.warm_epochs * total_batches)
    lr = opt.warmup_from + p * (opt.warmup_to - opt.warmup_from)

    # Accept a bare optimizer as well as the list returned by set_optimizer.
    target = optimizer if hasattr(optimizer, "param_groups") else optimizer[0]
    for param_group in target.param_groups:
        param_group["lr"] = lr


def set_optimizer(lr: float, weight_decay: float, model: nn.Sequential):
    """Build an Adam optimizer over every parameter of ``model``.

    Args:
        lr (float): Learning rate.
        weight_decay (float): Weight decay passed to Adam.
        model (nn.Sequential): Model whose parameters are optimized.

    Returns:
        list: A single-element list holding the ``optim.Adam`` instance.
        Callers in this module index it as ``optimizer[0]``.
    """
    param_groups = [{'params': model.parameters()}]
    adam = optim.Adam(param_groups, lr=lr, weight_decay=weight_decay)
    return [adam]


def _calmean_device(device: Union[str, torch.device, None]) -> torch.device:
    """Resolve ``device``; ``None`` selects CUDA when available, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _median_pairwise_distance(values: torch.Tensor, block_size: int) -> float:
    """Return the median of all pairwise distances between entries of a 1-D tensor.

    Distances come from ``torch.cdist(p=2)`` on the values viewed as (N, 1)
    points, computed ``block_size`` rows at a time. The pool includes the N
    zero self-distances and both orderings of each pair (N * N values).
    ``torch.median`` returns the lower middle value when that count is even.
    """
    column = values.unsqueeze(1)  # (N, 1) points, as cdist expects
    blocks = [
        torch.cdist(column[start:start + block_size], column, p=2)
        for start in range(0, column.shape[0], block_size)
    ]
    # NOTE(review): the concatenated matrix is still N x N, so peak memory is
    # O(N^2); blocking only bounds the size of each individual cdist call.
    return torch.median(torch.cat(blocks, dim=0).flatten()).item()


def _uniform_intervals(values: torch.Tensor, num_experts: int) -> list:
    """Split ``[values.min(), values.max()]`` into ``num_experts`` equal-width (low, high) pairs.

    Consecutive intervals share a boundary. The final upper bound is nudged by
    ``1e-6`` so that the maximum value falls inside the last interval.
    """
    min_val = values.min()
    max_val = values.max()
    step = (max_val - min_val) / num_experts
    intervals = [
        (float(min_val + i * step), float(min_val + (i + 1) * step))
        for i in range(num_experts)
    ]
    intervals[-1] = (intervals[-1][0], float(max_val + 1e-6))
    return intervals


def _overlap_intervals(intervals: list) -> list:
    """Widen each interval halfway into its neighbours; the outermost edges stay put.

    Interval ``i`` becomes ``((low[i-1] + low[i]) / 2, (high[i] + high[i+1]) / 2)``.
    """
    last = len(intervals) - 1
    overlapped = []
    for i, (low, high) in enumerate(intervals):
        # The first interval is not extended left; the last is not extended right.
        start = (intervals[i - 1][0] + low) / 2 if i > 0 else low
        end = (high + intervals[i + 1][1]) / 2 if i < last else high
        overlapped.append((float(start), float(end)))
    return overlapped


def _assign_to_intervals(values: torch.Tensor, intervals: list) -> torch.Tensor:
    """Return the index of the interval containing each entry of a 1-D tensor.

    Interval ``i`` is half-open ``[low, high)``, except the last, which is
    closed ``[low, high]``.

    Raises:
        ValueError: For the first value (e.g. NaN) that lies in no interval.
    """
    first_low, last_high = intervals[0][0], intervals[-1][1]
    in_range = (values >= first_low) & (values <= last_high)
    if not bool(in_range.all()):
        label = float(values[~in_range][0])
        print(f"Warning: label {label} not assigned to any interval.")
        raise ValueError(f"Label {label} is out of interval bounds.")
    # Consecutive intervals share boundaries (high[i] == low[i+1]), so an
    # in-range value's index is the number of interior lower bounds <= it.
    inner_lows = torch.tensor(
        [low for low, _ in intervals[1:]], dtype=values.dtype, device=values.device
    )
    return (values.unsqueeze(1) >= inner_lows.unsqueeze(0)).sum(dim=1)


def calmean(
    dataset: "Set[Data]",
    num_experts: int = 8,
    block_size: int = 1000,
    device: Union[str, torch.device, None] = None,
):
    """Compute label statistics and equal-width label-interval assignments for a regression task.

    Labels are z-score normalised with the dataset mean and standard deviation,
    then reshaped to an (N, 1) column. From the normalised labels this computes:

    * the median pairwise label distance and the label range (max - min);
    * ``num_experts`` equal-width intervals spanning the label range, plus a
      variant where each interval extends halfway into its neighbours;
    * for every sample, the interval containing its label and that interval's
      centre.

    Args:
        dataset (Set[Data]): Train set of the regression task. It is accessed
            by index (``len(dataset)`` and ``dataset[i].y``), so any indexable
            collection of items with a ``.y`` label tensor works.
        num_experts (int): Number of equal-width intervals. Defaults to 8.
        block_size (int): Rows of the pairwise-distance matrix computed per
            ``torch.cdist`` call. Defaults to 1000.
        device: Device for all computation. ``None`` (default) uses CUDA when
            available and falls back to CPU otherwise.

    Returns:
        A tuple ``(mm, ss, dynamic_t, max_dist, intervals, intervals_overlap,
        group_labels, group_nums)``:

        * ``mm``, ``ss``: mean and ``torch.std`` of the stacked raw labels along dim 0.
        * ``dynamic_t``: (num_tasks,) median pairwise distance of normalised labels.
        * ``max_dist``: (num_tasks,) max minus min of normalised labels.
        * ``intervals``: list of ``num_experts`` ``(low, high)`` floats.
        * ``intervals_overlap``: list of the widened ``(start, end)`` floats.
        * ``group_labels``: (N, 1) float32 centre of each sample's interval.
        * ``group_nums``: (N, 1) float32 index of each sample's interval.

    Raises:
        ValueError: If ``num_experts < 1``, or if a normalised label lies outside
            every interval (e.g. NaN labels caused by a zero standard deviation).
    """
    if num_experts < 1:
        raise ValueError("num_experts must be at least 1")
    device = _calmean_device(device)

    labels_tensor = torch.stack([dataset[i].y for i in range(len(dataset))]).to(device)
    mm = torch.mean(labels_tensor, dim=0)
    ss = torch.std(labels_tensor, dim=0)
    # z-score, then force an (N, 1) layout: squeeze drops singleton dims and
    # unsqueeze(1) re-adds a single column, so one task column is processed.
    yy = ((labels_tensor - mm) / ss).squeeze().unsqueeze(1)
    _, num_tasks = yy.shape

    dynamic_t_list = []
    max_dist_list = []
    for task in range(num_tasks):
        column = yy[:, task]
        dynamic_t_list.append(_median_pairwise_distance(column, block_size))
        max_dist_list.append((column.max() - column.min()).item())
    dynamic_t = torch.tensor(dynamic_t_list, device=device)  # (num_tasks,)
    max_dist = torch.tensor(max_dist_list, device=device)  # (num_tasks,)

    intervals = _uniform_intervals(yy, num_experts)
    intervals_overlap = _overlap_intervals(intervals)

    # Vectorised assignment of every sample to its interval (no per-sample
    # host round-trips).
    centers = torch.tensor(
        [(low + high) / 2 for low, high in intervals], dtype=torch.float32, device=device
    )
    group_index = _assign_to_intervals(yy[:, 0], intervals)
    group_labels = centers[group_index]
    group_nums = group_index.to(torch.float32)

    return mm, ss, dynamic_t, max_dist, intervals, intervals_overlap, group_labels.unsqueeze(1), group_nums.unsqueeze(1)


def save_model(
    model: nn.Sequential,
    optimizer: Union[optim.Optimizer, List[optim.Optimizer]],
    opt: Dict[str, Union[str, float, int, List]],
    epoch: int,
    save_file: str,
):
    """Write a training checkpoint to ``save_file`` with ``torch.save``.

    The checkpoint is a dict with keys ``"opt"``, ``"model"`` (the model's
    ``state_dict``), ``"optimizer"`` (the optimizer's ``state_dict``) and
    ``"epoch"``.

    Args:
        model (nn.Sequential): Model whose ``state_dict`` is saved.
        optimizer (Optimizer | List[Optimizer]): Optimizer to save, or the list
            wrapper returned by ``set_optimizer`` (only its first optimizer is
            saved).
        opt (Dict[str,Union[str,float,int,List]]): Parsed arguments, stored as-is.
        epoch (int): Epoch number stored in the checkpoint.
        save_file (str): The address to save the model.
    """
    print("==> Saving...")
    # Accept a bare optimizer as well as the list returned by set_optimizer.
    target = optimizer if hasattr(optimizer, "state_dict") else optimizer[0]
    torch.save(
        {
            "opt": opt,
            "model": model.state_dict(),
            "optimizer": target.state_dict(),
            "epoch": epoch,
        },
        save_file,
    )
