import argparse
import random

import numpy as np
import torch
from tqdm import tqdm

def set_seed(seed: int):
    """Seed every random number generator this project draws from.

    Covers Python's ``random`` module, NumPy's global RNG, torch's CPU
    generator and the generators of all CUDA devices, in that order.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # NOTE: deterministic kernels are NOT enforced here;
    # torch.use_deterministic_algorithms(...) was intentionally left disabled.

def auto_or_float(value):
    """Convert a command-line string to a float, or pass an ``'auto'`` value through.

    Shaped as an ``argparse`` ``type=`` converter: invalid input raises
    ``argparse.ArgumentTypeError`` so argparse can report it as a usage error.

    Args:
        value: the raw string from the command line.

    Returns:
        ``value`` unchanged if it starts with ``'auto'`` (e.g. ``'auto'`` or
        ``'auto_something'``), otherwise ``float(value)``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is neither ``'auto'``-prefixed
            nor parseable as a float.
    """
    if value.startswith('auto'):
        return value
    try:
        return float(value)
    except ValueError:
        # `from None` hides the internal float() traceback; the message already
        # states what went wrong.
        raise argparse.ArgumentTypeError(
            f"Invalid value: {value}. Must be a float or 'auto'."
        ) from None


class InverseTriangleDistribution:
    """V-shaped ("inverse triangle") distribution on ``[left, right]``.

    The density is piecewise linear. On ``[left, peak]`` it falls linearly
    from ``2 * left_area / (peak - left)`` at ``left`` to zero at ``peak``.
    On ``[peak, right]`` it rises linearly from zero at ``peak`` to
    ``2 * (1 - left_area) / (right - peak)`` at ``right``. The left piece
    carries probability mass ``left_area``; the right piece carries
    ``1 - left_area``.

    Args:
        left: lower end of the support.
        right: upper end of the support.
        peak: interior point where the density is zero.
        left_area: probability mass of the left piece (presumably in [0, 1]).

    Degenerate pieces (``peak == left`` or ``peak == right``) divide by zero
    in ``__init__``.
    """

    def __init__(self, left, right, peak, left_area):
        self.left = left
        self.right = right
        self.peak = peak
        self.left_area = left_area
        # Left density f(x) = left_a * x + left_b = 2*A*(peak - x) / (peak - left)**2
        self.left_a = -2 * left_area / ((peak - left) ** 2)
        self.left_b = 2 * left_area * peak / ((peak - left) ** 2)
        # Right density f(x) = right_a * x + right_b = 2*(1-A)*(x - peak) / (right - peak)**2
        self.right_a = 2 * (1 - left_area) / ((right - peak) ** 2)
        self.right_b = -2 * (1 - left_area) * peak / ((right - peak) ** 2)

        # Integral of each linear piece (algebraically left_area and
        # 1 - left_area). Retained for backward compatibility; sample() uses
        # the closed-form inverse CDF below instead.
        self.x_left_param = ((0.5 * self.left_a) * (self.peak ** 2 - self.left ** 2) + self.left_b * (self.peak - self.left))
        self.x_right_param = ((0.5 * self.right_a) * (self.right ** 2 - self.peak ** 2) + self.right_b * (self.right - self.peak))

        # Expanded quadratic-formula discriminant constants; retained for
        # backward compatibility only (numerically unstable far from 0).
        self.left_samples_param = self.left_b ** 2 + (self.left_a ** 2) * (self.left ** 2) + 2 * self.left_a * self.left_b * self.left
        self.right_samples_param = self.right_b ** 2 + (self.right_a ** 2) * (self.peak ** 2) + 2 * self.right_a * self.right_b * self.peak

    def sample(self, n_samples):
        """Draw ``n_samples`` independent samples by inverse-CDF sampling.

        Uses NumPy's global RNG (``np.random.rand``). It draws three batches,
        in order: ``n_samples`` side selectors, then the left-piece uniforms,
        then the right-piece uniforms. Results are therefore reproducible
        under ``np.random.seed``.

        Args:
            n_samples: number of samples to draw.

        Returns:
            1-D float ndarray of length ``n_samples`` with values in
            ``[left, right]``.
        """
        u = np.random.rand(n_samples)
        left_mask = u < self.left_area
        right_mask = ~left_mask

        # Left piece: solving CDF(x) = U * left_area for x gives
        #   x = peak - (peak - left) * sqrt(1 - U).
        # This is algebraically what the expanded quadratic formula computed,
        # but it avoids catastrophic cancellation (and NaNs from sqrt of a
        # tiny negative number) when the support lies far from zero.
        u_left = np.random.rand(left_mask.sum())
        left_samples = self.peak - (self.peak - self.left) * np.sqrt(1.0 - u_left)

        # Right piece: CDF(x) - left_area = U * (1 - left_area) gives
        #   x = peak + (right - peak) * sqrt(U).
        u_right = np.random.rand(right_mask.sum())
        right_samples = self.peak + (self.right - self.peak) * np.sqrt(u_right)

        samples = np.empty(n_samples)
        samples[left_mask] = left_samples
        samples[right_mask] = right_samples
        return samples

