import os
import random

import numpy as np  # type: ignore
import torch as th  # type: ignore


def seed_everything(seed: int) -> None:
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and the current CUDA device), writes ``PYTHONHASHSEED`` to the
    process environment, and switches cuDNN to deterministic mode.

    Args:
        seed: Value handed to each seeding function.
    """
    random.seed(seed)

    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
    # assignment can only influence processes launched after this call.
    os.environ["PYTHONHASHSEED"] = str(seed)

    for seed_fn in (np.random.seed, th.manual_seed, th.cuda.manual_seed):
        seed_fn(seed)

    # Ask cuDNN to pick deterministic algorithms only.
    th.backends.cudnn.deterministic = True


def nash_scalarization(reward: th.Tensor, w: th.Tensor) -> th.Tensor:
    """Scalarize a reward by taking the product over its objectives.

    Args:
        reward: Tensor whose last dimension is reduced. Presumably that
            dimension indexes the objectives and any leading dimensions
            are batch dimensions -- confirm against callers.
        w: Unused by this function. Presumably accepted so that every
            scalarization can be called as ``f(reward, w)``.

    Returns:
        For a 1-D ``reward``: the product as a plain Python number (the
        result of ``.item()``), not a tensor, despite the annotation. This
        is kept as-is for backward compatibility.
        For a ``reward`` with two or more dimensions: a tensor with the
        last dimension reduced by product.

    Raises:
        ValueError: If ``reward`` is 0-dimensional.
    """
    n_dims = reward.dim()
    if n_dims == 0:
        # Previously this (and any rank above 2) silently returned None.
        raise ValueError(
            "nash_scalarization expects a reward with at least 1 dimension"
        )
    if n_dims == 1:
        return th.prod(reward).item()
    return th.prod(reward, dim=-1)


def min_scalarization(reward: th.Tensor, w: th.Tensor) -> th.Tensor:
    """Scalarize a reward by taking the minimum over its objectives.

    Args:
        reward: Tensor whose last dimension is reduced. Presumably that
            dimension indexes the objectives and any leading dimensions
            are batch dimensions -- confirm against callers.
        w: Unused by this function. Presumably accepted so that every
            scalarization can be called as ``f(reward, w)``.

    Returns:
        For a 1-D ``reward``: the minimum as a plain Python number (the
        result of ``.item()``), not a tensor, despite the annotation. This
        is kept as-is for backward compatibility.
        For a ``reward`` with two or more dimensions: a tensor holding the
        minimum along the last dimension.

    Raises:
        ValueError: If ``reward`` is 0-dimensional.
    """
    n_dims = reward.dim()
    if n_dims == 0:
        # Previously this (and any rank above 2) silently returned None.
        raise ValueError(
            "min_scalarization expects a reward with at least 1 dimension"
        )
    if n_dims == 1:
        return th.min(reward).item()
    # th.min with a dim returns (values, indices); only the values matter.
    return th.min(reward, dim=-1).values


def owa_scalarization(r: th.Tensor, w: th.Tensor) -> th.Tensor:
    """Scalarize a reward as a weighted sum of its sorted objectives.

    The last dimension of ``r`` is sorted in ascending order (the default
    of ``th.sort``), so ``w[0]`` multiplies the smallest value.

    Args:
        r: Tensor whose last dimension is sorted and reduced. Presumably
            that dimension indexes the objectives -- confirm against
            callers.
        w: Weights applied to the sorted values. Moved to ``r``'s device
            before use. For a 1-D ``r`` it must be a 1-D tensor accepted by
            ``th.dot``; otherwise it is broadcast against the sorted
            tensor.

    Returns:
        A 0-D tensor for a 1-D ``r``; otherwise a tensor with the last
        dimension reduced.

    Raises:
        ValueError: If ``r`` is 0-dimensional.
    """
    n_dims = r.dim()
    if n_dims == 0:
        # Previously this (and any rank above 2) silently returned None.
        raise ValueError(
            "owa_scalarization expects a reward with at least 1 dimension"
        )

    w = w.to(r.device)
    if n_dims == 1:
        # Keep th.dot here so 1-D results stay numerically identical.
        return th.dot(th.sort(r).values, w)

    ordered = th.sort(r, dim=-1).values
    return th.sum(ordered * w, dim=-1)


def fair_ratio(
    vec_returns: th.Tensor, fair_efficient_threshold: float = 2
) -> th.Tensor:
    """Return the fraction of evaluations in which every objective is solved.

    An evaluation (row) counts only if all of its entries are greater than
    or equal to ``fair_efficient_threshold``.

    Args:
        vec_returns: 2-D tensor of size (n_evals, n_objectives) holding the
            vectorized returns after n_evals executions of the policy.
        fair_efficient_threshold: Threshold at or above which an objective
            is considered to be solved to the optimal.

    Returns:
        A 0-D tensor: the number of qualifying rows divided by
        ``len(vec_returns)``.
    """
    meets_threshold = vec_returns >= fair_efficient_threshold
    n_fair_evals = th.all(meets_threshold, dim=1).sum()
    return n_fair_evals / len(vec_returns)

def min_proportion(vec_returns: th.Tensor, demand: th.Tensor):
    """Return, for each row, the smallest ratio ``vec_returns / demand``.

    Args:
        vec_returns: Numerator tensor; reduced along dimension 1 after the
            division, so it must have at least two dimensions.
        demand: Denominator tensor, divided element-wise (with
            broadcasting) into ``vec_returns``.

    Returns:
        Tensor holding the minimum ratio along dimension 1.
    """
    satisfied_fraction = vec_returns / demand
    # th.min with a dim returns (values, indices); only the values matter.
    return th.min(satisfied_fraction, dim=1).values

def max_min_proportion(demand: th.Tensor, total_bag_size):
    """Greedily compute the best achievable minimum proportion per row.

    Every column of a row starts at ``1 / demand``. Then, for each of the
    remaining ``total_bag_size - n_objectives`` steps, the column that
    currently has the smallest proportion in each row is increased by its
    own ``1 / demand``. The result is each row's smallest proportion,
    capped at 1.

    Args:
        demand: 2-D tensor; rows are handled independently and dimension 1
            is the one the minimum is taken over. Not modified.
            Presumably shaped (n_demands, n_objectives) -- confirm against
            callers.
        total_bag_size: Total number of increments available per row,
            counting the initial one given to every column.

    Returns:
        1-D tensor with one value per row of ``demand``, clamped to at
        most 1.

    Raises:
        ValueError: If ``total_bag_size`` is smaller than the number of
            columns of ``demand``. ``ValueError`` is a subclass of the
            bare ``Exception`` raised previously.
    """
    len_demand, n_objectives = demand.size(0), demand.size(1)
    if total_bag_size < n_objectives:
        raise ValueError("Bag size not sufficient to compute max proportion")

    unit_prop = 1 / demand
    # Independent copy: it is updated in place while unit_prop stays fixed.
    max_min_prop = unit_prop.clone()

    # Row indices are loop-invariant, so build them once.
    rows = th.arange(len_demand, device=demand.device)
    for _ in range(n_objectives, total_bag_size):
        min_indices = th.argmin(max_min_prop, dim=1)
        max_min_prop[rows, min_indices] += unit_prop[rows, min_indices]

    return th.clamp(th.min(max_min_prop, dim=1).values, max=1)
