import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import time

def distorted_greedy(defender, args):
    """Select up to ``args.K`` dataset indices with sampled distorted greedy.

    Each of the ``K`` rounds draws ``s = round(N / K * log(1 / eps))`` random
    indices from the not-yet-selected pool, scores them batch by batch with
    ``calc_gc_modular`` and adds the best-scoring index to the selection when
    its score is non-negative.

    Args:
        defender: Object providing ``reset_pre_computed()``, ``dataset``
            (iterated as ``(x, y)`` pairs through a ``DataLoader``),
            ``adversary.attack_model_type`` and a writable ``timer``
            attribute. Its scoring methods are reached via
            ``calc_gc_modular``.
        args: Namespace with ``gamma``, ``K``, ``eps``, ``dg_batch_size`` and
            ``debug``.

    Returns:
        list: Selected dataset indices in selection order. It can hold fewer
        than ``K`` entries when a round's best score is negative or when the
        candidate pool runs out. In debug mode the first ``K`` indices are
        returned without any scoring.

    Raises:
        ValueError: If the per-round sample size ``s`` is not positive.
        NotImplementedError: If the adversary's attack model type is neither
            ``"pgd"`` nor ``"advgan"``.
    """
    defender.reset_pre_computed()
    gamma = args.gamma
    K = args.K
    eps = args.eps
    dg_batch_size = args.dg_batch_size

    if args.debug:
        print(f"Debugging mode ON! Returning first {K} from distorted greedy")
        return list(np.arange(K))

    N = len(defender.dataset)
    s = int(round(N / K * np.log(1 / eps)))  # number of random samples in an iteration
    print("\tEps is", eps)
    print("\tNumber of samples for Distorted Greedy:", s)
    if s <= 0:
        # Previously this surfaced later as an obscure ValueError raised by
        # DataLoader(batch_size=0) or by np.random.choice.
        raise ValueError(
            f"Distorted greedy needs a positive sample size per iteration, got s={s} "
            f"(N={N}, K={K}, eps={eps})"
        )

    attack_type = defender.adversary.attack_model_type
    if attack_type not in ("pgd", "advgan"):
        # The old fallback set value = 0 without an index and then crashed
        # with an unbound `idx` on the first batch; fail early and clearly.
        raise NotImplementedError(
            f"distorted_greedy only supports 'pgd' and 'advgan' adversaries, got {attack_type!r}"
        )

    S = []
    S_c = list(range(N))  # D\S

    print("\tIn Distorted Greedy!")
    print(f"\tValue of gamma/K is {float(gamma/K)}")
    for i in range(K):
        if not S_c:
            # Every element is already selected (K > N): nothing left to sample.
            print(f"\tNo candidates left after {len(S)} selections, stopping early.")
            break

        verbose = (i + 1) % 50 == 0
        # Dedicated timer for the whole iteration; the batch loop below keeps
        # its own lap timer so it can no longer clobber this one.
        iter_start = time.time()
        # Scores below this sentinel never replace it; since only a
        # non-negative best score is accepted, selection is unaffected.
        max_value = -1
        idx_best = None
        # Offset of the current batch inside R.
        index_counter = 0

        R = np.random.choice(S_c, min(s, len(S_c)), replace=False).tolist()
        subset_R = Subset(defender.dataset, R)
        batch_size = min(dg_batch_size, len(R))
        dl = DataLoader(subset_R, batch_size=batch_size, shuffle=False)
        if verbose:
            print(f"\tFor i={i+1}/{K} multiplicative coefficient is {float((1-gamma/K)**(K-i-1))}")

        lap_start = time.time()
        for ind, (x, y) in enumerate(dl):
            log_batch = (ind + 1) % 100 == 0
            if log_batch:
                now = time.time()
                print(f"\tFor S = {len(S)+1}, finished {ind+1}/{len(dl)} steps. Time taken since last lap is: {round(now-lap_start, 3)}")
                lap_start = time.time()
                # Flag is switched on only around the scoring of logged batches.
                defender.timer = True

            values = calc_gc_modular(defender, gamma, K, i, x, y)
            value, idx = torch.max(values, dim=0)

            if log_batch:
                # Running maximum before the current batch is folded in.
                print("\t\t\tMax Value is:", float(max_value))
                defender.timer = False

            if value > max_value:
                max_value = value
                # idx is relative to the batch; shift it to a position in R.
                idx_best = R[int(idx) + index_counter]
            index_counter += x.shape[0]

        if max_value >= 0:
            if verbose:
                print("\t\tMax objective value for current iteration is", float(max_value))
            S.append(int(idx_best))
            S_c.remove(int(idx_best))
        else:
            if verbose:
                print("\t\tMax value is ", float(max_value), " hence not adding to set!")
        if verbose:
            print(f"\t\tSet size is {len(S)}.")
        print(f"Time taken for {i+1} is : {round(time.time()-iter_start, 3)}")
    print("\tS is: ", S)
    return S

