import torch
from src.meta_learning.model import MetaEstimator, NormalEstimator
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import os
from typing import Optional, Type, Callable, Union
from src.util.tb_log import TBWritter
from src.util.metric import Metric
from src.util.utils import EarlyStopping, DataLoaderSampler
import sys
import math


class MetaFS(nn.Module):
    """Estimator wrapper that carries one learnable score per input feature.

    The module owns an estimator (built from ``estimator_class``) and a
    ``score`` parameter of length ``dims[0]``. A softmax over ``score`` is
    used as the sampling distribution for feature subsets.
    """

    def __init__(self, dims, meta_dims, loss_fn=nn.CrossEntropyLoss, estimator_class=MetaEstimator):
        """Build the estimator and a zero-initialised score vector.

        Args:
            dims: forwarded to ``estimator_class``; ``dims[0]`` is also used
                as the length of the score vector.
            meta_dims: forwarded to ``estimator_class``.
            loss_fn: forwarded to ``estimator_class``.
            estimator_class: callable invoked as
                ``estimator_class(dims, meta_dims, loss_fn)``.
        """
        super().__init__()
        print('create:', estimator_class)
        self.estimator = estimator_class(dims, meta_dims, loss_fn)
        n_features = dims[0]
        self.score = nn.Parameter(torch.zeros(n_features), requires_grad=True)

    def forward(self, x, y, select):
        """Run the estimator on ``x`` with every non-selected column zeroed.

        Args:
            x: input tensor, indexed as ``x[:, select]``.
            y: targets, passed through to the estimator.
            select: column indices to keep.

        Returns:
            ``(mean_loss, pred)``: the mean of the loss returned by the
            estimator, and the estimator's prediction output.
        """
        masked = torch.zeros_like(x)
        # Copy over only the selected columns; the rest stay zero.
        masked[:, select] = x[:, select]

        raw_loss, prediction = self.estimator(masked, y, select)
        return raw_loss.mean(), prediction

    def get_score(self):
        """Return the raw (unnormalised) feature score parameter."""
        return self.score

    def get_prob(self):
        """Return the softmax of the feature scores along dimension 0."""
        return self.get_score().softmax(dim=0)

    def feat_params(self):
        """Yield the parameters of the feature-selection part (the score)."""
        yield from (self.score,)

    def weight_params(self):
        """Yield every parameter of the wrapped estimator."""
        yield from self.estimator.parameters()

    def sample_subsets(self, k, n=1):
        """Draw ``n`` subsets of ``k`` distinct feature indices.

        Each subset is sampled without replacement from ``get_prob()``.

        Returns:
            NumPy array of sampled indices, one row per subset.
        """
        # One copy of the probability row per requested subset.
        tiled = self.get_prob().view(1, -1).repeat(n, 1)
        drawn = torch.multinomial(tiled, k, replacement=False)
        return drawn.detach().cpu().numpy()


