import logging
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data

from linear_probe.feature_loader import load_features
from utils.utils import AverageMeter, accuracy


def _set_seed(seed):
    print("Set seed", seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True  # This will slow down training.


class LinearClassifier(nn.Module):
    """A single fully connected layer mapping features to class logits.

    The weight is drawn from a normal distribution (mean 0, std 0.01) and the
    bias starts at zero.

    Args:
        dim: Size of the input feature vector.
        num_class: Number of output classes (defaults to 100).
    """

    def __init__(self, dim, num_class=100):
        super().__init__()
        self.num_class = num_class
        self.linear = nn.Linear(in_features=dim, out_features=num_class)
        # Replace nn.Linear's default initialisation with a small Gaussian
        # weight and a zero bias.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Return the logits produced by the linear layer for ``x``."""
        logits = self.linear(x)
        return logits


class Linear_Probe:
    """Train and evaluate a linear classifier on stored features, layer by layer.

    For every layer exposed by the training dataset, a fresh classifier is
    trained on the training features and then scored on the validation
    and/or test features.

    The configuration object ``cfg`` is read for: ``exp_path``, ``lr``,
    ``wd`` (log file name and optimizer settings), ``model_name``,
    ``num_classes`` and ``epochs``.
    """

    def __init__(self, cfg):
        """Load the feature loaders, configure file logging and pick the eval splits.

        Args:
            cfg: Configuration object; see the class docstring for the
                attributes that are read.

        Raises:
            AssertionError: If neither a validation nor a test loader is
                available.
        """
        self.cfg = cfg
        self.train_loader, self.val_loader, self.test_loader = load_features(cfg)
        # One log file per (lr, wd) combination inside the experiment folder.
        logging.basicConfig(
            filename=f"{cfg.exp_path}/{cfg.lr}_{cfg.wd}.log",
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
        )
        # Which splits the probe is evaluated on (val, test or both).
        self.test_modes = []
        if self.val_loader is not None:
            self.test_modes.append("val")
        if self.test_loader is not None:
            self.test_modes.append("test")
        assert len(self.test_modes) > 0, "Must test on something"
        # How many layers we probe.
        self.len_layers = self.train_loader.dataset.len_layers()
        # The k values of the top-k accuracies that get an AverageMeter.
        self.K = [1]

    def _probe(self, train_loader, test_loader, layer_name, feat_dim, progress_bar=False):
        """Train a fresh classifier on one layer's features and score it.

        The classifier is trained for ``cfg.epochs`` epochs with AdamW and
        label-smoothed cross entropy, then evaluated on ``test_loader``. The
        top-1 accuracy of each evaluation batch is accumulated into
        ``self.acc_meters[1]``, so ``on_probe_start`` must have run first.

        Args:
            train_loader: Iterable of ``(features, targets)`` training batches.
            test_loader: Iterable of ``(features, targets)`` evaluation batches.
            layer_name: Name of the probed layer (used for debug logging only).
            feat_dim: Dimensionality of the feature vectors.
            progress_bar: Currently unused.
        """
        if "resnet50" in self.cfg.model_name:
            # not use batchnorm for resnet50
            classifier = LinearClassifier(dim=feat_dim, num_class=self.cfg.num_classes)
        else:
            classifier = nn.Sequential(
                nn.BatchNorm1d(feat_dim, affine=False),
                LinearClassifier(dim=feat_dim, num_class=self.cfg.num_classes),
            )
        classifier = torch.nn.DataParallel(classifier).cuda()
        optimizer = torch.optim.AdamW(classifier.parameters(), self.cfg.lr, weight_decay=self.cfg.wd)
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1).cuda()

        for e in range(self.cfg.epochs):
            total_loss = 0
            train_correct = 0
            train_count = 0
            classifier.train()
            for train_features, train_targets in train_loader:
                train_features, train_targets = train_features.cuda(), train_targets.cuda()
                optimizer.zero_grad()
                outputs = classifier(train_features)
                preds = outputs.argmax(dim=1)
                train_correct += preds.eq(train_targets).sum().item()
                train_count += train_targets.size(0)

                loss = criterion(outputs, train_targets)
                total_loss += loss.item()
                loss.backward()
                optimizer.step()
            # Guard the division so an empty training loader does not raise.
            train_acc = round(train_correct / train_count, 2) if train_count else 0.0
            # DEBUG level keeps the INFO-level log file unchanged while still
            # making the per-epoch statistics available when needed.
            logging.debug(
                "Layer: %s, Epoch: %d, Loss: %s, Train Acc: %s",
                layer_name,
                e,
                total_loss,
                train_acc,
            )

        # Evaluation needs no gradients; switch to eval mode once, not per batch.
        classifier.eval()
        with torch.no_grad():
            for test_features, test_targets in test_loader:
                test_features, test_targets = test_features.cuda(), test_targets.cuda()
                outputs = classifier(test_features)
                # NOTE(review): top-5 is requested but never used; presumably this
                # fails when there are fewer than 5 classes - confirm against
                # utils.utils.accuracy before changing the topk argument.
                acc1, _ = accuracy(outputs, test_targets, topk=(1, 5))
                # update meters
                self.acc_meters[1].update(acc1, test_targets.size(0))

    def on_probe_start(self, layer_name):
        """Create one fresh accuracy meter per k in ``self.K``.

        Args:
            layer_name: Name of the layer about to be probed (unused).
        """
        # initialize meters
        self.acc_meters = {k: AverageMeter(f"Acc@{k}", "6:4") for k in self.K}

    def on_probe_end(self, layer_idx, mode):
        """Snapshot the meter averages into ``self.accs`` and reset the meters.

        Args:
            layer_idx: Index of the probed layer (unused).
            mode: Split that was evaluated, "val" or "test" (unused).

        Returns:
            The averaged accuracy of the last k tracked in ``self.acc_meters``;
            with ``self.K == [1]`` this is the top-1 accuracy.
        """
        self.accs = {k: meter.avg for k, meter in self.acc_meters.items()}
        for meter in self.acc_meters.values():
            meter.reset()

        # Select the key explicitly instead of relying on a loop variable
        # that outlives its loop.
        last_k = list(self.accs)[-1]
        return self.accs[last_k]

    def probe(self):
        """Probe every layer and collect the validation accuracies.

        Layers whose feature dimension equals ``cfg.num_classes`` are skipped.
        For each remaining layer a classifier is trained and evaluated once
        per entry in ``self.test_modes``; every result is written to the log.

        Returns:
            The validation accuracies in layer order: a 1-D numpy array with
            one entry per probed layer, or an empty list when no validation
            split is evaluated. Test accuracies are logged only.
        """
        lp_acc = []
        for l_idx in range(self.len_layers):
            layer_name, feat_dim = self.train_loader.dataset.set_layer(l_idx)
            if feat_dim == self.cfg.num_classes:
                # Presumably this layer is already the model's class output,
                # so there is nothing to probe - TODO confirm.
                continue
            print(f"processing layer {layer_name}")
            self.on_probe_start(layer_name)
            for test_mode in self.test_modes:
                if test_mode == "val":
                    self.val_loader.dataset.set_layer(l_idx)
                    self._probe(self.train_loader, self.val_loader, layer_name, feat_dim)
                    top1_acc = self.on_probe_end(l_idx, "val").cpu().numpy()
                    lp_acc = np.append(lp_acc, top1_acc)
                    logging.info(f"Layer: {layer_name}, Acc: {top1_acc}")
                elif test_mode == "test":
                    self.test_loader.dataset.set_layer(l_idx)
                    # feat_dim is a required argument of _probe; omitting it
                    # raised a TypeError on every test-split run.
                    self._probe(self.train_loader, self.test_loader, layer_name, feat_dim)
                    top1_acc = self.on_probe_end(l_idx, "test").cpu().numpy()
                    # Log the test result so it is not silently discarded.
                    logging.info(f"Layer: {layer_name}, Test Acc: {top1_acc}")
        return lp_acc
