import pdb

from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
import numpy as np
from tqdm import tqdm
import torch
from metric import TaskIncrementalMetric
from torch.nn import functional as F
import wandb
def accuracy(output, target, topk=(1,)):
    """Return precision@k, in percent, for every k in ``topk``.

    Args:
        output: score tensor; the top-k is taken along dim 1, so each row holds
            the per-class scores of one sample.
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate (the largest one bounds the ``topk`` call).

    Returns:
        A list with one 1-element float tensor per entry of ``topk`` (same
        order), holding the percentage of samples whose target is among the
        k highest-scoring classes.
    """
    n_samples = target.size(0)
    highest_k = max(topk)
    # (highest_k, n_samples): row j holds every sample's j-th best class.
    top_indices = output.topk(highest_k, 1, True, True)[1].t()
    hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]
def logging(x_name, x_value, y_name, y_value, args):
    """Log ``y_value`` against the custom step ``x_value`` to Weights & Biases.

    Nothing happens unless ``args.wandb`` is truthy. Otherwise ``x_name`` is
    declared as a metric, ``y_name`` is declared with ``x_name`` as its step
    metric, and both values are logged in a single ``wandb.log`` call.

    NOTE: this function's name shadows the stdlib ``logging`` module within
    this file.
    """
    if not args.wandb:
        return
    wandb.define_metric(x_name)
    wandb.define_metric(y_name, step_metric=x_name)
    record = {x_name: x_value, y_name: y_value}
    wandb.log(record)

class FrozenNMC(object):
    """Nearest-class-mean classifier built on ``model.encode_image`` features.

    Nothing is trained. ``evaluation`` folds the current task's training
    features into running per-class sums and counts, then classifies the test
    split of every task seen so far by negative squared Euclidean distance to
    the class means.
    """

    def __init__(self, args):
        self.args = args
        self.num_classes = args.num_classes
        # Per-class running statistics, indexed by class label. Entries start as
        # int 0 and become tensors once ``evaluation`` has run.
        self.mean_feature = [0] * args.num_classes
        self.sample_count = [0] * args.num_classes
        self.sum_feature = [0] * args.num_classes
        self.metric = TaskIncrementalMetric(args)

    def only_evaluation(self, model, dataset, task):
        """No-op; kept for interface compatibility."""
        pass

    def train(self, model, dataset, task):
        """No-op; class statistics are accumulated in ``evaluation`` instead."""
        pass

    def _update_class_means(self, model, dataset, task):
        """Fold this task's training features into the per-class statistics.

        Returns ``(mean_feature, seen)``: the class means stacked along dim 0
        (one row per class) and a boolean vector marking the classes that have
        at least one training sample so far.
        """
        features = []
        labels = []
        dataset.switch_to_train()
        train_dataloader = DataLoader(dataset, batch_size=self.args.batch_size, num_workers=self.args.workers)
        # Pure feature extraction: no_grad avoids building an autograd graph.
        with torch.no_grad():
            for image, label, _ in tqdm(train_dataloader, desc=f"Compute mean feature for {task}",
                                        total=len(train_dataloader)):
                features.append(F.normalize(model.encode_image(image.cuda())))
                labels.append(label)

        X = torch.cat(features, dim=0)
        y = torch.cat(labels, dim=0)
        for i in range(self.num_classes):
            class_mask = y == i
            class_total = torch.sum(X[class_mask], dim=0)
            class_count = torch.sum(class_mask)
            if task == 0:
                # Reset both accumulators together so repeated task-0 runs stay consistent.
                self.sum_feature[i] = class_total
                self.sample_count[i] = class_count
            else:
                self.sum_feature[i] = self.sum_feature[i] + class_total
                self.sample_count[i] = self.sample_count[i] + class_count
            # Clamp so classes without samples get a zero mean instead of 0/0 = NaN.
            self.mean_feature[i] = self.sum_feature[i] / torch.clamp(self.sample_count[i], min=1)

        seen = torch.stack([count > 0 for count in self.sample_count])
        return torch.stack(self.mean_feature, dim=0), seen

    def _evaluate_task(self, model, dataset, task, t, mean_feature, seen):
        """Score the test split of task ``t`` and record accuracy under ``(task, t)``."""
        dataset.switch_task(t)
        test_dataloader = DataLoader(dataset, batch_size=self.args.batch_size, num_workers=self.args.workers)
        # Classes never seen in training must not be predictable.
        unseen = ~seen.to(mean_feature.device)
        for image, label, _ in tqdm(test_dataloader, desc=f"Evaluation for {t}",
                                    total=len(test_dataloader)):
            image = image.cuda()
            label = label.cuda()
            with torch.no_grad():
                feature = F.normalize(model.encode_image(image))
                distance = -((feature.unsqueeze(1) - mean_feature.unsqueeze(0)) ** 2).sum(-1)
                distance = distance.masked_fill(unseen, float('-inf'))

            acc = accuracy(distance, label, topk=(1,))[0]
            self.metric.update(task, t, acc, n=image.size(0))
        self.metric.update_metric(task, t)

    def evaluation(self, model, dataset, task):
        """Update class means with ``task``'s training data, then evaluate tasks 0..task.

        Leaves ``dataset`` in test mode on task ``task``.
        """
        mean_feature, seen = self._update_class_means(model, dataset, task)
        dataset.switch_to_test()
        for t in range(task + 1):
            self._evaluate_task(model, dataset, task, t, mean_feature, seen)
        print(f' * End evaluation: task accuracy top1 {self.metric.average_accuracy[task]:.2f}')
        self.metric.print_matrix('Per task accuracy')
        # Tasks are 0-indexed (task 0 resets the statistics), so the last task is num_tasks - 1.
        if task == self.args.num_tasks - 1:
            print(f' * End evaluation: average accuracy top1 {self.metric.average_accuracy}')
        if self.args.report_to:
            logging('task', task, 'average accuracy', self.metric.average_accuracy[task], self.args)

    def save_checkpoint(self, model, task, args):
        """No-op; nothing is saved."""
        pass










