import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import time
import os
import pickle
from tqdm import tqdm

def fourier_integral_exp_vectorized(x, c, num_points=1000):
    """
    Compute f(x) = ∫₀ˣ exp(c0 + c1*sin(u) + c2*cos(u) + c3*sin(2u) + c4*cos(2u) + ...) du
    element-wise, where each element of x is treated as an upper integration limit.

    The integrand is sampled on one shared uniform grid that contains 0 and spans every
    limit in x. The cumulative trapezoidal integral on that grid is then linearly
    interpolated at each x. Limits below 0 yield the signed integral -∫ₓ⁰ exp(g(u)) du.

    Arguments:
      x         : tensor (or array-like) of arbitrary shape holding the upper integration
                  limits. Non-tensor inputs and non-float64 tensors are evaluated in
                  float32; float64 tensors in float64. Non-contiguous tensors are accepted.
      c         : coefficient vector of shape [1+2*n_max]; c[0] is the constant term,
                  c[2n-1] and c[2n] multiply sin(n*u) and cos(n*u). It is moved to x's
                  device and dtype via `.to`, so gradients still flow back into c.
      num_points: number of points in the integration grid (default 1000, must be >= 2),
                  which controls accuracy.

    Returns:
      a tensor with the same shape as x, where each element is the corresponding integral value

    Raises:
      ValueError: if num_points < 2 or the length of c is not of the form 1+2*n_max.
    """
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x, dtype=torch.float32)
    dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
    # reshape (unlike view) also works for non-contiguous inputs such as transposes.
    x_flat = x.reshape(-1).to(dtype)
    if x_flat.numel() == 0:
        return torch.zeros(x.shape, dtype=dtype, device=x.device)

    # Align the coefficients with the grid so the matmul below cannot hit a
    # device/dtype mismatch; `.to` is a no-op when they already match.
    c = torch.as_tensor(c).to(device=x.device, dtype=dtype)
    n_coeff = c.shape[0]
    if (n_coeff - 1) % 2 != 0:
        raise ValueError("The length of the coefficient vector must be 1+2*n_max")
    n_max = (n_coeff - 1) // 2

    # Shared increasing grid from min(0, min x) to max(0, max x). It always contains
    # the lower limit 0, and bucketize below requires increasing boundaries. For
    # non-negative x this is exactly linspace(0, max x).
    lo = min(0.0, x_flat.min().item())
    hi = max(0.0, x_flat.max().item())
    if hi == lo:
        # Every limit is 0. Any non-degenerate grid yields exactly 0 there, whereas a
        # zero-width grid would produce 0/0 = NaN slopes during interpolation.
        hi = 1.0
    u = torch.linspace(lo, hi, steps=num_points, device=x.device, dtype=dtype)  # [num_points]

    # Fourier basis: a constant column, then sin(n*u) and cos(n*u) for n = 1..n_max.
    basis_list = [torch.ones_like(u)]
    for n in range(1, n_max + 1):
        basis_list.append(torch.sin(n * u))
        basis_list.append(torch.cos(n * u))
    basis = torch.stack(basis_list, dim=-1)  # [num_points, 1+2*n_max]

    # Integrand exp(g(u)) with g(u) = c0 + c1*sin(u) + c2*cos(u) + ...
    integrand = torch.exp(torch.matmul(basis, c))  # [num_points]

    # Cumulative trapezoidal integral from the grid start: F_u[i] = ∫_{u[0]}^{u[i]}.
    area = 0.5 * (integrand[1:] + integrand[:-1]) * (u[1:] - u[:-1])  # [num_points-1]
    F_u = torch.cat([area.new_zeros(1), torch.cumsum(area, dim=0)])  # [num_points]

    def _interp(points):
        # Linear interpolation of (u, F_u). Clamping puts a point equal to u[0]
        # into the first interval.
        idx = torch.bucketize(points, u).clamp(1, num_points - 1)
        u0, u1 = u[idx - 1], u[idx]
        f0, f1 = F_u[idx - 1], F_u[idx]
        return f0 + (f1 - f0) / (u1 - u0) * (points - u0)

    f_x_flat = _interp(x_flat)
    if lo < 0.0:
        # F_u is anchored at the grid start lo < 0; shift it so that f(0) = 0.
        f_x_flat = f_x_flat - _interp(x_flat.new_zeros(1))

    # Restore the original shape of x.
    return f_x_flat.reshape(x.shape)

