import torch
from src.meta_learning.model import MetaFE
from src.meta_learning.utils import BestSubsets, count_subsets2
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
import os
from typing import Optional, Type, Callable, Union
from src.util.tb_log import TBWritter2
from trainer.utils import DatasetShell, SubShell
from src.util.metric import Metric
from src.util.utils import EarlyStopping, DataLoaderSampler, change_data
from src.verify.verify import test_subset
import sys
import math
import copy
import time


class FSS(nn.Module):
    """Learnable feature scores defining a sampling distribution over features.

    One trainable logit is kept per candidate feature (initialised to zero, so
    the initial distribution is uniform). Feature subsets are drawn without
    replacement from the softmax of these logits.
    """

    def __init__(self, n):
        super().__init__()
        # One logit per feature; all-zero logits give a uniform softmax.
        self.score = nn.Parameter(torch.zeros(n), requires_grad=True)

    def get_score(self):
        """Return the raw (unnormalised) score parameter, shape ``(n,)``."""
        return self.score

    def get_prob(self):
        """Return the per-feature probabilities: softmax of the scores."""
        logits = self.get_score()
        return F.softmax(logits, dim=0)

    def sample_subsets(self, k, n=1):
        """Draw ``n`` independent subsets of ``k`` distinct feature indices.

        Note that ``n`` here is the number of subsets, not the feature count
        passed to the constructor.

        Returns:
            numpy integer array of shape ``(n, k)``; each row is one subset
            sampled without replacement according to :meth:`get_prob`.
        """
        single_row = self.get_prob().view(1, -1)
        batched_probs = single_row.repeat(n, 1)
        picked = torch.multinomial(batched_probs, k, replacement=False)
        return picked.detach().cpu().numpy()


def check_dataset(dataset, max_batch_size: int):
    """Make sure ``dataset`` holds at least ``max_batch_size`` samples.

    A dataset that is already long enough is returned unchanged. A shorter one
    is repeated ``ceil(max_batch_size / len(dataset))`` times via ``SubShell``,
    so a loader with ``drop_last=True`` can still produce a full batch.

    Raises:
        ValueError: if ``dataset`` is empty (it cannot be tiled to any size).
    """
    dataset_length = len(dataset)
    if dataset_length == 0:
        # Without this guard the ceil division below raises a bare ZeroDivisionError.
        raise ValueError('check_dataset: dataset is empty; cannot build batches from it')
    if dataset_length >= max_batch_size:
        return dataset
    cat_num = math.ceil(max_batch_size / dataset_length)
    return SubShell(*[dataset] * cat_num)


