import torch
import torch.optim as optim
import numpy as np
import math
from tqdm import trange
from alg.models import NNet2l, mytestNNet, MYtrainNNet
from alg.data import loaddataset
import random



def SAbound(k, N, bet):
    """Return ``eps = 1 - t`` where ``t`` is found by bisection on (0, 1).

    The bisection locates the sign change of

        1 - bet / N * sum_{m=k}^{N-1} C(m, k) / C(N, k) * t ** (-(N - m))

    with the binomial ratios evaluated in log-space through ``math.lgamma``.

    Args:
        k: integer lower limit of the sum (presumably the number of support
            points -- confirm against the caller).
        N: integer exclusive upper limit of the sum (presumably the number
            of samples -- confirm against the caller).
        bet: multiplicative weight of the sum (presumably the confidence
            parameter beta -- confirm against the caller).

    Returns:
        ``1 - t1`` where ``t1`` is the lower end of the final bisection
        bracket, whose width is at most 1e-10.
    """

    def log_binom(top):
        # log of "top choose k", via the log-gamma function
        return math.lgamma(top + 1) - math.lgamma(top - k + 1) - math.lgamma(k + 1)

    m_values = range(k, N)
    # log( C(m, k) / C(N, k) ) for every m in [k, N)
    log_ratio = np.array([log_binom(m) for m in m_values]) - log_binom(N)
    # exponent (N - m) attached to each term of the sum
    gap = N - np.array(m_values)

    lo = 0
    hi = 1
    while hi - lo > 1e-10:
        mid = (lo + hi) / 2
        total = np.sum(np.exp(log_ratio - gap * np.log(mid)))
        if 1 - (bet / N * total) > 0:
            # expression is positive at mid: shrink the bracket from above
            hi = mid
        else:
            lo = mid
    return 1 - lo

def inv_binomial(n, k, delta):
    """Return a grid-based upper bound on the binomial success probability.

    Searches a uniform grid of 10,000 values ``p`` in ``[1e-6, 1 - 1e-6]``
    and returns the smallest ``p`` for which

        P[Binomial(n, p) <= k] <= delta.

    Args:
        n: number of trials.
        k: observed count; rounded to the nearest integer with ``np.round``.
        delta: threshold on the binomial CDF.

    Returns:
        The smallest grid value of ``p`` satisfying the inequality, or
        ``1.0`` when no grid value satisfies it (for example ``k == n``
        with ``delta < 1``), since 1 is then the only valid upper bound.
    """
    eps = 1e-6
    k = int(np.round(k))
    p = np.linspace(eps, 1 - eps, num=10_000)
    log_p = np.log(p)
    log_q = np.log(1 - p)

    # Accumulate log P[X <= k] term by term with log-sum-exp. This avoids
    # taking log(0) when every pmf term underflows and avoids building a
    # (k + 1) x len(p) temporary array.
    log_cdf = np.full_like(p, -np.inf)
    log_n_factorial = math.lgamma(n + 1)
    for i in range(k + 1):
        log_comb = log_n_factorial - math.lgamma(n - i + 1) - math.lgamma(i + 1)
        log_cdf = np.logaddexp(log_cdf, log_comb + i * log_p + (n - i) * log_q)

    valid = log_cdf <= np.log(delta)
    if not valid.any():
        # np.argmax on an all-False mask returns 0, which would silently
        # report the smallest grid point as the bound; return the trivial
        # bound instead.
        return np.float64(1.0)
    # p is increasing, so the first True entry is the smallest valid p
    return p[np.argmax(valid)]


def check_condition(C, net, data, target, device):
    """Check whether the largest absolute error on ``data`` is at most ``C``.

    Args:
        C: threshold on the maximum absolute error.
        net: callable model, forwarded to ``find_max_loss``.
        data: model inputs, forwarded to ``find_max_loss``.
        target: reference values, forwarded to ``find_max_loss``.
        device: forwarded to ``find_max_loss`` unchanged.

    Returns:
        ``(satisfied, max_loss)`` where ``satisfied`` is ``True`` exactly
        when ``max_loss <= C``.
    """
    _, worst_loss = find_max_loss(net, data, target, device)
    satisfied = bool(worst_loss <= C)
    return satisfied, worst_loss

def find_max_loss(net, data, target, device):
    """Locate the sample with the largest absolute prediction error.

    Args:
        net: callable applied to ``data`` to obtain predictions.
        data: input passed to ``net``.
        target: tensor of reference values; viewed as a column with one row
            per prediction before subtraction.
        device: accepted for signature compatibility; not used here.

    Returns:
        ``(index, max_loss)`` as Python scalars. ``index`` is the position
        in the flattened error tensor (equal to the row index when ``net``
        returns one value per sample -- assumed, confirm against the model).
    """
    with torch.no_grad():
        predictions = net(data)
        column_targets = target.view(len(predictions), 1)
        abs_errors = (predictions - column_targets).abs()
    worst_position = abs_errors.argmax()
    worst_value = abs_errors.max()
    return worst_position.item(), worst_value.item()


    