def f_transform(scores, transform_coefficients, kreg, lamda):
    """
    Applies the Fourier-based transformation to a score matrix, compatible with both RAPS and APS.

    Each row is sorted in descending order. `lamda` is added to every entry after the first
    `kreg`, and every entry is then mapped through `fourier_integral_exp_vectorized`.

    Args:
        scores (np.ndarray): A 2D NumPy array of shape (B, N), where each row contains scores.
            It is not modified.
        transform_coefficients (torch.Tensor): A 1D tensor of Fourier coefficients used in the
            integral transform. The transform is evaluated on this tensor's device.
        kreg (int): Regularization cutoff index. Only the top-kreg scores are unregularized.
        lamda (float): Regularization strength. Added to scores after the top-kreg.

    Returns:
        I (np.ndarray): Indices that sort the scores in descending order, shape (B, N).
            The array is C-contiguous, so it can be passed straight to torch.from_numpy.
        ordered (np.ndarray): Transformed scores after applying Fourier integral, shape (B, N).
        cumsum (np.ndarray): Cumulative sum of transformed scores along each row, shape (B, N).
    """
    with torch.no_grad():
        # Descending order. The reversed view has a negative stride, which
        # torch.from_numpy rejects, so materialize it contiguously.
        I = np.ascontiguousarray(scores.argsort(axis=1)[:, ::-1])  # shape: (B, N)

        # take_along_axis returns a new array, so the in-place add below leaves `scores` intact.
        ordered = np.take_along_axis(scores, I, axis=1)  # shape: (B, N)
        ordered[:, kreg:] += lamda

        # Build the tensor on the coefficients' device so CUDA coefficients work as well.
        device = getattr(transform_coefficients, "device", None)
        ordered_tensor = torch.tensor(ordered, dtype=torch.float32, device=device)
        transformed = fourier_integral_exp_vectorized(ordered_tensor, transform_coefficients)
        ordered = transformed.detach().cpu().numpy()  # back to NumPy on the host

        cumsum = np.cumsum(ordered, axis=1)

    return I, ordered, cumsum

def find_quantile_sample(scores, I, true_labels, alpha):
    """
    Find the sample whose cumulative score at its true label sits at the conformal quantile.

    For each row, E = cumsum(scores[row])[p], where p is the position of the true label in
    I[row]. The rows are ranked by E. The row at 0-based rank ceil((n+1)*(1-alpha)) - 1,
    clipped to [0, n-1], is selected.

    Args:
        scores: (n, C) array/tensor of per-position scores, in the same column order as I.
        I: (n, C) array/tensor; I[row, p] is the class at position p. Negative-stride NumPy
            views (e.g. `argsort(...)[:, ::-1]`) are accepted.
        true_labels: (n,) array/tensor of true class labels.
        alpha: miscoverage level.

    Returns:
        (sample_idx, true_label_pos): the selected row index and the position of its true
        label within I[sample_idx].

    Raises:
        ValueError: if there are no samples.
    """
    with torch.no_grad():
        # torch.from_numpy rejects negative strides, so make NumPy inputs contiguous first.
        if isinstance(scores, np.ndarray):
            scores = torch.from_numpy(np.ascontiguousarray(scores))
        if isinstance(true_labels, np.ndarray):
            true_labels = torch.from_numpy(np.ascontiguousarray(true_labels))
        if isinstance(I, np.ndarray):
            I = torch.from_numpy(np.ascontiguousarray(I)).long()

        n = len(true_labels)
        if n == 0:
            raise ValueError("find_quantile_sample requires at least one sample")

        cumsum = torch.cumsum(scores, dim=1)
        batch_indices = torch.arange(n, device=cumsum.device)

        # Position of each true label within its row of I. If a label were absent,
        # argmax would silently return 0.
        matches = I == true_labels.unsqueeze(1)
        true_positions = matches.int().argmax(dim=1).to(cumsum.device)

        E = cumsum[batch_indices, true_positions]

        # Compute the rank in float64 (NumPy) rather than float32 to avoid rounding
        # across an integer boundary.
        quantile_idx = int(np.ceil((n + 1) * (1 - alpha))) - 1
        quantile_idx = min(n - 1, max(0, quantile_idx))
        _, sorted_indices = torch.sort(E)
        sample_idx = sorted_indices[quantile_idx].item()
        true_label_pos = true_positions[sample_idx].item()

    return sample_idx, true_label_pos


