import torch
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
from alg.models import NNet4l, mytestNNet, MYtrainNNet
from alg.data import loaddataset
from tqdm import trange
import random


def SAbound(k, N, bet):
    """Solve a one-dimensional bound equation by bisection and return ``1 - t``.

    The function looks for the point ``t`` in (0, 1) where

        1 - (bet / N) * sum_{m=k}^{N-1} [C(m, k) / C(N, k)] * t ** (-(N - m))

    changes sign, and returns ``1 - t1`` where ``t1`` is the lower end of the
    final bisection bracket (bracket width at most 1e-10).

    NOTE(review): presumably this is a generalization bound for a support /
    compression set of size ``k`` out of ``N`` samples at confidence
    ``1 - bet`` -- confirm against the reference the formula comes from.

    Args:
        k: size of the support set (0 <= k <= N).
        N: total number of samples.
        bet: confidence parameter multiplying the sum.

    Returns:
        ``1 - t1``; exactly ``1`` when the lower bracket end never moves
        (e.g. ``k == N``, where the sum is empty).
    """

    def log_binom(top):
        # log C(top, k), via log-gamma so large arguments do not overflow.
        return math.lgamma(top + 1) - math.lgamma(top - k + 1) - math.lgamma(k + 1)

    sizes = list(range(k, N))
    log_numerators = np.array([log_binom(m) for m in sizes])
    # log of C(m, k) / C(N, k) for every m in [k, N).
    log_ratios = log_numerators - log_binom(N)
    # Exponent (N - m) applied to 1/t for each term of the sum.
    powers = N - np.array(sizes)
    scale = bet / N

    lo, hi = 0, 1
    while hi - lo > 1e-10:
        mid = (lo + hi) / 2
        residual = 1 - scale * np.sum(np.exp(log_ratios - powers * np.log(mid)))
        # Positive residual: the sign change lies to the left of ``mid``.
        lo, hi = (lo, mid) if residual > 0 else (mid, hi)

    return 1 - lo

def inv_binomial(n, k, delta):
    """Return a grid-based upper bound obtained by inverting the binomial tail.

    Scans 10,000 candidate probabilities ``p`` evenly spaced in
    ``[1e-6, 1 - 1e-6]`` and returns the smallest one for which
    ``P[Bin(n, p) <= k] <= delta``.

    Args:
        n: number of trials.
        k: observed number of successes/errors; rounded to the nearest integer.
        delta: tail probability level (expected in (0, 1)).

    Returns:
        The smallest grid value ``p`` satisfying the tail condition, or ``1.0``
        when no grid value satisfies it (e.g. ``k == n``), since then only the
        trivial bound holds. Previously that case fell through to
        ``np.argmax`` on an all-False mask and returned the *smallest* grid
        point (1e-6), which is not a valid upper bound.
    """
    grid_margin = 1e-6
    k = int(np.round(k))
    p = np.linspace(grid_margin, 1 - grid_margin, num=10_000)
    log_p = np.log(p)
    log_q = np.log(1 - p)

    # Accumulate log P[Bin(n, p) <= k] term by term in log space; this avoids
    # both underflow of individual pmf terms and a (k+1) x 10000 matrix.
    log_cdf = np.full(p.shape, -np.inf)
    for i in range(k + 1):
        log_comb = math.lgamma(n + 1) - math.lgamma(n - i + 1) - math.lgamma(i + 1)
        log_cdf = np.logaddexp(log_cdf, log_comb + i * log_p + (n - i) * log_q)

    valid = log_cdf <= np.log(delta)
    if not valid.any():
        # No candidate on the grid is small enough in the tail: only the
        # trivial bound is available.
        return np.float64(1.0)
    # argmax on a boolean array gives the first True, i.e. the smallest valid p.
    return p[np.argmax(valid)]

