import random
import torch
import numpy as np
import logging

# Module-level logger named after this module (standard ``getLogger(__name__)`` pattern).
_logger = logging.getLogger(__name__)


def sample_op(search_space, probabilities=None):
    """Draw a single element of ``search_space`` using NumPy's global RNG.

    Args:
        search_space: Sequence of candidates to choose from.
        probabilities: Optional per-candidate probabilities (same length as
            ``search_space``); ``None`` means a uniform draw.

    Returns:
        The chosen element, as produced by ``np.random.choice`` (so strings
        come back as ``numpy.str_``).
    """
    # Spell out choice()'s defaults explicitly: one draw, with replacement.
    return np.random.choice(a=search_space, size=None, replace=True, p=probabilities)


class Sampler:
    """Base class for operation samplers.

    Concrete samplers override :meth:`sample_op`; :meth:`set_seed` is shared.
    """

    def __init__(self):
        """The base class keeps no state."""

    def sample_op(self):
        """Draw one operation. Must be implemented by a subclass."""
        raise NotImplementedError

    def set_seed(self, seed):
        """Seed torch (CPU and all CUDA devices), NumPy's global RNG and ``random``.

        Args:
            seed: Value passed unchanged to every seeding function.
        """
        seeders = (
            torch.manual_seed,
            torch.cuda.manual_seed_all,
            np.random.seed,
            random.seed,
        )
        for seeder in seeders:
            seeder(seed)


class HierarchicalSampler(Sampler):
    """Sampler whose "exploit" distribution is derived from expert similarities.

    ``search_space`` is matched positionally against the probability vector built
    in :meth:`calculate_probabilities`: one "reuse" entry per similarity, then one
    "adapt" entry per similarity, then a "new" entry and a "skip" entry. Only the
    total length is checked; the ordering is assumed — TODO confirm against the
    code that builds ``search_space``.
    """

    def __init__(self, primitives, similarities, search_space):
        """
        Args:
            primitives: Stored as ``self.primitives``; not used by this class.
            similarities: 1-D tensor of similarity scores (presumably one per
                "reuse" entry of ``search_space`` — verify against caller).
            search_space: Sequence of operation names. Entries containing
                ``"reuse"`` contribute ``name.split("_", 1)[1]`` to ``self.experts``.
        """
        self.primitives = primitives
        self.similarities = similarities
        self.search_space = search_space
        self.experts = [s.split("_", 1)[1] for s in self.search_space if "reuse" in s]
        self.calculate_probabilities()

    def calculate_probabilities(self):
        """Compute ``self.search_space_prob`` from ``self.similarities``.

        With ``p = softmax(similarities)`` and ``r = sigmoid(similarities)``:

        * ``reuse_i = p_i * r_i ** 2``
        * ``adapt_i = p_i * r_i * (1 - r_i)``
        * ``new = skip = 0.5 * sum_i p_i * (1 - r_i)``

        These terms sum to 1 analytically. The vector is renormalised anyway so
        that floating-point drift does not make ``np.random.choice`` reject it.

        Raises:
            AssertionError: if the probability vector and ``search_space``
                differ in length.
        """
        # Detach so this also works when the similarities are part of an
        # autograd graph: ``Tensor.numpy()`` refuses tensors that require grad.
        similarities = self.similarities.detach()
        self.expert_sampling_prob = similarities.softmax(dim=0).cpu().numpy()
        self.retention_prob = torch.sigmoid(similarities).cpu().numpy()

        # Per-expert probabilities of the categorical distribution.
        reuse = self.expert_sampling_prob * self.retention_prob**2
        adapt = self.expert_sampling_prob * self.retention_prob * (1 - self.retention_prob)
        # Mass not retained by any expert is split evenly between "new" and "skip".
        leftover = 0.5 * np.sum(self.expert_sampling_prob * (1 - self.retention_prob))
        new = np.array([leftover])
        skip = np.array([leftover])

        search_space_prob = np.concatenate([reuse, adapt, new, skip])
        # NOTE: workaround for floating-point error, which can leave the sum
        # slightly off 1 even though it is exactly 1 analytically.
        self.search_space_prob = search_space_prob / search_space_prob.sum()

        assert len(self.search_space_prob) == len(self.search_space), (
            f"{len(self.search_space_prob)} probabilities for "
            f"{len(self.search_space)} search-space entries"
        )

    def sample_op(self, mode, seed=None, **kwargs):
        """Sample one operation from ``self.search_space``.

        Args:
            mode: ``"exploit"`` samples from ``self.search_space_prob``;
                ``"explore"`` samples uniformly.
            seed: If given, reseeds all global RNGs via :meth:`set_seed` first.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: if ``mode`` is not ``"explore"`` or ``"exploit"``.
        """
        # Validate before touching the global RNG state.
        assert mode in ["explore", "exploit"], f"unknown mode: {mode!r}"
        if seed is not None:
            self.set_seed(seed)
        probabilities = self.search_space_prob if mode == "exploit" else None
        return sample_op(self.search_space, probabilities=probabilities)
    

