import datetime
import gc
import time

from sklearn.metrics import confusion_matrix
import torch
import logging

from tqdm import tqdm

from methods._trainer import _Trainer

# Module logger. With no name, ``getLogger`` hands back the root logger, so
# records logged through it are handled by the root logger's handlers.
logger = logging.getLogger(name=None)


class ContinualCLIP(_Trainer):
    """Zero-shot baseline for the online continual-learning trainer.

    The model is never trained: each online step only registers newly seen
    classes and refreshes the class names the model scores against, and
    evaluation classifies directly against those class names.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments unchanged to ``_Trainer``."""
        super().__init__(**kwargs)

    def online_step(self, images, labels, idx):
        """Register the classes in ``labels``; no parameter update is made.

        Any new labels are added via ``add_new_class`` and the updated
        ``exposed_classes_names`` are pushed to the model.

        Args:
            images: Batch of inputs (unused; zero-shot needs no training).
            labels: Batch labels, passed to ``add_new_class``.
            idx: Stream position (unused).

        Returns:
            ``(-1, -1)`` as placeholder values — presumably the
            (train_loss, train_acc) pair later given to ``report_training``.
        """
        self.add_new_class(labels)
        self.model.update_class_names(self.exposed_classes_names)

        # zero-shot, don't need to train. ``del`` only drops this frame's
        # references; the caller may still hold the batch.
        del (images, labels)
        gc.collect()
        return -1, -1

    def report_training(self, sample_num, train_loss, train_acc):
        """No-op: ``online_step`` performs no training, so nothing to report."""

    def report_test(self, sample_num, avg_loss, avg_acc):
        """Log test accuracy, elapsed time and an ETA for the remaining stream.

        Args:
            sample_num: Number of stream samples seen so far.
            avg_loss: Unused; kept for interface compatibility.
            avg_acc: Test accuracy to log (4 decimal places).
        """
        elapsed = time.time() - self.start_time
        # The ETA extrapolates the average time per sample so far; with no
        # samples yet there is nothing to extrapolate from, and dividing by
        # sample_num would raise ZeroDivisionError.
        if sample_num > 0:
            eta = datetime.timedelta(
                seconds=int(elapsed * (self.total_samples - sample_num) / sample_num))
        else:
            eta = "n/a"
        # Module-level logging.info (root logger): unlike logger.info it
        # installs a default handler if the root logger has none.
        logging.info(
            f"Test | Sample # {sample_num} | test_acc {avg_acc:.4f} | "
            f"running_time {datetime.timedelta(seconds=int(elapsed))} | "
            f"ETA {eta}"
        )

    def online_train(self, data):
        """No-op: the zero-shot model is never trained."""

    def online_before_task(self, task_id):
        """No-op: no per-task setup is needed."""

    def online_after_task(self, task_id):
        """No-op: no per-task teardown is needed."""

    def online_evaluate(self, test_loader):
        """Evaluate the model on ``test_loader`` over the exposed classes.

        Ground-truth labels are remapped from dataset ids to their position in
        ``self.exposed_classes`` (the first matching position, as with
        ``list.index``) so they can be compared with the model's logits.

        Args:
            test_loader: Iterable of ``(x, y)`` batches; every label in ``y``
                must appear in ``self.exposed_classes``.

        Returns:
            dict with ``avg_loss`` (always 0.0: no loss is computed),
            ``avg_acc`` (a hit counts if the label is in the top ``self.topk``
            logits), ``cls_acc`` (per-class accuracy from ``_interpret_pred``
            on argmax predictions), ``confusion_matrix`` (argmax predictions)
            and ``feature`` holding the image features, the text features of
            the last batch, the remapped labels and the exposed class names.

        Raises:
            ValueError: if a label is not in ``self.exposed_classes`` or the
                loader yields no samples.
        """
        total_correct, total_num_data, total_loss = 0.0, 0.0, 0.0
        num_batches = 0
        correct_l = torch.zeros(self.n_classes)
        num_data_l = torch.zeros(self.n_classes)
        label, pred_list = [], []
        image_features, text_features = [], None

        self.model.eval()
        with torch.no_grad():
            for x, y in test_loader:
                num_batches += 1
                # Build a new tensor (same shape/dtype) rather than writing
                # into y, so the loader's batch is left untouched.
                y = torch.tensor(
                    [self.exposed_classes.index(v.item()) for v in y],
                    dtype=y.dtype,
                ).view_as(y)

                x = x.to(self.device)
                y = y.to(self.device)

                logit, batch_image_features, text_features = self.model(x)
                pred = torch.argmax(logit, dim=-1)
                _, preds = logit.topk(self.topk, 1, True, True)
                total_correct += torch.sum(preds == y.unsqueeze(1)).item()
                total_num_data += y.size(0)

                xlabel_cnt, correct_xlabel_cnt = self._interpret_pred(y, pred)
                correct_l += correct_xlabel_cnt.detach().cpu()
                num_data_l += xlabel_cnt.detach().cpu()

                label += y.tolist()
                pred_list += pred.tolist()
                image_features.append(batch_image_features.cpu())

        if total_num_data == 0:
            raise ValueError("online_evaluate: test_loader yielded no samples")

        avg_acc = total_correct / total_num_data
        # Counted batches rather than len(test_loader), so loaders without
        # __len__ work too.
        avg_loss = total_loss / num_batches
        # Small epsilon avoids 0/0 for classes absent from the test set.
        cls_acc = (correct_l / (num_data_l + 1e-5)).numpy().tolist()
        cm = confusion_matrix(label, pred_list)

        eval_dict = {
            "avg_loss": avg_loss,
            "avg_acc": avg_acc,
            "cls_acc": cls_acc,
            "confusion_matrix": cm.tolist(),
            "feature": {
                'image_features': torch.cat(image_features, 0).numpy(),
                'text_features': text_features.cpu().numpy(),
                'labels': label,
                'label_names': self.exposed_classes_names
            }
        }
        return eval_dict

    @torch.no_grad()
    def offline_evaluate(self, zs_test_loader, zs_classes, zs_meta, zsaf=True):
        """Evaluate on a sequence of held-out datasets, one after another.

        Class-name text features are accumulated across datasets, so each
        dataset is also scored against every earlier dataset's classes.

        Args:
            zs_test_loader: Sequence of per-dataset iterables of ``(x, y)``
                batches. ``y`` is compared with the argmax over that
                dataset's own classes.
            zs_classes: Sequence of per-dataset class-name lists, tokenized
                with ``self.model.labels_tokenize``.
            zs_meta: Sequence of per-dataset metadata; ``zs_meta[i][0]`` is
                used as the dataset name. Its length sets how many datasets
                are evaluated.
            zsaf: Also compute ``zsaf_acc`` for datasets after the first.

        Returns:
            List with one dict per dataset containing ``name`` and:
            ``zs_acc`` — accuracy against this dataset's classes only;
            ``zsaf_acc`` — accuracy against the first dataset's classes
            followed by this dataset's (0.0 for the first dataset or when
            ``zsaf`` is False);
            ``total_acc`` — running accuracy over all samples evaluated so
            far, each scored against every class bank seen up to its dataset.

        Raises:
            ValueError: if a dataset's loader yields no samples.
        """
        result = []
        text_features = []  # one class-text-feature bank per dataset, in order
        total_correct, total_num_data = 0.0, 0.0
        self.model.eval()
        for idx in range(len(zs_meta)):
            zs_correct, zsaf_correct, zs_num_data = 0.0, 0.0, 0.0
            _test_loader = zs_test_loader[idx]
            _classes = zs_classes[idx]
            _meta = zs_meta[idx]

            # text_features for this dataset's class names
            _text_tokens = self.model.labels_tokenize(_classes).to(self.device)
            cur_text_features = self.model.forward_text(_text_tokens)
            text_features.append(cur_text_features)

            # The class banks are fixed for the whole dataset, so concatenate
            # them (and derive label offsets) once here, not per batch.
            run_zsaf = zsaf and idx > 0
            if run_zsaf:
                zsaf_text_features = torch.cat(
                    (text_features[0], cur_text_features), 0)
                # This dataset's classes follow the first dataset's rows.
                zsaf_offset = text_features[0].shape[0]
            all_text_features = torch.cat(text_features, 0)
            # This dataset's classes are the last rows of the combined bank.
            total_offset = (all_text_features.shape[0]
                            - cur_text_features.shape[0])

            # image_feature
            for x, y in tqdm(_test_loader):
                x = x.to(self.device)
                y = y.to(self.device)
                _image_features = self.model.forward_image(x)
                target = y.squeeze()

                # zs: this dataset's classes only
                logit = self.model.forward_head(_image_features,
                                                cur_text_features)
                pred = torch.argmax(logit, dim=-1)
                zs_correct += torch.sum(pred.squeeze() == target).item()
                zs_num_data += y.size(0)

                # zsaf: first dataset's classes + this dataset's classes
                if run_zsaf:
                    logit = self.model.forward_head(_image_features,
                                                    zsaf_text_features)
                    pred = torch.argmax(logit, dim=-1)
                    zsaf_correct += torch.sum(
                        pred.squeeze() == (target + zsaf_offset)).item()

                # total: every class bank seen so far
                logit = self.model.forward_head(_image_features,
                                                all_text_features)
                pred = torch.argmax(logit, dim=-1)
                total_correct += torch.sum(
                    pred.squeeze() == (target + total_offset)).item()
                total_num_data += y.size(0)

            if zs_num_data == 0:
                raise ValueError(
                    f"offline_evaluate: test set {_meta[0]!r} yielded no samples")

            zs_acc = zs_correct / zs_num_data
            zsaf_acc = zsaf_correct / zs_num_data
            total_acc = total_correct / total_num_data
            result.append({
                'name': _meta[0],
                'zs_acc': zs_acc,
                'zsaf_acc': zsaf_acc,
                'total_acc': total_acc,
            })
            print(
                '{} | zs_acc: {:.2f}% | zsaf_acc: {:.2f}% | total_acc: {:.2f}% |'
                .format(_meta[0], zs_acc * 100, zsaf_acc * 100,
                        total_acc * 100))

        return result