class FeatureSelection(object):
    """Jointly train a meta feature-evaluation model and a feature-subset sampler.

    Two alternating procedures are provided:

    * :meth:`weight_step` trains ``metafe`` on random ``k``-feature subsets
      drawn from the learnable sampler :class:`FSS`;
    * :meth:`feat_step` updates the sampler's scores from the validation loss
      of sampled subsets, shifting probability toward low-loss subsets.

    Evaluated subsets and their F1 scores are recorded in a ``BestSubsets``
    store saved to ``<LOG_DIR>/best_selects.yaml``.

    Args:
        metafe: model providing ``forward(x, select)`` and
            ``forward_batch(x, selects, batch_size)``.
        k: number of features per subset.
        n: total number of candidate features.
        train_dataset: dataset used for weight updates.
        valid_dataset: dataset used for feature updates, evaluation and search.
        hold_subsets_num: capacity of the best-subset store.
        sample_batch_size: rows given to each subset during weight training
            and :meth:`_evaluate`.
        max_batch_size: rows per loader batch; rounded down to a multiple of
            ``sample_batch_size``.
        weight_update_batch_num: subsets per weight update (rounded up to a
            whole number of batches).
        feat_sample_batch_size: rows given to each subset during feature updates.
        feat_update_batch_num: subsets sampled per feature update.
        weight_step_lr: Adam learning rate for ``metafe``.
        feature_step_lr: RMSprop learning rate for the sampler scores.
        device: passed to the data samplers and used for the loss buffers;
            defaults to CUDA when available, else CPU.
        LOG_DIR: directory for logs, best subsets and checkpoints.
        loss_fn: loss class, instantiated with no arguments.

    Raises:
        ValueError: if ``max_batch_size < sample_batch_size`` (a batch could
            not hold the rows of even one subset).
    """

    def __init__(self, metafe: nn.Module, k, n,
                 train_dataset, valid_dataset,
                 hold_subsets_num=100,
                 sample_batch_size=128,
                 max_batch_size=4096,
                 weight_update_batch_num=32,
                 feat_sample_batch_size=128,
                 feat_update_batch_num=256,
                 weight_step_lr=0.001, feature_step_lr=0.01,
                 device=None, LOG_DIR='log/', loss_fn=nn.CrossEntropyLoss):
        # The rounded batch size below would be 0, and every training/eval path
        # would later divide by zero; fail early with a clear message instead.
        if max_batch_size < sample_batch_size:
            raise ValueError(f'max_batch_size ({max_batch_size}) must be >= '
                             f'sample_batch_size ({sample_batch_size})')

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        self.metafe = metafe
        # NOTE(review): the sampler is created on the default device and is not
        # moved to ``device``; _feat_train moves its scores explicitly.
        self.fss = FSS(n)
        self.loss_fn = loss_fn()

        self.k = k
        self.n = n

        self.weight_step_lr = weight_step_lr
        self.weight_optim = torch.optim.Adam(self.metafe.parameters(), lr=weight_step_lr)
        self.feature_optim = torch.optim.RMSprop(self.fss.parameters(), lr=feature_step_lr)

        self.sample_batch_size = sample_batch_size
        # Round down so each batch splits exactly into sample_batch_size-row
        # chunks, one chunk per sampled subset.
        self.max_batch_size = (max_batch_size // sample_batch_size) * sample_batch_size
        self.weight_update_batch_num = weight_update_batch_num
        self.feat_update_batch_num = feat_update_batch_num
        self.feat_sample_batch_size = feat_sample_batch_size

        # Use the *rounded* batch size so the rows per batch match what the
        # training/evaluation code assumes (num_subsets * sample_batch_size).
        self.train_loader = DataLoader(check_dataset(train_dataset, self.max_batch_size),
                                       batch_size=self.max_batch_size, shuffle=True, drop_last=True)
        self.valid_loader = DataLoader(check_dataset(valid_dataset, self.max_batch_size),
                                       batch_size=self.max_batch_size, shuffle=True, drop_last=True)
        self.train_sampler = DataLoaderSampler(self.train_loader, device)
        self.eval_sampler = DataLoaderSampler(self.valid_loader, device)

        self.LOG_DIR = LOG_DIR
        self.log_writer = TBWritter2(LOG_DIR)

        self.best_select = BestSubsets(max_len=hold_subsets_num, save_path=os.path.join(LOG_DIR, "best_selects.yaml"))

        # 1-based count of feature steps performed across all feat_step calls.
        self._feat_step_i = 0

    def _weight_train(self):
        """Run one Adam update of ``metafe``.

        Gradients are accumulated over ``ceil(weight_update_batch_num / num_subsets)``
        training batches. Each batch is evaluated under ``num_subsets`` freshly
        sampled subsets, one subset per ``sample_batch_size`` chunk.

        Returns:
            Metric accumulated over all processed batches.
        """
        self.metafe.train()

        metric = Metric()
        num_subsets = self.max_batch_size // self.sample_batch_size
        self.weight_optim.zero_grad()
        for _ in range(math.ceil(self.weight_update_batch_num / num_subsets)):
            x, y = self.train_sampler.get()
            selected = self.fss.sample_subsets(self.k, num_subsets)
            y_ = self.metafe.forward_batch(x, selected, self.sample_batch_size)
            loss = self.loss_fn(y_, y)
            loss.backward()  # accumulate; a single optimizer step follows the loop
            metric.update(loss, y_, y)
        self.weight_optim.step()

        return metric

    def _feat_train(self):
        """Run one update of the FSS scores from validation losses of sampled subsets.

        Subsets are evaluated with ``metafe`` frozen (no grad), ``num_subsets``
        at a time on one validation batch. The objective
        ``sum(loss_i * p_i)`` is then minimised, where ``p`` is the softmax over
        each subset's summed FSS scores. Only ``p`` carries gradient, so the
        step moves probability mass toward lower-loss subsets.

        Returns:
            Metric accumulated over the evaluated subsets.
        """
        self.metafe.eval()
        metric = Metric()
        num_subsets = self.max_batch_size // self.feat_sample_batch_size
        num_batches = self.feat_update_batch_num // num_subsets
        num_evaluated = num_batches * num_subsets

        selects = self.fss.sample_subsets(self.k, self.feat_update_batch_num)  # (feat_update_batch_num, k)
        # Only whole groups of num_subsets are evaluated. Drop the remainder:
        # keeping it with a placeholder loss of 0 would make never-evaluated
        # subsets look best and bias the update toward them.
        selects = selects[:num_evaluated]

        # Score of a subset = sum of its features' scores (vectorised gather).
        subset_scores = self.fss.get_score()[torch.from_numpy(selects)].sum(dim=1).to(self.device)
        subset_prob = F.softmax(subset_scores, dim=0)
        subset_loss = torch.zeros(num_evaluated, device=self.device)

        for i in range(num_batches):
            xs, ys = self.eval_sampler.get()
            start = num_subsets * i
            selects_batch = selects[start: start + num_subsets]
            with torch.no_grad():
                ys_ = self.metafe.forward_batch(xs, selects_batch, self.feat_sample_batch_size)
                chunks = zip(torch.chunk(ys_, num_subsets), torch.chunk(ys, num_subsets))
                for j, (y_, y) in enumerate(chunks):
                    loss = self.loss_fn(y_, y)
                    subset_loss[start + j] = loss
                    metric.update(loss, y_, y)

        self.feature_optim.zero_grad()
        torch.sum(subset_loss.detach() * subset_prob).backward()
        self.feature_optim.step()

        return metric

    def _evaluate(self, sample_subset_num=10, update=False):
        """Estimate validation metrics of ``metafe`` on randomly sampled subsets.

        Uses ``ceil(sample_subset_num / num_subsets)`` validation batches, each
        split across ``num_subsets`` sampled subsets. When ``update`` is true,
        the result is logged under ``'evaluate'``.
        """
        self.metafe.eval()

        metric = Metric()
        num_subsets = self.max_batch_size // self.sample_batch_size
        for _ in range(math.ceil(sample_subset_num / num_subsets)):
            x, y = self.eval_sampler.get()
            select = self.fss.sample_subsets(self.k, num_subsets)
            with torch.no_grad():
                y_ = self.metafe.forward_batch(x, select, self.sample_batch_size)
                loss = self.loss_fn(y_, y)
                metric.update(loss, y_, y)
        if update:
            self._update_metric('evaluate', metric)
        return metric

    def evaluate(self, select, max_iter=None):
        """Evaluate one fixed feature subset ``select`` on validation batches.

        Draws at most one loader pass worth of batches, further capped by
        ``max_iter`` when given.
        """
        self.metafe.eval()

        select = np.array(select)
        metric = Metric()

        # drop_last=True: the loader yields exactly len(valid_loader) full batches per pass.
        data_max_iter = len(self.valid_loader)
        max_iter = data_max_iter if max_iter is None else min(max_iter, data_max_iter)

        for _ in range(max_iter):
            x, y = self.eval_sampler.get()
            with torch.no_grad():
                y_ = self.metafe.forward(x, select)
                loss = self.loss_fn(y_, y)
                metric.update(loss, y_, y)
        return metric

    def weight_step(self, max_iter=1000, check_period=50, weight_early_stop_patience=200, weight_early_stop_filter=0.6,
                    load_best=True, prefix=''):
        """Train ``metafe`` for up to ``max_iter`` updates, with periodic validation.

        Every ``check_period`` updates, :meth:`_evaluate` feeds ``EarlyStopping``.
        On early stop, or when ``load_best`` is set, the best tracked weights
        are restored.
        """
        weight_early_stop = EarlyStopping(self.metafe, weight_early_stop_patience, weight_early_stop_filter)
        stopped_early = False

        for i in range(max_iter):
            _t = time.time()

            metric = self._weight_train()

            self._update_metric('weight', metric, prefix + f'{i}/{max_iter}  t:{time.time() - _t:.2f}', show=True, counter='weight')

            if (i + 1) % check_period == 0 and (load_best or weight_early_stop_patience >= 0):
                metric = self._evaluate(update=True)
                weight_early_stop.update(metric.loss())
                if weight_early_stop.is_stop:
                    print('early stopping...')
                    stopped_early = True
                    break

        # NOTE(review): if no validation check ran (max_iter < check_period),
        # best_weight may never have been set -- presumably None; confirm in
        # EarlyStopping. Skip restoring in that case instead of crashing.
        if (load_best or stopped_early) and weight_early_stop.best_weight is not None:
            self.metafe.load_state_dict(weight_early_stop.best_weight)

    def feat_step(self, max_iter=1000, top_k_p=0.8, search_max_iter=None, prefix='', search=True):
        """Run up to ``max_iter`` sampler updates.

        Every 10th feature step (counted across calls), it optionally searches
        the current top-k and a sampled subset, then logs sampler concentration.
        It returns ``True`` early once the top-k features hold at least
        ``top_k_p`` probability mass. Results are saved every 100 steps and on
        exit.

        Returns:
            True if the ``top_k_p`` threshold was reached, else False.
        """
        for i in range(max_iter):
            self._feat_step_i += 1  # already 1-based; the modulo checks below need no +1

            _t = time.time()
            metric = self._feat_train()
            self._update_metric('feat_train', metric, prefix + f"{i}/{max_iter}  t:{time.time() - _t:.2f}", counter='feat')

            if self._feat_step_i % 10 == 0:

                if search:
                    self.search_topk(max_iter=search_max_iter)
                    self.search_ones(max_iter=search_max_iter)

                subset_appear_prob = count_subsets2(self.fss.sample_subsets(self.k, 1000))
                # Plain float: avoids printing/logging a tensor that carries an autograd graph.
                with torch.no_grad():
                    top_k_prob = torch.topk(self.fss.get_prob(), self.k).values.sum().item()

                print(f'feat_train: top_k_prob: {top_k_prob}  subset_appear_prob: {subset_appear_prob}')

                self.log_writer.append('feat_train', update=False, counter='feat',
                                       subset_appear_prob=subset_appear_prob, top_k_prob=top_k_prob,
                                       feat_step_i=self._feat_step_i)

                if top_k_prob >= top_k_p:
                    self.save_results()
                    return True

            if self._feat_step_i % 100 == 0:
                self.save_results()

        self.save_results()
        return False

    def search_step(self, max_iter=1000, search_max_iter=None, prefix=''):
        """Evaluate ``max_iter`` sampled subsets, reporting and saving every 100."""
        for i in range(max_iter):
            _t = time.time()
            self.search_ones(max_iter=search_max_iter)
            if (i + 1) % 100 == 0:
                print(f'searching: {prefix} {i}/{max_iter}  t:{time.time() - _t:.2f}  best_select_score: {self.best_select.best_score}')
                self.save_results()
        self.save_results()

    def search_ones(self, select=None, max_iter=None):
        """Evaluate ``select`` (or one freshly sampled subset) and record its F1.

        Returns:
            The Metric from :meth:`evaluate`.
        """
        if select is None:
            select = self.fss.sample_subsets(self.k, 1)[0]
        metric = self.evaluate(select, max_iter)
        self.best_select.append(select, metric.f1())

        self.log_writer.append(type='search', update=True, counter='search',
                               best_f1=self.best_select.best_score,
                               **metric.dict())
        return metric

    def save_results(self):
        """Persist the best-subset store."""
        self.best_select.save()

    def save(self, name):
        """Save the best subsets plus ``metafe_<name>`` and ``fss_<name>`` state dicts to ``LOG_DIR``."""
        self.save_results()

        save_dir = self.LOG_DIR
        with open(os.path.join(save_dir, 'metafe_' + name), 'wb') as f:
            torch.save(self.metafe.state_dict(), f)
        with open(os.path.join(save_dir, 'fss_' + name), 'wb') as f:
            torch.save(self.fss.state_dict(), f)

    def load(self, name, load_dir=None):
        """Load the best subsets and any ``metafe_<name>`` / ``fss_<name>`` checkpoints.

        Returns:
            True only if both model checkpoints existed and were loaded.
        """
        if load_dir is None:
            load_dir = self.LOG_DIR

        self.best_select.load(os.path.join(load_dir, 'best_selects.yaml'))

        res = [False, False]
        for idx, (prefix, module) in enumerate((('metafe_', self.metafe), ('fss_', self.fss))):
            path = os.path.join(load_dir, prefix + name)
            if os.path.exists(path):
                with open(path, 'rb') as f:
                    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
                    # hosts; load_state_dict copies onto each parameter's device.
                    module.load_state_dict(torch.load(f, map_location='cpu'))
                res[idx] = True
        return all(res)

    def search_topk(self, max_iter=None):
        """Evaluate the current top-``k`` features unless already recorded."""
        topk = self.sorted_idx()[:self.k]
        if not self.best_select.is_content(topk):
            self.search_ones(topk, max_iter)

    def _update_metric(self, type, metric, prefix='', show=True, counter='train'):
        """Log ``metric`` under ``type`` and optionally print it.

        ``type`` shadows the builtin but is kept for caller compatibility.
        """
        self.log_writer.append(type=type, counter=counter, **metric.dict())
        if show:
            print(f'{type}: {prefix}  ' + ', '.join([f'{k}: {v:.4f}' for k, v in metric.dict().items()]))
        sys.stdout.flush()

    def sorted_idx(self):
        """Return feature indices ordered by descending sampler probability."""
        prob = self.fss.get_prob().detach().cpu().numpy()
        return np.flip(prob.argsort())
