from tqdm import tqdm
import torch
import torch.nn.functional as F
import time
from sklearn.svm import LinearSVC
from copy import deepcopy
import numpy as np
from collections import Counter
from benchmarking_utils.clustering_metrics import mapper
from joblib import parallel_backend

class LinearClassifierHandler:
    """Fit and evaluate linear probes on per-layer feature banks.

    Two families of probes are supported:

    * scikit-learn ``LinearSVC`` models, fitted eagerly in ``__init__`` -- one
      on the original targets and one on the targets remapped through
      ``mapper.coarse_labels``;
    * a least-squares solution computed on demand by
      ``compute_lstqsq_classifiers``.

    ``feature_bank`` and ``targets_bank`` are indexed by layer index for every
    layer in ``range(from_layer, end_layer)``. Features are transposed before
    use, so each entry is presumably laid out as (feature_dim, n_samples) --
    TODO confirm against the caller.
    """

    def __init__(self, feature_bank, targets_bank, from_layer=2, end_layer=5,
                 num_classes=100):
        """Store the training banks and immediately fit the sklearn probes.

        Args:
            feature_bank: Per-layer training features, indexable by layer.
            targets_bank: Per-layer training targets, indexable by layer.
            from_layer: First layer index to probe (inclusive).
            end_layer: Last layer index to probe (exclusive).
            num_classes: Width of the one-hot encoding used by the
                least-squares probe. Defaults to 100, the value that used to
                be hard-coded.
        """
        self.feature_bank = feature_bank
        self.targets_bank = targets_bank
        self.from_layer = from_layer
        self.end_layer = end_layer
        self.lin_classifiers = {}  # layer idx -> lstsq solution tensor, or None
        self.sklearn_solutions = {}  # classifier name -> {layer idx -> fitted model}
        self.num_classes = num_classes

        # Unfitted templates; a fresh copy is fitted for every layer.
        self.sklearn_classifiers = [
            ("LinearSVC_orig", LinearSVC(random_state=0, max_iter=150)),
            ("LinearSVC_superclass", LinearSVC(random_state=0, max_iter=150)),
        ]

        self.compute_sklearn_classifiers()

    @staticmethod
    def _solve_lstsq(one_hot_labels, features_t):
        """Return X minimising ``||one_hot_labels @ X - features_t||``.

        ``torch.lstsq`` has been deprecated and then removed from PyTorch, so
        ``torch.linalg.lstsq`` is preferred and the legacy call is used only
        when the newer API is unavailable.

        Args:
            one_hot_labels: Float tensor of shape (n_samples, n_classes).
            features_t: Tensor of shape (n_samples, feature_dim).

        Returns:
            Tensor of shape (n_classes, feature_dim).
        """
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "lstsq"):
            return linalg.lstsq(one_hot_labels, features_t).solution

        # Legacy API: the argument order is (B, A) and the result is padded
        # to max(m, n) rows; only the first n_classes rows are the solution.
        # https://github.com/pytorch/pytorch/issues/56833
        solution, _ = torch.lstsq(features_t, one_hot_labels)
        return solution[:one_hot_labels.shape[1]]

    def compute_sklearn_classifiers(self):
        """Fit every sklearn template on every probed layer.

        Fitted models are stored in
        ``self.sklearn_solutions[classifier_name][layer_idx]``.
        """
        print(f"beginning compute classifiers")
        for classifier_name, classifier_template in self.sklearn_classifiers:
            self.sklearn_solutions[classifier_name] = {}
            for idx in tqdm(range(self.from_layer, self.end_layer)):
                cur_features, cur_targets = self.feature_bank[idx], self.targets_bank[idx]
                if "superclass" in classifier_name:
                    # Remap targets through the project's coarse-label table
                    # (presumably fine -> superclass ids; verify in mapper).
                    cur_targets = mapper.coarse_labels[cur_targets.cpu()]
                    print(f"in condition, cur_targets:{cur_targets}, {Counter(cur_targets)}")

                print(f"computing {classifier_name} for layer: {idx}, "
                      f"cur_features: {cur_features.T.shape}, cur_targets: {cur_targets.shape}")

                start = time.time()
                # Fit a private copy so the shared template never carries
                # state from one layer to the next.
                layer_classifier = deepcopy(classifier_template)
                with parallel_backend('threading'):
                    layer_classifier.fit(cur_features.T, cur_targets)
                self.sklearn_solutions[classifier_name][idx] = layer_classifier
                print(f"time for computing: {classifier_name}: {time.time() - start}")

    def predict_sklearn(self, features, targets):
        """Evaluate every fitted sklearn probe on the given data.

        Args:
            features: Per-layer evaluation features, indexable by layer.
            targets: Per-layer evaluation targets (torch tensors), indexable
                by layer.

        Returns:
            ``{classifier_name: {layer_idx: accuracy}}``.
        """
        print(f"starting sklearn predict")
        result = {}
        for classifier_key, layer_classifiers in self.sklearn_solutions.items():
            result[classifier_key] = {}
            for layer_idx in tqdm(range(self.from_layer, self.end_layer)):
                start = time.time()
                cur_features, cur_targets = features[layer_idx], targets[layer_idx]

                print(f"cur_features:{cur_features.shape}, cur_classifier:{layer_classifiers}")

                prediction = layer_classifiers[layer_idx].predict(cur_features.T)
                print(f"prediction:{prediction.shape}, cur_targets:{cur_targets.shape}")
                print(f"prediction:{type(prediction)}, cur_targets:{type(cur_targets)}")

                prediction = prediction.squeeze()
                if "super" in classifier_key:
                    cur_targets = mapper.coarse_labels[cur_targets.cpu()]
                    print(f"in super, cur_targets:{cur_targets}, prediction:{prediction}")
                else:
                    # Move to host first: .numpy() raises on CUDA tensors.
                    cur_targets = cur_targets.cpu().numpy()

                accuracy = np.sum(prediction == cur_targets) / len(prediction)
                print(f"{classifier_key}, accuracy for layer:{layer_idx}: {accuracy}")
                print(f"time for layer:{layer_idx}:{time.time() - start}")
                result[classifier_key][layer_idx] = accuracy

        return result

    def compute_lstqsq_classifiers(self):
        """Compute a least-squares probe for every probed layer.

        For each layer, solves ``one_hot(targets) @ X ~= features.T`` and
        stores ``X`` (shape ``(num_classes, feature_dim)``) in
        ``self.lin_classifiers[layer_idx]``. A layer whose solve fails is
        stored as ``None`` so that ``predict_lstsq`` can skip it.
        """
        print(f"beginning compute lstsq classifiers")
        for idx in tqdm(range(self.from_layer, self.end_layer)):
            cur_features, cur_targets = self.feature_bank[idx], self.targets_bank[idx]
            one_hot_labels = F.one_hot(cur_targets, num_classes=self.num_classes).float()
            print(
                f"computing lstsq classifier for layer: {idx}, cur_features: {cur_features.T.shape}, cur_targets: {one_hot_labels.shape}")
            start = time.time()
            try:
                lstsq_solution = self._solve_lstsq(one_hot_labels, cur_features.T)
                print(f"lstsq_solution:{lstsq_solution.shape}")
            except Exception as exc:
                # Best-effort: one failing layer must not abort the others,
                # but report the cause instead of swallowing it silently.
                print(f"lstsq failed for layer {idx}: {exc!r}")
                lstsq_solution = None

            print(f"total time for layer {idx}, {time.time() - start}")
            self.lin_classifiers[idx] = lstsq_solution

    def predict_lstsq(self, features, targets):
        """Evaluate the least-squares probes on the given data.

        Layers whose probe is ``None`` (the solve failed) are skipped.

        Args:
            features: Per-layer evaluation features, indexable by layer.
            targets: Per-layer evaluation targets (torch tensors), indexable
                by layer.

        Returns:
            ``{'lstsq': {layer_idx: accuracy}}`` with Python-float accuracies.
        """
        print(f"starting predict")
        result = {'lstsq': {}}
        for layer_idx in tqdm(range(self.from_layer, self.end_layer)):
            start = time.time()
            cur_classifier = self.lin_classifiers[layer_idx]
            if cur_classifier is None:
                continue
            cur_features, cur_targets = features[layer_idx], targets[layer_idx]
            print(f"cur_features:{cur_features.shape}, cur_classifier:{cur_classifier.shape}")
            # Score every class by its dot product with the sample features.
            prediction = (cur_features.T @ cur_classifier.T).argmax(dim=1)
            # Float mean rather than int-tensor division, which floors on
            # old PyTorch versions.
            accuracy = (prediction == cur_targets).float().mean()
            print(f"lsqtsq: accuracy for layer:{layer_idx}: {accuracy}")
            print(f"time for layer:{layer_idx}:{time.time() - start}")
            result['lstsq'][layer_idx] = accuracy.item()
        return result
