"""
----------------------------------------------------------------------------------------------------------------- benchmark: 64 tests ------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                         Min                    Max                   Mean              StdDev                 Median                 IQR            Outliers         OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_sample_M_K[S100-N2-K10-C10-sampling_unoptimized-CUDA]           115.3980 (1.0)       1,098.6000 (1.36)        119.7189 (1.0)       27.2355 (2.23)        117.2720 (1.0)        0.7905 (1.88)      178;610  8,352.9008 (1.0)        8636           1
test_sample_M_K[S100-N2-K10-C10-sampling-CUDA]                       115.7010 (1.00)        916.8340 (1.14)        120.0428 (1.00)      25.2207 (2.06)        117.5270 (1.00)       0.8110 (1.93)      241;672  8,330.3594 (1.00)       8621           1
test_sample_M_K[S100-N4-K10-C10-sampling-CUDA]                       118.2790 (1.02)      1,393.3830 (1.73)        123.2128 (1.03)      37.3837 (3.06)        120.0900 (1.02)       0.6220 (1.48)      145;816  8,116.0408 (0.97)       8409           1
test_sample_M_K[S100-N4-K10-C10-sampling_unoptimized-CUDA]           118.6440 (1.03)      1,429.2190 (1.77)        123.2051 (1.03)      39.3847 (3.22)        120.0515 (1.02)       0.6060 (1.44)      133;752  8,116.5477 (0.97)       8418           1
test_sample_M_K[S100-N8-K10-C10-sampling-CUDA]                       126.8290 (1.10)      1,407.0470 (1.74)        131.2746 (1.10)      38.7937 (3.18)        128.5160 (1.10)       0.6662 (1.58)       94;681  7,617.6211 (0.91)       7845           1
test_sample_M_K[S100-N8-K10-C10-sampling_unoptimized-CUDA]           127.3980 (1.10)      1,379.6720 (1.71)        131.4920 (1.10)      38.6273 (3.16)        128.6800 (1.10)       0.5830 (1.38)      106;552  7,605.0278 (0.91)       7851           1
test_sample_M_K[S100-N1-K10-C10-sampling_unoptimized-CUDA]           130.4810 (1.13)      1,036.2080 (1.28)        133.8942 (1.12)      19.5906 (1.60)        132.2600 (1.13)       0.7660 (1.82)      140;493  7,468.5809 (0.89)       7629           1
test_sample_M_K[S100-N1-K10-C10-sampling-CUDA]                       131.1070 (1.14)        807.7310 (1.0)         133.3965 (1.11)      12.2144 (1.0)         132.3900 (1.13)       0.4917 (1.17)      189;405  7,496.4509 (0.90)       7607           1
test_sample_M_K[S100-N2-K10-C100-sampling-CUDA]                      137.1880 (1.19)      1,114.2850 (1.38)        141.0906 (1.18)      30.1107 (2.47)        138.6765 (1.18)       0.6630 (1.57)      126;400  7,087.6432 (0.85)       7296           1
test_sample_M_K[S100-N2-K10-C100-sampling_unoptimized-CUDA]          137.4190 (1.19)      1,253.4450 (1.55)        141.3135 (1.18)      33.3941 (2.73)        138.8150 (1.18)       0.6205 (1.47)      101;410  7,076.4644 (0.85)       7272           1
test_sample_M_K[S100-N4-K10-C100-sampling-CUDA]                      141.8000 (1.23)      1,424.8650 (1.76)        145.7474 (1.22)      39.3403 (3.22)        142.9610 (1.22)       0.6800 (1.62)       73;439  6,861.1857 (0.82)       7043           1
test_sample_M_K[S100-N4-K10-C100-sampling_unoptimized-CUDA]          141.8990 (1.23)      1,433.2500 (1.77)        145.8467 (1.22)      39.7206 (3.25)        142.9950 (1.22)       0.5920 (1.41)       77;453  6,856.5152 (0.82)       7047           1
test_sample_M_K[S100-N8-K10-C100-sampling_unoptimized-CUDA]          150.3790 (1.30)      1,350.1170 (1.67)        154.7450 (1.29)      41.2286 (3.38)        151.8510 (1.29)       0.7050 (1.67)       66;349  6,462.2455 (0.77)       6646           1
test_sample_M_K[S100-N8-K10-C100-sampling-CUDA]                      150.5440 (1.30)      1,420.3660 (1.76)        154.4711 (1.29)      41.2138 (3.37)        151.6740 (1.29)       0.5270 (1.25)       59;452  6,473.7038 (0.78)       6645           1
test_sample_M_K[S100-N1-K10-C100-sampling-CUDA]                      151.1230 (1.31)        918.7210 (1.14)        153.5306 (1.28)      17.9351 (1.47)        152.1485 (1.30)       0.4530 (1.08)      115;364  6,513.3609 (0.78)       6640           1
test_sample_M_K[S100-N1-K10-C100-sampling_unoptimized-CUDA]          152.7350 (1.32)        926.0940 (1.15)        155.2467 (1.30)      19.4687 (1.59)        153.6930 (1.31)       0.4210 (1.0)       106;388  6,441.3615 (0.77)       6588           1
test_sample_M_K[S100-N2-K100-C10-sampling_unoptimized-CUDA]          611.1100 (5.30)      1,951.0260 (2.42)        624.3404 (5.22)      87.6932 (7.18)        614.2320 (5.24)       1.7760 (4.22)       33;149  1,601.6903 (0.19)       1637           1
test_sample_M_K[S100-N2-K100-C10-sampling-CUDA]                      611.3590 (5.30)      1,931.6740 (2.39)        623.5669 (5.21)      82.6228 (6.76)        613.8040 (5.23)       1.3983 (3.32)       36;154  1,603.6771 (0.19)       1637           1
test_sample_M_K[S100-N2-K100-C100-sampling_unoptimized-CUDA]         656.6800 (5.69)      1,993.4140 (2.47)        670.7363 (5.60)      93.0716 (7.62)        659.8100 (5.63)       1.6480 (3.91)       31;153  1,490.8990 (0.18)       1523           1
test_sample_M_K[S100-N2-K100-C100-sampling-CUDA]                     657.0680 (5.69)      1,953.0810 (2.42)        671.2928 (5.61)      91.8765 (7.52)        659.8690 (5.63)       1.7573 (4.17)       33;138  1,489.6630 (0.18)       1523           1
test_sample_M_K[S1000-N2-K10-C10-sampling-CUDA]                      708.7790 (6.14)      1,953.8350 (2.42)        723.6297 (6.04)      96.5859 (7.91)        711.2410 (6.06)       1.3990 (3.32)       32;147  1,381.9223 (0.17)       1411           1
test_sample_M_K[S1000-N2-K10-C10-sampling_unoptimized-CUDA]          709.0560 (6.14)      1,948.2510 (2.41)        723.6908 (6.04)      95.9928 (7.86)        711.0010 (6.06)       1.3237 (3.14)       35;153  1,381.8056 (0.17)       1411           1
test_sample_M_K[S1000-N4-K10-C10-sampling-CUDA]                      756.7630 (6.56)      1,996.9770 (2.47)        776.5509 (6.49)      98.9206 (8.10)        763.4705 (6.51)       2.8945 (6.88)        36;92  1,287.7456 (0.15)       1320           1
test_sample_M_K[S1000-N4-K10-C10-sampling_unoptimized-CUDA]          756.8150 (6.56)      1,997.4310 (2.47)        776.4418 (6.49)      99.1441 (8.12)        763.4730 (6.51)       2.8900 (6.86)        34;95  1,287.9265 (0.15)       1320           1
test_sample_M_K[S1000-N8-K10-C10-sampling-CUDA]                      844.7340 (7.32)      2,114.3720 (2.62)        862.2258 (7.20)      98.7264 (8.08)        849.3305 (7.24)       2.4190 (5.75)        31;78  1,159.7890 (0.14)       1184           1
test_sample_M_K[S1000-N8-K10-C10-sampling_unoptimized-CUDA]          845.2520 (7.32)      2,080.2060 (2.58)        863.7457 (7.21)     103.4849 (8.47)        849.6925 (7.25)       2.4820 (5.90)        32;93  1,157.7482 (0.14)       1184           1
test_sample_M_K[S1000-N2-K10-C100-sampling_unoptimized-CUDA]         860.2410 (7.45)      2,104.9090 (2.61)        878.7849 (7.34)     107.3992 (8.79)        862.8930 (7.36)       1.2533 (2.98)       38;129  1,137.9349 (0.14)       1163           1
test_sample_M_K[S1000-N2-K10-C100-sampling-CUDA]                     860.3410 (7.46)      2,107.7620 (2.61)        877.6557 (7.33)     104.5316 (8.56)        862.6420 (7.36)       1.1840 (2.81)       36;128  1,139.3990 (0.14)       1163           1
test_sample_M_K[S1000-N1-K10-C10-sampling_unoptimized-CUDA]          893.5190 (7.74)      2,136.6360 (2.65)        911.9366 (7.62)     106.4057 (8.71)        898.1215 (7.66)       1.6910 (4.02)        30;96  1,096.5674 (0.13)       1118           1
test_sample_M_K[S1000-N1-K10-C10-sampling-CUDA]                      894.3480 (7.75)      2,143.3030 (2.65)        912.1185 (7.62)     105.2790 (8.62)        898.0585 (7.66)       1.6900 (4.01)       31;102  1,096.3488 (0.13)       1120           1
test_sample_M_K[S1000-N4-K10-C100-sampling_unoptimized-CUDA]         919.2220 (7.97)      2,153.2540 (2.67)        939.6714 (7.85)     110.4809 (9.05)        923.3940 (7.87)       1.9130 (4.54)       38;106  1,064.2018 (0.13)       1088           1
test_sample_M_K[S1000-N4-K10-C100-sampling-CUDA]                     919.3080 (7.97)      2,214.2450 (2.74)        939.9104 (7.85)     112.2672 (9.19)        923.2930 (7.87)       2.1005 (4.99)        39;97  1,063.9312 (0.13)       1088           1
test_sample_M_K[S1000-N8-K10-C100-sampling_unoptimized-CUDA]       1,007.9870 (8.73)      2,249.2340 (2.78)      1,030.7574 (8.61)     113.4529 (9.29)      1,012.9430 (8.64)       3.5365 (8.40)        35;84    970.1604 (0.12)        992           1
test_sample_M_K[S1000-N8-K10-C100-sampling-CUDA]                   1,009.1320 (8.74)      2,243.8130 (2.78)      1,029.1735 (8.60)     113.6460 (9.30)      1,012.6115 (8.63)       1.9105 (4.54)        32;92    971.6535 (0.12)        992           1
test_sample_M_K[S1000-N1-K10-C100-sampling_unoptimized-CUDA]       1,042.0940 (9.03)      2,292.5320 (2.84)      1,062.4068 (8.87)     114.5244 (9.38)      1,045.2345 (8.91)       1.3565 (3.22)       31;103    941.2590 (0.11)        960           1
test_sample_M_K[S1000-N1-K10-C100-sampling-CUDA]                   1,042.3030 (9.03)      2,277.6670 (2.82)      1,060.1058 (8.85)     107.8228 (8.83)      1,045.2340 (8.91)       1.2130 (2.88)        27;97    943.3021 (0.11)        960           1
test_sample_M_K[S100-N4-K100-C10-sampling-CUDA]                    1,074.1230 (9.31)      2,385.9170 (2.95)      1,095.1679 (9.15)     115.5765 (9.46)      1,077.7660 (9.19)       1.8345 (4.36)        31;94    913.1020 (0.11)        931           1
test_sample_M_K[S100-N4-K100-C10-sampling_unoptimized-CUDA]        1,074.8520 (9.31)      2,382.8540 (2.95)      1,097.2659 (9.17)     119.7913 (9.81)      1,078.4330 (9.20)       2.0012 (4.75)       31;110    911.3561 (0.11)        931           1
test_sample_M_K[S100-N4-K100-C100-sampling-CUDA]                   1,164.7590 (10.09)     2,414.2950 (2.99)      1,188.6677 (9.93)     122.8347 (10.06)     1,169.3170 (9.97)       2.0572 (4.89)        30;86    841.2780 (0.10)        859           1
test_sample_M_K[S100-N4-K100-C100-sampling_unoptimized-CUDA]       1,164.9310 (10.09)     2,390.7120 (2.96)      1,186.3220 (9.91)     114.9815 (9.41)      1,169.0000 (9.97)       2.1053 (5.00)        28;81    842.9415 (0.10)        859           1
test_sample_M_K[S100-N8-K100-C10-sampling-CUDA]                    1,999.2670 (17.32)     3,226.3640 (3.99)      2,032.8876 (16.98)    149.2479 (12.22)     2,003.7740 (17.09)      2.4285 (5.77)        23;72    491.9111 (0.06)        500           1
test_sample_M_K[S100-N8-K100-C10-sampling_unoptimized-CUDA]        1,999.2790 (17.33)     3,233.6770 (4.00)      2,036.8232 (17.01)    157.6587 (12.91)     2,004.1910 (17.09)      3.1303 (7.44)        19;71    490.9606 (0.06)        501           1
test_sample_M_K[S100-N8-K100-C100-sampling_unoptimized-CUDA]       2,172.2440 (18.82)     3,756.3010 (4.65)      2,215.3265 (18.50)    175.8818 (14.40)     2,177.8960 (18.57)      3.4808 (8.27)        20;56    451.4007 (0.05)        461           1
test_sample_M_K[S100-N8-K100-C100-sampling-CUDA]                   2,173.4290 (18.83)     3,406.9590 (4.22)      2,210.7765 (18.47)    158.1013 (12.94)     2,178.4670 (18.58)      2.8230 (6.71)        18;57    452.3298 (0.05)        461           1
test_sample_M_K[S100-N1-K100-C10-sampling_unoptimized-CUDA]        2,359.9020 (20.45)     3,148.6650 (3.90)      2,390.9461 (19.97)    106.9768 (8.76)      2,363.5565 (20.15)      2.3235 (5.52)        30;62    418.2445 (0.05)        424           1
test_sample_M_K[S100-N1-K100-C10-sampling-CUDA]                    2,360.1860 (20.45)     3,246.4230 (4.02)      2,390.5060 (19.97)    107.0442 (8.76)      2,364.2435 (20.16)      2.7100 (6.44)        29;55    418.3215 (0.05)        424           1
test_sample_M_K[S100-N1-K100-C100-sampling-CUDA]                   2,400.5930 (20.80)     3,260.4650 (4.04)      2,434.4330 (20.33)    119.0841 (9.75)      2,404.5135 (20.50)      2.4870 (5.91)        28;58    410.7733 (0.05)        418           1
test_sample_M_K[S100-N1-K100-C100-sampling_unoptimized-CUDA]       2,400.7230 (20.80)     3,327.5100 (4.12)      2,433.1447 (20.32)    121.7276 (9.97)      2,404.8330 (20.51)      2.1872 (5.20)        26;54    410.9908 (0.05)        417           1
test_sample_M_K[S1000-N2-K100-C10-sampling_unoptimized-CUDA]       5,768.6300 (49.99)     7,456.1780 (9.23)      5,886.6562 (49.17)    301.9042 (24.72)     5,780.0325 (49.29)     14.8100 (35.18)       14;38    169.8757 (0.02)        174           1
test_sample_M_K[S1000-N2-K100-C10-sampling-CUDA]                   5,768.6900 (49.99)     7,415.1510 (9.18)      5,882.6116 (49.14)    301.1378 (24.65)     5,778.9145 (49.28)      9.5490 (22.68)       13;36    169.9925 (0.02)        174           1
test_sample_M_K[S1000-N2-K100-C100-sampling-CUDA]                  7,488.6570 (64.89)     9,220.6300 (11.42)     7,643.5589 (63.85)    362.5935 (29.69)     7,507.8195 (64.02)     15.9080 (37.79)       13;29    130.8291 (0.02)        134           1
test_sample_M_K[S1000-N2-K100-C100-sampling_unoptimized-CUDA]      7,495.3380 (64.95)     9,441.5100 (11.69)     7,647.2575 (63.88)    366.4033 (30.00)     7,511.6015 (64.05)     22.5970 (53.67)       14;32    130.7658 (0.02)        134           1
test_sample_M_K[S1000-N4-K100-C10-sampling_unoptimized-CUDA]      10,559.9160 (91.51)    12,516.3090 (15.50)    10,753.9646 (89.83)    436.9131 (35.77)    10,571.6750 (90.15)     76.4780 (181.66)      11;23     92.9890 (0.01)         95           1
test_sample_M_K[S1000-N4-K100-C10-sampling-CUDA]                  10,560.5160 (91.51)    12,717.5520 (15.74)    10,775.2289 (90.00)    461.6329 (37.79)    10,573.5470 (90.16)    198.6235 (471.79)      10;14     92.8055 (0.01)         95           1
test_sample_M_K[S1000-N4-K100-C100-sampling-CUDA]                 14,043.1190 (121.69)   16,205.2400 (20.06)    14,308.2999 (119.52)   540.2521 (44.23)    14,065.3650 (119.94)   183.0275 (434.75)       8;13     69.8895 (0.01)         72           1
test_sample_M_K[S1000-N4-K100-C100-sampling_unoptimized-CUDA]     14,051.5550 (121.77)   16,190.5550 (20.04)    14,334.4979 (119.73)   569.3038 (46.61)    14,072.6005 (120.00)   201.0240 (477.49)       8;14     69.7618 (0.01)         72           1
test_sample_M_K[S1000-N8-K100-C10-sampling_unoptimized-CUDA]      20,160.8040 (174.71)   22,330.3320 (27.65)    20,566.7493 (171.79)   665.2066 (54.46)    20,181.4700 (172.09)   693.2550 (>1000.0)       8;5     48.6222 (0.01)         50           1
test_sample_M_K[S1000-N8-K100-C10-sampling-CUDA]                  20,161.6880 (174.71)   22,372.0680 (27.70)    20,588.9064 (171.98)   685.9877 (56.16)    20,182.6065 (172.10)   702.7570 (>1000.0)       8;4     48.5698 (0.01)         50           1
test_sample_M_K[S1000-N1-K100-C10-sampling_unoptimized-CUDA]      23,258.7180 (201.55)   25,699.2170 (31.82)    23,737.7690 (198.28)   758.4687 (62.10)    23,272.0730 (198.45)   788.0585 (>1000.0)       8;4     42.1270 (0.01)         43           1
test_sample_M_K[S1000-N1-K100-C10-sampling-CUDA]                  23,259.6510 (201.56)   25,466.9020 (31.53)    23,706.9447 (198.02)   751.0281 (61.49)    23,272.5580 (198.45)   630.0245 (>1000.0)       8;5     42.1817 (0.01)         43           1
test_sample_M_K[S1000-N1-K100-C100-sampling_unoptimized-CUDA]     24,196.1440 (209.68)   26,553.5000 (32.87)    24,654.5650 (205.94)   747.6243 (61.21)    24,212.2580 (206.46)   605.2080 (>1000.0)       8;6     40.5604 (0.00)         42           1
test_sample_M_K[S1000-N1-K100-C100-sampling-CUDA]                 24,197.8070 (209.69)   26,422.4180 (32.71)    24,674.4081 (206.10)   735.9563 (60.25)    24,211.3590 (206.45)   808.5170 (>1000.0)       9;4     40.5278 (0.00)         42           1
test_sample_M_K[S1000-N8-K100-C100-sampling_unoptimized-CUDA]     27,160.7890 (235.37)   29,501.5720 (36.52)    27,672.3092 (231.14)   735.3299 (60.20)    27,193.8400 (231.89)   897.7118 (>1000.0)       7;1     36.1372 (0.00)         37           1
test_sample_M_K[S1000-N8-K100-C100-sampling-CUDA]                 27,160.8890 (235.37)   29,466.0220 (36.48)    27,736.6659 (231.68)   779.0481 (63.78)    27,201.1280 (231.95)   943.2115 (>1000.0)       7;0     36.0534 (0.00)         37           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
"""
import pytest
import pytest_benchmark
import torch
import joint_entropy.exact as exact
import joint_entropy.sampling as sampling

