import numpy as np

# Lookup table of class splits and normalisation statistics, keyed by
# dataset name.
#
# Each dataset entry is a dict with exactly three keys:
#   'splits'    - a list with one list of integer class indices per split
#   'means'     - a list with one list of floats per split, or None
#   'variances' - a list with one list of floats per split, or None
# Within a dataset the three lists have the same length (where not None).
# NOTE(review): means/variances hold 1 value for mnist and 3 values for the
# other datasets - presumably per-channel image statistics; confirm against
# the code that consumes this table.
osr_splits = {
    # ------------------------------------------------------------------ mnist
    'mnist': {
        'splits': [
            [2, 4, 5, 9, 8, 3],
            [3, 2, 6, 9, 4, 0],
            [5, 8, 3, 2, 4, 6],
            [3, 7, 8, 4, 0, 5],
            [6, 3, 4, 9, 8, 2],
        ],
        'means': [
            [0.1357],
            [0.14090382],
            [0.13819146],
            [0.13822709],
            [0.13702373],
        ],
        'variances': [
            [0.09739873],
            [0.10117427],
            [0.09905109],
            [0.09929745],
            [0.0984771],
        ],
    },

    # ------------------------------------------------------------------- svhn
    'svhn': {
        'splits': [
            [5, 3, 7, 2, 8, 6],
            [3, 8, 7, 6, 2, 5],
            [8, 9, 4, 7, 2, 1],
            [3, 8, 2, 5, 0, 6],
            [4, 9, 2, 7, 1, 0],
        ],
        # The same statistics are used for all five splits.  The
        # comprehension evaluates the inner literal once per split, so every
        # row is its own list object (no aliasing between rows).
        'means': [[0.4377, 0.4438, 0.4728] for _ in range(5)],
        # NOTE(review): these magnitudes (~0.2) are much larger than the
        # variances listed for the other RGB datasets (~0.06-0.08); they may
        # actually be standard deviations - confirm before relying on them.
        'variances': [[0.1980, 0.2010, 0.1970] for _ in range(5)],
    },

    # ---------------------------------------------------------------- cifar10
    'cifar10': {
        'splits': [
            # (airplane, frog, deer, truck, automobile, horse)
            [0, 6, 4, 9, 1, 7],
            # same classes as the split above, listed in a different order
            [7, 6, 4, 9, 0, 1],
            # (automobile, dog, horse, cat, truck, deer)
            [1, 5, 7, 3, 9, 4],
            # (ship, frog, automobile, truck, airplane, horse)
            [8, 6, 1, 9, 0, 7],
            # (bird, deer, automobile, horse, truck, frog)
            [2, 4, 1, 7, 9, 6],
        ],
        'means': [
            [0.4898587, 0.48060894, 0.4424159],
            [0.4898587, 0.48060894, 0.4424159],
            [0.48981622, 0.46766308, 0.42541358],
            [0.49295485, 0.49064583, 0.47182837],
            [0.48378143, 0.46913853, 0.41493976],
        ],
        'variances': [
            [0.06126362, 0.06001796, 0.06984805],
            [0.06126362, 0.06001796, 0.06984805],
            [0.06328715, 0.06151306, 0.06617628],
            [0.06371225, 0.06276005, 0.07345987],
            [0.05917137, 0.05731499, 0.06365602],
        ],
    },

    # --------------------------------------------------------------- cifar+10
    'cifar+10': {
        'splits': [
            [30, 25, 1, 9, 8, 0, 46, 52, 49, 71],
            [41, 9, 49, 40, 73, 60, 48, 30, 95, 71],
            [8, 9, 49, 40, 73, 60, 48, 95, 30, 71],
            [95, 60, 30, 73, 46, 49, 68, 99, 8, 71],
            [33, 2, 3, 97, 46, 21, 64, 63, 88, 43],
        ],
        # One identical (but independently allocated) row per split.
        'means': [[0.49642983, 0.5063898, 0.51721925] for _ in range(5)],
        'variances': [[0.06754102, 0.06635283, 0.07538214] for _ in range(5)],
    },

    # --------------------------------------------------------------- cifar+50
    'cifar+50': {
        # Each split holds 50 indices, wrapped here at ten per line.
        'splits': [
            [
                27, 94, 29, 77, 88, 26, 69, 48, 75, 5,
                59, 93, 39, 57, 45, 40, 78, 20, 98, 47,
                66, 70, 91, 76, 41, 83, 99, 32, 53, 72,
                2, 95, 21, 73, 84, 68, 35, 11, 55, 60,
                30, 25, 1, 9, 8, 0, 46, 52, 49, 71,
            ],
            [
                65, 97, 86, 24, 45, 67, 2, 3, 91, 98,
                79, 29, 62, 82, 33, 76, 0, 35, 5, 16,
                54, 11, 99, 52, 85, 1, 25, 66, 28, 84,
                23, 56, 75, 46, 21, 72, 55, 68, 8, 69,
                41, 9, 49, 40, 73, 60, 48, 30, 95, 71,
            ],
            [
                20, 83, 65, 97, 94, 2, 93, 16, 67, 29,
                62, 33, 24, 98, 5, 86, 35, 54, 0, 91,
                52, 66, 85, 84, 56, 11, 1, 76, 25, 55,
                21, 99, 72, 41, 23, 75, 28, 68, 69, 46,
                8, 9, 49, 40, 73, 60, 48, 95, 30, 71,
            ],
            [
                92, 82, 77, 64, 5, 33, 62, 56, 70, 0,
                20, 28, 67, 14, 84, 53, 91, 29, 85, 2,
                52, 83, 75, 35, 11, 21, 72, 98, 55, 1,
                41, 76, 25, 66, 69, 9, 48, 54, 40, 23,
                95, 60, 30, 73, 46, 49, 68, 99, 8, 71,
            ],
            [
                47, 6, 19, 0, 62, 93, 59, 65, 54, 70,
                34, 55, 23, 38, 72, 76, 53, 31, 78, 96,
                77, 27, 92, 18, 82, 50, 98, 32, 1, 75,
                83, 4, 51, 35, 80, 11, 74, 66, 36, 42,
                33, 2, 3, 97, 46, 21, 64, 63, 88, 43,
            ],
        ],
        # One identical (but independently allocated) row per split.
        'means': [[0.49642983, 0.5063898, 0.51721925] for _ in range(5)],
        'variances': [[0.06754102, 0.06635283, 0.07538214] for _ in range(5)],
    },

    # ------------------------------------------------------- cat_dog_vs_tiger
    # Single split; no statistics are provided for this entry, so consumers
    # must handle None for 'means' and 'variances'.
    'cat_dog_vs_tiger': {
        'splits': [[3, 5]],
        'means': None,
        'variances': None,
    },

    # ---------------------------------------------------------- tiny_imagenet
    'tiny_imagenet': {
        # Each split holds 20 indices, wrapped here at ten per line.
        'splits': [
            [
                108, 147, 17, 58, 193, 123, 72, 144, 75, 167,
                134, 14, 81, 171, 44, 197, 152, 66, 1, 133,
            ],
            [
                198, 161, 91, 59, 57, 134, 61, 184, 90, 35,
                29, 23, 199, 38, 133, 19, 186, 18, 85, 67,
            ],
            [
                177, 0, 119, 26, 78, 80, 191, 46, 134, 92,
                31, 152, 27, 60, 114, 50, 51, 133, 162, 93,
            ],
            [
                98, 36, 158, 177, 189, 157, 170, 191, 82, 196,
                138, 166, 43, 13, 152, 11, 75, 174, 193, 190,
            ],
            [
                95, 6, 145, 153, 0, 143, 31, 23, 189, 81,
                20, 21, 89, 26, 36, 170, 102, 177, 108, 169,
            ],
        ],
        'means': [
            [0.47541523, 0.44702694, 0.4026736],
            [0.48255503, 0.45423204, 0.39971313],
            [0.4771495, 0.45703092, 0.41323453],
            [0.48007926, 0.44863188, 0.39478093],
            [0.4643234, 0.42005086, 0.36253068],
        ],
        'variances': [
            [0.07662108, 0.07449844, 0.07999679],
            [0.07501277, 0.07167984, 0.07835567],
            [0.0774068, 0.07507991, 0.08436972],
            [0.07526458, 0.07246085, 0.0787185],
            [0.07833668, 0.07275151, 0.07406835],
        ],
    },

    # ------------------------------------------------- full-dataset variants
    # The '*_full' entries each have a single split covering every class
    # index in range(N).
    'cifar10_full': {
        'splits': [list(range(10))],
        'means': [[0.4912229, 0.48212662, 0.44638672]],
        'variances': [[0.06083429, 0.05912956, 0.06832606]],
    },

    'cifar100_full': {
        'splits': [list(range(100))],
        'means': [[0.50715786, 0.48687312, 0.4411894]],
        'variances': [[0.07155167, 0.06576824, 0.07626488]],
    },

    'tiny_imagenet_full': {
        'splits': [list(range(200))],
        'means': [[0.47596434, 0.4482439, 0.39279512]],
        'variances': [[0.07636914, 0.07223518, 0.07920632]],
    },
}