class FeatureSelection(object):
    """Training driver for a ``MetaFS`` model.

    Offers two update routines over the same model:

    * ``weight_step`` updates the estimator weights (Adam) on batches from
      the training sampler, using a freshly sampled feature subset per batch.
    * ``feat_step`` updates the per-feature scores (RMSprop) from losses
      measured on batches from the validation sampler.

    Metrics are forwarded to a ``TBWritter`` created on ``LOG_DIR``; model
    checkpoints are written to / read from the same directory.
    """

    def __init__(self, embedding_dims, meta_dims, k,
                 train_loader, valid_loader,
                 weight_step_lr=0.001, feature_step_lr=0.01,
                 weight_early_stop_patience=-1, feature_early_stop_patience=-1,
                 weight_early_stop_filter=0, feature_early_stop_filter=0,
                 device=None, LOG_DIR='log/', loss_fn=nn.CrossEntropyLoss, is_callback=True,
                 estimator_class=MetaEstimator):
        """Create the model, its two optimisers, samplers and early stoppers.

        Args:
            embedding_dims: forwarded to ``MetaFS`` as ``dims``;
                ``embedding_dims[0]`` is stored as ``self.n``.
            meta_dims: forwarded to ``MetaFS``.
            k: number of features per sampled subset.
            train_loader: wrapped in a ``DataLoaderSampler`` for ``weight_step``.
            valid_loader: wrapped in a ``DataLoaderSampler`` for evaluation
                and ``feat_step``.
            weight_step_lr: Adam learning rate for the estimator weights.
            feature_step_lr: RMSprop learning rate for the feature scores.
            weight_early_stop_patience: passed to ``EarlyStopping`` for the
                weight phase; ``None`` disables early stopping there.
            feature_early_stop_patience: same, for the feature phase.
            weight_early_stop_filter: passed to ``EarlyStopping`` (weights).
            feature_early_stop_filter: passed to ``EarlyStopping`` (features).
            device: torch device; defaults to CUDA when available, else CPU.
            LOG_DIR: directory used for the log writer and checkpoints.
            loss_fn: forwarded to ``MetaFS``.
            is_callback: when true, ``feat_step`` drives early stopping from
                its ``callback`` result instead of ``feat_evaluate``.
            estimator_class: forwarded to ``MetaFS``.
        """
        super(FeatureSelection, self).__init__()

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        self.model = MetaFS(embedding_dims, meta_dims, loss_fn, estimator_class=estimator_class).to(device)

        # Separate optimisers so each phase only touches its own parameters.
        self.weight_optim = torch.optim.Adam(self.model.weight_params(), lr=weight_step_lr)
        self.feature_optim = torch.optim.RMSprop(self.model.feat_params(), lr=feature_step_lr)

        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.train_sampler = DataLoaderSampler(train_loader, device)
        self.eval_sampler = DataLoaderSampler(valid_loader, device)

        self.k = k
        self.n = embedding_dims[0]
        self.LOG_DIR = LOG_DIR
        self.log_writer = TBWritter(LOG_DIR)
        self.is_callback = is_callback

        # Global step counters; they keep increasing across repeated calls.
        self._feat_step_i = 0
        self._weight_step_i = 0

        if weight_early_stop_patience is None:
            self.weight_early_stop = None
        else:
            self.weight_early_stop = EarlyStopping(self.model, weight_early_stop_patience, weight_early_stop_filter)

        if feature_early_stop_patience is None:
            self.feature_early_stop = None
        else:
            self.feature_early_stop = EarlyStopping(self.model, feature_early_stop_patience, feature_early_stop_filter)

        self._best_selected = None
        self._best_f1 = 0

    def evaluate(self, select=None, update=True):
        """Evaluate the model on batches from the validation sampler.

        Args:
            select: feature indices to keep. When ``None`` a single subset is
                sampled on the first batch and reused for the remaining ones.
            update: when true, log the final metric as ``'weight_evaluate'``.

        Returns:
            The ``Metric`` accumulated over the evaluated batches.
        """
        metric = Metric()
        for _ in range(max(1, self.eval_sampler.loader_length)):
            x, y = self.eval_sampler.get()
            if select is None:
                select = self.model.sample_subsets(self.k, 1)[0]
            with torch.no_grad():
                loss, pred = self.model.forward(x, y, select)
            metric.update(loss, pred, y)

            # NOTE(review): the best-F1 bookkeeping runs after every batch on
            # the running metric, not once on the final one - confirm that
            # this is intended.
            if metric.f1() >= self._best_f1:
                self._best_f1 = metric.f1()
                self._best_selected = select

        if update:
            self._update_metric('weight_evaluate', metric)
        return metric

    def feat_evaluate(self):
        """Evaluate using the ``k`` highest-probability features.

        Prints the probabilities of up to ten features on either side of the
        selection boundary, then evaluates without logging.

        Returns:
            The ``Metric`` returned by ``evaluate``.
        """
        P = self.get_prob().detach().cpu().numpy()
        _sorted_idx = self.sorted_idx()
        k = self.k
        # Clamp the window start: with k < 10 a negative start would wrap
        # around to the tail of the list and print an empty/wrong window.
        lo = max(0, k - 10)
        print('feature_p:', P[_sorted_idx[lo:k]], '|', P[_sorted_idx[k: k + 10]])
        # Reuse the ordering computed above instead of sorting a second time.
        return self.evaluate(_sorted_idx[:k], update=False)

    def feat_step(self, max_iter=10000, subsets_batch_size=256, batch_iter=1, callback=lambda model, i: {},
                  keep_best=False):
        """Update the feature scores for up to ``max_iter`` iterations.

        Each iteration samples ``subsets_batch_size`` subsets, measures the
        loss of each on ``batch_iter`` validation batches (no gradient), and
        minimises the sum of those losses weighted by a softmax over the
        subsets' summed scores. Gradients therefore reach only the scores.

        Args:
            max_iter: maximum number of iterations.
            subsets_batch_size: number of subsets sampled per iteration.
            batch_iter: validation batches averaged per subset.
            callback: called as ``callback(self, step)`` when ``is_callback``
                is true; must return a dict that is logged and may contain
                ``'f1'`` or ``'loss'`` to drive early stopping.
            keep_best: when true, reload the early stopper's best weights at
                the end if it holds any.
        """
        early_stop = self.feature_early_stop
        for i in range(max_iter):
            self._feat_step_i += 1

            selects = self.model.sample_subsets(self.k, subsets_batch_size)
            subset_loss = torch.zeros(subsets_batch_size, device=self.device).float()
            subset_scores = torch.zeros(subsets_batch_size, device=self.device).float()
            for _s_i, select in enumerate(selects):
                subset_scores[_s_i] = torch.sum(self.model.get_score()[select])
            subset_prob = F.softmax(subset_scores, dim=0)

            metric = Metric()
            for subset_i, select in enumerate(selects):
                for _ in range(batch_iter):
                    x, y = self.eval_sampler.get()

                    with torch.no_grad():
                        _loss, _pred = self.model.forward(x, y, select)
                    metric.update(_loss, _pred, y)
                    subset_loss[subset_i] += _loss / batch_iter

            self.feature_optim.zero_grad()
            # Losses are constants here; only subset_prob carries gradient.
            torch.sum(subset_loss.detach() * subset_prob).backward()
            self.feature_optim.step()

            self._update_metric('feat', metric, i, max_iter)

            if self.is_callback:
                result = callback(self, self._feat_step_i)
                self.log_writer.append('feat_callback', **result)
                if 'f1' in result:
                    # Negated so that a higher F1 is a lower tracked value.
                    _es_update = -1 * result['f1']
                elif 'loss' in result:
                    _es_update = result['loss']
                else:
                    _es_update = None
            else:
                metric = self.feat_evaluate()
                self._update_metric('feat_evaluate', metric, i, max_iter)
                # Without a callback, early stopping tracks the evaluation loss.
                _es_update = metric.loss()

            # Early stopping is optional (patience=None in the constructor).
            if early_stop is not None:
                if _es_update is not None:
                    early_stop.update(_es_update)
                if early_stop.is_stop:
                    print('early stoping...')
                    self.model.load_state_dict(early_stop.best_weight)
                    early_stop.reset()
                    break
            sys.stdout.flush()

        if keep_best and early_stop is not None and early_stop.best_weight is not None:
            self.model.load_state_dict(early_stop.best_weight)

    def weight_step(self, max_iter=10000, subset_iter=1, check_period=1, keep_best=False):
        """Update the estimator weights for up to ``max_iter`` iterations.

        Each iteration accumulates gradients over ``subset_iter`` training
        batches, each with its own sampled feature subset, then takes one
        optimiser step. Every ``check_period`` global steps the model is
        evaluated, early stopping is updated (if enabled) and a checkpoint is
        saved as ``'weight.tmp'``.

        Args:
            max_iter: maximum number of iterations.
            subset_iter: batches (and subsets) accumulated per optimiser step.
            check_period: evaluation period, scaled by
                ``loader_length / subset_iter`` and floored at 1.
            keep_best: when true, reload the early stopper's best weights at
                the end if it holds any.
        """
        check_period = max(1, int(check_period * self.train_sampler.loader_length / subset_iter))
        early_stop = self.weight_early_stop
        for i in range(max_iter):
            self._weight_step_i += 1

            self.weight_optim.zero_grad()
            for j in range(subset_iter):
                x, y = self.train_sampler.get()

                selected = self.model.sample_subsets(self.k, 1)[0]
                loss, pred = self.model.forward(x, y, selected)
                # Print only for the first batch of each iteration.
                self._create_update_metric('weight', loss, pred, y, i=i, max_i=max_iter, show=(j == 0))
                loss.backward()
            self.weight_optim.step()

            if self._weight_step_i % check_period == 0:
                metric = self.evaluate()
                # Early stopping is optional (patience=None in the constructor).
                if early_stop is not None:
                    early_stop.update(metric.loss())
                    if early_stop.is_stop:
                        print('early stopping...')
                        self.model.load_state_dict(early_stop.best_weight)
                        early_stop.reset()
                        break
                self.save('weight.tmp')

        # end iter & load best
        if keep_best and early_stop is not None and early_stop.best_weight is not None:
            self.model.load_state_dict(early_stop.best_weight)

    def save(self, name):
        """Write the model's ``state_dict`` to ``LOG_DIR/name``."""
        with open(os.path.join(self.LOG_DIR, name), 'wb') as f:
            torch.save(self.model.state_dict(), f)

    def load(self, name):
        """Load a ``state_dict`` from ``LOG_DIR/name`` if the file exists.

        Returns:
            ``True`` when the file existed and was loaded, ``False`` otherwise.
        """
        path = os.path.join(self.LOG_DIR, name)
        if os.path.exists(path):
            with open(path, 'rb') as f:
                # map_location lets a checkpoint saved on another device
                # (e.g. CUDA) be restored on the device this instance uses.
                self.model.load_state_dict(torch.load(f, map_location=self.device))
            return True
        return False

    def _update_metric(self, type, metric, i=0, max_i=0, show=True):
        """Log ``metric`` under the tag ``type`` and optionally print it.

        Args:
            type: tag passed to the log writer and printed as a prefix.
            metric: object exposing ``loss()``, ``precision()``, ``f1()`` and
                ``recall()``.
            i: current iteration, for display only.
            max_i: total iterations, for display only.
            show: when true, print a one-line summary.
        """
        self.log_writer.append(type=type, loss=metric.loss(), precision=metric.precision(),
                               f1=metric.f1(), recall=metric.recall())
        if show:
            print(f'{type}: {i}/{max_i}  '
                  f'Loss: {metric.loss():.4f}, Acc: {metric.precision():.4f}, '
                  f'F1: {metric.f1():.4f}, recall: {metric.recall()}')
        sys.stdout.flush()

    def _create_update_metric(self, type, loss, pred, y_true, i=0, max_i=0, show=True):
        """Build a one-batch ``Metric`` and pass it to ``_update_metric``."""
        metric = Metric()
        metric.update(loss, pred, y_true)
        self._update_metric(type, metric, i, max_i, show)

    def sorted_idx(self):
        """Return feature indices ordered by decreasing selection probability.

        Ties keep ascending index order.
        """
        prob = self.model.get_prob().detach().cpu().numpy()
        # A stable sort on the negated values gives descending order while
        # keeping equal probabilities in their original index order.
        order = np.argsort(-prob, kind='stable')
        return list(order)

    def select(self, n):
        """Return the indices of the ``n`` highest-probability features."""
        return self.sorted_idx()[:n]

    def get_prob(self):
        """Return the model's feature selection probabilities."""
        return self.model.get_prob()

    def get_score(self):
        """Return the model's raw feature scores."""
        return self.model.get_score()
