import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.utils_loss import logistic_loss
from utils.utils_models import linear_model, mlp_model
from cifar_models.resnet import resnet34, resnet50
    
def get_model(ds, mo, dim, device):
    """Build the model selected by dataset name and model name, placed on ``device``.

    Supported combinations:

    * ``mo == 'linear'``: ``linear_model(input_dim=dim, output_dim=1)`` for
      every dataset.
    * ``mo == 'resnet'``: ``resnet34`` for ``'cifar10'`` (``num_pool=8``) and
      ``'stl10'`` (``num_pool=24``); ``resnet50()`` for ``'alzheimer'``.
    * ``mo == 'mlp'``: ``mlp_model`` with ``hidden_dim=300`` for any dataset
      other than the three named above.

    Args:
        ds: Dataset name.
        mo: Model name (``'linear'``, ``'resnet'`` or ``'mlp'``).
        dim: Input dimensionality, used by the linear and MLP models only.
        device: Target passed to ``.to(device)`` on the constructed model.

    Returns:
        The constructed model after ``.to(device)``.

    Raises:
        ValueError: If ``mo`` is not supported for ``ds``. Previously this
            case surfaced as an ``UnboundLocalError`` at the ``return``.
    """
    image_datasets = ('cifar10', 'stl10', 'alzheimer')

    if mo == 'linear':
        # The linear model is configured identically for every dataset.
        model = linear_model(input_dim=dim, output_dim=1)
    elif mo == 'resnet' and ds == 'cifar10':
        model = resnet34(depth=32, num_pool=8, num_classes=1)
    elif mo == 'resnet' and ds == 'stl10':
        model = resnet34(depth=32, num_pool=24, num_classes=1)
    elif mo == 'resnet' and ds == 'alzheimer':
        model = resnet50()
    elif mo == 'mlp' and ds not in image_datasets:
        model = mlp_model(input_dim=dim, hidden_dim=300, output_dim=1)
    else:
        raise ValueError(
            "Unsupported model {!r} for dataset {!r}".format(mo, ds)
        )
    return model.to(device)

def accuracy_check(loader, model, device):
    """Return the fraction of samples whose predicted sign matches the label.

    The first output column of ``model`` is used as the score. A score
    ``>= 0`` is predicted as ``+1.0`` and anything else as ``-1.0``; the
    prediction is then compared for equality against the labels.

    This function does not switch the model to evaluation mode itself.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        model: Callable whose output is indexed with ``[:, 0]``.
        device: Device both tensors of each batch are moved to.

    Returns:
        Number of matching predictions divided by the number of labels seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            scores = model(images)[:, 0]
            # Map the boolean (score >= 0) onto {+1.0, -1.0}.
            predicted = (scores >= 0).float() * 2.0 - 1.0
            correct += (predicted == labels).sum().item()
            seen += labels.size(0)
    return correct / seen

def train_data_confidence_gen(loader, model, device, all_data_confidence):
    """Fill ``all_data_confidence`` with sigmoid confidences for every sample.

    The model is put into evaluation mode (and left there). For each batch the
    first output column is passed through a sigmoid and written into
    ``all_data_confidence`` at consecutive positions, in the order the loader
    yields samples.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs. The labels are
            ignored, so they are no longer moved to ``device``.
        model: Module providing ``eval()`` whose output is indexed with
            ``[:, 0]``.
        device: Device each image batch is moved to before the forward pass.
        all_data_confidence: Preallocated buffer, modified in place. It must
            be able to hold one value per sample produced by ``loader``.

    Returns:
        Tuple ``(all_data_confidence, count)`` where ``count`` is the total
        number of samples processed.
    """
    model.eval()
    offset = 0
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            n = images.shape[0]
            # squeeze() turns a single-sample batch into a 0-d tensor, which
            # still broadcasts into the length-1 destination slice.
            confidence = torch.sigmoid(model(images)[:, 0]).squeeze()
            all_data_confidence[offset:offset + n] = confidence
            offset += n
    return all_data_confidence, offset


def update_ema(model, ema_model, alpha, global_step):
    """Update ``ema_model`` parameters as an exponential moving average of ``model``.

    Each parameter is updated in place as
    ``ema = alpha_eff * ema + (1 - alpha_eff) * param`` where
    ``alpha_eff = min(1 - 1 / (global_step + 1), alpha)``. The cap makes the
    decay ramp up from 0 at ``global_step == 0`` towards ``alpha``, so early
    steps track ``model`` closely.

    Only tensors returned by ``parameters()`` are updated, paired by
    iteration order.

    Args:
        model: Source module whose current parameters are read.
        ema_model: Module whose parameters are modified in place.
        alpha: Upper bound on the EMA decay factor.
        global_step: Step counter used for the warm-up cap.
    """
    alpha = min(1 - 1 / (global_step + 1), alpha)
    # no_grad allows in-place updates on leaf parameters without autograd
    # tracking; the keyword form of add_ replaces the deprecated
    # add_(scalar, tensor) overload.
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param.detach(), alpha=1 - alpha)


def exp_rampup(rampup_length):
    """Create an exponential ramp-up schedule over ``rampup_length`` epochs.

    Args:
        rampup_length: Number of epochs over which the value ramps up.

    Returns:
        A function ``f(epoch) -> float``. For ``epoch < rampup_length`` it
        returns ``exp(-5 * (1 - e / rampup_length) ** 2)`` with ``e`` being
        ``epoch`` clipped to ``[0, rampup_length]``; otherwise it returns
        ``1.0``.
    """
    def ramp(epoch):
        # Once the ramp-up window has passed the schedule stays at 1.
        if not epoch < rampup_length:
            return 1.0
        clipped = np.clip(epoch, 0.0, rampup_length)
        remaining = 1.0 - clipped / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))

    return ramp