import joint_entropy.unoptimized.exact as exact_unoptimized
import joint_entropy.unoptimized.sampling as sampling_unoptimized

import torch_utils


# Only CUDA is benchmarked by default. To also benchmark on CPU, use
# ``params=[False, True], ids=["CPU", "CUDA"]`` instead.
@pytest.fixture(params=[True], ids=["CUDA"])
def torch_device(request):
    """Provide the torch device a benchmark runs on.

    ``request.param`` is ``True`` for CUDA and ``False`` for CPU. When CUDA is
    requested but not available, the test is skipped rather than failed (an
    ``assert`` would also be stripped under ``python -O``, letting the test
    proceed onto a missing GPU).
    """
    use_cuda = request.param
    if not use_cuda:
        return torch.device("cpu")
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    # NOTE(review): presumably releases cached GPU memory left by earlier
    # benchmarks so each run starts from a comparable state — see torch_utils.
    torch_utils.gc_cuda()
    return torch.device("cuda")


@pytest.fixture(name="C", params=[10, 100], ids=["C10", "C100"])
def C(request):
    """Size of the trailing ``C`` dimension of the probability tensors (10 or 100)."""
    return getattr(request, "param")


@pytest.fixture(name="B", ids=["B100"], params=[100])
def B(request):
    """Size of the leading ``B`` dimension of ``probs_B_K_C`` / ``result_B_M_C``."""
    size = request.param
    return size


