import os
import torch
import pickle
import math
import numpy as np
import torchvision   
from torch.utils.data import DataLoader
from collections import Counter
import tqdm
class NormalizeInverse(torchvision.transforms.Normalize):
    """Transform that approximately undoes ``Normalize(mean, std)``.

    The parent transform is configured with ``mean' = -mean / std`` and
    ``std' = 1 / std``. Applying it to a normalized tensor therefore maps it
    back to roughly ``x * std + mean``. A small epsilon (1e-7) is added to
    ``std`` before inverting, so a zero std does not divide by zero.
    """

    def __init__(self, mean, std):
        mean_t = torch.as_tensor(mean)
        inv_std = 1 / (torch.as_tensor(std) + 1e-7)
        super().__init__(mean=-mean_t * inv_std, std=inv_std)

    def __call__(self, tensor):
        # Operate on a copy so the caller's tensor is left untouched.
        return super().__call__(tensor.clone())

def corner_mask_generation(patch=None, location="RD", image_size=(3, 224, 224)):
    """Place ``patch`` in one corner of a zero canvas and build its mask.

    Args:
        patch: Array-like with a ``shape`` of at least 3 dims; ``shape[1]`` and
            ``shape[2]`` give the patch extent along canvas axes 1 and 2.
        location: One of ``"RD"``, ``"RU"``, ``"LD"``, ``"LU"`` selecting the corner.
        image_size: Shape of the zero canvas the patch is written into.

    Returns:
        ``(applied_patch, mask, x_location, y_location)``:
        - ``applied_patch`` is the canvas with the patch written in.
        - ``mask`` is 1.0 wherever ``applied_patch`` is non-zero and 0 elsewhere.
        - ``x_location`` / ``y_location`` are the offsets along axes 1 and 2.

    Raises:
        ValueError: if ``patch`` is None, ``location`` is unknown, or the patch
            does not fit inside ``image_size``.
    """
    if patch is None:
        raise ValueError("corner_mask_generation requires a patch")
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    if patch_h > image_size[1] or patch_w > image_size[2]:
        raise ValueError(
            f"patch extent {(patch_h, patch_w)} does not fit in image_size {tuple(image_size)}"
        )

    # NOTE(review): if image_size is (C, H, W), axis 1 is rows, so "RU" puts the
    # patch at the bottom-left and "LD" at the top-right on screen. The original
    # offsets are kept unchanged here; confirm the intended naming convention.
    offsets = {
        "RD": (image_size[1] - patch_h, image_size[2] - patch_w),  # Right-Down
        "RU": (image_size[1] - patch_h, 0),                        # Right-Up
        "LD": (0, image_size[2] - patch_w),                        # Left-Down
        "LU": (0, 0),                                              # Left-Up
    }
    try:
        x_location, y_location = offsets[location]
    except KeyError:
        raise ValueError(
            f"unknown location {location!r}; expected one of {sorted(offsets)}"
        ) from None

    applied_patch = np.zeros(image_size)
    applied_patch[:, x_location:x_location + patch_h, y_location:y_location + patch_w] = patch
    # NOTE(review): the mask is derived from non-zero values, so patch pixels that
    # are exactly 0 are excluded from the mask.
    mask = applied_patch.copy()
    mask[mask != 0] = 1.0
    return applied_patch, mask, x_location, y_location

def assign_learning_rate(param_group, new_lr):
    """Overwrite the ``"lr"`` entry of an optimizer parameter group in place."""
    lr_key = "lr"
    param_group[lr_key] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a per-step LR setter: linear warm-up followed by cosine decay.

    Args:
        optimizer: Object exposing ``param_groups`` (a sequence of dict-like groups).
        base_lrs: A single base LR applied to every group, or a list with one
            base LR per parameter group.
        warmup_length: Number of initial steps using linear warm-up.
        steps: Total number of steps; the cosine phase spans
            ``steps - warmup_length`` steps and reaches 0 at ``step == steps``.

    Returns:
        A function ``_lr_adjuster(step)`` that writes the LR for ``step`` into
        every parameter group. For ``step > steps`` the cosine continues past its
        minimum, so the LR rises again. If there is no decay phase
        (``steps <= warmup_length``), every post-warm-up step gets LR 0 instead
        of dividing by zero.
    """
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                e = step - warmup_length
                es = steps - warmup_length
                if es <= 0:
                    # Empty decay phase: the schedule is already at its end
                    # (cosine value 0). Previously this raised ZeroDivisionError.
                    lr = 0.0
                else:
                    lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)

    return _lr_adjuster


def accuracy(output, target, topk=(1,)):
    """Count top-k correct predictions for each ``k`` in ``topk``.

    Args:
        output: Score tensor with classes along dim 1 (one row per sample).
        target: Tensor of true class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list of floats giving, for each ``k``, the number of samples whose
        target is among the ``k`` highest scores. These are raw counts, not
        percentages.
    """
    maxk = max(topk)
    # pred: (maxk, batch) after the transpose, so pred[:k] holds the top-k guesses.
    pred = output.topk(maxk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a Python scalar directly. Calling float() on a 1-element
    # ndarray is deprecated in NumPy >= 1.25.
    return [float(correct[:k].sum().item()) for k in topk]


def torch_load_old(save_path, device=None):
    """Unpickle the object stored at ``save_path``, optionally calling ``.to(device)``.

    Security: ``pickle.load`` can execute arbitrary code; only use this on
    trusted files.
    """
    with open(save_path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded if device is None else loaded.to(device)


def torch_save(model, save_path):
    """Serialize ``model`` to ``save_path``, creating parent directories as needed.

    Note: ``model.cpu()`` is called on the caller's object before saving, so the
    model is moved to CPU as a side effect.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir != '':
        os.makedirs(parent_dir, exist_ok=True)
    cpu_model = model.cpu()
    torch.save(cpu_model, save_path)