def get_softmax_targets(dataset_logits):
    """
    Return softmax probabilities and labels for the rows selected by a logits dataset.

    dataset_logits: a TensorDataset(logits, labels), or a torch.utils.data.Subset of one.
        Subsets of Subsets are also accepted, e.g. the result of calling random_split on
        an earlier random_split output.
    Returns: (softmax probabilities over dim 1, labels), restricted to the selected rows
    """
    with torch.no_grad():
        dataset = dataset_logits
        idx = None  # composed row indices into the innermost TensorDataset
        while isinstance(dataset, torch.utils.data.Subset):
            level = dataset.indices
            if isinstance(level, torch.Tensor):
                level = level.long()
            else:
                level = torch.as_tensor(list(level), dtype=torch.long)
            # Element k of this Subset is row level[k] of its parent, so map the
            # positions collected so far through this level's indices.
            idx = level if idx is None else level[idx]
            dataset = dataset.dataset

        logits = dataset.tensors[0]
        labels = dataset.tensors[1]
        if idx is not None:
            logits = logits[idx]
            labels = labels[idx]
        scores = torch.softmax(logits, dim=1)
    return scores, labels



def index_save(scores, kreg, lamda):
    """
    Sort each row of `scores` in descending order and add `lamda` past the first `kreg` entries.

    The input tensor is not modified; the sort produces fresh tensors.

    Returns:
        (indices, penalised): the descending-sort indices and the sorted, regularized scores.
    """
    with torch.no_grad():
        result = torch.sort(scores, dim=1, descending=True)
        penalised = result.values
        penalised[:, kreg:] += lamda
        return result.indices, penalised

def calculate_loss(f_model, scores, beta, tau):
    """
    Smooth count of prefix positions whose cumulative transformed score stays below a threshold.

    For each row, every position p contributes sigmoid((T - F[p]) / beta), where F is the
    cumulative sum of f_model(scores) along dim 1 and T = sum(f_model(tau)). The
    per-row totals are averaged over the batch. For APS/RAPS-style sorted scores this acts
    as a differentiable stand-in for the prediction-set size (interpretation; confirm
    against the caller).

    Args:
        f_model: callable applied to `scores` and to `tau`; assumed to return a tensor of
            the same shape as its input (TODO confirm).
        scores: 2-D tensor [B, C]; a non-2-D input raises ValueError during unpacking.
        beta: sigmoid temperature; smaller values make the count sharper.
        tau: input whose transformed values are summed into the threshold T.

    Returns:
        A scalar tensor.
    """
    _batch, _classes = scores.shape  # enforces a 2-D input, as before
    prefix = f_model(scores).cumsum(dim=1)  # [B, C]
    threshold = f_model(tau).sum()  # scalar
    soft_inside = torch.sigmoid((threshold - prefix) / beta)  # [B, C]
    return soft_inside.sum(dim=1).mean()




def sort_sum(scores):
    """
    Sort each row of a 2-D array in descending order and accumulate it.

    Returns:
        (I, ordered, cumsum): the descending-sort indices, the descending-sorted scores,
        and their running sum along each row. I and ordered are reversed views.
    """
    desc_idx = np.flip(scores.argsort(axis=1), axis=1)
    desc_vals = np.flip(np.sort(scores, axis=1), axis=1)
    running = np.cumsum(desc_vals, axis=1)
    return desc_idx, desc_vals, running

class AverageMeter:
    """Keep the latest value of a metric together with its running weighted mean."""

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the weight count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**self.__dict__)

