""" helper function

author baiyu
"""
from conf import settings  # Ensure you have this module with the constants defined
from torchvision import transforms
import os
import sys
import re
import datetime
from conf import settings
import numpy
import timm
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTConfig, AutoImageProcessor


class HuggingFaceTransform:
    """Adapt a Hugging Face image processor to a callable dataset transform.

    Args:
        processor: callable invoked as
            ``processor(images=image, return_tensors="pt")`` whose result
            is indexed with ``'pixel_values'``.
    """

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, image):
        """Run the processor on one image and return its pixel values.

        Args:
            image: a single image, forwarded unchanged to the processor.

        Returns:
            ``pixel_values`` from the processor output after ``squeeze(0)``.
        """
        batch = self.processor(images=image, return_tensors="pt")
        pixel_values = batch['pixel_values']
        # presumably the processor adds a leading batch axis of size 1;
        # drop it so a single sample is returned
        return pixel_values.squeeze(0)

# Wrapper class to handle different output types from Hugging Face models


class ModelWrapper(nn.Module):
    """Wrap a model so that ``forward`` always hands back its logits.

    Args:
        model: the callable model to delegate to.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Call the wrapped model on ``x`` and unwrap its output.

        Returns:
            ``result.logits`` when the model output has a ``logits``
            attribute (as Hugging Face model outputs do), otherwise the
            output itself, which is assumed to already be the logits.
        """
        result = self.model(x)
        return result.logits if hasattr(result, 'logits') else result


def get_network(args):
    """Build the network selected by ``args`` and wrap it in ``ModelWrapper``.

    Args:
        args: object exposing
            ``dataset``: 'CIFAR10' or 'CIFAR100' (decides the number of
                output classes, 10 or 100), and
            ``net``: 'resnet18', 'resnet34', 'resnet50', 'resnet101',
                'resnet152' or 'vits'.

    Returns:
        The selected model wrapped in ``ModelWrapper``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.net`` is not supported.
    """
    class_counts = {'CIFAR10': 10, 'CIFAR100': 100}
    # Fail early with a clear message; previously an unknown dataset left
    # num_classes unbound and surfaced as an UnboundLocalError further down.
    if args.dataset not in class_counts:
        raise ValueError(
            "Unsupported dataset type. Choose 'CIFAR100' or 'CIFAR10'.")
    num_classes = class_counts[args.dataset]

    resnet_variants = ('resnet18', 'resnet34', 'resnet50',
                       'resnet101', 'resnet152')
    if args.net in resnet_variants:
        # Imported lazily so the project-local module is only required when
        # a ResNet is actually requested; the factory shares the net's name.
        import resnet
        net = getattr(resnet, args.net)(num_classes)
    elif args.net == 'vits':
        net = timm.create_model(
            "deit_small_patch16_224", pretrained=True)
        # Change the head to output the correct number of classes
        net.head = nn.Linear(
            in_features=net.head.in_features, out_features=num_classes, bias=True)
    else:
        raise ValueError(
            'the network name you have entered is not supported yet')
    return ModelWrapper(net)


def get_training_dataloader(dataset_type='CIFAR100', batch_size=16, num_workers=2, shuffle=True, arch='resnet50'):
    """
    Build the DataLoader over the training split of a CIFAR dataset.

    The data is stored under ``./data`` and downloaded when missing.

    Args:
        dataset_type: 'CIFAR100' or 'CIFAR10'
        batch_size: number of samples per batch
        num_workers: number of DataLoader workers
        shuffle: whether the loader shuffles the samples
        arch: model architecture name; names starting with 'vit' get the
            Hugging Face image processor as transform, every other name
            gets the standard CIFAR augmentation pipeline

    Returns:
        PyTorch DataLoader over the training split

    Raises:
        ValueError: if ``dataset_type`` is neither 'CIFAR100' nor 'CIFAR10'
    """
    if dataset_type not in ('CIFAR10', 'CIFAR100'):
        raise ValueError(
            "Unsupported dataset type. Choose 'CIFAR100' or 'CIFAR10'.")

    # normalisation statistics and dataset class for the chosen dataset
    if dataset_type == 'CIFAR10':
        mean, std = settings.CIFAR10_TRAIN_MEAN, settings.CIFAR10_TRAIN_STD
        dataset_class = torchvision.datasets.CIFAR10
    else:
        mean, std = settings.CIFAR100_TRAIN_MEAN, settings.CIFAR100_TRAIN_STD
        dataset_class = torchvision.datasets.CIFAR100

    if not arch.startswith('vit'):
        # crop/flip augmentation followed by tensor conversion and normalisation
        augmentations = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        transform_train = transforms.Compose(augmentations)
    else:
        # ViT models: preprocessing is delegated to the pretrained image
        # processor; mean/std above are not used on this path
        transform_train = HuggingFaceTransform(
            AutoImageProcessor.from_pretrained("Ahmed9275/Vit-Cifar100"))

    training_set = dataset_class(
        root='./data', train=True, download=True, transform=transform_train)

    return DataLoader(
        training_set,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
    )