@pytest.fixture(name="M", ids=["M10000"], params=[10000])
def M(request):
    """Size of the ``M`` dimension used by ``samples_M_K`` and ``result_B_M_C``."""
    return getattr(request, "param")


@pytest.fixture(name="K", params=[10, 100], ids=["K10", "K100"])
def K(request):
    """Size of the ``K`` dimension shared by ``samples_M_K``, ``probs_B_K_C`` and ``probs_N_K_C``."""
    size = request.param
    return size


@pytest.fixture(name="S", params=[100, 1000], ids=["S100", "S1000"])
def S(request):
    """Value passed as ``S`` to ``sample_M_K`` in the sampling benchmark (100 or 1000)."""
    return getattr(request, "param")


@pytest.fixture(name="N", params=[1, 2, 4, 8], ids=["N1", "N2", "N4", "N8"])
def N(request):
    """Size of the leading ``N`` dimension of ``probs_N_K_C`` (1, 2, 4 or 8)."""
    size = request.param
    return size


@pytest.fixture
def samples_M_K(M, K, torch_device):
    """All-ones float64 tensor of shape ``(M, K)`` allocated on ``torch_device``."""
    shape = (M, K)
    return torch.ones(shape, device=torch_device, dtype=torch.float64)


@pytest.fixture
def probs_B_K_C(B, K, C, torch_device):
    """All-ones float64 tensor of shape ``(B, K, C)`` allocated on ``torch_device``.

    NOTE(review): the values are not normalized over ``C`` (each row sums to C);
    confirm the code under test accepts unnormalized weights.
    """
    return torch.ones(B, K, C, device=torch_device, dtype=torch.float64)