def distorted_greedy_pointwise(defender, args):
    """Select up to ``args.K`` dataset indices from pre-computed per-sample scores.

    The whole dataset is scored once, in order, with
    ``defender.compute_g_on_batch`` and ``defender.compute_c_on_batch``. Then,
    for ``K`` rounds, the remaining index maximising
    ``(1 - gamma/K)**(K-i-1) * g - c`` is added to the selection when that
    value is non-negative.

    Args:
        defender: Object providing ``reset_pre_computed()``, ``dataset``
            (iterated as ``(x, y)`` pairs), ``classifier.eval()``, ``device``
            and the two ``compute_*_on_batch`` methods.
            NOTE(review): the scoring methods are assumed to return one score
            per sample of the batch, in batch order -- confirm against the
            defender implementation.
        args: Namespace with ``gamma``, ``K``, ``dg_batch_size`` and ``debug``.

    Returns:
        list: Selected dataset indices in selection order. It can hold fewer
        than ``K`` entries when a round's best value is negative or when the
        dataset has fewer than ``K`` elements. In debug mode the first ``K``
        indices are returned without any scoring.
    """
    start = time.time()
    defender.reset_pre_computed()
    gamma = args.gamma
    K = args.K
    N = len(defender.dataset)

    S = []
    S_c = list(range(N))  # D\S

    print("\tIn Distorted Greedy!")
    print(f"\tValue of gamma/K is {float(gamma/K)}")

    if args.debug:
        print(f"Debugging mode ON! Returning first {K} from distorted greedy")
        return list(np.arange(K))

    adv_batch_size = args.dg_batch_size
    pert_dl = DataLoader(defender.dataset, batch_size=adv_batch_size, shuffle=False)
    # Per-batch scores are collected and concatenated once afterwards;
    # concatenating inside the loop re-copies everything gathered so far.
    g_chunks = []
    c_chunks = []
    print("\tStarting adv perturbations ...")
    start_per = time.time()
    for i, (x, y) in enumerate(pert_dl):
        log_batch = (i + 1) % 5 == 0
        if log_batch:
            print(f"\t\tPerturbing batch {i+1}/{len(pert_dl)}")
            start_per_cur = time.time()
        defender.classifier.eval()
        x = x.to(defender.device)
        y = y.to(defender.device)
        g_chunks.append(defender.compute_g_on_batch(x, y).detach().cpu())
        c_chunks.append(defender.compute_c_on_batch(x, y).detach().cpu())
        if log_batch:
            end_per_cur = time.time()
            print(f"\t\t Time taken for {i+1} is {end_per_cur-start_per_cur}")
    end_per = time.time()
    print(f"\tEnding adv perturbations ... time taken is {end_per-start_per}")

    # With an empty dataset there are no chunks; S_c is empty as well, so the
    # selection loop below stops before touching these tensors.
    g_ds = torch.cat(g_chunks) if g_chunks else None
    c_ds = torch.cat(c_chunks) if c_chunks else None
    g_Sc = g_ds
    c_Sc = c_ds

    start_sel = time.time()
    for i in range(K):
        if not S_c:
            # Pool exhausted (K > N): torch.max would fail on an empty tensor.
            print(f"\tNo candidates left after {len(S)} selections, stopping early.")
            break

        verbose = (i + 1) % 500 == 0
        coefficient = (1 - gamma / K) ** (K - i - 1)
        if verbose:
            print(f"\tFor i={i+1}/{K} multiplicative coefficient is {float(coefficient)}")

        values = coefficient * g_Sc - c_Sc
        max_value, idx_best = torch.max(values, dim=0)

        if max_value >= 0:
            if verbose:
                print("\t\tMax objective value for current iteration is", float(max_value))
            # idx_best is a position among the remaining candidates; pop maps
            # it back to the dataset index and removes it in one step.
            S.append(S_c.pop(int(idx_best)))
            if S_c:
                g_Sc = g_ds[S_c]
                c_Sc = c_ds[S_c]
        else:
            # Keep iterating: the coefficient changes with i, so a later
            # round may still produce a non-negative value.
            print("\t\tIteration", str(i+1) ,"Max value is ", float(max_value), " hence not adding to set!")
        if verbose:
            print(f"\t\tSet size is {len(S)}.")
    end_sel = time.time()
    print(f"\tTime taken for subset sel. using pre computed adv perturbation is {end_sel - start_sel}")
    end = time.time()
    print(f"REQUIRED NUMBERS------------Total time for distorted greeedy is {end-start}")
    return S

def calc_gc(defender, gamma, K, i, idx, S, gS=None):
    """Return the distorted marginal gain of adding ``idx`` to ``S``.

    Computes
    ``(1 - gamma/K)**(K-i-1) * (g(S + [idx]) - g(S)) - c([idx])``
    using ``defender.compute_g`` and ``defender.compute_c``.

    Args:
        defender: Object providing ``compute_g(indices)`` and
            ``compute_c(indices)``.
        gamma: Numerator of the per-step distortion ``gamma / K``.
        K: Total number of greedy iterations.
        i: Zero-based index of the current iteration.
        idx: Candidate index to evaluate.
        S: List of already selected indices; it is not modified.
        gS: Optional pre-computed ``g(S)``. When ``None`` it is computed here.

    Returns:
        The distorted gain, in whatever numeric type the defender's
        ``compute_g`` / ``compute_c`` return.
    """
    # Identity test: `gS == None` is evaluated elementwise for array-like
    # values and then cannot be used as a condition.
    if gS is None:
        gS = defender.compute_g(S)
    distortion = (1 - gamma / K) ** (K - i - 1)
    marginal_gain = defender.compute_g(S + [idx]) - gS
    return distortion * marginal_gain - defender.compute_c([idx])

def calc_gc_modular(defender, gamma, K, i, X, y):
    """Return the distorted objective ``(1 - gamma/K)**(K-i-1) * g - c`` for a batch.

    ``g`` and ``c`` are obtained from ``defender.compute_g_on_batch(X, y)``
    and ``defender.compute_c_on_batch(X, y)``, in that order; the result has
    whatever type those two values produce under ``*`` and ``-``.
    """
    gain = defender.compute_g_on_batch(X, y)
    cost = defender.compute_c_on_batch(X, y)
    remaining_steps = K - i - 1
    distortion = (1 - gamma / K) ** remaining_steps
    return distortion * gain - cost
