import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy


class CPL:
    """Decorator that feeds a batch to an adaptation step in selected chunks.

    An instance is applied to a function ``func(instance, batch_x, ...)``.
    The wrapped function repeatedly asks :meth:`select_courses` for a subset
    of the not-yet-used sample indices, calls ``func`` on that subset, and
    finally returns ``instance.classifier(instance.featurizer(batch_x))``.

    The decorated object's ``instance`` is expected to expose ``featurizer``
    and ``classifier`` callables; ``optimizer`` when ``accumulate_grad`` is
    set; ``num_classes``, ``labels``, ``ent`` and ``filter_K`` when
    ``use_thresholds`` is set.

    Attributes (all plain instance attributes, meant to be overridden after
    construction):
        use_cpl: when falsy, ``func`` is called once on the whole batch.
        algo: stored only; not read anywhere in this class.
        random_sample: forwarded to :meth:`select_courses`.
        update_before_evaluate: when truthy the returned prediction is
            computed after the updates, otherwise before them.
        accumulate_grad: forwarded to ``func``; also enables the periodic
            ``optimizer.step()`` / ``zero_grad()`` in the chunk loop.
        accumulate_max: number of processed samples that triggers a step.
        use_thresholds: when truthy, per-class entropy thresholds are
            computed and forwarded to ``func``.
        random_rate: fraction of each chunk that is drawn at random.
        criterion: name passed to ``selection_criterion``.
    """

    # Number of samples handed to the wrapped function per chunk.
    _COURSE_SIZE = 128

    def __init__(self, algo="default"):
        self.use_cpl = True
        self.algo = algo
        self.random_sample = False
        self.update_before_evaluate = 1
        self.accumulate_grad = True
        # opt hyper params
        self.accumulate_max = 128
        self.use_thresholds = 1
        self.random_rate = 0.6
        self.criterion = "grad_consistency_0"

    def __call__(self, func):
        """Return ``func`` wrapped so that it is driven chunk by chunk."""
        import functools

        @functools.wraps(func)
        def wrap(instance, batch_x, *args, **kargs):
            if not self.use_cpl:
                if self.update_before_evaluate:
                    func(instance, batch_x, *args, **kargs)
                    return instance.classifier(instance.featurizer(batch_x))
                return func(instance, batch_x, *args, **kargs)

            if not self.update_before_evaluate:
                # Predict with the model as it is before any update.
                with torch.no_grad():
                    batch_pred = instance.classifier(instance.featurizer(batch_x))

            if self.use_thresholds:
                thresholds = self._entropy_thresholds(instance, batch_x)
            else:
                thresholds = []

            remaining = list(range(len(batch_x)))
            accumulate_num = 0
            while remaining:
                chosen, remaining = self.select_courses(
                    instance, batch_x, remaining, random_sample=self.random_sample
                )
                func(
                    instance,
                    batch_x[chosen, :],
                    *args,
                    accumulate_grad=self.accumulate_grad,
                    use_thresholds=self.use_thresholds,
                    thresholds=thresholds,
                    **kargs,
                )
                accumulate_num += len(chosen)
                # NOTE(review): a step is taken only once the running sample
                # count reaches accumulate_max; a smaller remainder at the end
                # of the batch is not stepped here. Confirm this is intended.
                if self.accumulate_grad and accumulate_num >= self.accumulate_max:
                    instance.optimizer.step()
                    instance.optimizer.zero_grad()
                    accumulate_num = 0

            if self.update_before_evaluate:
                with torch.no_grad():
                    batch_pred = instance.classifier(instance.featurizer(batch_x))
            return batch_pred

        return wrap

    def _entropy_thresholds(self, instance, batch_x):
        """Return one entropy cut-off per class as a list of 0-dim tensors.

        The entropies stored on ``instance`` are pooled with those of the
        current batch (classes taken from the arg-max of the stored labels
        and of the batch predictions). For each class the cut-off is the
        ``filter_K``-th smallest pooled entropy, clamped to the largest one
        when the class has fewer entries.
        """
        device = batch_x.device
        with torch.no_grad():
            batch_p = instance.classifier(instance.featurizer(batch_x))
            batch_yhat = F.one_hot(batch_p.argmax(1), num_classes=instance.num_classes).float()
            batch_ent = softmax_entropy(batch_p)
        labels = torch.cat([instance.labels.to(device), batch_yhat])
        ent = torch.cat([instance.ent.to(device), batch_ent])
        y_hat = labels.argmax(dim=1).long()

        thresholds = []
        for i in range(instance.num_classes):
            class_ent, _ = torch.sort(ent[y_hat == i])
            # NOTE(review): if no pooled sample falls in class i, class_ent is
            # empty, filter_K becomes -1 and the lookup raises IndexError.
            filter_K = min(instance.filter_K, len(class_ent) - 1)
            thresholds.append(class_ent[filter_K])
        return thresholds

    def select_courses(self, instance, batch_x, left_courses, random_sample=False):
        """Split ``left_courses`` into the next chunk and the rest.

        Args:
            instance: object forwarded to ``selection_criterion``.
            batch_x: full batch tensor forwarded to ``selection_criterion``.
            left_courses: list of integer sample indices not yet used.
            random_sample: when true, a random half of ``left_courses`` is
                selected instead of using the criterion.

        Returns:
            ``(selected, remaining)`` lists of integer indices. When at most
            one chunk is left, ``left_courses`` itself is returned together
            with an empty list.
        """
        size = self._COURSE_SIZE
        if len(left_courses) <= size:
            return left_courses, []

        if random_sample:
            selected_courses = random.sample(left_courses, k=len(left_courses) // 2)
        else:
            # The criterion is negated, so top-k picks its k smallest values.
            based_values = -selection_criterion(
                instance, batch_x, left_courses, self.criterion
            ).view((-1,))
            k = int(size * (1 - self.random_rate))
            _, top_positions = torch.topk(based_values, k=k, sorted=False)
            # Map positions back through the Python list so the indices stay
            # exact integers (no float tensor round trip).
            selected_courses = [left_courses[pos] for pos in top_positions.tolist()]
            # Fill the rest of the chunk with random, not yet selected samples.
            unselected = list(set(left_courses) - set(selected_courses))
            selected_courses += random.sample(unselected, k=size - k)

        left_courses = list(set(left_courses) - set(selected_courses))
        return selected_courses, left_courses


def selection_criterion(instance, batch_x, left_courses, criterion):
    """Score the samples ``batch_x[left_courses]`` under the named criterion.

    Two kinds of logits are used below:

    * "prototype" logits: features times the column-normalised matrix
      ``instance.supports.T @ instance.labels``;
    * classifier logits: ``instance.classifier(features)``.

    Supported criteria (matched by prefix and by last character):

    * ``uncertainty*0`` / ``uncertainty*1``: softmax entropy of the prototype /
      classifier logits.
    * ``uncertainty*2`` / ``uncertainty*3``: one minus the largest softmax
      probability of the prototype / classifier logits.
    * ``loss*0`` / ``loss*1``: cross-entropy of the prototype logits against
      their own arg-max (both suffixes compute the same value).
    * ``grad_consistency*``: per-sample gradients of the cross-entropy
      (prototype logits vs. classifier arg-max) with respect to the support
      weight matrix are compared with their mean over the samples:
      ``0`` negative L2 norm of the difference, ``1`` negative cosine
      similarity averaged over the weight columns, ``2`` and ``3`` the same
      two measures restricted to the column of each sample's arg-max class.

    Args:
        instance: object exposing ``featurizer``, ``classifier`` and, for the
            prototype-based criteria, ``supports`` and ``labels``.
        batch_x: batch tensor indexed along its first axis.
        left_courses: list of sample indices to score.
        criterion: criterion name as described above.

    Returns:
        1-D tensor with one score per entry of ``left_courses``.

    Raises:
        ValueError: if ``criterion`` matches none of the supported names.
    """
    z = instance.featurizer(batch_x[left_courses, :])

    def support_weights():
        # Columns follow the second axis of instance.labels.
        return instance.supports.to(z.device).T @ instance.labels.to(z.device)

    def prototype_logits(weights):
        return z @ F.normalize(weights, dim=0)

    def one_minus_max_prob(logits):
        return 1 - torch.softmax(logits, dim=1).max(dim=1)[0]

    if criterion.startswith("uncertainty"):
        if criterion.endswith("0"):
            return softmax_entropy(prototype_logits(support_weights()))
        if criterion.endswith("1"):
            return softmax_entropy(instance.classifier(z))
        if criterion.endswith("2"):
            return one_minus_max_prob(prototype_logits(support_weights()))
        if criterion.endswith("3"):
            return one_minus_max_prob(instance.classifier(z))

    elif criterion.startswith("loss"):
        if criterion.endswith(("0", "1")):
            logits = prototype_logits(support_weights())
            return F.cross_entropy(logits, logits.argmax(dim=1), reduction="none")

    elif criterion.startswith("grad_consistency") and criterion.endswith(("0", "1", "2", "3")):
        weights = support_weights()
        weights.requires_grad = True
        logits = prototype_logits(weights)
        pseudo_labels = instance.classifier(z).argmax(dim=1)
        loss = F.cross_entropy(logits, pseudo_labels, reduction="none")

        # One backward pass per sample; the graph is kept alive in between.
        grads = torch.stack(
            [torch.autograd.grad(sample_loss, weights, retain_graph=True)[0] for sample_loss in loss]
        )  # (samples, *weights.shape)
        mean_grad = grads.mean(dim=0)

        if criterion.endswith("0"):
            return -(grads - mean_grad).flatten(1).norm(p=2, dim=1)
        if criterion.endswith("1"):
            cosine = (grads * mean_grad).sum(dim=1) / (
                grads.norm(p=2, dim=1) * mean_grad.norm(p=2, dim=0)
            )
            return -cosine.mean(dim=1)

        # Variants 2 and 3 only look at the weight column of each sample's
        # pseudo label.
        rows = torch.arange(len(pseudo_labels), device=z.device)
        own = grads.transpose(1, 2)[rows, pseudo_labels]
        ref = mean_grad.T[pseudo_labels]
        if criterion.endswith("2"):
            return -(own - ref).norm(p=2, dim=1)
        return -(own * ref).sum(dim=1) / (own.norm(p=2, dim=1) * ref.norm(p=2, dim=1))

    raise ValueError(f"unknown selection criterion: {criterion!r}")

@torch.jit.script
def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
    """Return the entropy of ``softmax(x)`` taken along dimension 1.

    ``x`` holds logits; dimension 1 is reduced away in the result.
    """
    probs = x.softmax(1)
    log_probs = x.log_softmax(1)
    weighted = probs * log_probs
    return -weighted.sum(1)


@torch.jit.script
def softmax_kl_loss(input_logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor:
    """Elementwise KL-divergence terms between two softmax distributions.

    Both arguments are logits and must have the same size; softmax is taken
    along dimension 1 on each side.

    Note:
    - The result is NOT reduced (``reduction='none'``): it has the same shape
      as the inputs. Sum or average it yourself if a scalar is needed.
    - ``target_logits`` is not detached here.
    """
    assert input_logits.size() == target_logits.size()
    log_q = F.log_softmax(input_logits, dim=1)
    p = F.softmax(target_logits, dim=1)
    return F.kl_div(log_q, p, reduction='none')


