from type.sampler import Sampler
from data.bodata import BoData
from typing import List, Dict, Optional
import torch, os
import numpy as np
from loguru import logger
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.acquisition import (
    qLogExpectedImprovement,
    ExpectedImprovement,
    UpperConfidenceBound,
    qExpectedImprovement,
    qUpperConfidenceBound,
)
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from botorch import fit_gpytorch_mll
from utils.acq import optimize_acqf_discrete_idx, optimize_acqf_discrete_weighted_idx


class BOSampler(Sampler):
    def __init__(self, dataset: BoData, init_x=None, init_y=None):
        """Create the sampler and, when initial data is given, fit a first GP.

        Args:
            dataset: Candidate pool; forwarded to the ``Sampler`` base class.
            init_x: Optional initial training inputs. When ``None`` the
                sampler starts with empty training tensors and no model is
                built here.
            init_y: Optional initial training targets paired with ``init_x``.
        """
        super().__init__(dataset, init_x, init_y)

        if init_x is None:
            # No observations yet: start from empty tensors. ``self.model``
            # is not created on this path.
            self.train_x = torch.Tensor([])
            self.train_y = torch.Tensor([])
            return

        self.train_x, self.train_y = init_x, init_y
        self.init_model(self.train_x, self.train_y)

    def update_train(self, idx: List[int], obj: torch.Tensor) -> None:
        """Append the features of the given dataset entries and their targets.

        Only ``train_x`` / ``train_y`` are extended; the model is not rebuilt
        or refitted by this method.

        Args:
            idx: Dataset indices whose ``feat`` rows are appended to ``train_x``.
            obj: Target values appended to ``train_y``.
        """
        new_rows = [self.dataset[i].feat for i in idx]
        self.train_x = torch.cat([self.train_x, torch.vstack(new_rows)])
        self.train_y = torch.cat([self.train_y, obj])

    def acq_score(self, sub_space_idx: List[int]) -> torch.Tensor:
        """Evaluate qLogExpectedImprovement on each listed dataset entry.

        Args:
            sub_space_idx: Dataset indices to score.

        Returns:
            The acquisition values returned by the acquisition function when
            called on the candidates, one single-point batch per candidate.
        """
        acquisition = qLogExpectedImprovement(self.model, best_f=self.train_y.max())
        points = torch.vstack([self.dataset[i].feat for i in sub_space_idx])
        n_points, n_dims = points.shape[0], points.shape[-1]
        # Insert a q=1 axis so every candidate is scored on its own.
        batched_points = points.view(n_points, 1, n_dims)
        return acquisition(batched_points)

    def pseudo_label_sample(
        self, sub_space_idx: List[int], n_sample: int, weight: Optional[torch.Tensor] = None
    ) -> List[int]:
        """Pick candidates with a GP temporarily augmented by pseudo-labels.

        The candidates' ``predict_value`` entries are used as pseudo-labels.
        A fraction ``factor`` of the candidates is left out of the pseudo
        training set, the rest is concatenated with the real training data,
        and a GP is rebuilt on that union (re-using the current model's
        state dict, without refitting). qLogExpectedImprovement on that GP
        selects ``n_sample`` candidates, whose ``_observed_value`` is then
        appended to the real training data and the model is refitted.

        Which pseudo-labels are left out depends on the environment variable
        ``USE_WEIGHT_RANDOM_SAMPLE``:

        * ``"YES"``: points are kept by weighted random sampling without
          replacement, with weight inversely related to the predicted value
          (high predictions are more likely to be dropped).
        * anything else: the top ``factor`` fraction by predicted value is
          dropped deterministically.

        Args:
            sub_space_idx: Dataset indices forming the candidate space.
            n_sample: Number of candidates to select (passed as ``q``).
            weight: Optional weights forwarded to
                ``optimize_acqf_discrete_weighted_idx``.

        Returns:
            Dataset indices of the selected candidates.
        """
        use_weighted_drop = os.getenv("USE_WEIGHT_RANDOM_SAMPLE", "NO") == "YES"

        best_observed = self.train_y.max()
        # Fraction of pseudo-labels to leave out. The ``/ 100`` presumably
        # assumes a 0-100 objective scale (TODO confirm). Clamp at 0: a best
        # value above 100 would otherwise give a negative fraction, which turns
        # the slice below into "keep only the last few points" and makes the
        # random branch request more samples than exist.
        factor = min(0.5, max(0.0, float(1 - best_observed / 100)))

        predict_value = torch.vstack(
            [self.dataset[idx].predict_value for idx in sub_space_idx]
        )
        search_space = torch.vstack([self.dataset[idx].feat for idx in sub_space_idx])
        n_points = search_space.shape[0]
        # Assumes one scalar prediction per candidate, as the original
        # per-row sort already required.
        flat_pred = predict_value.detach().flatten()

        if use_weighted_drop:
            logger.warning("USE_WEIGHT_RANDOM_SAMPLE")

            # Inverse-weight sampling: high predicted values are more likely
            # to be dropped. ``.cpu()`` keeps this working for CUDA tensors;
            # float64 keeps the probabilities summing to 1 within numpy's
            # tolerance.
            yields = flat_pred.cpu().numpy().astype(np.float64)
            inv_weights = 1.0 / (yields - yields.min() + 1e-6)
            inv_weights /= inv_weights.sum()

            keep_size = max(1, int(n_points * (1 - factor)))
            keep_indices = np.random.choice(
                n_points, size=keep_size, replace=False, p=inv_weights
            )
            # Keep the surviving samples in their original order.
            keep = torch.as_tensor(np.sort(keep_indices), dtype=torch.long)
        else:
            # Stable descending sort reproduces ``sorted(..., reverse=True)``;
            # the top ``factor`` fraction is then dropped.
            order = torch.sort(flat_pred, descending=True, stable=True).indices
            keep = order[int(n_points * factor):].cpu()

        pseudo_x = torch.concat(
            [search_space[keep].to(self.train_x.device), self.train_x]
        )
        pseudo_y = torch.concat(
            [predict_value[keep].to(self.train_y.device), self.train_y]
        )

        # Rebuild the GP on pseudo + real data, re-using the current
        # hyper-parameters; pseudo_init_model does not refit.
        self.pseudo_init_model(pseudo_x, pseudo_y, self.model.state_dict())
        acq = qLogExpectedImprovement(self.model, best_f=best_observed)

        if weight is not None:
            picked = optimize_acqf_discrete_weighted_idx(
                acq, q=n_sample, choices=search_space, weights=weight
            ).tolist()
        else:
            picked = optimize_acqf_discrete_idx(
                acq, q=n_sample, choices=search_space
            ).tolist()

        # ``tolist()`` yields a bare scalar for a 0-dim result; normalise to a
        # list of dataset indices either way.
        if isinstance(picked, list):
            candidates = [sub_space_idx[pos] for pos in picked]
        else:
            candidates = [sub_space_idx[picked]]

        # Record the real observations and refit on real data only.
        feats = torch.vstack([self.dataset[idx].feat for idx in candidates])
        self.train_x = torch.cat([self.train_x, feats])
        obj = torch.vstack([self.dataset[idx]._observed_value for idx in candidates])
        self.train_y = torch.cat([self.train_y, obj])
        self.init_model(self.train_x, self.train_y, self.model.state_dict())
        return candidates
    
    
    def real_exp_pseudo_label_sample(
        self,
        sub_space_idx: List[int],
        n_sample: int,
        cur_time: int,
        weight: Optional[torch.Tensor] = None,
        eta=0.001,
        beta=0.01,
        n_times=20,
        mask: Optional[List[int]] = None,
    ) -> List[int]:
        """Propose candidates for a real experiment using pseudo-labels.

        Candidates listed in ``mask`` are excluded. The remaining candidates'
        ``predict_value`` entries serve as pseudo-labels; a fraction
        ``factor`` of them is left out (see ``USE_WEIGHT_RANDOM_SAMPLE``
        below), the rest is shuffled and concatenated with the real training
        data, and a GP is rebuilt on that union without refitting. During the
        first half of the campaign (``cur_time < n_times / 2``)
        qLogExpectedImprovement is used, afterwards qUpperConfidenceBound.

        This method does not append anything to ``train_x`` / ``train_y`` and
        does not refit; ``self.model`` is left as the pseudo-label model.

        If the environment variable ``USE_WEIGHT_RANDOM_SAMPLE`` is ``"YES"``
        the kept pseudo-labels are drawn by weighted random sampling (high
        predictions more likely to be dropped); otherwise the top ``factor``
        fraction by predicted value is dropped.

        Args:
            sub_space_idx: Dataset indices forming the candidate space.
            n_sample: Number of candidates to select (passed as ``q``).
            cur_time: Current round, compared against ``n_times / 2``.
            weight: Optional weights forwarded to
                ``optimize_acqf_discrete_weighted_idx``.
            eta: Forwarded to ``qLogExpectedImprovement``.
            beta: Forwarded to ``qUpperConfidenceBound``.
            n_times: Total number of rounds.
            mask: Dataset indices to exclude; ``None`` excludes nothing.

        Returns:
            Dataset indices of the selected candidates.
        """
        use_weighted_drop = os.getenv("USE_WEIGHT_RANDOM_SAMPLE", "NO") == "YES"

        # ``mask`` defaults to None; treat that as "nothing masked" instead of
        # failing on the membership test. A set makes the test O(1).
        masked = set(mask) if mask is not None else set()
        open_idx = [idx for idx in sub_space_idx if idx not in masked]

        best_observed = self.train_y.max()
        # Clamp at 0 so a best value above 100 cannot produce a negative drop
        # fraction (negative slice start / oversampling in the random branch).
        factor = min(0.3, max(0.0, float(1 - best_observed / 100)))

        # Features and pseudo-labels must come from the same filtered index
        # list, otherwise rows and labels are paired incorrectly.
        predict_value = torch.vstack(
            [self.dataset[idx].predict_value for idx in open_idx]
        )
        search_space = torch.vstack([self.dataset[idx].feat for idx in open_idx])
        n_points = search_space.shape[0]
        # Assumes one scalar prediction per candidate, as the original
        # per-row sort already required.
        flat_pred = predict_value.detach().flatten()

        if use_weighted_drop:
            logger.warning("USE_WEIGHT_RANDOM_SAMPLE")
            # Inverse-weight sampling: high predicted values are more likely
            # to be dropped.
            yields = flat_pred.cpu().numpy().astype(np.float64)
            inv_weights = 1.0 / (yields - yields.min() + 1e-6)
            inv_weights /= inv_weights.sum()

            keep_size = max(1, int(n_points * (1 - factor)))
            keep_indices = np.random.choice(
                n_points, size=keep_size, replace=False, p=inv_weights
            )
            # Keep the surviving samples in their original order.
            keep = torch.as_tensor(np.sort(keep_indices), dtype=torch.long)
        else:
            # Stable descending sort reproduces ``sorted(..., reverse=True)``;
            # the top ``factor`` fraction is then dropped.
            order = torch.sort(flat_pred, descending=True, stable=True).indices
            keep = order[int(n_points * factor):].cpu()

        # Shuffle the retained pseudo-points before stacking them.
        keep = keep[torch.randperm(keep.numel())]

        pseudo_x = torch.concat(
            [search_space[keep].to(self.train_x.device), self.train_x]
        )
        pseudo_y = torch.concat(
            [predict_value[keep].to(self.train_y.device), self.train_y]
        )

        self.pseudo_init_model(pseudo_x, pseudo_y, self.model.state_dict())
        if cur_time < (n_times / 2):
            acq = qLogExpectedImprovement(self.model, best_f=best_observed, eta=eta)
        else:
            acq = qUpperConfidenceBound(self.model, beta=beta)

        # NOTE(review): ``weight`` is forwarded unchanged; if it is defined
        # over the unmasked ``sub_space_idx`` it will not line up with the
        # masked ``search_space`` - confirm against the caller.
        if weight is not None:
            picked = optimize_acqf_discrete_weighted_idx(
                acq, q=n_sample, choices=search_space, weights=weight
            ).tolist()
        else:
            picked = optimize_acqf_discrete_idx(
                acq, q=n_sample, choices=search_space
            ).tolist()

        # Positions refer to the masked search space, so map them back through
        # the filtered index list rather than the original one.
        if isinstance(picked, list):
            candidates = [open_idx[pos] for pos in picked]
        else:
            candidates = [open_idx[picked]]

        return candidates
    def real_exp(
        self,
        sub_space_idx: List[int],
        n_sample: int,
        weight: Optional[torch.Tensor] = None,
        mask: Optional[List[int]] = None,
    ) -> List[int]:
        """Propose candidates for a real experiment and flag them as observed.

        Candidates listed in ``mask`` are excluded, qLogExpectedImprovement on
        the current model selects ``n_sample`` of the rest, and each selected
        dataset entry gets ``_is_observed = True``. Training data and the
        model are not modified.

        Args:
            sub_space_idx: Dataset indices forming the candidate space.
            n_sample: Number of candidates to select (passed as ``q``).
            weight: Optional weights forwarded to
                ``optimize_acqf_discrete_weighted_idx``.
            mask: Dataset indices to exclude; ``None`` excludes nothing.

        Returns:
            Dataset indices of the selected candidates.
        """
        # ``mask`` defaults to None; treat that as "nothing masked" instead of
        # failing on the membership test. A set makes the test O(1).
        masked = set(mask) if mask is not None else set()
        open_idx = [idx for idx in sub_space_idx if idx not in masked]

        acq = qLogExpectedImprovement(self.model, best_f=self.train_y.max())
        search_space = torch.vstack([self.dataset[idx].feat for idx in open_idx])
        print('sub space have {} points'.format(search_space.shape[0]))

        # NOTE(review): ``weight`` is forwarded unchanged; if it is defined
        # over the unmasked ``sub_space_idx`` it will not line up with the
        # masked ``search_space`` - confirm against the caller.
        if weight is not None:
            picked = optimize_acqf_discrete_weighted_idx(
                acq, q=n_sample, choices=search_space, weights=weight
            ).tolist()
        else:
            picked = optimize_acqf_discrete_idx(
                acq, q=n_sample, choices=search_space
            ).tolist()

        # Positions refer to the masked search space, so map them back through
        # the filtered index list; using ``sub_space_idx`` here would return
        # (and flag) the wrong entries whenever something is masked.
        if isinstance(picked, list):
            candidates = [open_idx[pos] for pos in picked]
        else:
            candidates = [open_idx[picked]]

        for idx in candidates:
            self.dataset[idx]._is_observed = True

        return candidates
    def sample(
        self, sub_space_idx: List[int], n_sample: int, weight: torch.Tensor = None
    ) -> List[int]:
        """Select candidates with qLogExpectedImprovement and refit on them.

        The selected entries' ``feat`` and ``_observed_value`` are appended to
        the training data, then the model is rebuilt from the current state
        dict and refitted.

        Args:
            sub_space_idx: Dataset indices forming the candidate space.
            n_sample: Number of candidates to select (passed as ``q``).
            weight: Optional weights; when given, the weighted discrete
                optimiser is used instead of the plain one.

        Returns:
            Dataset indices of the selected candidates.
        """
        acquisition = qLogExpectedImprovement(self.model, best_f=self.train_y.max())
        choices = torch.vstack([self.dataset[i].feat for i in sub_space_idx])

        if weight is None:
            picked = optimize_acqf_discrete_idx(
                acquisition, q=n_sample, choices=choices
            ).tolist()
        else:
            picked = optimize_acqf_discrete_weighted_idx(
                acquisition, q=n_sample, choices=choices, weights=weight
            ).tolist()

        # ``tolist()`` gives a bare scalar for a 0-dim result; wrap it so the
        # return value is always a list of dataset indices.
        if isinstance(picked, list):
            candidates = [sub_space_idx[pos] for pos in picked]
        else:
            candidates = [sub_space_idx[picked]]

        new_x = torch.vstack([self.dataset[i].feat for i in candidates])
        self.train_x = torch.cat([self.train_x, new_x])
        new_y = torch.vstack([self.dataset[i]._observed_value for i in candidates])
        self.train_y = torch.cat([self.train_y, new_y])

        previous_state = self.model.state_dict()
        self.init_model(self.train_x, self.train_y, previous_state)

        return candidates

    def init_model(
        self,
        train_x: torch.Tensor,
        train_y: torch.Tensor,
        state_dict: Optional[Dict] = None,
    ):
        """Build a SingleTaskGP on the given data and fit its hyper-parameters.

        Args:
            train_x: Training inputs; the last dimension sets the size of the
                input normalisation.
            train_y: Training targets (standardised as a single outcome).
            state_dict: Optional model state loaded before fitting.
        """
        input_dim = train_x.shape[-1]
        gp = SingleTaskGP(
            train_x,
            train_y,
            input_transform=Normalize(d=input_dim),
            outcome_transform=Standardize(m=1),
        )
        # Match the model's dtype/device to the training inputs.
        self.model = gp.to(train_x)
        self.mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)

        if state_dict is not None:
            self.model.load_state_dict(state_dict)

        fit_gpytorch_mll(self.mll)

    def pseudo_init_model(
        self,
        train_x: torch.Tensor,
        train_y: torch.Tensor,
        state_dict: Optional[Dict] = None,
    ):
        """Build a SingleTaskGP on the given data without fitting it.

        Same construction as ``init_model`` but the marginal log-likelihood is
        only created, never optimised; hyper-parameters come solely from
        ``state_dict`` when it is supplied.

        Args:
            train_x: Training inputs; the last dimension sets the size of the
                input normalisation.
            train_y: Training targets (standardised as a single outcome).
            state_dict: Optional model state to load.
        """
        transforms = {
            "input_transform": Normalize(d=train_x.shape[-1]),
            "outcome_transform": Standardize(m=1),
        }
        # Match the model's dtype/device to the training inputs.
        self.model = SingleTaskGP(train_x, train_y, **transforms).to(train_x)
        self.mll = ExactMarginalLogLikelihood(self.model.likelihood, self.model)

        if state_dict is None:
            return
        self.model.load_state_dict(state_dict)
            

