import torch
from src.meta_learning.model import MetaEstimator, NormalEstimator
from src.meta_learning.utils import BestSubset, count_subsets
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import os
from typing import Optional, Type, Callable, Union
from src.util.tb_log import TBWritter2
from src.util.metric import Metric
from src.util.utils import EarlyStopping, DataLoaderSampler
from src.verify.verify import test_subset
import sys
import math
import copy
import time


class FSS(nn.Module):
    """Learnable per-feature scores defining a sampling distribution over features.

    The module owns a single 1-D parameter ``score`` of length ``n``. A softmax
    over that vector gives the probability of each feature, and subsets are
    drawn from it without replacement.
    """

    def __init__(self, n):
        """Create ``n`` trainable scores, all initialised to zero.

        Args:
            n: Number of candidate features.
        """
        super().__init__()
        # All-zero scores make the initial softmax distribution uniform.
        self.score = nn.Parameter(torch.zeros(n), requires_grad=True)

    def get_score(self):
        """Return the raw (unnormalised) score parameter."""
        return self.score

    def get_prob(self):
        """Return the softmax of the scores along dimension 0."""
        raw_scores = self.get_score()
        return F.softmax(raw_scores, dim=0)

    def sample_subsets(self, k, n=1):
        """Draw ``n`` subsets of ``k`` distinct feature indices.

        Args:
            k: Number of indices per subset (sampled without replacement).
            n: Number of subsets to draw.

        Returns:
            A NumPy integer array of shape ``(n, k)``; the result is detached
            from the autograd graph and lives on the CPU.
        """
        # One identical probability row per requested subset.
        per_draw_probs = self.get_prob().unsqueeze(0).repeat(n, 1)
        drawn = torch.multinomial(per_draw_probs, k, replacement=False)
        return drawn.detach().cpu().numpy()