class HierarchicalSamplerV2(Sampler):
    """Hierarchical sampler that gives new/skip their own pseudo-similarity.

    ``search_space`` is matched positionally against the probability vector built
    in :meth:`calculate_probabilities`: one "reuse" entry per similarity, then one
    "adapt" entry per similarity, then a "new" entry and a "skip" entry. Only the
    total length is checked; the ordering is assumed — TODO confirm against the
    code that builds ``search_space``.
    """

    def __init__(self, primitives, similarities, search_space):
        """
        Args:
            primitives: Stored as ``self.primitives``; not used by this class.
            similarities: 1-D tensor of similarity scores (presumably one per
                "reuse" entry of ``search_space`` — verify against caller).
                Must be non-empty, because ``.max()`` is taken over it.
            search_space: Sequence of operation names. Entries containing
                ``"reuse"`` contribute ``name.split("_", 1)[1]`` to ``self.experts``.
        """
        self.primitives = primitives
        self.similarities: torch.Tensor = similarities
        self.search_space = search_space
        self.experts = [s.split("_", 1)[1] for s in self.search_space if "reuse" in s]
        self.calculate_probabilities()

    def calculate_probabilities(self):
        """Compute ``self.search_space_prob`` from ``self.similarities``.

        Let ``s`` be the similarities and ``s' = [s, -max(s)]``; the appended
        value is a pseudo-similarity shared by the new/skip options. With
        ``p = softmax(s')`` and ``r = sigmoid(s)``:

        * ``reuse_i = p_i * r_i``
        * ``adapt_i = p_i * (1 - r_i)``
        * ``new = skip = 0.5 * p_last``

        These terms sum to 1 analytically. The vector is renormalised anyway so
        that floating-point drift does not make ``np.random.choice`` reject it.

        Raises:
            AssertionError: if the probability vector and ``search_space``
                differ in length.
        """
        # Detach so this also works when the similarities are part of an
        # autograd graph: ``Tensor.numpy()`` refuses tensors that require grad.
        base_similarities = self.similarities.detach()

        # Pseudo similarity for the new/skip options.
        new_skip_similarity = -1. * base_similarities.max()
        similarities = torch.cat([base_similarities, new_skip_similarity.unsqueeze(0)], dim=0)
        self.expert_sampling_prob = similarities.softmax(dim=0).cpu().numpy()
        self.retention_prob = torch.sigmoid(similarities[:-1]).cpu().numpy()

        # Per-expert probabilities of the categorical distribution.
        reuse = self.expert_sampling_prob[:-1] * self.retention_prob
        adapt = self.expert_sampling_prob[:-1] * (1 - self.retention_prob)
        # The pseudo-expert's mass is split evenly between "new" and "skip".
        leftover = 0.5 * self.expert_sampling_prob[-1]
        new = np.array([leftover])
        skip = np.array([leftover])

        search_space_prob = np.concatenate([reuse, adapt, new, skip])
        # NOTE: workaround for floating-point error, which can leave the sum
        # slightly off 1 even though it is exactly 1 analytically.
        self.search_space_prob = search_space_prob / search_space_prob.sum()

        assert len(self.search_space_prob) == len(self.search_space), (
            f"{len(self.search_space_prob)} probabilities for "
            f"{len(self.search_space)} search-space entries"
        )

    def sample_op(self, mode, seed=None, **kwargs):
        """Sample one operation from ``self.search_space``.

        Args:
            mode: ``"exploit"`` samples from ``self.search_space_prob``;
                ``"explore"`` samples uniformly.
            seed: If given, reseeds all global RNGs via :meth:`set_seed` first.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: if ``mode`` is not ``"explore"`` or ``"exploit"``.
        """
        # Validate before touching the global RNG state.
        assert mode in ["explore", "exploit"], f"unknown mode: {mode!r}"
        if seed is not None:
            self.set_seed(seed)
        probabilities = self.search_space_prob if mode == "exploit" else None
        return sample_op(self.search_space, probabilities=probabilities)


class UniformSampler(Sampler):
    """Sampler that picks every entry of ``search_space`` with equal probability."""

    def __init__(self, primitives, search_space):
        """
        Args:
            primitives: Stored as ``self.primitives``; not used by this class.
            search_space: Sequence of candidate operations.
        """
        self.primitives = primitives
        self.search_space = search_space

    def sample_op(self, seed=None, **kwargs):
        """Draw one operation uniformly from ``self.search_space``.

        Args:
            seed: If given, reseeds all global RNGs via :meth:`set_seed` first.
            **kwargs: Accepted and ignored.
        """
        must_reseed = seed is not None
        if must_reseed:
            self.set_seed(seed)
        # ``None`` probabilities means a uniform draw.
        return sample_op(self.search_space, probabilities=None)