def validate(val_loader, model, print_bool):
    """
    Evaluate a set-valued model over `val_loader`, tracking empirical coverage and set size.

    `model(x)` must return a pair; its second element holds one prediction set per sample,
    and the first element is ignored. Inputs and targets are moved to CUDA. When
    `print_bool` is true, a running progress line is rewritten after every batch.

    Returns:
        (average coverage, average set size), each weighted by batch size.
    """
    with torch.no_grad():
        batch_time = AverageMeter('batch_time')
        coverage = AverageMeter('RAPS coverage')
        size = AverageMeter('RAPS size')
        model.eval()  # evaluation mode
        last_tick = time.time()
        n_seen = 0
        for x, target in val_loader:
            target = target.cuda()
            _, prediction_sets = model(x.cuda())
            batch_cvg, batch_sz = coverage_size(prediction_sets, target)

            batch_n = x.shape[0]
            coverage.update(batch_cvg, n=batch_n)
            size.update(batch_sz, n=batch_n)

            batch_time.update(time.time() - last_tick)
            last_tick = time.time()
            n_seen += batch_n
            if print_bool:
                print(f'\rN: {n_seen} | Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) | Cvg@RAPS: {coverage.val:.3f} ({coverage.avg:.3f}) | Size@RAPS: {size.val:.3f} ({size.avg:.3f})', end='')
    if print_bool:
        print('')  # terminate the progress line

    return coverage.avg, size.avg

def coverage_size(S, targets):
    """
    Compute the empirical coverage and mean set size of a batch of prediction sets.

    Args:
        S: indexable collection of prediction sets; S[i] must support `in` and `.shape[0]`.
        targets: tensor/array of true labels, one per prediction set.

    Returns:
        (fraction of samples whose label lies in its set, average set size).
    """
    hits = 0
    total_size = 0
    for idx, label in enumerate(targets):
        pred_set = S[idx]
        hits += int(label.item() in pred_set)
        total_size = total_size + pred_set.shape[0]
    n = targets.shape[0]
    return float(hits) / n, total_size / n

def accuracy(output, target, topk=(1,)):
    """
    Return top-k accuracy (in percent) for every k in `topk`.

    Args:
        output: [B, num_classes] score tensor.
        target: [B] tensor of true class indices.
        topk: iterable of k values.

    Returns:
        A list of scalar tensors, one per k, in the order of `topk`.
    """
    largest_k = max(topk)
    n = target.size(0)
    _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
    ranked = top_idx.t()  # [largest_k, B]
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    scale = 100.0 / n
    return [hits[:k].float().sum().mul_(scale) for k in topk]

def data2tensor(data):
    """
    Stack an iterable of (image, label) samples into an image batch and a label tensor.

    `data` is iterated twice, so it must be a re-iterable collection.

    Returns:
        imgs: the images stacked along a new leading dimension and moved to CUDA.
        targets: CPU int64 tensor of labels (built via int(), float32, then int64).
    """
    imgs = torch.stack([sample[0] for sample in data], dim=0).cuda()
    targets = torch.Tensor([int(sample[1]) for sample in data]).long()
    return imgs, targets

def split2ImageFolder(path, transform, n1, n2):
    """
    Load an ImageFolder and draw two disjoint random subsets of sizes n1 and n2.

    Samples not in either subset are discarded. The second subset is a Subset of a Subset.
    """
    full = torchvision.datasets.ImageFolder(path, transform)
    total = len(full)
    first, remainder = torch.utils.data.random_split(full, [n1, total - n1])
    second, _discarded = torch.utils.data.random_split(remainder, [n2, total - n1 - n2])
    return first, second

def split2(dataset, n1, n2):
    """
    Draw two disjoint random subsets of sizes n1 and n2 from `dataset`; the rest is discarded.

    Works with any sized Dataset (TensorDataset, Subset, ...). The second subset is a
    Subset of a Subset.

    Raises:
        ValueError: if n1 or n2 is negative or n1 + n2 exceeds len(dataset).
    """
    total = len(dataset)  # equals tensors[0].shape[0] for a TensorDataset
    if n1 < 0 or n2 < 0 or n1 + n2 > total:
        raise ValueError(f"cannot draw n1={n1} and n2={n2} samples from a dataset of size {total}")
    data1, rest = torch.utils.data.random_split(dataset, [n1, total - n1])
    data2, _ = torch.utils.data.random_split(rest, [n2, total - n1 - n2])
    return data1, data2