def torch_load(save_path, device=None):
    """Load a whole pickled object (e.g. a module written by ``torch_save``).

    Args:
        save_path: Path passed to ``torch.load``.
        device: If not None, the loaded object is moved there with ``.to(device)``.

    Returns:
        The loaded object.

    Security: this unpickles arbitrary Python objects; only load trusted files.
    """
    import inspect

    load_kwargs = {}
    # torch >= 2.6 defaults to weights_only=True, which rejects full nn.Module
    # pickles such as those written by torch_save. Opt out explicitly on
    # versions whose torch.load accepts the argument.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(save_path, **load_kwargs)
    if device is not None:
        model = model.to(device)
    return model



def get_logits(inputs, classifier):
    """Run ``classifier`` on ``inputs``, first moving it to ``inputs.device`` if it supports ``.to``.

    Raises:
        AssertionError: if ``classifier`` is not callable.
    """
    assert callable(classifier)
    model = classifier.to(inputs.device) if hasattr(classifier, 'to') else classifier
    return model(inputs)


def get_probs(inputs, classifier):
    """Return class probabilities for ``inputs``.

    If ``classifier`` has ``predict_proba``, it is called on a detached CPU NumPy
    copy of ``inputs`` and the result is wrapped with ``torch.from_numpy``.
    Otherwise the logits from ``get_logits`` are softmaxed over dim 1.
    """
    if not hasattr(classifier, 'predict_proba'):
        return get_logits(inputs, classifier).softmax(dim=1)
    as_array = inputs.detach().cpu().numpy()
    return torch.from_numpy(classifier.predict_proba(as_array))


class LabelSmoothing(torch.nn.Module):
    """Cross-entropy with label smoothing, averaged over the batch.

    Per sample the loss is
    ``(1 - smoothing) * NLL(target) + smoothing * mean_over_classes(-log p)``.
    With ``smoothing=0`` it reduces to plain cross-entropy.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        log_p = torch.nn.functional.log_softmax(x, dim=-1)
        # Log-probability of the true class for each row.
        picked = log_p.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        nll_term = -picked
        uniform_term = -log_p.mean(dim=-1)
        per_sample = self.confidence * nll_term + self.smoothing * uniform_term
        return per_sample.mean()


def print_label_distribution(new_labels):
    """
    Print how many times each label occurs, one line per label in sorted order.

    Args:
        new_labels (iterable): Labels to tally (e.g. predicted class indices).
    """
    tally = Counter(new_labels)
    print("Class-Label Distribution:")
    for lbl in sorted(tally):
        print(f"Class {lbl}: {tally[lbl]} samples")

def relabel_dataset(subset, image_encoder, classification_head, batch_size=128, device="cuda", num_workers=4):
    """
    Relabel the samples of a ``Subset`` in place using the model's predictions.

    Each sample of ``subset`` is passed through ``image_encoder`` followed by
    ``classification_head``. The argmax prediction is written into
    ``subset.dataset.targets`` at that sample's underlying index. When the
    underlying dataset also has a ``samples`` list of ``(path, label)`` pairs
    (as torchvision's ImageFolder does), those entries are updated too. The
    labels are read from there, so updating ``targets`` alone would not take
    effect.

    Args:
        subset: A ``torch.utils.data.Subset`` whose ``dataset`` has a mutable
            ``targets`` attribute.
        image_encoder: Model mapping an image batch to embeddings.
        classification_head: Model mapping embeddings to class scores (dim 1).
        batch_size: Batch size used for inference.
        device: Device to run the models on (e.g. "cuda" or "cpu").
        num_workers: DataLoader worker processes (default 4).

    Returns:
        The underlying dataset (``subset.dataset``), modified in place.
        Both models are left in eval mode on ``device``.
    """
    image_encoder.eval()
    image_encoder.to(device)
    classification_head.eval()
    classification_head.to(device)
    dataset = subset.dataset

    # shuffle=False keeps predictions in the same order as subset.indices.
    # Pinned memory only helps host-to-CUDA copies.
    dataloader = DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=torch.device(device).type == "cuda",
    )

    # new_labels[pos] is the prediction for the pos-th element of the subset.
    new_labels = np.empty(len(subset), dtype=int)
    offset = 0
    with torch.no_grad():
        for images, _ in tqdm.tqdm(dataloader):
            images = images.to(device)
            outputs = classification_head(image_encoder(images))
            preds = torch.argmax(outputs, dim=1)
            n_batch = preds.size(0)
            new_labels[offset:offset + n_batch] = preds.cpu().numpy()
            offset += n_batch

    samples = getattr(dataset, "samples", None)
    # Map subset position -> underlying dataset index. The previous code indexed
    # new_labels with the dataset index, which is wrong for any non-identity subset.
    for pos, idx in enumerate(subset.indices):
        label = int(new_labels[pos])
        dataset.targets[idx] = label
        if samples is not None:
            entry = samples[idx]
            if isinstance(entry, tuple) and len(entry) == 2:
                samples[idx] = (entry[0], label)

    return dataset