class FeatureSelection(object):
    def __init__(self, embedding_dims, meta_dims, k,
                 train_loader, valid_loader, test_dataset,
                 weight_step_lr=0.001, feature_step_lr=0.01,
                 weight_early_stop_patience=-1, weight_early_stop_filter=0,
                 device=None, LOG_DIR='log/', loss_fn=nn.CrossEntropyLoss,
                 use_retrain_test=True,
                 estimator_class=MetaEstimator):
        """Build the estimator, the feature-score module, optimizers and bookkeeping.

        Args:
            embedding_dims: Indexable sequence; ``embedding_dims[0]`` is used as
                the number of candidate features, and ``embedding_dims[1:]`` is
                reused by ``test_select`` for the retrained model's layers.
            meta_dims: Forwarded unchanged to ``estimator_class``.
            k: Size of every sampled feature subset.
            train_loader: Loader wrapped by the training sampler.
            valid_loader: Loader wrapped by the evaluation sampler.
            test_dataset: Either ``None`` or a ``(train, valid, test)`` triple
                consumed by ``test_select``.
            weight_step_lr: Learning rate of the Adam optimizer for the estimator.
            feature_step_lr: Learning rate of the RMSprop optimizer for the scores.
            weight_early_stop_patience: ``None`` disables early stopping;
                any other value is passed to ``EarlyStopping``.
            weight_early_stop_filter: Passed to ``EarlyStopping``.
            device: Target device; when ``None``, CUDA is used if available,
                otherwise the CPU.
            LOG_DIR: Directory handed to the log writer and used for result files.
            loss_fn: Forwarded unchanged to ``estimator_class``.
            use_retrain_test: When true, ``search_ones`` scores subsets via
                ``test_select`` instead of the validation-only evaluation.
            estimator_class: Callable invoked as
                ``estimator_class(embedding_dims, meta_dims, loss_fn)``.
        """
        super(FeatureSelection, self).__init__()

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        # Plain configuration values.
        self.embedding_dims = embedding_dims
        self.n = embedding_dims[0]
        self.k = k
        self.weight_step_lr = weight_step_lr
        self.use_retrain_test = use_retrain_test
        self.LOG_DIR = LOG_DIR
        self.test_dataset = test_dataset
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # Step counters, one per phase.
        self._feat_step_i = 0
        self._weight_step_i = 0
        self._search_step_i = 0

        # Models and their optimizers. Only the estimator is moved to `device`.
        self.metafe = estimator_class(embedding_dims, meta_dims, loss_fn).to(device)
        self.fss = FSS(embedding_dims[0])
        self.weight_optim = torch.optim.Adam(self.metafe.parameters(), lr=weight_step_lr)
        self.feature_optim = torch.optim.RMSprop(self.fss.parameters(), lr=feature_step_lr)

        self.train_sampler = DataLoaderSampler(train_loader, device)
        self.eval_sampler = DataLoaderSampler(valid_loader, device)

        self.log_writer = TBWritter2(LOG_DIR)

        if weight_early_stop_patience is not None:
            self.weight_early_stop = EarlyStopping(self.metafe, weight_early_stop_patience, weight_early_stop_filter)
        else:
            self.weight_early_stop = None

        self.best_select = BestSubset()
        self.evaluate_best_select = BestSubset()

        # Snapshot of the freshly initialised weights, restored by reset_metafe().
        self._init_metafe_weights = copy.deepcopy(self.metafe.state_dict())


    def reset_metafe(self):
        """Restore the estimator to its construction-time weights.

        Reloads the state dict snapshotted in ``__init__``, resets the
        early-stopping tracker when one exists, and creates a fresh Adam
        optimizer so no optimizer state from previous training carries over.
        """
        print('................weight reset...........')
        # Early stopping is optional: __init__ leaves it as None when the
        # patience argument is None, so it must not be dereferenced blindly.
        if self.weight_early_stop is not None:
            self.weight_early_stop.reset()
        self.metafe.load_state_dict(self._init_metafe_weights)
        self.weight_optim = torch.optim.Adam(self.metafe.parameters(), lr=self.weight_step_lr)

    def test_select(self, select, final_test_dataset=None):
        """Score a feature subset through the external ``test_subset`` routine.

        Dataset choice:
            * ``self.test_dataset`` is used as a ``(train, valid, test)`` triple
              when it is set and no ``final_test_dataset`` is given.
            * Otherwise train/valid come from the loaders' datasets and the test
              split is ``final_test_dataset``.
            * If the resulting test split is ``None``, the validation split is
              used for testing as well.

        Args:
            select: Feature indices to evaluate.
            final_test_dataset: Optional dataset that overrides the test split.

        Returns:
            Whatever ``test_subset`` returns (callers index it with ``'f1'``
            and ``'loss'``).
        """
        if self.test_dataset is not None:
            train, valid, test = self.test_dataset

        # The loader-backed datasets are used when there is no preset triple
        # or when the caller supplied an explicit final test set.
        if self.test_dataset is None or final_test_dataset is not None:
            train = self.train_loader.dataset
            valid = self.valid_loader.dataset
            test = final_test_dataset

        if test is None:
            test = valid

        # The literal 2 is passed positionally as-is; presumably the number of
        # classes -- TODO confirm against test_subset's signature.
        hidden = [self.k] + self.embedding_dims[1:]
        return test_subset(select, 2, train, valid, test,
                           hidden_layers=hidden, batch_size=1024,
                           max_epoch=2000, verbose=0, device=self.device)

    def search_ones(self, select=None):
        """Evaluate one subset and record it in both best-subset trackers.

        Args:
            select: Feature indices to evaluate. When ``None`` a single subset
                of size ``self.k`` is sampled from the current distribution.

        Returns:
            A pair ``(non_appear_times, res)`` where ``non_appear_times`` is read
            from ``self.best_select`` and ``res`` is the ``test_select`` result
            when ``use_retrain_test`` is set, otherwise the validation metric.
        """
        self._search_step_i += 1

        if select is None:
            select = self.fss.sample_subsets(self.k, 1)[0]

        # The validation-only score is always tracked.
        eval_res = self.evaluate_select(select)
        self.evaluate_best_select.append(select, eval_res.f1())

        # The primary tracker uses the retrain score when enabled.
        if self.use_retrain_test:
            res = self.test_select(select)
            primary_f1 = res['f1']
        else:
            res = eval_res
            primary_f1 = eval_res.f1()
        self.best_select.append(select, primary_f1)

        tracker = self.best_select
        self.log_writer.append(type='search', update=True, counter='search',
                               non_appears=tracker.non_appear_times,
                               best_f1=tracker.best_score,
                               eval_best_f1=self.evaluate_best_select.best_score)
        return tracker.non_appear_times, res

    def search_topk(self):
        """Evaluate the ``k`` currently most probable features as one subset.

        The subset is scored through ``search_ones`` and its F1 and loss are
        logged under the ``'search'`` type. The result may be either a dict
        (retrain path) or a metric object (validation path); both are handled.
        """
        topk = self.sorted_idx()[:self.k]

        _, metric = self.search_ones(topk)
        if isinstance(metric, dict):
            f1_value = metric['f1']
            self.best_select.append(topk, f1_value)
            loss_value = metric['loss']
        else:
            f1_value = metric.f1()
            self.best_select.append(topk, f1_value)
            loss_value = metric.loss()

        # NOTE(review): search_ones() already appended this subset to
        # best_select, so it is recorded a second time here -- confirm that
        # this double registration is intended.
        self.log_writer.append(type='search', update=False, counter='search',
                               f1_topk=f1_value, loss=loss_value)

    def search(self, max_iter, early_stop_search=10000, final_test_loader=None):
        """Repeatedly sample and score subsets, checkpointing the results.

        Args:
            max_iter: Maximum number of subsets to sample.
            early_stop_search: The loop stops as soon as the ``non_appear_times``
                value reported by ``search_ones`` exceeds this threshold.
            final_test_loader: Optional loader; when given, the current best
                subset is additionally scored on ``final_test_loader.dataset``
                every 10th search step, as long as
                ``best_select.true_subset_nums(best_subset)`` is below 30.

        Both trackers are written to ``LOG_DIR`` every 10th search step and once
        more when the loop ends.
        """
        result_path = os.path.join(self.LOG_DIR, 'search_result.yaml')
        eval_result_path = os.path.join(self.LOG_DIR, 'search_eval_result.yaml')

        for i in range(max_iter):
            appear_times, _ = self.search_ones()
            if appear_times > early_stop_search:
                break
            if self._search_step_i % 10 != 0:
                continue

            print(f'{i}/{max_iter} searching: best score: {self.best_select.best_score},  appear_times:{appear_times}')
            # Checkpoint before the (potentially long) final test below.
            self.best_select.save(result_path)
            self.evaluate_best_select.save(eval_result_path)

            if final_test_loader is None:
                continue

            best_subset = self.best_select.best_subset
            if self.best_select.true_subset_nums(best_subset) < 30:
                final_test_res = self.test_select(best_subset, final_test_loader.dataset)
                self.best_select.add_true_result(best_subset, final_test_res)
                self.log_writer.append(type='search', update=False, counter='search',
                                       **{k+'_test_final': v for k, v in final_test_res.items()})
                # Only best_select changed since the checkpoint above, so only
                # it needs to be written again.
                self.best_select.save(result_path)

        self.evaluate_best_select.save(eval_result_path)
        self.best_select.save(result_path)

    def _forward_select(self, dataloader, select):
        """Accumulate estimator metrics for one subset over a loader's worth of batches.

        Args:
            dataloader: Used only to decide how many batches to evaluate. The
                batches themselves are drawn from ``self.eval_sampler``.
            select: Feature indices passed to the estimator.

        Returns:
            A ``Metric`` updated with the loss, predictions and targets of
            every evaluated batch.
        """
        # Only the batch count is needed from `dataloader`; iterating it would
        # load (and discard) every batch a second time.
        try:
            num_batches = len(dataloader)
        except TypeError:
            # Loaders without a length have to be walked once to be counted.
            num_batches = sum(1 for _ in dataloader)

        metric = Metric()
        for _ in range(num_batches):
            x, y = self.eval_sampler.get()
            with torch.no_grad():
                loss, pred = self.metafe.forward(x, y, select)
            metric.update(loss, pred, y)
        return metric

    def evaluate_select(self, select):
        """Return the metric of ``select`` computed via ``_forward_select`` on the validation loader."""
        validation_metric = self._forward_select(self.valid_loader, select)
        return validation_metric

    def evaluate(self, sample_subset_num=10, update=True):
        """Evaluate the estimator on freshly sampled subsets.

        For each of ``sample_subset_num`` rounds, one batch is taken from the
        evaluation sampler and one subset of size ``self.k`` is sampled; the
        estimator is run without gradient tracking.

        Args:
            sample_subset_num: Number of (batch, subset) pairs to evaluate.
            update: When true, the aggregated metric is logged under
                the ``'evaluate'`` type.

        Returns:
            The aggregated ``Metric``.
        """
        aggregated = Metric()

        for _round in range(sample_subset_num):
            # Batch first, then subset: keeps the order of random draws stable.
            batch_x, batch_y = self.eval_sampler.get()
            subset = self.fss.sample_subsets(self.k, 1)[0]
            with torch.no_grad():
                batch_loss, batch_pred = self.metafe.forward(batch_x, batch_y, subset)
            aggregated.update(batch_loss, batch_pred, batch_y)

        if not update:
            return aggregated
        self._update_metric('evaluate', aggregated)
        return aggregated

    def feat_print(self):
        """Print the feature probabilities on either side of the top-k boundary.

        Shows up to 10 probabilities ranked just inside the top ``k`` and up to
        10 ranked just outside it, separated by ``'|'``.
        """
        P = self.get_prob().detach().cpu().numpy()
        _sorted_idx = self.sorted_idx()
        # Clamp the window start: with k < 10 a negative start would wrap
        # around to the tail of the ranking and show the wrong (or no) entries.
        inside_start = max(0, self.k - 10)
        print('feature_p:', P[_sorted_idx[inside_start: self.k]], '|', P[_sorted_idx[self.k: self.k + 10]])

    def feat_step(self, max_iter=1000, subsets_batch_size=256, batch_iter=1, top_k_p=0.8, no_search_iter=0):
        """Update the feature scores while the estimator weights stay untouched.

        Each iteration samples ``subsets_batch_size`` subsets, measures the
        estimator's loss for each of them without gradient tracking, and
        minimises the loss weighted by a softmax over the subsets' summed
        scores.

        Args:
            max_iter: Maximum number of score updates.
            subsets_batch_size: Number of subsets sampled per update.
            batch_iter: Number of evaluation batches averaged per subset.
            top_k_p: Convergence threshold on the summed probability of the
                ``k`` most probable features.
            no_search_iter: The periodic search is only run once the global
                feature-step counter exceeds this value.

        Returns:
            ``True`` if the top-k probability mass reached ``top_k_p`` at one
            of the periodic checks, ``False`` if ``max_iter`` was exhausted.
        """
        for i in range(max_iter):
            self._feat_step_i += 1

            selects = self.fss.sample_subsets(self.k, subsets_batch_size)
            subset_loss = torch.zeros(subsets_batch_size, device=self.device).float()
            subset_scores = torch.zeros(subsets_batch_size, device=self.device).float()
            # Score of a subset = sum of the raw scores of its features; this is
            # the only part of the objective that carries gradient.
            for _s_i, select in enumerate(selects):
                subset_scores[_s_i] = torch.sum(self.fss.get_score()[select])
            # Softmax across the sampled subsets (not across features).
            subset_prob = F.softmax(subset_scores, dim=0)

            metric = Metric()
            for subset_i, select in enumerate(selects):
                for b_iter in range(batch_iter):
                    x, y = self.eval_sampler.get()

                    # The estimator is only evaluated here, never trained.
                    with torch.no_grad():
                        _loss, _pred = self.metafe.forward(x, y, select)
                    metric.update(_loss, _pred, y)
                    # Running mean of the loss over batch_iter batches.
                    subset_loss[subset_i] += _loss / batch_iter

            self.feature_optim.zero_grad()
            # Losses act as constant weights; gradients flow through subset_prob only.
            torch.sum(subset_loss.detach() * subset_prob).backward()
            self.feature_optim.step()

            self._update_metric('feat_train', metric, f"{i}/{max_iter}", counter='feat')

            self.feat_print()

            # Periodic search / convergence check, keyed on the global step
            # counter (which persists across calls), not on the local index i.
            if (self._feat_step_i + 1) % 10 == 0:

                if self._feat_step_i > no_search_iter:
                    self.search_topk()
                    self.search_ones()
                    self.evaluate_best_select.save(os.path.join(self.LOG_DIR, 'search_eval_result.yaml'))
                    self.best_select.save(os.path.join(self.LOG_DIR, 'search_result.yaml'))

                test_samples = self.fss.sample_subsets(self.k, 1000)
                # presumably the fraction of distinct subsets among the 1000
                # samples -- TODO confirm against count_subsets
                differ_rate = count_subsets(test_samples) / 1000
                # Probability mass held by the k most probable features.
                top_k_prob = torch.sum(self.fss.get_prob().sort(descending=True)[0][:self.k])

                self.log_writer.append('feat_train', update=False, counter='feat',
                                       differ_rate=differ_rate, top_k_prob=top_k_prob, feat_step_i=self._feat_step_i)

                # Converged: the distribution is concentrated enough on k features.
                if top_k_prob >= top_k_p:
                    self.evaluate_best_select.save(os.path.join(self.LOG_DIR, 'search_eval_result.yaml'))
                    self.best_select.save(os.path.join(self.LOG_DIR, 'search_result.yaml'))
                    return True

        # Iteration budget exhausted without reaching the threshold.
        self.best_select.save(os.path.join(self.LOG_DIR, 'search_result.yaml'))
        self.evaluate_best_select.save(os.path.join(self.LOG_DIR, 'search_eval_result.yaml'))
        return False

    def weight_step(self, max_iter=1000, subset_iter=34, check_period=1, keep_best=False):
        """Train the estimator weights on randomly sampled feature subsets.

        Every iteration accumulates gradients over ``subset_iter`` (batch,
        subset) pairs and then takes a single optimizer step. The accumulated
        losses are summed as returned by the estimator; no averaging is applied.

        Args:
            max_iter: Maximum number of optimizer steps.
            subset_iter: Number of (batch, subset) pairs per optimizer step;
                also the number of subsets used for each periodic evaluation.
            check_period: Evaluation period, scaled by
                ``train_sampler.loader_length / subset_iter`` to obtain a
                number of optimizer steps (at least 1).
            keep_best: When true and early stopping holds a best snapshot at
                the end of the loop, that snapshot is loaded into the estimator.

        When early stopping is disabled (``self.weight_early_stop`` is
        ``None``), the periodic evaluation still runs and is logged, but
        training always uses the full ``max_iter`` budget.
        """
        check_period = max(1, int(check_period * self.train_sampler.loader_length / subset_iter))
        early_stop = self.weight_early_stop
        for i in range(max_iter):
            self._weight_step_i += 1

            metric = Metric()
            self.weight_optim.zero_grad()
            for _ in range(subset_iter):
                x, y = self.train_sampler.get()
                selected = self.fss.sample_subsets(self.k, 1)[0]
                loss, pred = self.metafe.forward(x, y, selected)
                loss.backward()
                metric.update(loss, pred, y)
            self.weight_optim.step()
            self._update_metric('weight', metric, f'{i}/{max_iter}', show=True, counter='weight')

            if self._weight_step_i % check_period == 0:
                metric = self.evaluate(subset_iter)
                # Guard before touching the tracker: it is None when early
                # stopping was disabled at construction time.
                if early_stop is None:
                    continue
                early_stop.update(metric.loss())
                if early_stop.is_stop:
                    print('early stopping...')
                    self.metafe.load_state_dict(early_stop.best_weight)
                    early_stop.reset()
                    break

        # end iter & load best
        if keep_best and early_stop is not None and early_stop.best_weight is not None:
            self.metafe.load_state_dict(early_stop.best_weight)

    def save(self, name, dir=None):
        """Write the estimator and feature-score state dicts to disk.

        Two files are produced in the target directory: ``'metafe_' + name``
        and ``'fss_' + name``.

        Args:
            name: Suffix shared by both file names.
            dir: Target directory; defaults to ``self.LOG_DIR`` when ``None``.
        """
        target_dir = self.LOG_DIR if dir is None else dir
        # Estimator first, then the feature scores.
        for prefix, module in (('metafe_', self.metafe), ('fss_', self.fss)):
            with open(os.path.join(target_dir, prefix + name), 'wb') as f:
                torch.save(module.state_dict(), f)

    def load(self, name, dir=None):
        """Restore the estimator and feature-score state dicts from disk.

        Looks for ``'metafe_' + name`` and ``'fss_' + name`` in the directory
        and loads each file that exists; a missing file is skipped silently.

        Args:
            name: Suffix shared by both file names.
            dir: Source directory; defaults to ``self.LOG_DIR`` when ``None``.

        Returns:
            ``True`` only if both files were found and loaded.
        """
        if dir is None:
            dir = self.LOG_DIR

        loaded = []
        for prefix, module in (('metafe_', self.metafe), ('fss_', self.fss)):
            path = os.path.join(dir, prefix + name)
            if not os.path.exists(path):
                loaded.append(False)
                continue
            with open(path, 'rb') as f:
                # Deserialize onto the CPU so a checkpoint written on a GPU
                # host can be read on a CPU-only one; load_state_dict then
                # copies the values into the module's existing parameters.
                state = torch.load(f, map_location='cpu')
            module.load_state_dict(state)
            loaded.append(True)
        return all(loaded)

    def _update_metric(self, type, metric, prefix='', show=True, counter='train'):
        """Log a metric's loss, precision, F1 and recall, optionally printing them.

        Args:
            type: Log category forwarded to the log writer and printed first.
            metric: Object exposing ``loss()``, ``precision()``, ``f1()`` and
                ``recall()``.
            prefix: Free-form progress text printed after the category.
            show: When true, a one-line summary is printed to stdout.
            counter: Counter name forwarded to the log writer.
        """
        loss = metric.loss()
        precision = metric.precision()
        f1 = metric.f1()
        recall = metric.recall()

        self.log_writer.append(type=type, counter=counter, loss=loss, precision=precision,
                               f1=f1, recall=recall)
        if show:
            # The fields are comma-separated; the second one is the precision
            # value, so it is labelled as such.
            print(f'{type}: {prefix}  '
                  f'Loss: {loss:.4f}, Precision: {precision:.4f}, '
                  f'F1: {f1:.4f}, recall: {recall}')
        sys.stdout.flush()

    def sorted_idx(self):
        """Return the feature indices ordered by decreasing probability.

        Features with equal probability keep their original index order.

        Returns:
            A list of NumPy integer indices, most probable feature first.
        """
        prob = self.fss.get_prob().detach().cpu().numpy()
        # A stable ascending sort of the negated values gives a descending
        # order in which ties stay in index order.
        descending_order = np.argsort(-prob, kind='stable')
        return list(descending_order)

    def select(self):
        """Return the best subset currently held by ``self.best_select``."""
        tracker = self.best_select
        return tracker.best_subset

    def get_prob(self):
        """Return the feature probabilities from the underlying ``FSS`` module."""
        feature_probs = self.fss.get_prob()
        return feature_probs

    def get_score(self):
        """Return the raw feature scores from the underlying ``FSS`` module."""
        feature_scores = self.fss.get_score()
        return feature_scores