def find_max_loss(net, test_loader, device='cuda'):
    """Find the sample with the largest per-sample NLL loss over a whole loader.

    Every batch yielded by ``test_loader`` is evaluated under
    ``torch.no_grad()`` and the running maximum is kept, so the result covers
    all batches (the previous version overwrote ``losses`` on each iteration
    and therefore only looked at the last batch).

    Args:
        net: callable mapping a data batch to the input expected by
            ``F.nll_loss`` (log-probabilities).
        test_loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(index, loss)``: position of the worst sample counted in iteration
        order across batches, and its loss as a Python float. For a loader
        with a single batch this equals the index within that batch.

    Raises:
        ValueError: if the loader yields no samples.
    """
    worst_loss = None
    worst_index = -1
    seen = 0  # number of samples in the batches processed so far
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            losses = F.nll_loss(net(data), target, reduction='none')
            if losses.numel() == 0:
                continue
            batch_max = torch.max(losses).item()
            if worst_loss is None or batch_max > worst_loss:
                worst_loss = batch_max
                worst_index = seen + torch.argmax(losses).item()
            seen += losses.numel()
    if worst_loss is None:
        raise ValueError("find_max_loss: test_loader yielded no samples")
    return worst_index, worst_loss


def check_condition(C, net, test_loader, device='cuda'):
    """Tell whether the largest loss reported by ``find_max_loss`` is at most ``C``.

    Args:
        C: loss threshold.
        net: network forwarded to ``find_max_loss``.
        test_loader: loader forwarded to ``find_max_loss``.
        device: device forwarded to ``find_max_loss``.

    Returns:
        ``True`` if the reported maximum loss is ``<= C``, else ``False``.
    """
    _, worst_loss = find_max_loss(net, test_loader, device)
    return bool(worst_loss <= C)


def MYcheck_condition(C, net, data, target, device='cuda'):
    """Tell whether the largest loss reported by ``MYfind_max_loss`` is at most ``C``.

    Args:
        C: loss threshold.
        net: network forwarded to ``MYfind_max_loss``.
        data: input tensor forwarded to ``MYfind_max_loss``.
        target: label tensor forwarded to ``MYfind_max_loss``.
        device: forwarded unchanged to ``MYfind_max_loss``.

    Returns:
        ``True`` if the reported maximum loss is ``<= C``, else ``False``.
    """
    _, worst_loss = MYfind_max_loss(net, data, target, device)
    return bool(worst_loss <= C)

def MYfind_max_loss(net, data, target, device='cuda'):
    """Locate the sample with the largest per-sample NLL loss in one tensor batch.

    Args:
        net: callable mapping ``data`` to the input expected by ``F.nll_loss``.
        data: input tensor, evaluated in a single forward pass.
        target: label tensor matching ``data``.
        device: accepted for signature parity with ``find_max_loss``; it is not
            used here, so ``data`` and ``target`` are evaluated wherever they
            already live.

    Returns:
        ``(index, loss)``: position of the worst sample within ``data`` and its
        loss, both as Python scalars.
    """
    with torch.no_grad():
        per_sample = F.nll_loss(net(data), target, reduction='none')
        worst_position = torch.argmax(per_sample)
        worst_value = torch.max(per_sample)
    return worst_position.item(), worst_value.item()


    

     
