# -*- coding: utf-8 -*-

import copy
import math
import os
from sklearn.cluster import MiniBatchKMeans
import numpy as np
from torch.autograd import Variable
from tqdm import tqdm
import csv
from .strategy import Strategy
from sklearn.cluster import KMeans
import torch
import torch.nn.functional as F


class AlphaMixSampling(Strategy):
    def __init__(self, Train_Args, Args, model, unlabeledOriginLabels, unlabeledOutputList,
                 labeledEmbeddingList, unlabeledEmbeddingList, unlabeledImgPath, add_ratio,
                 labeledProtoList, num_classes, unlabeledTarget):
        """Store the inputs and hyper-parameters used by the ALFA-Mix query.

        Args:
            Train_Args: run configuration; ``num_train_set``, ``dataset_name``,
                ``classifier_name``, ``select_strategy`` and ``select_type`` are read here.
            Args: runtime configuration; only ``device`` is read here.
            model: object whose ``linear`` attribute is called on (mixed) embeddings.
            unlabeledOriginLabels: per-sample labels of the unlabeled pool (indexed by position).
            unlabeledOutputList: per-sample model outputs; ``query`` takes ``argmax`` over axis 1.
            labeledEmbeddingList: embeddings of the labeled set (stored, not used in this class).
            unlabeledEmbeddingList: embeddings of the unlabeled pool.
            unlabeledImgPath: per-sample image paths of the unlabeled pool.
            add_ratio: fraction of ``Train_Args.num_train_set`` to select per query.
            labeledProtoList: one anchor/prototype embedding per class (indexed by class id).
            num_classes: number of classes.
            unlabeledTarget: forwarded to the ``Strategy`` base class only.
        """
        super().__init__(Train_Args, unlabeledOriginLabels, unlabeledImgPath,
                         add_ratio, num_classes, unlabeledTarget)

        # Data and model handles.
        self.labeledEmbeddingList = labeledEmbeddingList
        self.unlabeledEmbeddingList = unlabeledEmbeddingList
        self.labeledProtoList = labeledProtoList
        self.unlabeledOriginLabels = unlabeledOriginLabels
        self.unlabeledOutputList = unlabeledOutputList
        self.unlabeledImgPath = unlabeledImgPath
        self.model = model

        # Query bookkeeping and budget.
        self.queryCount = 0
        self.DEVICE = Args.device
        self.addRatio = add_ratio

        # ALFA-Mix hyper-parameters.
        self.ALPHA_CAP = 0.03125  # increment applied to the alpha cap on every sweep in query()
        self.APLHA_OPT = False  # True -> findCandidateSet() refines alpha via learnAlpha()
        self.ALPHA_CLOSED_FORM_APPROX = False  # True -> use calculateOptimumAlpha()
        self.ALPHA_LEARNING_RATE = 0.1  # Adam learning rate in learnAlpha()
        self.ALPHA_CLF_COEF = 1.0  # weight of the classification term in learnAlpha()
        self.ALPHA_L2_COEF = 0.01  # weight of the alpha-norm term in learnAlpha()
        self.ALPHA_LEARNING_ITERA = 5  # number of optimisation passes in learnAlpha()
        self.ALPHA_LEARN_BATCH_SIZE = 1000000  # rows per optimisation batch in learnAlpha()
        self.NUM_TRAIN_SET = Train_Args.num_train_set
        self.NUM_CLASSES = num_classes

        # Components of the CSV output directory built in query().
        self.dataset_name = Train_Args.dataset_name
        self.classifier_name = Train_Args.classifier_name
        self.select_strategy = Train_Args.select_strategy
        self.select_type = Train_Args.select_type
    def query(self):
        """Select the next batch of unlabeled samples and log it to a CSV file.

        The alpha cap is swept from ``ALPHA_CAP`` up to 1.0 in steps of
        ``ALPHA_CAP``. Every unlabeled sample whose predicted class flips for at
        least one cap/anchor combination becomes a candidate. Candidates are
        clustered by ``sample`` and one representative per cluster is kept; if
        fewer than ``n`` samples result, the rest is drawn at random (without
        duplicates) from the samples not selected yet.

        Side effects:
            Increments ``self.queryCount`` and writes one ``path,label`` row per
            selected sample to
            ``./Selcetion/<dataset>/<classifier>/<strategy>/<type>/<dataset>_<select_ratio>.csv``.
            ``self.select_ratio`` is not assigned in this class -- presumably it
            is provided by the ``Strategy`` base class; TODO confirm.

        Returns:
            list[tuple]: ``(image path, str(original label))`` for each selected sample.
        """
        self.queryCount += 1
        n = int(self.NUM_TRAIN_SET * self.addRatio)
        ulbProbs = self.unlabeledOutputList
        predOne = torch.tensor(np.argmax(ulbProbs, axis=1))
        ulbEmbedding = torch.tensor(np.array(self.unlabeledEmbeddingList))

        unlabeledSize = ulbEmbedding.size(0)
        embeddingSize = ulbEmbedding.size(1)

        minAlphas = torch.ones((unlabeledSize, embeddingSize), dtype=torch.float)
        candidate = torch.zeros(unlabeledSize, dtype=torch.bool)

        if self.ALPHA_CLOSED_FORM_APPROX:
            # NOTE(review): this branch calls model.linear(..., embedding=True) and unpacks a
            # tuple, while the default path in findCandidateSet uses the plain tensor return
            # value -- confirm the signature of model.linear before enabling it.
            varEmb = ulbEmbedding.to(self.DEVICE).detach().requires_grad_(True)
            out, _ = self.model.linear(varEmb, embedding=True)
            loss = F.cross_entropy(out, predOne.to(self.DEVICE))
            grads = torch.autograd.grad(loss, varEmb)[0].detach().cpu()
            del loss, varEmb, out
        else:
            grads = None

        alphaCap = 0.
        print('Find candidate set...')
        if self.APLHA_OPT:
            print('Learning alpha...')
        while alphaCap < 1.0:
            alphaCap += self.ALPHA_CAP

            tmpPredChange, tmpMinAlphas = \
                self.findCandidateSet(ulbEmbedding, predOne, alphaCap=alphaCap, grads=grads)

            # Keep, per sample, the alpha with the smallest norm seen so far.
            isChanged = minAlphas.norm(dim=1) >= tmpMinAlphas.norm(dim=1)
            minAlphas[isChanged] = tmpMinAlphas[isChanged]
            candidate |= tmpPredChange

        print('ALFA-Mix Sample...')
        numCandidates = int(candidate.sum().item())
        if numCandidates > 0:
            print('Number of inconsistencies: %d' % numCandidates)

            print('alpha_mean_mean: %f' % minAlphas[candidate].mean(dim=1).mean().item())
            print('alpha_std_mean: %f' % minAlphas[candidate].mean(dim=1).std().item())
            print('alpha_mean_std %f' % minAlphas[candidate].std(dim=1).mean().item())

        if numCandidates > 0 and n > 0:
            cAlpha = F.normalize(ulbEmbedding[candidate].view(numCandidates, -1), p=2, dim=1).detach()
            picked = self.sample(min(n, numCandidates), feats=cAlpha)
            # Map positions inside the candidate subset back to pool indices.
            selectedIdxs = np.flatnonzero(candidate.numpy())[np.asarray(picked, dtype=int)]
        else:
            # np.int was removed in NumPy 1.24; the builtin int is the replacement.
            selectedIdxs = np.array([], dtype=int)

        if len(selectedIdxs) < n:
            idxUlb = np.zeros(unlabeledSize, dtype=bool)
            idxUlb[selectedIdxs] = True
            pool = np.flatnonzero(~idxUlb)
            # Never ask for more than the pool holds, and never pick a sample twice.
            remained = min(n - len(selectedIdxs), len(pool))
            if remained > 0:
                extra = np.random.choice(pool, remained, replace=False)
                selectedIdxs = np.concatenate([selectedIdxs, extra])
            print('picked {} samples from RandomSampling.'.format(remained))

        filepath = './Selcetion/{}/{}/{}/{}/'.format(self.dataset_name, self.classifier_name, self.select_strategy,
                                                    self.select_type)
        os.makedirs(filepath, exist_ok=True)
        csvPath = '{}{}_{}.csv'.format(filepath, self.dataset_name, self.select_ratio)

        dirList = []
        lableList = []
        # The context manager guarantees the rows are flushed and the handle is released.
        with open(csvPath, 'w', newline='') as ft:
            ftCsv = csv.writer(ft)
            for idx in tqdm(selectedIdxs):
                ftCsv.writerow([self.unlabeledImgPath[idx]] + [str(self.unlabeledOriginLabels[idx])])
                dirList.append(self.unlabeledImgPath[idx])
                lableList.append(str(self.unlabeledOriginLabels[idx]))
                if len(lableList) >= self.addRatio * self.NUM_TRAIN_SET:
                    break

        return list(zip(dirList, lableList))

    def findCandidateSet(self, ulbEmbedding, predOne, alphaCap, grads):
        """Find the samples whose prediction flips when mixed with any class anchor.

        For each class ``i`` the unlabeled embeddings are interpolated with
        ``self.labeledProtoList[i]`` as ``(1 - alpha) * emb + alpha * anchor`` and
        passed through ``self.model.linear``. ``alpha`` comes from
        ``calculateOptimumAlpha`` (closed-form mode), from ``generateAlpha``
        refined by ``learnAlpha`` (``APLHA_OPT``), or from ``generateAlpha`` alone.

        Args:
            ulbEmbedding: tensor of unlabeled embeddings; dims 0 and 1 are read as
                (samples, features).
            predOne: tensor of currently predicted class ids, one per sample.
            alphaCap: upper bound for the mixing coefficients in this sweep.
            grads: loss gradients w.r.t. the embeddings; only used when
                ``ALPHA_CLOSED_FORM_APPROX`` is set (otherwise may be ``None``).

        Returns:
            tuple: ``(predChange, minAlphas)`` -- a bool tensor marking flipped
            samples and, per sample, the smallest-norm alpha that caused a flip
            (rows of ones where no flip happened).
        """
        unlabeledSize = ulbEmbedding.size(0)
        embeddingSize = ulbEmbedding.size(1)

        minAlphas = torch.ones((unlabeledSize, embeddingSize), dtype=torch.float)
        predChange = torch.zeros(unlabeledSize, dtype=torch.bool)

        if self.ALPHA_CLOSED_FORM_APPROX:
            alphaCap /= math.sqrt(embeddingSize)
            grads = grads.to(self.DEVICE)

        for i in tqdm(range(self.NUM_CLASSES), leave=False):
            anchorI = torch.tensor(self.labeledProtoList[i])
            if self.ALPHA_CLOSED_FORM_APPROX:
                # Pure inference: no autograd graph is needed for the mixed forward pass.
                with torch.no_grad():
                    embedI, ulbEmbed = anchorI.to(self.DEVICE), ulbEmbedding.to(self.DEVICE)
                    alpha = self.calculateOptimumAlpha(alphaCap, embedI, ulbEmbed, grads)
                    embeddingMix = (1 - alpha) * ulbEmbed + alpha * embedI
                    out, _ = self.model.linear(embeddingMix)
                out = out.detach().cpu()
                alpha = alpha.cpu()
                pc = out.argmax(dim=1) != predOne
            else:
                alpha = self.generateAlpha(unlabeledSize, embeddingSize, alphaCap)
                if self.APLHA_OPT:
                    # learnAlpha needs gradients, so it must stay outside no_grad.
                    alpha, pc = self.learnAlpha(ulbEmbedding, predOne, anchorI, alpha, alphaCap)
                else:
                    with torch.no_grad():
                        embeddingMix = (1 - alpha) * ulbEmbedding + alpha * anchorI
                        out = self.model.linear(embeddingMix.to(self.DEVICE).float())
                    out = out.detach().cpu()
                    pc = out.argmax(dim=1) != predOne
            torch.cuda.empty_cache()

            # Rows without a flip get alpha = 1 so they never win the min-norm comparison.
            alpha[~pc] = 1.
            predChange |= pc
            isMin = minAlphas.norm(dim=1) > alpha.norm(dim=1)
            minAlphas[isMin] = alpha[isMin]

        return predChange, minAlphas

    def calculateOptimumAlpha(self, eps, lbEmbedding, ulbEmbedding, ulbGrads):
        z = (lbEmbedding - ulbEmbedding)  # * ulb_grads
        alpha = (eps * z.norm(dim=1) / ulbGrads.norm(dim=1)).unsqueeze(dim=1).repeat(1, z.size(1)) * ulbGrads / (
                z + 1e-8)

        return alpha

    # def sample(self, n, feats):
    #     feats = feats.numpy()
    #     clusterLearner = KMeans(n_clusters=n)
    #     clusterLearner.fit(feats)
    #     clusterIdxs = clusterLearner.predict(feats)
    #     centers = clusterLearner.cluster_centers_[clusterIdxs]
    #     dis = (feats - centers) ** 2
    #     dis = dis.sum(axis=1)
    #     return np.array(
    #         [np.arange(feats.shape[0])[clusterIdxs == i][dis[clusterIdxs == i].argmin()] for i in tqdm(range(n)) if
    #          (clusterIdxs == i).sum() > 0])
    def sample(self, n, feats):
        feats = feats.numpy()
        print(f"[INFO] Start KMeans clustering with {n} clusters, shape: {feats.shape}")

        # clusterLearner = KMeans(n_clusters=n, random_state=0)
        # clusterLearner.fit(feats)
        clusterLearner = MiniBatchKMeans(n_clusters=n, batch_size=1024, random_state=0)
        clusterLearner.fit(feats)
        clusterIdxs = clusterLearner.predict(feats)
        centers = clusterLearner.cluster_centers_[clusterIdxs]
        dis = (feats - centers) ** 2
        dis = dis.sum(axis=1)

        selected = []

        for i in tqdm(range(n), desc="Selecting representative samples"):
            in_cluster = np.where(clusterIdxs == i)[0]
            if len(in_cluster) == 0:
                continue
            closest = in_cluster[dis[in_cluster].argmin()]
            selected.append(closest)

        return np.array(selected)
    def retrieveAnchor(self, embeddings, count):
        return embeddings.mean(dim=0).view(1, -1).repeat(count, 1)

    def generateAlpha(self, size, embeddingSize, alphaCap):
        alpha = torch.normal(
            mean=alphaCap / 2.0,
            std=alphaCap / 2.0,
            size=(size, embeddingSize))

        alpha[torch.isnan(alpha)] = 1

        return self.clampAlpha(alpha, alphaCap)

    def clampAlpha(self, alpha, alphaCap):
        return torch.clamp(alpha, min=1e-8, max=alphaCap)

    def learnAlpha(self, orgEmbed, labels, anchorEmbed, alpha, alphaCap):

        labels = labels.to(self.DEVICE)
        minAlpha = torch.ones(alpha.size(), dtype=torch.float)
        predChanged = torch.zeros(labels.size(0), dtype=torch.bool)

        lossFunc = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.linear.eval()

        for i in range(self.ALPHA_LEARNING_ITERA):
            totNrm, totLoss, totClfLoss = 0., 0., 0.
            for b in range(math.ceil(float(alpha.size(0)) / self.ALPHA_LEARN_BATCH_SIZE)):
                self.model.linear.zero_grad()
                startIdx = b * self.ALPHA_LEARN_BATCH_SIZE
                endIdx = min((b + 1) * self.ALPHA_LEARN_BATCH_SIZE, alpha.size(0))

                l = alpha[startIdx:endIdx]
                l = torch.autograd.Variable(l.to(self.DEVICE), requires_grad=True)
                opt = torch.optim.Adam([l], lr=self.ALPHA_LEARNING_RATE / (
                    1. if i < self.ALPHA_LEARNING_ITERA * 2 / 3 else 10.))
                e = orgEmbed[startIdx:endIdx].to(self.DEVICE)
                c_e = anchorEmbed[startIdx:endIdx].to(self.DEVICE)
                embeddingMix = (1 - l) * e + l * c_e

                out = self.model.linear(embeddingMix.float())

                labelChange = out.argmax(dim=1) != labels[startIdx:endIdx]

                tmpPc = torch.zeros(labels.size(0), dtype=torch.bool).to(self.DEVICE)
                tmpPc[startIdx:endIdx] = labelChange
                predChanged[startIdx:endIdx] += tmpPc[startIdx:endIdx].detach().cpu()

                tmpPc[startIdx:endIdx] = tmpPc[startIdx:endIdx] * (
                        l.norm(dim=1) < minAlpha[startIdx:endIdx].norm(dim=1).to(self.DEVICE))
                minAlpha[tmpPc] = l[tmpPc[startIdx:endIdx]].detach().cpu()

                clfLoss = lossFunc(out, labels[startIdx:endIdx].to(self.DEVICE))

                l2Norm = torch.norm(l, dim=1)

                clfLoss *= -1

                loss = self.ALPHA_CLF_COEF * clfLoss + self.ALPHA_L2_COEF * l2Norm
                loss.sum().backward(retain_graph=True)
                opt.step()

                l = self.clampAlpha(l, alphaCap)

                alpha[startIdx:endIdx] = l.detach().cpu()

                totClfLoss += clfLoss.mean().item() * l.size(0)
                totLoss += loss.mean().item() * l.size(0)
                totNrm += l2Norm.mean().item() * l.size(0)

                del l, e, c_e, embeddingMix
                torch.cuda.empty_cache()

        return minAlpha.cpu(), predChanged.cpu()