def get_test_dataloader(dataset_type='CIFAR100', batch_size=16, num_workers=2, shuffle=False, arch='resnet50'):
    """
    Build the DataLoader over the test split of a CIFAR dataset.

    The data is stored under ``./data`` and downloaded when missing.

    Args:
        dataset_type: 'CIFAR100' or 'CIFAR10'
        batch_size: number of samples per batch
        num_workers: number of DataLoader workers
        shuffle: whether the loader shuffles the samples
        arch: model architecture name; names starting with 'vit' get the
            Hugging Face image processor as transform, every other name
            gets tensor conversion plus normalisation

    Returns:
        PyTorch DataLoader over the test split

    Raises:
        ValueError: if ``dataset_type`` is neither 'CIFAR100' nor 'CIFAR10'
    """
    is_cifar10 = dataset_type == 'CIFAR10'
    if not is_cifar10 and dataset_type != 'CIFAR100':
        raise ValueError(
            "Unsupported dataset type. Choose 'CIFAR100' or 'CIFAR10'.")

    # the test split is normalised with the statistics of the training split
    if is_cifar10:
        mean = settings.CIFAR10_TRAIN_MEAN
        std = settings.CIFAR10_TRAIN_STD
        dataset_class = torchvision.datasets.CIFAR10
    else:
        mean = settings.CIFAR100_TRAIN_MEAN
        std = settings.CIFAR100_TRAIN_STD
        dataset_class = torchvision.datasets.CIFAR100

    wants_vit_preprocessing = arch.startswith('vit')
    if wants_vit_preprocessing:
        # preprocessing is delegated to the pretrained image processor;
        # mean/std above are not used on this path
        vit_processor = AutoImageProcessor.from_pretrained(
            "Ahmed9275/Vit-Cifar100")
        transform_test = HuggingFaceTransform(vit_processor)
    else:
        # no augmentation at test time: only tensor conversion + normalisation
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean, std)
        transform_test = transforms.Compose([to_tensor, normalize])

    test_set = dataset_class(
        root='./data', train=False, download=True, transform=transform_test)

    loader_options = dict(
        shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
    return DataLoader(test_set, **loader_options)


def compute_mean_std(cifar100_dataset):
    """Compute the per-channel mean and std over an entire dataset.

    Every sample is fetched exactly once (the dataset used to be traversed
    three times, once per channel), so the three channels are always taken
    from the same reads even when fetching a sample is expensive or not
    deterministic.

    Args:
        cifar100_dataset: dataset supporting ``len()`` and integer indexing;
            element ``[i][1]`` must be an array that can be sliced as
            ``[:, :, c]`` for the channels ``c`` = 0, 1, 2.

    Returns:
        a tuple ``(mean, std)``; each is a 3-tuple holding the value for
        channel 0, 1 and 2 computed over all samples
    """
    # NOTE(review): the image is read from index 1 of each item. torchvision's
    # CIFAR datasets yield (image, target), i.e. the image at index 0 --
    # confirm the dataset passed in here really stores the image second.
    samples = [cifar100_dataset[i][1] for i in range(len(cifar100_dataset))]

    channel_stacks = []
    for channel in range(3):
        # stack this channel of every sample along a third axis
        planes = [sample[:, :, channel] for sample in samples]
        channel_stacks.append(numpy.dstack(planes))

    mean = tuple(numpy.mean(stack) for stack in channel_stacks)
    std = tuple(numpy.std(stack) for stack in channel_stacks)

    return mean, std


class WarmUpLR(_LRScheduler):
    """Learning rate scheduler for the warm-up phase of training.

    The learning rate of every parameter group grows linearly with the
    number of scheduler steps taken so far.

    Args:
        optimizer: optimizer (e.g. SGD) whose learning rates are scheduled
        total_iters: total number of iterations of the warm-up phase
        last_epoch: index of the last step, forwarded to the base scheduler
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # stored before the base constructor runs -- presumably the base
        # class already calls get_lr() while initialising
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * m / total_iters`` for every parameter group,
        where ``m`` is the number of batches seen so far (``last_epoch``).
        """
        # the tiny offset keeps the division defined when total_iters is 0
        denominator = self.total_iters + 1e-8
        batches_seen = self.last_epoch

        warmup_lrs = []
        for base_lr in self.base_lrs:
            warmup_lrs.append(base_lr * batches_seen / denominator)
        return warmup_lrs


def most_recent_folder(net_weights, fmt):
    """Return the name of the most recent non-empty folder under net_weights.

    "Most recent" is decided by the timestamp encoded in the folder name,
    parsed with ``fmt``; the file system times are not consulted.

    Entries that are not directories, directories that are empty, and
    directories whose name cannot be parsed with ``fmt`` are ignored rather
    than aborting the lookup.

    Args:
        net_weights: path of the directory that holds the run folders
        fmt: ``datetime.strptime`` format the run folder names follow

    Returns:
        the folder name (not the full path), or an empty string when no
        suitable folder exists
    """
    candidates = []
    for name in os.listdir(net_weights):
        path = os.path.join(net_weights, name)
        # stray regular files used to make os.listdir(path) raise
        if not os.path.isdir(path) or not os.listdir(path):
            continue
        try:
            stamp = datetime.datetime.strptime(name, fmt)
        except ValueError:
            # folder unrelated to the timestamped runs
            continue
        candidates.append((stamp, name))

    if not candidates:
        return ''

    # sort folders by the time encoded in their name; the newest comes last
    candidates.sort(key=lambda item: item[0])
    return candidates[-1][1]


def most_recent_weights(weights_folder):
    """Return the weights file with the highest epoch number in a folder.

    File names are expected to contain ``<net>-<epoch>-<regular|best>``;
    files that do not match this pattern are ignored.

    Args:
        weights_folder: path of the folder holding the weights files

    Returns:
        the file name (not the full path), or an empty string when the
        folder holds no matching weights file
    """
    regex_str = r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)'

    epoch_and_file = []
    for name in os.listdir(weights_folder):
        match = re.search(regex_str, name)
        if match is None:
            # not a weights file; it used to crash the epoch extraction
            continue
        epoch_and_file.append((int(match.group(2)), name))

    # The emptiness check has to look at the files found, not at the length
    # of the folder path (which is never zero for a listable folder).
    if not epoch_and_file:
        return ''

    # sort files by epoch; the most recent one comes last
    epoch_and_file.sort(key=lambda item: item[0])
    return epoch_and_file[-1][1]


