
import os
import numpy as np
import random
import argparse
import functools
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from utils import *

# Autograd anomaly detection is switched on for the whole process as soon as
# this module is imported (module-level call).
# NOTE(review): PyTorch documents anomaly mode as a debugging aid that slows
# execution; consider enabling it only while debugging.
torch.autograd.set_detect_anomaly(True)
# Small / large numeric constants. Not referenced by the code visible in this
# file — TODO confirm whether other modules rely on them before removing.
eps = 1e-10
infty = 1e10


class SampleModel(nn.Module):
    """Learnable centroids for ActiveFT-style sample selection.

    The model holds ``sample_num`` centroid vectors initialised from rows of
    ``features``. ``get_loss`` returns an objective that pulls each feature
    towards its most similar centroid while pushing centroids apart.

    Args:
        features: (n, c) tensor of feature vectors. It is used directly in
            dot products, so it is presumably L2-normalised by the caller
            (``ActiveFT_sampling`` normalises it).
        sample_num: total number of centroids k.
        temperature: temperature dividing every similarity.
        init: centroid initialisation, ``"random"`` or ``"hybrid"``.
        distance: stored on the instance; not used by the code in this class.
        balance: weight of the centroid-repulsion term in the loss.
        slice: stored on the instance; not used by the code in this class.
        batch_size: stored on the instance; not used by the code in this class.
        alidx: indices of already-selected samples, required for
            ``init="hybrid"``. In that mode their centroids stay fixed and only
            the remaining ``sample_num - len(alidx)`` centroids are learned.

    Raises:
        NotImplementedError: if ``init`` is not a supported mode.
    """

    def __init__(self, features, sample_num, temperature, init, distance, balance=1.0, slice=None, batch_size=100000, alidx=None):
        super(SampleModel, self).__init__()
        self.features = features
        self.total_num = features.shape[0]
        self.temperature = temperature
        self.sample_num = sample_num
        self.balance = balance
        self.slice = slice
        self.batch_size = batch_size

        self.init = init
        self.distance = distance

        self.alidx = alidx

        centroids = self.init_centroids()
        # Wrap in nn.Parameter WITHOUT a trailing .cuda(). On a CPU tensor,
        # Parameter(x).cuda() returns a plain non-leaf tensor that nn.Module
        # does not register, which leaves parameters() empty. The centroids
        # start on the device of `features`; callers move the module as a whole.
        if init == 'hybrid':
            n_fixed = len(alidx)
            # Centroids of already-selected samples: a plain tensor, never optimised.
            self.centroids_alidx = centroids[:n_fixed, :]
            self.centroids_new = nn.Parameter(centroids[n_fixed:, :])
        else:
            self.centroids = nn.Parameter(centroids)

    def init_centroids(self):
        """Pick the initial centroid rows from ``self.features``.

        Returns:
            (sample_num, c) tensor holding a copy of the selected feature rows.
            For ``"hybrid"``, the first ``len(self.alidx)`` rows are the
            already-selected samples, in ``alidx`` order.

        Raises:
            NotImplementedError: for an unsupported ``self.init``.
            ValueError: (from ``random.sample``) if fewer candidates remain
                than centroids requested.
        """
        if self.init == "random":
            sample_ids = random.sample(range(self.total_num), self.sample_num)
        elif self.init == 'hybrid':
            # Normalise alidx (list / ndarray / tensor) to a list of ints so that
            # `+=` below concatenates instead of broadcasting elementwise.
            sample_ids = [int(i) for i in self.alidx]
            taken = set(sample_ids)  # O(1) membership instead of scanning a list
            candidates = [i for i in range(self.total_num) if i not in taken]
            sample_ids += random.sample(candidates, self.sample_num - len(sample_ids))
        else:
            # Other modes (e.g. farthest-point sampling) are not implemented here.
            raise NotImplementedError(f"unsupported init mode: {self.init!r}")

        return self.features[sample_ids].clone()

    def get_loss(self):
        """Return the scalar selection loss for the current centroids.

        With s_ik = <f_i, c_k> / T and k_i = argmax_k s_ik, each feature gets
            J_i = s_ik_i - log(exp(s_ik_i) + balance * sum_j exp(<c_j, c_k_i> / T))
        and the loss is -mean_i(J_i). The computation stays in log space
        (logsumexp / logaddexp), so small temperatures do not overflow exp().
        """
        if self.init == 'hybrid':
            centroids = torch.cat(
                [F.normalize(self.centroids_alidx, dim=1), F.normalize(self.centroids_new, dim=1)],
                dim=0,
            )
        else:
            centroids = F.normalize(self.centroids, dim=1)

        # Feature-to-centroid logits; the max selects each feature's nearest centroid.
        logits = torch.matmul(self.features, centroids.transpose(1, 0)) / self.temperature  # (n, k)
        pos_logit, pos_k = torch.max(logits, dim=1)  # (n,)

        # Centroid-to-centroid logits. The left factor is detached, so only the
        # column centroid receives gradient from the repulsion term.
        cent_logits = torch.matmul(centroids.detach(), centroids.transpose(1, 0)) / self.temperature  # (k, k)
        log_cent_sum = torch.logsumexp(cent_logits, dim=0)  # (k,) = log(sum_j exp(.))

        # log(balance); balance == 0 gives -inf, which makes logaddexp return pos_logit.
        log_balance = torch.log(torch.as_tensor(self.balance, dtype=logits.dtype, device=logits.device))
        J = pos_logit - torch.logaddexp(pos_logit, log_cent_sum[pos_k] + log_balance)
        return -torch.mean(J)