def MYSAalg(test_loader, data_slice_idx, C, num_supp_init, n_datapoints, learning_rate, momentum, batch_size, train_epochs, dropout_prob, device, name_data):
    """Run the meta algorithm: grow a support set until all other points have loss <= C.

    A fresh ``NNet4l`` is trained with SGD on the current support set for
    ``train_epochs`` epochs. After each round the non-support point with the
    largest NLL loss is found; if that loss exceeds ``C`` the point is moved
    into the support set and training continues (same network and optimizer),
    otherwise the loop stops. The loop also stops once every point has been
    moved into the support set.

    After the first round only, the network is evaluated with ``mytestNNet``
    on the initial non-support points and on the test split; these two values
    are returned unchanged by later rounds.

    Args:
        test_loader: not used by this function.
        data_slice_idx: which window of ``n_datapoints`` consecutive samples of
            the shuffled training set to use (window ``i`` covers
            ``[i * n_datapoints, (i + 1) * n_datapoints)``).
        C: loss threshold every non-support point must satisfy on exit.
        num_supp_init: number of points in the initial, randomly chosen
            support set.
        n_datapoints: number of training points used.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        batch_size: mini-batch size for training on the support set.
        train_epochs: epochs of training per round.
        dropout_prob: dropout probability passed to ``NNet4l``.
        device: device passed to ``NNet4l`` and used for ``net.to``.
        name_data: dataset name passed to ``loaddataset``.

    Returns:
        ``(test_error_for_binomial, SGD_p_misclass, net, supp_indx,
        nonsupp_indx)``: the two first-round ``mytestNNet`` results, the
        trained network, and the final support / non-support index lists
        (indices into the selected window).

    NOTE(review): the training tensors are never moved to ``device`` here;
    whether ``NNet4l`` / ``MYtrainNNet`` handle device placement is not
    visible from this function -- confirm before running on GPU.
    """
    # Initialize net and optimizer
    net = NNet4l(dropout_prob=dropout_prob, device=device).to(device)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

    ### Dataset preprocessing
    # Load the raw dataset and rescale it by hand.
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std used
    # by the dataloader-based pipeline -- confirm for other datasets.
    train, test = loaddataset(name_data)
    data = (train.data.float()/255-0.1307)/0.3081
    targets = train.targets.long()
    test.data = (test.data.float()/255-0.1307)/0.3081
    test.targets = test.targets.long()

    # Shuffle data and targets with one common permutation.
    idx = torch.randperm(data.shape[0])
    data = data[idx]
    targets = targets[idx]

    # Keep only the requested window of n_train samples.
    n_train = n_datapoints
    data_idx_start = data_slice_idx*n_train
    data_idx_end = (data_slice_idx+1)*n_train
    data = data[data_idx_start:data_idx_end]
    targets = targets[data_idx_start:data_idx_end]

    # Random initial split into support and non-support indices.
    indices = list(range(n_train))
    random.shuffle(indices)
    supp_indx = indices[:num_supp_init]
    nonsupp_indx = indices[num_supp_init:]

    condition = False
    first_round = True

    # meta algorithm
    while not condition:

        # Materialize the tensors for the current split.
        data_supp        = data[supp_indx]
        data_nonsupp     = data[nonsupp_indx]
        targets_supp     = targets[supp_indx]
        targets_nonsupp  = targets[nonsupp_indx]

        # The support set grows every round, so the batch count does too.
        n_supp = len(supp_indx)
        N_batches = math.ceil(n_supp/batch_size)

        for _ in trange(train_epochs):

            # Reshuffle the support set at the start of every epoch.
            idx = torch.randperm(data_supp.shape[0])
            data_supp = data_supp[idx]
            targets_supp = targets_supp[idx]

            net.train()
            for batch_id in range(N_batches):
                start_batch_idx = batch_id*batch_size
                end_batch_idx   = min(start_batch_idx + batch_size, n_supp)
                MYtrainNNet(net, optimizer,
                            data_supp[start_batch_idx:end_batch_idx],
                            targets_supp[start_batch_idx:end_batch_idx])

        net.eval()
        if not nonsupp_indx:
            # Every point is already in the support set: nothing left to check.
            condition = True
        else:
            # One forward pass serves both the stopping test and the choice of
            # the point to move (the network is in eval mode, so evaluating
            # twice would give the same answer).
            max_loss_indx, max_loss = MYfind_max_loss(net, data_nonsupp, targets_nonsupp, device=device)
            condition = bool(max_loss <= C)
            if not condition:
                # max_loss_indx is a position within nonsupp_indx; the index
                # lists hold unique values, so popping by position moves
                # exactly the worst point.
                supp_indx.append(nonsupp_indx.pop(max_loss_indx))

        # First round only: evaluate the network trained on the initial
        # support set, on the initial non-support points and on the test
        # split. (data_nonsupp was extracted before the index lists changed.)
        if first_round:
            test_error_for_binomial = mytestNNet(net, [data_nonsupp, targets_nonsupp])
            SGD_p_misclass = mytestNNet(net, [test.data, test.targets])
            first_round = False

    return test_error_for_binomial, SGD_p_misclass, net, supp_indx, nonsupp_indx

