import torch
import math
from tqdm import tqdm
import os 

@torch.jit.script
def get_square(mat: torch.Tensor) -> torch.Tensor:
    """Return ``mat`` squared element-wise (compiled with TorchScript)."""
    return torch.pow(mat, 2)


def get_cdist(V):
    """Return pairwise squared Euclidean distances of ``V`` with itself.

    ``torch.cdist(V, V)`` gives the (p=2) pairwise distance matrix; each entry
    is then squared.
    """
    pairwise = torch.cdist(V, V)
    return torch.pow(pairwise, 2)


def get_rbf_kernel(dist_mat, kw):
    """Convert a distance matrix into an RBF similarity kernel.

    Computes ``exp(-d / (kw * mean(d)))`` element-wise, i.e. the bandwidth is
    ``kw`` times the mean entry of ``dist_mat``.

    Args:
        dist_mat: Tensor of non-negative pairwise distances (``get_cdist`` in
            this file supplies squared Euclidean distances).
        kw: Multiplier applied to the mean distance to form the bandwidth.

    Returns:
        Tensor with the same shape as ``dist_mat``.
    """
    mean = dist_mat.mean()
    if bool(mean == 0):
        # Distances are non-negative, so a zero mean means every entry is zero
        # (e.g. a single point or all-identical points). The generic formula
        # would evaluate 0/0 -> NaN; identical points have similarity exp(0)=1.
        return torch.ones_like(dist_mat)
    return torch.exp(-dist_mat / (kw * mean))


def FL_evaluation(sim, sets, norm=1.):
    """Evaluate the facility-location function for a batch of index sets.

    For each row ``S`` of ``sets`` this computes
    ``sum_v max_{s in S} sim[s, v] / norm``.

    Args:
        sim: 2-D similarity matrix whose rows are selected by ``sets``.
        sets: (B, k) integer index tensor (or nested sequence); each row is one
            set of row indices into ``sim``.
        norm: Divisor applied to every result (scalar or 0-d tensor).

    Returns:
        Tensor of shape (B,), one value per set, even when B == 1 or k == 1.
    """
    # Indexing dim 1 of a singleton-batched view makes `sets` unambiguously an
    # index into the rows of `sim`, giving shape (1, B, k, N). Take [0] to drop
    # only that leading singleton: a bare .squeeze() would also drop B == 1 or
    # k == 1, breaking the reduction axes below.
    picked = sim.unsqueeze(0)[:, sets][0]
    assert picked.shape[0] == len(sets)
    # Max over the k members of each set, then sum over the ground set.
    return torch.amax(picked, dim=1).sum(dim=-1) / norm

def get_features(model, X, device, bsz=2048):
    """Run ``model`` over ``X`` in mini-batches and concatenate its features.

    The model is put in eval mode and called as ``model(batch, last=True)``
    under ``torch.no_grad()``. Each call must return a pair; the second item
    of every pair is kept.

    Args:
        model: Callable module exposing ``eval()``.
        X: Sliceable input tensor; each batch is moved to ``device`` and cast
            to float before the forward pass.
        device: Target device for each batch.
        bsz: Batch size.

    Returns:
        The per-batch outputs concatenated along dim 0.
    """
    chunks = []
    with torch.no_grad():
        model.eval()
        for start in range(0, len(X), bsz):
            batch = X[start:start + bsz].to(device).float()
            _, feats = model(batch, last=True)
            chunks.append(feats)
    return torch.cat(chunks, dim=0)


