import os
import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import cross_entropy, softmax, kl_div, log_softmax
from torch.utils.data import DataLoader
from tqdm import trange
from .base import CMEBase
from time import time

class LinearProxy(CMEBase):
    """Bias-free linear classifier whose weight matrix is used as a proxy.

    A single ``nn.Linear(cfg['input_dim'], cfg['n_classes'], bias=False)``
    layer is fitted to ``(X, Y)`` with cross-entropy, optionally adding a KL
    term towards the predictions of a supplied "learnware" model. The learned
    weight matrix can be saved, reloaded, and compared with the matrix of
    another instance via :meth:`compare`.

    NOTE(review): ``self.cfg``, ``self.X``, ``self.Y``, ``self.device`` and
    ``self.path`` are read here but never assigned in this class; they are
    presumably set by ``CMEBase.__init__`` -- confirm against the base class.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """Build the linear layer on ``self.device``.

        Args:
            cfg: Mapping that must provide ``'input_dim'`` and ``'n_classes'``
                (``'batch_size'``, ``'lr'`` and ``'steps'`` are read later by
                :meth:`generate_helper`).
            X: Training inputs, forwarded to the base class.
            Y: Training labels, forwarded to the base class.
            **kwargs: Forwarded to the base class. The presence of a
                ``'learnware'`` key is recorded in ``self.learnware``.
        """
        super().__init__(cfg, X, Y, **kwargs)
        self.W = nn.Linear(cfg['input_dim'], cfg['n_classes'], bias=False).to(self.device)
        # Only records whether the key was passed, not the object itself.
        self.learnware = 'learnware' in kwargs

    def generate_helper(self, *args, **kwargs):
        """Train ``self.W`` on ``(self.X, self.Y)`` and print the elapsed time.

        Args:
            *args: Ignored.
            **kwargs: If ``'learnware'`` is given and is not ``None``, it is
                used as a frozen teacher: it must be callable on a feature
                batch and expose a ``.model`` attribute with ``eval()``. A KL
                term between its softmax output and the layer's prediction is
                added to the cross-entropy loss.
        """
        start = time()
        model = kwargs.get('learnware', None)
        # Explicit None check: relying on the object's truthiness would
        # silently disable distillation for a teacher that defines
        # ``__len__``/``__bool__`` and happens to evaluate as falsy.
        distill = model is not None
        if distill:
            model.model.eval()
        self.W.train()
        dataloader = DataLoader(list(zip(self.X, self.Y)), batch_size=self.cfg['batch_size'], shuffle=True)
        optimizer = optim.Adam(self.W.parameters(), lr=self.cfg['lr'])
        # Each of the cfg['steps'] iterations is one full pass over the data.
        for _ in trange(self.cfg['steps']):
            for features, labels in dataloader:
                features = features.to(self.device)
                labels = labels.to(self.device)
                outputs = self.W(features)
                loss = cross_entropy(outputs, labels)

                if distill:
                    # The teacher only provides targets; no gradient is
                    # tracked through it.
                    with torch.no_grad():
                        probs = softmax(model(features), dim=-1)
                    # kl_div expects log-probabilities as its first argument
                    # and (by default) probabilities as its target.
                    log_probs = log_softmax(outputs, dim=-1)
                    loss = loss + kl_div(log_probs, probs, reduction='batchmean')

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        self.W.eval()
        print('Training time:', time() - start)

    def save(self):
        """Write the weight matrix to ``self.path`` under the key ``'W'``.

        NOTE(review): ``np.savez`` appends ``.npz`` when the file name does
        not already end with it, while ``load_helper`` opens ``self.path``
        verbatim -- confirm ``self.path`` carries the extension.
        """
        print('save to', self.path)
        np.savez(
            self.path,
            W=self.get_W()
        )

    def load_helper(self):
        """Load the weight matrix from ``self.path`` and freeze the layer."""
        # The archive is opened as a context manager so the underlying file
        # handle is released as soon as the array has been read.
        # NOTE(review): allow_pickle=True permits arbitrary code execution on
        # untrusted files. save() stores a plain numeric array, so pickling is
        # probably unnecessary -- confirm before switching it off.
        with np.load(self.path, allow_pickle=True) as archive:
            W = archive['W']
        self.W.weight = nn.Parameter(torch.from_numpy(W).to(self.device))
        self.W.requires_grad_(False)

    def get_W(self):
        """Return the layer's weight matrix as a NumPy array on the CPU."""
        return self.W.weight.detach().cpu().numpy()

    def compare(self, other):
        """Return the reciprocal of the best one-to-one row matching score.

        The rows of both weight matrices are compared by inner product, and
        the assignment maximising the summed inner products is selected.

        Args:
            other: Object exposing ``get_W()`` whose matrix has the same
                number of columns as this one.

        Returns:
            ``1 / s`` where ``s`` is the summed inner product over the optimal
            assignment; a larger matching score yields a smaller value.
            NOTE(review): ``s`` is not guarded against being zero or negative.
        """
        W1 = self.get_W()
        W2 = other.get_W()
        W = W1 @ W2.T    # k1 * k2 pairwise inner products between rows
        # linear_sum_assignment minimises cost, so negate to maximise.
        rows, cols = linear_sum_assignment(-W)
        max_matching = W[rows, cols].sum()
        return 1 / max_matching