import causalchamber.lab as lab
import numpy as np
import pandas as pd


def mapping(ir_1, intercept=0, slope=4095.0 / 65535.0):
    """
    Convert a measurement of ir_1 into a brightness value for led_2_uv.

    By default this is the linear map that stretches the full ir_1 range
    onto the full led_2_uv range:
      - ir_1 goes from 0 to 65535
      - led_2_uv goes from 0 to 4095

    Parameters
    ----------
    ir_1 : number
        Reading of the ir_1 sensor.
    intercept : float, optional
        Offset of the linear map (default 0).
    slope : float, optional
        Slope of the linear map (default 4095 / 65535).

    Returns
    -------
    int
        ``intercept + slope * ir_1``, clamped to [0, 4095] and then rounded
        with Python's built-in ``round`` (round-half-to-even).
    """
    led_2_uv = intercept + slope * ir_1
    # Clamp to the valid brightness range so that non-default slope /
    # intercept values can never produce an out-of-range LED setting.
    led_2_uv = max(min(4095, led_2_uv), 0)
    return round(led_2_uv)


def sample(chamber, red, led_1_ir, led_1_uv):
    """
    Run one two-stage measurement on ``chamber``.

    Stage one sets ``red``, ``led_1_ir`` and ``led_1_uv`` in a new batch,
    takes one measurement and reads ``ir_1`` from the submitted result.
    Stage two sets ``led_2_uv = mapping(ir_1)`` in a second batch, takes one
    measurement and reads ``ir_2`` from that result.

    Returns
    -------
    tuple
        ``(red, led_1_ir, led_1_uv, ir_1, led_2_uv, ir_2)``.
    """
    # Stage 1: apply the sampled inputs, then read ir_1.
    first_stage = chamber.new_batch()
    inputs = (("red", red), ("led_1_ir", led_1_ir), ("led_1_uv", led_1_uv))
    for variable, value in inputs:
        first_stage.set(variable, value)
    first_stage.measure(n=1)
    measured_ir_1 = first_stage.submit().ir_1.values[0]

    # Stage 2: drive led_2_uv from the ir_1 reading, then read ir_2.
    uv_brightness = mapping(measured_ir_1)
    second_stage = chamber.new_batch()
    second_stage.set("led_2_uv", uv_brightness)
    second_stage.measure(n=1)
    measured_ir_2 = second_stage.submit().ir_2.values[0]

    return red, led_1_ir, led_1_uv, measured_ir_1, uv_brightness, measured_ir_2


def generate_dataset(
    m,
    r,
    random_state=1957981623,
    chamber_id="lt-aeon-dlpv",
    config="standard",
    credentials_file="./causal_chambers/.credentials",
):
    """
    Collect an (m * r)-row dataset from a causal chamber.

    Sampled inputs
      - red (confounder): random integers in [0, 255]
      - led_1_ir (instrument): the sweep 0, 1, ..., m - 1, repeated r times
      - led_1_uv (to modulate extra noise on ir_1): fixed at 0 (no extra noise)
    Generated with the chamber (via ``sample``)
      - ir_1 (treatment)
      - ir_2 (effect)
      - led_2_uv (mediator, reported for debugging if you need it)

    Parameters
    ----------
    m : int
        Number of distinct led_1_ir levels; must be in [1, 4096] so that
        every led_1_ir value stays within the valid range [0, 4095].
    r : int
        Number of repetitions of the led_1_ir sweep; must be >= 1.
    random_state : int, optional
        Seed for ``numpy.random.default_rng`` used to draw the red values.
    chamber_id, config, credentials_file : str, optional
        Forwarded to ``lab.Chamber``.

    Returns
    -------
    pandas.DataFrame
        Columns red, led_1_ir, led_1_uv, ir_1, ir_2, led_2_uv; one row per
        observation, in sampling order.

    Raises
    ------
    ValueError
        If m or r is out of range. This is checked before connecting to
        the chamber.
    """
    # Explicit exceptions rather than `assert`: asserts are stripped under
    # `python -O`, and bad arguments should fail before using the chamber.
    if r < 1:
        raise ValueError(f"r must be >= 1, got {r}")
    if not 1 <= m <= 4096:
        raise ValueError(
            f"m must be in [1, 4096] so led_1_ir stays within [0, 4095], got {m}"
        )
    n_obs = m * r

    # >>>>>>> Sample your values for red, led_1_ir, led_1_uv <<<<<<<<<<<
    rng = np.random.default_rng(random_state)
    red_values = rng.integers(256, size=n_obs)  # values in [0..255]
    # led_1_ir accepts [0...4095]; above ~500 ir_1 starts to saturate
    # (maxes out at 65535). Here: the sweep 0..m-1, repeated r times.
    led_1_ir_values = np.tile(np.arange(0, m), r)
    # led_1_uv accepts [0...4095]; 0 everywhere means no additional noise.
    led_1_uv_values = np.zeros(n_obs, dtype=int)

    # Connect to the chamber
    chamber = lab.Chamber(
        chamber_id=chamber_id,
        config=config,
        credentials_file=credentials_file,  # Path to the credentials file
    )

    # Collect data for ir_1, ir_2 and led_2_uv (for visualizing your mapping, just in case)
    ir_1_values, ir_2_values, led_2_uv_values = [], [], []
    print("Running experiment")
    for i, (red, led_1_ir, led_1_uv) in enumerate(
        zip(red_values, led_1_ir_values, led_1_uv_values)
    ):
        print(f"  collecting obs. {i+1}/{n_obs}", end="\r")
        _, _, _, ir_1, led_2_uv, ir_2 = sample(chamber, red, led_1_ir, led_1_uv)
        ir_1_values.append(ir_1)
        led_2_uv_values.append(led_2_uv)
        ir_2_values.append(ir_2)
    print()

    # Compose and return dataframe
    return pd.DataFrame(
        {
            "red": red_values,
            "led_1_ir": led_1_ir_values,
            "led_1_uv": led_1_uv_values,
            "ir_1": ir_1_values,
            "ir_2": ir_2_values,
            "led_2_uv": led_2_uv_values,
        }
    )


if __name__ == "__main__":
    # Running the experiment drives the real chamber, so only do it when this
    # file is executed as a script, never as a side effect of importing it.
    n_reps = 16
    df = generate_dataset(m=750, r=n_reps)
    # Save df to csv; the filename records the number of repetitions.
    df.to_csv(f"./causal_chambers/df_standard_{n_reps}.csv", index=False)