class facilityLocation:
    """Facility-location (FL) evaluator backed by a dense similarity matrix.

    ``self.sim`` is kept on the CPU between calls. Each evaluation method moves
    it to ``self.device`` for the computation and moves it back afterwards,
    including when the computation raises.
    """

    def __init__(self, V, device, args, labels=None):
        """Load or build the ground-set similarity matrix.

        Args:
            V: Ground-set feature tensor. Only used when no precomputed matrix
                is loaded; ``V.squeeze()`` is fed to ``get_cdist``.
            device: torch device used for computation.
            args: Config mapping. ``args['dset']`` and ``args['kw']`` are always
                read; ``args['root']`` is also read when ``args['dset']`` is
                ``'IN100'``.
            labels: Unused.
        """
        self.sim = None
        if args['dset'] == 'IN100':
            path = os.path.join(args["root"], args["dset"], "similarity")
            self.sim = self._load_sharded_similarity(path)

        self.device = device
        kw = args['kw']
        if self.sim is None:
            V = V.to(self.device)
            print(f"Generating {len(V)},{len(V)} RBF similarity matrix")
            self.sim = get_rbf_kernel(get_cdist(V.squeeze()), kw).cpu()
            # Drop this method's device copy of V before releasing cached blocks.
            del V
            torch.cuda.empty_cache()

    @staticmethod
    def _load_sharded_similarity(path, num_shards=65):
        """Load ``IN100_sim_train_sharded_{i}.npy`` shards from ``path`` and concatenate them along axis 0."""
        import numpy as np
        print("Loading matrix in shards..")
        shards = [np.load(f"{path}/IN100_sim_train_sharded_{ii}.npy")
                  for ii in tqdm(range(num_shards))]
        print("Loaded all shards. Concatenating now..")
        return torch.from_numpy(np.concatenate(shards, axis=0))

    @staticmethod
    def _distinct_per_row(values):
        """Count distinct entries in each row of ``values``.

        Each row is flattened first, matching ``len(torch.unique(values[i]))``.
        """
        flat = values.flatten(start_dim=1)
        if flat.shape[1] == 0:
            return torch.zeros(len(flat), dtype=torch.long, device=flat.device)
        ordered = flat.sort(dim=-1).values
        # After sorting, every new distinct value starts where neighbours differ.
        return (ordered[:, 1:] != ordered[:, :-1]).sum(dim=-1) + 1

    def get_FL_evals_nested(self, D_M_full, nesting_list, MAX=0, SCMI=False, bsz=2048):
        """Evaluate normalized FL on nested prefixes of each index row.

        For every size ``k`` in ``nesting_list`` and every row of ``D_M_full``,
        ``FL_evaluation`` is applied to the first ``k`` indices of that row,
        normalized by ``len(self.sim)``.

        Args:
            D_M_full: 2-D integer tensor of indices into ``self.sim``; each row
                is one candidate set.
            nesting_list: Prefix lengths to evaluate (e.g. ``[10, 20, ..., 100]``).
            MAX, SCMI: Unused.
            bsz: Number of rows of ``D_M_full`` processed per batch.

        Returns:
            CPU tensor stacking one row per entry of ``nesting_list``. Each row
            is the per-batch ``FL_evaluation`` results concatenated in order.
        """
        print(self.device)
        self.sim = self.sim.to(self.device)
        try:
            # NOTE(review): this normalizes by |V| = len(self.sim), whereas
            # get_margin_FL_nested uses FL(V) = sum_v max_u sim[v, u]. The two
            # agree when every row's maximum is 1 (e.g. an RBF kernel with zero
            # self-distance); confirm this is intended for precomputed matrices.
            f_V = len(self.sim)
            evals = [[] for _ in nesting_list]
            for start in range(0, len(D_M_full), bsz):
                A = D_M_full[start:start + bsz]
                for level, k in enumerate(nesting_list):
                    evals[level].append(FL_evaluation(self.sim, A[:, :k], f_V))
            eval_full = torch.stack(
                [torch.cat(per_level, dim=-1) for per_level in evals], dim=0).cpu()
        finally:
            self.sim = self.sim.cpu()
            torch.cuda.empty_cache()
        return eval_full

    def get_margin_FL_nested(self, D_M_full, D_E_full, nesting_list, Y=None, matroid_rank=-1, rank_tradeoff=0., target_responsibility=1., bsz=40):
        """Score nested prefixes of paired index sets and their margins.

        For each size ``k`` in ``nesting_list`` and each row pair (M from
        ``D_M_full``, E from ``D_E_full``), both ``k``-prefixes are scored as
        ``target_responsibility * FL(S) / FL(V) + rank_tradeoff * rank(S)``.
        Here ``FL(V) = sum_v max_u sim[v, u]``. When ``matroid_rank > -1``,
        ``rank(S)`` is the number of distinct values in ``Y[S]`` divided by 100;
        otherwise it is 0.

        Args:
            D_M_full, D_E_full: 2-D integer index tensors with matching row
                counts; row ``i`` of each forms the compared pair.
            nesting_list: Prefix lengths to evaluate.
            Y: Per-element values (presumably class labels) indexed by the
                sets; required when ``matroid_rank > -1``.
            matroid_rank: Enables the rank term when ``> -1``.
            rank_tradeoff: Weight of the rank term.
            target_responsibility: Weight of the FL term.
            bsz: Number of row pairs processed per batch.

        Returns:
            ``(margin, score_M, score_E)`` CPU tensors. Each stacks one row per
            entry of ``nesting_list``, and ``margin`` is ``score_E - score_M``.
        """
        self.sim = self.sim.to(self.device)
        try:
            f_V = torch.max(self.sim, dim=-1)[0].sum()
            if matroid_rank > -1:
                Y = Y.to(self.device)

            margins = [[] for _ in nesting_list]
            scores_M = [[] for _ in nesting_list]
            scores_E = [[] for _ in nesting_list]

            for start in tqdm(range(0, len(D_M_full), bsz)):
                A = D_M_full[start:start + bsz]
                B = D_E_full[start:start + bsz]
                for level, k in enumerate(nesting_list):
                    A_ = A[:, :k]
                    n = len(A_)
                    # Evaluate both sides in one call: rows [:n] are M, [n:] are E.
                    M_E = torch.cat([A_, B[:, :k]], dim=0)
                    FL_M_E = FL_evaluation(self.sim, M_E, f_V)

                    # Partition matroid rank with budgets of 1: distinct labels per set.
                    # NOTE(review): the /100 normalizer is hard-coded (presumably
                    # the number of classes, e.g. IN100) — confirm for other datasets.
                    if matroid_rank > -1:
                        distinct = self._distinct_per_row(Y[M_E])
                        rank_M_E = distinct.to(torch.get_default_dtype()) / 100
                    else:
                        rank_M_E = torch.zeros(len(FL_M_E), device=self.device)

                    margins[level].append(
                        target_responsibility * (FL_M_E[n:] - FL_M_E[:n])
                        + rank_tradeoff * (rank_M_E[n:] - rank_M_E[:n]))
                    scores_M[level].append(
                        target_responsibility * FL_M_E[:n] + (rank_tradeoff * rank_M_E)[:n])
                    scores_E[level].append(
                        target_responsibility * FL_M_E[n:] + (rank_tradeoff * rank_M_E)[n:])

            margin_full, FL_M_full, FL_E_full = (
                torch.stack([torch.cat(m, dim=-1) for m in parts], dim=0).cpu()
                for parts in (margins, scores_M, scores_E))
        finally:
            self.sim = self.sim.cpu()
            torch.cuda.empty_cache()
        return margin_full, FL_M_full, FL_E_full