def _unique_nearest_ids(ids_sort):
    """Greedily assign each centroid its most similar not-yet-taken sample.

    Args:
        ids_sort: (k, n) integer array; row i lists sample indices ordered from
            most to least similar to centroid i.

    Returns:
        List of distinct Python ints in centroid order. A centroid whose whole
        row is already taken (only possible when k > n) contributes nothing,
        so the list may be shorter than k.
    """
    chosen = []
    taken = set()  # mirrors `chosen` for O(1) membership tests
    for row in ids_sort:
        for idx in row:
            idx = int(idx)
            if idx not in taken:
                taken.add(idx)
                chosen.append(idx)
                break
    return chosen


def optimize_dist(features, sample_num, args, alidx):
    """Optimise ``sample_num`` centroids over ``features`` and map them to samples.

    Args:
        features: (n, c) feature tensor (``ActiveFT_sampling`` passes it
            L2-normalised on CUDA).
        sample_num: number of centroids k (in hybrid mode this includes
            ``len(alidx)``).
        args: namespace providing the ``activeft_*`` options read below.
        alidx: already-selected sample indices, used when
            ``args.activeft_init == 'hybrid'``.

    Returns:
        List of distinct sample indices, one per centroid in centroid order
        (fewer than k only if k > n).

    Raises:
        NotImplementedError: for an unknown ``args.activeft_scheduler``.
    """
    sample_model = SampleModel(
        features, sample_num, args.activeft_temperature, args.activeft_init,
        args.activeft_distance, args.activeft_balance, args.activeft_slice,
        args.activeft_batch_size, alidx,
    )
    # Keep the learnable centroids on the same device as `features`.
    sample_model = sample_model.to(features.device)

    optimizer = optim.Adam(sample_model.parameters(), lr=args.activeft_lr)
    scheduler = None
    if args.activeft_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.activeft_max_iter, eta_min=1e-6)
    elif args.activeft_scheduler != "none":
        raise NotImplementedError(f"unsupported scheduler: {args.activeft_scheduler!r}")

    for it in range(args.activeft_max_iter):
        loss = sample_model.get_loss()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        lr = optimizer.param_groups[0]["lr"]
        print("Iter: %d, lr: %.6f, loss: %f" % (it, lr, loss.item()))

    if args.activeft_init == 'hybrid':
        centroids = torch.cat([sample_model.centroids_alidx, sample_model.centroids_new], dim=0)
    else:
        centroids = sample_model.centroids
    centroids = F.normalize(centroids.detach(), dim=1)
    sim = torch.matmul(centroids, features.transpose(1, 0))  # (k, n)

    # Sort on the CPU with NumPy to avoid GPU out-of-memory on the (k, n) sort.
    ids_sort = (-sim.cpu().numpy()).argsort(axis=1)
    sample_ids = _unique_nearest_ids(ids_sort)
    print(len(sample_ids))
    return sample_ids


def ActiveFT_sampling(features, num_budget, alidx, args):
    """Select samples with ActiveFT centroid optimisation.

    Args:
        features: (n, c) array-like or tensor; converted to a float32 CUDA
            tensor and L2-normalised per row.
        num_budget: number of new samples to select.
        alidx: indices already selected (empty on the first round).
        args: namespace of ``activeft_*`` options. It is NOT modified; hybrid
            mode is applied to a shallow copy.

    Returns:
        List of sample indices from ``optimize_dist``. With an empty ``alidx``
        there is one per centroid (``num_budget`` centroids). Otherwise there
        are ``num_budget + len(alidx)`` centroids, and the first ``len(alidx)``
        of them are fixed at the ``alidx`` features.
    """
    import copy

    features = torch.as_tensor(features, dtype=torch.float32).cuda()
    features = F.normalize(features, dim=1)

    if len(alidx) == 0:
        return optimize_dist(features, num_budget, args, alidx)

    # Shallow copy so the caller's args keep their original init mode.
    hybrid_args = copy.copy(args)
    hybrid_args.activeft_init = 'hybrid'
    return optimize_dist(features, num_budget + len(alidx), hybrid_args, alidx)
