import logging
from typing import Dict

import torch
from torch import Tensor, optim
from bbo.algorithms.basic_bo.bo import BO
from bbo.algorithms.dropout.dropout import Dropout
from bbo.algorithms.dropout.dropout import select_active_dim
from bbo.utils import print_dict

log = logging.getLogger(__name__)


class PretrainDropoutKumar(Dropout):
    """Dropout optimizer whose inner GP is initialised from a pretrained checkpoint.

    While no data has been observed, ``ask`` delegates to ``self.init()``.
    Afterwards, each ``ask``:

    1. picks ``active_dim`` coordinate indices with ``select_active_dim``;
    2. builds an inner optimizer (currently only ``BO``) on that subspace and
       tells it the observations projected onto those coordinates;
    3. loads ``saved_models/<pretrain_file>.pth`` into the inner GP model,
       slicing the stored lengthscales to the selected coordinates;
    4. fine-tunes only ``covar_module.wrapper.alpha`` / ``.beta``;
    5. optimizes the acquisition function and completes the result to a
       full ``dim``-dimensional point with ``self.fill``.
    """

    def __init__(
        self,
        dim: int,
        lb: Tensor,
        ub: Tensor,
        active_dim: int,
        pretrain_file: str,
        name: str = 'PretrainDropout',
        n_init: int = 10,
        q: int = 1,
        inner_algo: str = 'BO',
        unimportant_strategy: str = 'bestk',
        k: int = 10,
        finetune_cfg: Dict = None,
        **inner_config,
    ):
        """Store the problem definition and optimizer settings.

        Args:
            dim: number of input dimensions of the full problem.
            lb: 1-D tensor of lower bounds; same shape as ``ub``.
            ub: 1-D tensor of upper bounds; must be strictly greater than ``lb``.
            active_dim: size of the subspace optimized per iteration; clipped
                to ``dim`` (with a warning) if larger.
            pretrain_file: checkpoint stem; the file read is
                ``saved_models/<pretrain_file>.pth``.
            name: display name of this optimizer.
            n_init: stored as ``self.n_init`` (presumably used by the
                inherited ``init``; not visible here).
            q: stored as ``self.q``.
            inner_algo: inner optimizer type; only ``'BO'`` is supported.
            unimportant_strategy: ``'random'`` or ``'bestk'``; converted via
                the inherited ``create_unimportant_strategy``.
            k: stored as ``self.k`` (presumably consumed by the ``'bestk'``
                strategy; not visible here).
            finetune_cfg: dict with ``'lr'`` and ``'epochs'`` used by
                ``finetune``. If None, no fine-tuning is performed and the
                pretrained hyperparameters are used as loaded.
            **inner_config: forwarded verbatim to the inner ``BO`` constructor.

        Raises:
            AssertionError: if the bounds are malformed or an unsupported
                ``inner_algo`` / ``unimportant_strategy`` is given.
        """
        assert lb.ndim == 1 and ub.ndim == 1
        assert lb.shape == ub.shape
        assert (lb < ub).all()
        assert inner_algo in ['BO']
        assert unimportant_strategy in ['random', 'bestk']
        self.dim = dim
        self.lb = lb
        self.ub = ub
        if active_dim > dim:
            log.warning('Active dim {} > dim {}. Set active_dim to {}'.format(active_dim, dim, dim))
            active_dim = dim
        self.active_dim = active_dim
        self.name = name
        self.n_init = n_init
        self.q = q
        self.inner_algo = inner_algo
        self.k = k
        self.unimportant_strategy = self.create_unimportant_strategy(unimportant_strategy)
        self.pretrain_file = pretrain_file
        self.finetune_cfg = finetune_cfg
        self.inner_config = inner_config

        # Observation history in the full input space (empty until `tell`).
        self.X = torch.zeros((0, dim))
        self.Y = torch.zeros((0, 1))

    def create_inner_algo(self, dim, lb, ub):
        """Build the inner optimizer for a ``dim``-dimensional subspace.

        Raises:
            NotImplementedError: if ``self.inner_algo`` is not ``'BO'``.
        """
        if self.inner_algo == 'BO':
            algo = BO(dim, lb, ub, name='PretrainDropout-BO', **self.inner_config)
        else:
            raise NotImplementedError
        return algo

    def finetune(self, mll, model, train_X, train_Y):
        """Fine-tune only the ``alpha``/``beta`` parameters of ``model.covar_module.wrapper``.

        All other named parameters of ``model`` are frozen
        (``requires_grad = False``). Adam with learning rate
        ``finetune_cfg['lr']`` then minimises ``-mll`` for
        ``finetune_cfg['epochs']`` steps. On return, the model and its
        likelihood are in eval mode.

        If ``self.finetune_cfg`` is None, no optimisation is done; the model
        and likelihood are only switched to eval mode.

        Args:
            mll: marginal-log-likelihood object from ``create_model``,
                called as ``mll(output, targets)``.
            model: GP model, tuned in place.
            train_X: training inputs passed to ``model``.
            train_Y: training targets; flattened with ``reshape(-1)``.
        """
        if self.finetune_cfg is None:
            # No config means "use the pretrained hyperparameters as-is".
            # Previously this path crashed with a TypeError on `None['lr']`.
            log.info('No finetune_cfg given; skip finetuning')
            model.eval()
            model.likelihood.eval()
            return

        trainable = {'covar_module.wrapper.alpha', 'covar_module.wrapper.beta'}
        for name, param in model.named_parameters():
            if name not in trainable:
                param.requires_grad = False

        # Only the wrapper's parameters are handed to the optimizer. Frozen
        # ones among them receive no gradient and are skipped by Adam.
        optimizer = optim.Adam(model.covar_module.wrapper.parameters(), lr=self.finetune_cfg['lr'])
        model.train()
        model.likelihood.train()
        for _ in range(self.finetune_cfg['epochs']):
            optimizer.zero_grad()
            output = model(train_X)
            loss = - mll(output, train_Y.reshape(-1))
            loss.backward()
            optimizer.step()
        model.eval()
        model.likelihood.eval()

    def _load_pretrained(self, model, idx, device):
        """Load ``saved_models/<pretrain_file>.pth`` into ``model``, restricted to coordinates ``idx``."""
        path = 'saved_models/{}.pth'.format(self.pretrain_file)
        # NOTE: torch.load unpickles the file, so `pretrain_file` must point
        # at a trusted checkpoint. map_location lets a checkpoint saved on
        # another device (e.g. GPU) load where the current data lives.
        state_dict = torch.load(path, map_location=device)
        key = 'covar_module.base_kernel.base_kernel.raw_lengthscale'
        # The stored lengthscales are indexed by full-space coordinate along
        # dim 1; keep only the coordinates selected for this subspace model.
        state_dict[key] = state_dict[key][:, idx]
        # strict=False: keys present in only one of checkpoint/model are ignored.
        model.load_state_dict(state_dict, strict=False)
        log.info('Load from {}'.format(path))

    def ask(self) -> Tensor:
        """Propose the next point(s) to evaluate.

        Returns:
            ``self.init()`` while no data has been observed. Otherwise, the
            acquisition optimum on a subspace of ``active_dim`` coordinates,
            completed to all ``dim`` coordinates by ``self.fill``.
        """
        if len(self.X) == 0:
            return self.init()

        idx = select_active_dim(self.dim, self.active_dim)

        # Inner optimizer lives on the selected subspace only.
        algo = self.create_inner_algo(self.active_dim, self.lb[idx], self.ub[idx])
        algo.tell(self.X[:, idx], self.Y)

        train_X, train_Y = algo.preprocess()
        mll, model = algo.create_model(train_X, train_Y)

        # Warm-start from the checkpoint, then adapt the wrapper parameters.
        self._load_pretrained(model, idx, train_X.device)
        self.finetune(mll, model, train_X, train_Y)

        AF = algo.create_acqf(model, train_X, train_Y)
        important_X = algo.postprocess(algo.optimize_acqf(AF))

        # Complete the subspace proposal to a full-dimensional point.
        return self.fill(idx, important_X)