@pytest.fixture
def probs_N_K_C(N, K, C, torch_device):
    """All-ones float64 tensor of shape ``(N, K, C)`` allocated on ``torch_device``.

    NOTE(review): the values are not normalized over ``C`` (each row sums to C);
    confirm ``sample_M_K`` accepts unnormalized weights.
    """
    return torch.ones(N, K, C, device=torch_device, dtype=torch.float64)


@pytest.fixture
def result_B_M_C(B, M, C, torch_device):
    """Uninitialized float64 output buffer of shape ``(B, M, C)`` on ``torch_device``.

    The contents are arbitrary (``torch.empty``); callers must overwrite them.
    """
    return torch.empty(B, M, C, device=torch_device, dtype=torch.float64)


@pytest.fixture(
    name="module_exact",
    params=[exact, exact_unoptimized],
    ids=["exact", "exact_unoptimized"],
)
def exact_module(request) -> exact:
    """Provide each ``exact`` implementation in turn (optimized, then unoptimized).

    Registered under the fixture name ``module_exact`` rather than the function
    name, so tests must request ``module_exact``.
    """
    implementation = request.param
    return implementation


@pytest.fixture(
    name="sampling_module",
    ids=["sampling", "sampling_unoptimized"],
    params=[sampling, sampling_unoptimized],
)
def sampling_module(request) -> sampling:
    """Provide each ``sampling`` implementation in turn (optimized, then unoptimized)."""
    implementation = request.param
    return implementation


@pytest.mark.benchmark(warmup=True)
def test_sample_M_K(S, N, K, C, benchmark, probs_N_K_C, sampling_module: sampling):
    benchmark.extra_info["Debug Mode"] = __debug__

    def inner():
        result = sampling_module.sample_M_K(probs_N_K_C, S)

        torch.cuda.synchronize()
        return result.shape

    benchmark(inner)