def last_epoch(weights_folder):
    """Return the epoch number of the most recent weights file in a folder.

    Args:
        weights_folder: path of the folder holding the weights files

    Returns:
        the epoch as an int, taken from the second dash-separated field of
        the file name returned by ``most_recent_weights``

    Raises:
        Exception: if ``most_recent_weights`` finds no weights file
    """
    latest = most_recent_weights(weights_folder)
    if latest:
        name_fields = latest.split('-')
        return int(name_fields[1])
    raise Exception('no recent weights were found')


def best_acc_weights(weights_folder):
    """Return the 'best' weights file with the highest epoch in a folder.

    File names are expected to contain ``<net>-<epoch>-<regular|best>``;
    only files tagged ``best`` are considered, and files that do not match
    the pattern at all are ignored.

    Args:
        weights_folder: path of the folder holding the weights files

    Returns:
        the file name (not the full path), or an empty string when no
        best-accuracy weights file is found
    """
    regex_str = r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)'

    epoch_and_file = []
    for name in os.listdir(weights_folder):
        match = re.search(regex_str, name)
        # unrelated files yield no match and used to crash on .groups()
        if match is None or match.group(3) != 'best':
            continue
        epoch_and_file.append((int(match.group(2)), name))

    if not epoch_and_file:
        return ''

    # sort by epoch; the latest best checkpoint comes last
    epoch_and_file.sort(key=lambda item: item[0])
    return epoch_and_file[-1][1]
