import os

import torch
import pickle
from tqdm import tqdm
import math

import numpy as np


def assign_learning_rate(param_group, new_lr):
    """Write ``new_lr`` into the ``"lr"`` entry of one optimizer param group.

    The group is mutated in place; nothing is returned.
    """
    lr_key = "lr"
    param_group[lr_key] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a per-step learning-rate callback: linear warmup, then cosine decay.

    Args:
        optimizer: object exposing ``param_groups`` (read fresh on every call).
        base_lrs: one base learning rate per param group, or a single non-list
            value that is broadcast to every group.
        warmup_length: number of warmup steps handled by ``_warmup_lr``.
        steps: total number of steps; decay runs over ``steps - warmup_length``.

    Returns:
        A function ``f(step)`` that writes the scheduled lr into each param group.
        For ``step >= warmup_length`` the value is
        ``0.5 * (1 + cos(pi * e / es)) * base_lr`` with ``e = step - warmup_length``
        and ``es = steps - warmup_length``.

    Note:
        ``steps == warmup_length`` makes ``es`` zero, so any post-warmup call
        divides by zero.
    """
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs] * len(optimizer.param_groups)
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        def _scheduled(base_lr):
            if step >= warmup_length:
                elapsed = step - warmup_length
                decay_span = steps - warmup_length
                return 0.5 * (1 + np.cos(np.pi * elapsed / decay_span)) * base_lr
            return _warmup_lr(base_lr, warmup_length, step)

        # Look up param_groups on each call: the optimizer may replace the list
        # (e.g. when a state dict is loaded).
        for group, group_base_lr in zip(optimizer.param_groups, base_lrs):
            assign_learning_rate(group, _scheduled(group_base_lr))

    return _lr_adjuster


def accuracy(output, target, topk=(1,)):
    """Count how many samples have their target among the top-k predictions.

    Args:
        output: score tensor of shape [batch, num_classes]; ranked along dim 1.
        target: class-index tensor with ``batch`` elements.
        topk: iterable of k values to evaluate.

    Returns:
        A list of floats, one per entry of ``topk``: the number (not fraction)
        of samples whose target appears in that many highest-scoring classes.
    """
    # [batch, maxk] indices of the highest scores -> transposed to [maxk, batch].
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # Sum the boolean hits as integers and take a Python scalar via .item().
    # This avoids float() on a 1-element ndarray (deprecated since NumPy 1.25),
    # and float32 rounding for counts beyond 2**24.
    return [float(correct[:k].sum().item()) for k in topk]


def torch_save(classifier, save_path):
    """Pickle ``classifier.cpu()`` to ``save_path``, creating parent directories.

    NOTE(review): for a ``torch.nn.Module``, ``.cpu()`` moves the module in place,
    so the caller's model is left on CPU after saving. Confirm callers expect this.
    """
    target_dir = os.path.dirname(save_path)
    if target_dir != '':
        os.makedirs(target_dir, exist_ok=True)
    with open(save_path, 'wb') as handle:
        cpu_classifier = classifier.cpu()
        pickle.dump(cpu_classifier, handle)


def torch_load(save_path, device=None):
    """Unpickle an object from ``save_path``, optionally calling ``.to(device)`` on it.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files from
    trusted sources.
    """
    with open(save_path, 'rb') as handle:
        loaded = pickle.load(handle)
    if device is None:
        return loaded
    return loaded.to(device)


def fisher_save(fisher, save_path):
    """Pickle a name -> tensor mapping to ``save_path`` with every value on CPU.

    Parent directories are created as needed. The input dict is not modified;
    a new dict of CPU tensors is written.
    """
    target_dir = os.path.dirname(save_path)
    if target_dir != '':
        os.makedirs(target_dir, exist_ok=True)
    cpu_fisher = {name: tensor.cpu() for name, tensor in fisher.items()}
    with open(save_path, 'wb') as handle:
        pickle.dump(cpu_fisher, handle)


def fisher_load(save_path, device=None):
    """Unpickle a name -> tensor mapping, optionally moving every value to ``device``.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files from
    trusted sources.
    """
    with open(save_path, 'rb') as handle:
        fisher = pickle.load(handle)
    if device is None:
        return fisher
    return {name: tensor.to(device) for name, tensor in fisher.items()}


def get_logits(inputs, classifier):
    """Run ``classifier`` on ``inputs``, first moving it to the inputs' device if possible.

    Any classifier with a ``.to`` attribute is relocated to ``inputs.device``.
    NOTE(review): for a ``torch.nn.Module`` this move happens in place on the
    caller's object. Plain callables are invoked as-is.
    """
    assert callable(classifier)
    model = classifier.to(inputs.device) if hasattr(classifier, 'to') else classifier
    return model(inputs)


def get_probs(inputs, classifier):
    """Return class probabilities for ``inputs``.

    If ``classifier`` has a ``predict_proba`` method (scikit-learn style), it is
    fed a detached CPU NumPy copy of ``inputs``. Its result is wrapped with
    ``torch.from_numpy``, so it is a CPU tensor with whatever dtype
    ``predict_proba`` returned. Otherwise the logits from ``get_logits`` are
    softmaxed over dim 1.
    """
    if not hasattr(classifier, 'predict_proba'):
        return get_logits(inputs, classifier).softmax(dim=1)
    features = inputs.detach().cpu().numpy()
    return torch.from_numpy(classifier.predict_proba(features))


class LabelSmoothing(torch.nn.Module):
    """Cross-entropy with uniform label smoothing.

    For each sample the loss is
    ``(1 - smoothing) * NLL(target) + smoothing * mean_over_classes(-log p)``,
    averaged over all samples.
    """

    def __init__(self, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Compute the smoothed loss.

        Args:
            x: logits with the class axis last, e.g. [N, C] or [N, ..., C].
            target: integer class indices shaped like ``x`` without its last axis.

        Returns:
            Scalar tensor: the mean smoothed loss over all positions.
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)

        # Gather along the same (last) axis the softmax used. The old
        # unsqueeze(1) only lined up for 2-D input and silently picked wrong
        # entries for [N, ..., C] logits.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1))
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
    


def label_for_samples(alpha_expanded, samples_outputs_reshaped, num_pairs, label1_batch, label2_batch, device):
    """Replace the first and last point of every sample path with label targets.

    Position 0 of each row becomes a one-hot-style vector for ``label1_batch``,
    and position -1 one for ``label2_batch``. All interior positions keep their
    original values. The "off" value written to non-target classes is
    ``alpha_expanded[0, 0, 0, 0, 0]``, and the "on" value written to the target
    class is ``alpha_expanded[0, -1, 0, 0, 0]``. An alpha grid running from 0 to 1
    therefore gives exact one-hot vectors.

    Args:
        alpha_expanded: 5-D tensor of coefficients; only its first and last
            entries along dim 1 are read.
        samples_outputs_reshaped: outputs, presumably [num_pairs, resolution,
            num_classes] (per the original comments; TODO confirm with callers).
        num_pairs: number of leading rows that receive label targets. Any rows
            beyond it get all-zero endpoints.
        label1_batch: per-pair class index for the first endpoint.
        label2_batch: per-pair class index for the last endpoint.
        device: device of the internal row-index tensor.

    Returns:
        Tensor shaped like ``samples_outputs_reshaped``.
    """
    off_value = alpha_expanded[0, 0, 0, 0, 0]
    on_value = alpha_expanded[0, -1, 0, 0, 0]

    one_hot_full = torch.zeros_like(samples_outputs_reshaped)
    batch_indices = torch.arange(num_pairs, device=device)

    # First endpoint -> label1 (fill the row with the "off" value, then set the target class).
    one_hot_full[batch_indices, 0] = off_value
    one_hot_full[batch_indices, 0, label1_batch] = on_value

    # Last endpoint -> label2.
    one_hot_full[batch_indices, -1] = off_value
    one_hot_full[batch_indices, -1, label2_batch] = on_value

    # [rows, resolution, 1] selector, broadcast across the class dimension.
    endpoint_mask = torch.zeros_like(samples_outputs_reshaped[:, :, :1], dtype=torch.bool)
    endpoint_mask[:, 0] = True
    endpoint_mask[:, -1] = True

    # Select rather than blend with x*(1-mask) + y*mask. The blend turns a
    # non-finite original endpoint value into NaN (inf*0). Values and gradients
    # are otherwise identical.
    return torch.where(endpoint_mask, one_hot_full, samples_outputs_reshaped)


def pca(samples_outputs_reshaped, num_pairs, k: int = 1):
    """
    Batched PCA: project each pair's path onto its own leading principal directions.

    Each pair is centered along the resolution axis and decomposed with one batched
    SVD. The centered points are then projected onto the leading right-singular
    vectors. The directions are computed under ``torch.no_grad`` and detached, so
    gradients flow only through the centering and the projection.

    Args:
        samples_outputs_reshaped: [num_pairs, resolution, num_classes]
        num_pairs: number of sample pairs. Kept for interface compatibility; the
            batch size is always taken from ``samples_outputs_reshaped.shape[0]``.
        k: number of principal components to retain (``None`` is treated as 1).

    Returns:
        scores: [num_pairs, resolution, k_eff], where
            ``k_eff = min(k, resolution, num_classes)``. No padding is added when
            ``k`` exceeds the number of available components.

    Raises:
        ValueError: if ``k`` is not a positive integer.

    Note:
        Singular vectors are defined only up to sign, so the sign of each score
        column is arbitrary.
    """
    if k is None:
        k = 1
    k = int(k)
    if k <= 0:
        raise ValueError(f"k must be a positive int, got {k}")

    # 1. Center each pair along the resolution axis.
    mean = samples_outputs_reshaped.mean(dim=1, keepdim=True)  # [N, 1, C]
    X_centered = samples_outputs_reshaped - mean                # [N, R, C]

    batch_size = X_centered.shape[0]
    resolution = X_centered.shape[1]
    num_classes = X_centered.shape[2]
    # At most min(R, C) principal components exist.
    k_eff = min(k, resolution, num_classes)

    # 2. Batched SVD (directions only; treated as constants).
    with torch.no_grad():
        try:
            _, _, Vh = torch.linalg.svd(X_centered, full_matrices=False)
            pcs = Vh[:, :k_eff, :]  # [N, k_eff, C]
        except RuntimeError:
            # torch.linalg.LinAlgError subclasses RuntimeError. One failing pair
            # aborts the whole batched call, so every pair falls back to the
            # standard basis e_1..e_k_eff. Because k_eff <= num_classes, the
            # identity slice always has k_eff rows. dtype, device and batch size
            # follow the input so the bmm below cannot mismatch.
            basis = torch.eye(num_classes, dtype=X_centered.dtype, device=X_centered.device)[:k_eff]
            pcs = basis.unsqueeze(0).expand(batch_size, -1, -1)

    pcs = pcs.detach()

    # 3. Projection: [N, R, C] @ [N, C, k_eff] -> [N, R, k_eff]
    return torch.bmm(X_centered, pcs.transpose(1, 2))


def adjust_lambda_reg(epoch, args):
    """
    Return the regularization weight for ``epoch`` according to ``args.warmup_type``.

    ``args.warmup_type`` defaults to ``"normal"`` when absent. A non-positive
    ``args.lambda_reg`` is returned unchanged, whatever the schedule.
        - 'linear'   : rises linearly from 0 to lambda_reg over
                       args.warmup_epochs_for_lambda epochs, then holds.
        - 'sin_up'   : rises as sin(progress * pi / 2) from 0 to lambda_reg over
                       args.warmup_epochs_for_lambda epochs, then holds.
        - 'sin_down' : holds lambda_reg until epoch int(2 * args.epochs / 3), then
                       follows a half-cosine, lambda_reg * 0.5 * (1 + cos(pi * progress)),
                       towards 0. Returns 0.0 from epoch args.epochs onward.
        - any other value (including 'normal'): constant lambda_reg.
    """
    target = args.lambda_reg
    if target <= 0:
        return target

    schedule = getattr(args, "warmup_type", "normal")

    if schedule in ("linear", "sin_up"):
        ramp_epochs = args.warmup_epochs_for_lambda
        if epoch >= ramp_epochs:
            return target
        if schedule == "linear":
            return target * epoch / ramp_epochs
        progress = epoch / ramp_epochs
        return target * math.sin(progress * math.pi / 2)

    if schedule == "sin_down":
        total = args.epochs
        decay_begin = int(2 * total / 3)
        if epoch < decay_begin:
            return target
        if epoch >= total:
            return 0.0
        # progress runs 0 -> 1 across [decay_begin, total); the factor runs 1 -> 0.
        progress = (epoch - decay_begin) / (total - decay_begin)
        return target * (0.5 * (1 + math.cos(math.pi * progress)))

    return target

