import os
import torch
import numpy as np
from tqdm import tqdm
from .learnware import Learnware
from user import User

class Market:
    """Collection of learnwares that can be evaluated on, and ranked for, a user.

    The expensive results (per-learnware evaluation and specification
    distances) are cached as files under ``logs/``; the file names are built
    from the user id plus the task / specification name.

    NOTE(review): the cache files are never invalidated here. If learnwares
    are submitted after a cache file was written, the stale file is still
    returned for that user.
    """

    def __init__(self, cfg):
        """Create an empty market.

        Args:
            cfg: Mapping of settings. This class reads the keys ``'task'``,
                ``'specification'`` and (for ``'ClassWiseRKME'`` only)
                ``'similarity_lambda'``. The same mapping is passed to every
                ``Learnware`` built by :meth:`submit_learnware`.
        """
        self.cfg = cfg
        self.learnwares = []
        self.n_learnware = 0

    def submit_learnware(self, learnware_id):
        """Build the learnware identified by *learnware_id* and add it to the market."""
        print(f'Submitting learnware {learnware_id} to market...', end='\r')
        learnware = Learnware(self.cfg, learnware_id)
        self.learnwares.append(learnware)
        self.n_learnware += 1

    def evaluate(self, user: 'User'):
        """Evaluate every learnware on the user's ``'test'`` loader.

        Results are cached on disk. For ``task == 'regression'`` only the
        performances are cached; for any other task the cache is used only
        when the performance, prediction and probability files all exist.

        Args:
            user: Object providing ``user_id`` and ``dataset.get_loader``.

        Returns:
            A tuple ``(model_perfs, model_preds, model_probs)``.
            ``model_perfs`` is a NumPy array with one entry per learnware.
            For ``'classification'`` the other two are stacked tensors.
            NOTE(review): for regression a cache hit returns two empty lists,
            while a fresh run returns the raw per-learnware lists.
        """
        task = self.cfg['task']
        perfs_path = os.path.join('logs', 'model_perfs', task, f'{user.user_id}.npy')
        preds_path = os.path.join('logs', 'model_preds', f'{user.user_id}.pt')
        probs_path = os.path.join('logs', 'model_probs', f'{user.user_id}.pt')

        if os.path.exists(perfs_path):
            if task == 'regression':
                return np.load(perfs_path), [], []
            if os.path.exists(preds_path) and os.path.exists(probs_path):
                # NOTE(review): torch.load unpickles the file; acceptable only
                # because these are cache files written by this method.
                return np.load(perfs_path), torch.load(preds_path), torch.load(probs_path)

        testloader = user.dataset.get_loader('test')
        model_perfs = []
        model_preds = []
        model_probs = []
        for learnware in tqdm(self.learnwares):
            perfs, preds, probs = learnware.evaluate(testloader, pred=True, prob=True)
            model_perfs.append(perfs)
            model_preds.append(preds)
            model_probs.append(probs)

        model_perfs = np.array(model_perfs)
        os.makedirs(os.path.dirname(perfs_path), exist_ok=True)
        np.save(perfs_path, model_perfs)

        if task == 'classification':
            model_preds = torch.stack(model_preds)
            model_probs = torch.stack(model_probs)
            os.makedirs(os.path.dirname(preds_path), exist_ok=True)
            os.makedirs(os.path.dirname(probs_path), exist_ok=True)
            torch.save(model_preds, preds_path)
            torch.save(model_probs, probs_path)

        return model_perfs, model_preds, model_probs

    def recommend(self, user: 'User'):
        """Return learnware indices ordered from smallest to largest distance."""
        return self.__learnware_dists(user).argsort()

    def __learnware_dists(self, user: 'User'):
        """Return one distance per learnware for *user*, using the disk cache.

        ``cfg['specification'] == 'ClassWiseRKME'`` is delegated to
        ``__cRKME_dists``; every other specification uses
        ``user.spec.compare(learnware.spec)``.
        """
        spec = self.cfg['specification']
        dist_path = os.path.join('logs', 'spec_dists', self.cfg['task'], spec, f'{user.user_id}.npy')
        os.makedirs(os.path.dirname(dist_path), exist_ok=True)

        if spec == 'ClassWiseRKME':
            return self.__cRKME_dists(user, dist_path)

        if os.path.exists(dist_path):
            print('load from', dist_path)
            return np.load(dist_path)

        spec_dists = np.array([user.spec.compare(learnware.spec) for learnware in self.learnwares])
        np.save(dist_path, spec_dists)
        return spec_dists

    def __cRKME_dists(self, user: 'User', path: str):
        """Return class-wise distances between *user* and every learnware.

        For each learnware, ``user.spec.class_distance`` is evaluated for
        every (user class, learnware class) pair. The largest value per user
        class is kept, the values are averaged with the user's class
        frequencies as weights, and the result is negated so that a smaller
        entry means a better match.
        NOTE(review): despite its name, the value from ``class_distance`` is
        treated as larger-is-better here; confirm against the specification.

        Args:
            user: Object whose ``spec`` provides ``classes``, ``y`` and
                ``class_distance``.
            path: ``.npy`` cache file; loaded if present, written otherwise.

        Returns:
            A float array with one entry per learnware. Entries are
            ``+inf`` when the user or the learnware has no classes.
        """
        if os.path.exists(path):
            return np.load(path)

        lambd = self.cfg['similarity_lambda']
        user_classes = list(user.spec.classes)
        y = user.spec.y
        counts = np.array([(y == cls).sum().item() for cls in user_classes], dtype=float)
        if counts.sum() > 0:
            user_class_weight = counts / counts.sum()
        else:
            # No labelled samples: fall back to uniform class weights.
            user_class_weight = np.ones(len(user_classes)) / len(user_classes)

        # Size from the list itself so the array always matches the loop
        # below. Default to +inf: these values are sorted ascending by
        # recommend(), so a learnware that cannot be compared ranks last.
        match_scores = np.full(len(self.learnwares), np.inf)

        for i, learnware in enumerate(tqdm(self.learnwares)):
            learnware_classes = list(learnware.spec.classes)
            if not user_classes or not learnware_classes:
                continue

            weight_matrix = np.array(
                [
                    [
                        float(user.spec.class_distance(learnware.spec, user_class, learnware_class, lambd=lambd))
                        for learnware_class in learnware_classes
                    ]
                    for user_class in user_classes
                ],
                dtype=float,
            )

            # Aggregate once, after the whole matrix has been filled.
            best_per_user_class = weight_matrix.max(axis=1)
            score = np.dot(best_per_user_class, user_class_weight)
            match_scores[i] = -score

        np.save(path, match_scores)
        return match_scores