def split4(dataset, n1, n2, n3, n4):
    """
    Randomly partition `dataset` into four disjoint subsets of sizes n1..n4.

    The sizes must add up to len(dataset). Returns a 4-tuple of Subsets.
    """
    sizes = [n1, n2, n3, n4]
    assert sum(sizes) == len(dataset), "n1 + n2 + n3 + n4 must equal the dataset size"
    parts = torch.utils.data.random_split(dataset, sizes)
    return tuple(parts)


def get_model(modelname):
    """
    Build a pretrained torchvision classifier by display name, in eval mode, on CUDA.

    The model is wrapped in torch.nn.DataParallel.

    Raises:
        NotImplementedError: if `modelname` is not one of the supported names.
    """
    # Display name -> torchvision.models constructor name (resolved only once chosen).
    constructor_names = {
        'ResNet101': 'resnet101',
        'ResNet152': 'resnet152',
        'ResNeXt101': 'resnext101_32x8d',
        'VGG16': 'vgg16',
        'ShuffleNet': 'shufflenet_v2_x1_0',
        'DenseNet161': 'densenet161',
    }
    if modelname not in constructor_names:
        raise NotImplementedError

    build = getattr(torchvision.models, constructor_names[modelname])
    network = build(pretrained=True, progress=True)
    network.eval()
    return torch.nn.DataParallel(network).cuda()

def get_logits_targets(model, loader):
    """
    Run `model` over every batch of `loader` and collect its outputs with the labels.

    Inputs are moved to CUDA and gradients are disabled. The output buffer has 1000
    columns (the ImageNet class count), so the model must emit 1000 logits per sample.
    Labels are accumulated in float32 and converted to int64 at the end.

    Returns:
        torch.utils.data.TensorDataset(logits, labels) with one row per dataset sample.
    """
    n_samples = len(loader.dataset)
    logits = torch.zeros((n_samples, 1000))  # 1000 classes in ImageNet.
    labels = torch.zeros((n_samples,))
    print('Computing logits for model (only happens once).')
    offset = 0
    with torch.no_grad():
        for x, targets in tqdm(loader):
            stop = offset + x.shape[0]
            logits[offset:stop, :] = model(x.cuda()).detach().cpu()
            labels[offset:stop] = targets.cpu()
            offset = stop

    return torch.utils.data.TensorDataset(logits, labels.long())

def get_logits_dataset(modelname, datasetname, datasetpath, cache='.cache/'):
    """
    Return cached (logits, labels) for a model/dataset pair, computing and caching them on a miss.

    The cache file is <cache>/<datasetname>/<modelname>.pkl. `cache` may be given with or
    without a trailing separator. On a miss, the model is run over the ImageFolder at
    `datasetpath` using the standard ImageNet resize/crop/normalize preprocessing.

    Returns:
        The object stored in the cache file (a TensorDataset of logits and labels when it
        was produced by this function).
    """
    # os.path.join tolerates `cache` with or without a trailing separator.
    fname = os.path.join(cache, datasetname, modelname + '.pkl')

    if os.path.exists(fname):
        # NOTE(review): pickle.load can execute arbitrary code from the file; only point
        # `cache` at a directory you trust.
        with open(fname, 'rb') as handle:
            return pickle.load(handle)

    # Cache miss: load the model, run it on the dataset, then save and return the output.
    model = get_model(modelname)

    transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std =[0.229, 0.224, 0.225])
                    ])

    dataset = torchvision.datasets.ImageFolder(datasetpath, transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size = 32, shuffle=False, pin_memory=True)

    dataset_logits = get_logits_targets(model, loader)

    # Write to a temporary file, then atomically rename it into place, so an
    # interrupted run never leaves a truncated cache file for the next load.
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    tmp_fname = fname + '.tmp'
    with open(tmp_fname, 'wb') as handle:
        pickle.dump(dataset_logits, handle, protocol=pickle.HIGHEST_PROTOCOL)
    os.replace(tmp_fname, fname)

    return dataset_logits