def _train_on_support(net, optimizer, data_supp, targets_supp, batch_size, train_epochs):
    """Train ``net`` for ``train_epochs`` epochs of mini-batches over the support set."""
    n_supp = data_supp.shape[0]
    for _ in trange(train_epochs):
        # reshuffle the support set at the start of every epoch
        perm = torch.randperm(n_supp)
        data_supp = data_supp[perm]
        targets_supp = targets_supp[perm]

        net.train()
        for start in range(0, n_supp, batch_size):
            stop = start + batch_size  # slicing clamps the last, shorter batch
            MYtrainNNet(net, optimizer, data_supp[start:stop], targets_supp[start:stop])


def MYSAalg(data_slice_idx, C, num_supp_init, n_points_dataset, learning_rate, momentum, batch_size, train_epochs, dropout_prob, device, name_data):
    """Grow a support set until the net fits all remaining points within ``C``.

    A window of ``n_points_dataset`` points is taken from the shuffled
    training set and split at random into a support set (used for training)
    and a non-support set. The loop trains on the support set, then finds
    the non-support point with the largest absolute error. If that error
    exceeds ``C`` the point is moved into the support set and training is
    repeated (the same net and optimizer are reused across rounds). The
    loop stops when the largest error is at most ``C`` or when no
    non-support points remain.

    Args:
        data_slice_idx: index of the window; rows
            ``[data_slice_idx * n_points_dataset,
            (data_slice_idx + 1) * n_points_dataset)`` of the shuffled
            training set are used.
        C: threshold on the absolute error of non-support points.
        num_supp_init: number of points initially placed in the support set.
        n_points_dataset: number of points in the window.
        learning_rate: learning rate passed to ``optim.SGD``.
        momentum: momentum passed to ``optim.SGD``.
        batch_size: mini-batch size used during training.
        train_epochs: number of epochs run on the support set in each round.
        dropout_prob: dropout probability passed to ``NNet2l``.
        device: device passed to ``NNet2l`` and used for ``.to(device)``.
        name_data: dataset name passed to ``loaddataset``.

    Returns:
        ``(test_error_for_binomial, SGD_p_outside, net, supp_indx,
        nonsupp_indx)``. The first two values are the results of
        ``mytestNNet`` after the first training round, evaluated on the
        initial non-support set and on the test set respectively (per the
        original comments, the fraction of points outside the strip of
        half-width ``C`` -- confirm against ``mytestNNet``). The index
        lists refer to positions inside the selected window.

    Raises:
        IndexError: if the requested window does not contain
            ``n_points_dataset`` points.
    """
    # Initialize net and optimizer
    net = NNet2l(dropout_prob=dropout_prob, device=device).to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

    ### Dataset preprocessing
    train, test = loaddataset(name_data)
    data = torch.from_numpy(train[0]).float()
    targets = torch.from_numpy(train[1]).float()

    # shuffle data and targets with the same random permutation
    idx = torch.randperm(data.shape[0])
    data = data[idx]
    targets = targets[idx]

    # keep only the requested window of n_train points
    n_train = n_points_dataset
    window = slice(data_slice_idx * n_train, (data_slice_idx + 1) * n_train)
    data = data[window]
    targets = targets[window]
    if data.shape[0] != n_train:
        # Without this check the index lists built below would point past
        # the end of the window and fail later with an opaque IndexError.
        raise IndexError(
            f"data window {data_slice_idx} holds {data.shape[0]} points, "
            f"expected {n_train}"
        )

    # random initial split into support / non-support indices
    indices = list(range(n_train))
    random.shuffle(indices)
    supp_indx = indices[:num_supp_init]
    nonsupp_indx = indices[num_supp_init:]

    ### meta-algorithm
    condition = False
    first_round = True
    while not condition:
        # non-support data for this round (taken before any index is moved)
        data_nonsupp = data[nonsupp_indx]
        targets_nonsupp = targets[nonsupp_indx]

        _train_on_support(net, optimizer, data[supp_indx], targets[supp_indx], batch_size, train_epochs)
        net.eval()

        if not nonsupp_indx:
            # every point has been moved into the support set: nothing left to check
            condition = True
        else:
            # one forward pass gives both the worst point and its loss
            worst_pos, max_loss = find_max_loss(net, data_nonsupp, targets_nonsupp, device)
            condition = bool(max_loss <= C)
            if not condition:
                # move the worst-fitted point from non-support to support
                supp_indx.append(nonsupp_indx.pop(worst_pos))

        # After the first round only, evaluate the net trained on the initial
        # support set; these values are what the function returns for the
        # SGD + test-set bounds.
        if first_round:
            test_error_for_binomial = mytestNNet(net, [data_nonsupp, targets_nonsupp], C)
            SGD_p_outside = mytestNNet(net, test, C)
            first_round = False

    return test_error_for_binomial, SGD_p_outside, net, supp_indx, nonsupp_indx
