import numpy as np

class ClassSplit:
	def __init__(self, args):
		args.logger.print(f"Using class split: {args.class_order}")
		self.dataset = args.dataset

		# if class_order == 0, use the default consecutive order
		if args.class_order == 0:
			if args.dataset == 'mnist':
				self.split = np.arange(10)
			elif args.dataset == 'svhn':
				self.split = np.arange(10)
			elif args.dataset == 'cifar10':
				self.split = np.arange(10)
			elif args.dataset == 'cifar100':
				self.split = np.arange(100)
			elif args.dataset == 'timgnet':
				self.split = np.arange(200)
			elif args.dataset == 'imgnet380':
				self.split = np.array([2, 3, 4, 7, 9, 14, 16, 17, 18, 19, 20, 24, 25, 26, 28, 30, 31, 32, 35, 40, 41, 42, 44, 45, 46, 48, 49, 50, 54, 55, 61, 62, 63, 65, 67, 68, 69, 70, 71, 75, 76, 78, 79, 81, 83, 85, 86, 87, 88, 89, 90, 92, 93, 94, 96, 97, 98, 99, 104, 105, 106, 107, 108, 109, 111, 112, 113, 114, 115, 116, 117, 118, 122, 123, 124, 125, 126, 128, 129, 130, 131, 132, 133, 135, 138, 139, 140, 142, 143, 144, 145, 147, 148, 149, 150, 151, 153, 154, 155, 159, 160, 163, 167, 170, 171, 172, 173, 175, 176, 178, 179, 180, 183, 186, 187, 188, 190, 191, 195, 196, 197, 199, 200, 201, 203, 206, 207, 208, 209, 210, 220, 229, 230, 231, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 248, 249, 256, 260, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 291, 294, 301, 308, 309, 311, 313, 314, 315, 319, 323, 325, 329, 338, 341, 345, 347, 349, 353, 354, 365, 367, 372, 382, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 445, 447, 448, 457, 458, 462, 463, 466, 467, 470, 471, 474, 480, 485, 488, 492, 496, 498, 499, 508, 509, 511, 517, 525, 526, 542, 543, 557, 562, 565, 567, 568, 570, 573, 576, 604, 605, 612, 614, 619, 621, 625, 627, 635, 645, 652, 655, 675, 677, 678, 682, 683, 687, 704, 707, 716, 720, 731, 733, 734, 735, 737, 739, 744, 747, 758, 760, 761, 765, 768, 774, 779, 781, 786, 801, 806, 808, 811, 815, 817, 821, 826, 837, 839, 842, 845, 849, 850, 853, 862, 866, 873, 874, 877, 879, 887, 888, 890, 899, 900, 909, 910, 917, 923, 924, 928, 929, 932, 935, 938, 945, 947, 950, 951, 954, 957, 962, 963, 964, 967, 970, 972, 973, 975, 978, 988])
			else:
				raise NotImplementedError()
		elif args.class_order == 1:
			if args.dataset == 'svhn':
				self.split = np.array([3, 9, 1, 8, 0, 2, 6, 4, 5, 7])
			elif args.dataset == 'cifar10':
				self.split = np.array([3, 9, 1, 8, 0, 2, 6, 4, 5, 7])
			elif args.dataset == 'cifar100':
				self.split = np.array([49, 98, 97, 53, 48, 62, 89, 23, 82, 13, 40, 35, 71, 59, 34, 95, 67, 11, 27, 7, 47, 85, 36, 70, 51, 32, 60, 16, 29, 84, 39, 8, 17, 42, 72, 18, 15, 55, 83, 10, 37, 99, 66, 22, 14, 57, 24, 38, 80, 0, 52, 88, 77, 3, 50, 6, 41, 44, 93, 9, 96, 81, 45, 58, 5, 64, 86, 2, 78, 68, 75, 56, 46, 91, 43, 20, 87, 1, 33, 28, 19, 61, 30, 74, 65, 79, 63, 25, 12, 21, 31, 90, 69, 73, 4, 94, 76, 92, 54, 26])
			elif args.dataset == 'timgnet':
				self.split = np.array([117, 8, 183, 39, 40, 47, 75, 133, 193, 28, 130, 31, 98, 119, 188, 161, 57, 92, 54, 134, 6, 71, 147, 70, 139, 68, 77, 149, 17, 87, 132, 184, 59, 52, 194, 187, 159, 196, 166, 50, 63, 62, 141, 20, 126, 99, 19, 182, 164, 34, 2, 13, 97, 78, 151, 85, 150, 74, 111, 11, 61, 83, 41, 24, 55, 101, 110, 88, 60, 14, 65, 4, 51, 5, 30, 171, 158, 84, 15, 10, 46, 165, 118, 140, 90, 186, 107, 148, 180, 42, 152, 64, 189, 109, 136, 106, 91, 66, 178, 73, 172, 29, 25, 103, 44, 108, 191, 36, 72, 76, 82, 167, 160, 199, 9, 155, 175, 174, 179, 144, 177, 197, 170, 81, 121, 113, 58, 21, 89, 0, 69, 157, 137, 1, 26, 37, 153, 124, 143, 95, 23, 105, 79, 48, 32, 3, 190, 38, 135, 80, 198, 33, 53, 56, 49, 112, 125, 156, 131, 116, 129, 67, 162, 173, 123, 12, 181, 7, 192, 169, 185, 104, 100, 138, 168, 195, 43, 93, 45, 35, 22, 142, 146, 16, 127, 86, 128, 114, 27, 120, 145, 163, 102, 122, 18, 154, 94, 96, 115, 176])
			elif args.dataset == 'imgnet380':
				self.split = np.array([509, 436, 627, 79, 426, 117, 888, 975, 70, 99, 492, 432, 288, 294, 113, 18, 440, 437, 191, 862, 94, 280, 419, 471, 17, 932, 201, 761, 605, 31, 234, 49, 116, 236, 716, 837, 947, 249, 187, 69, 129, 210, 108, 466, 25, 203, 462, 173, 408, 414, 573, 928, 526, 63, 142, 386, 149, 687, 954, 781, 106, 126, 744, 576, 83, 910, 382, 114, 739, 806, 54, 970, 542, 384, 179, 209, 817, 16, 151, 409, 418, 635, 132, 131, 557, 338, 143, 438, 425, 87, 496, 427, 170, 935, 281, 46, 264, 570, 26, 801, 285, 208, 28, 242, 747, 431, 392, 237, 269, 417, 567, 850, 190, 319, 135, 839, 9, 424, 41, 873, 230, 314, 439, 275, 734, 67, 325, 329, 186, 104, 238, 988, 244, 284, 98, 474, 144, 243, 309, 279, 923, 372, 625, 768, 195, 345, 917, 733, 420, 93, 89, 463, 619, 385, 78, 707, 86, 760, 183, 410, 458, 808, 267, 389, 720, 973, 65, 125, 130, 167, 154, 7, 97, 50, 159, 899, 229, 853, 197, 287, 199, 96, 938, 866, 90, 498, 978, 347, 467, 909, 435, 278, 758, 231, 301, 404, 124, 365, 354, 393, 391, 525, 508, 139, 967, 964, 683, 163, 353, 765, 122, 111, 133, 731, 180, 962, 341, 289, 88, 890, 433, 349, 488, 397, 887, 879, 412, 196, 115, 423, 963, 311, 470, 92, 842, 612, 62, 815, 511, 421, 945, 138, 779, 480, 323, 682, 678, 737, 614, 485, 68, 811, 148, 76, 270, 402, 604, 265, 220, 291, 416, 849, 273, 239, 268, 445, 565, 42, 14, 30, 107, 406, 845, 313, 645, 235, 429, 394, 272, 246, 55, 207, 123, 61, 652, 877, 499, 175, 413, 562, 274, 430, 245, 826, 147, 543, 4, 415, 266, 105, 35, 240, 704, 282, 972, 308, 735, 675, 109, 145, 517, 178, 388, 315, 655, 457, 399, 774, 3, 2, 155, 900, 176, 950, 85, 40, 621, 75, 874, 447, 821, 171, 924, 19, 401, 400, 128, 387, 140, 568, 277, 160, 172, 44, 929, 263, 81, 256, 150, 48, 248, 112, 200, 398, 405, 188, 951, 24, 241, 411, 390, 118, 448, 677, 957, 206, 422, 786, 45, 286, 434, 153, 271, 283, 20, 403, 71, 367, 260, 32])
			else:
				raise NotImplementedError()
		elif args.class_order == 2:
			if args.dataset == 'mnist':
				self.split = np.arange(10)
			elif args.dataset == 'svhn':
				self.split = np.array([6, 0, 2, 8, 1, 9, 7, 3, 5, 4])
			elif args.dataset == 'cifar10':
				self.split = np.array([6, 0, 2, 8, 1, 9, 7, 3, 5, 4])
			elif args.dataset == 'cifar100':
				self.split = np.array([64, 79, 89, 9, 88, 3, 26, 94, 61, 62, 73, 69, 83, 8, 75, 23, 45, 92, 74, 1, 84, 71, 96, 52, 7, 95, 2, 5, 70, 28, 77, 60, 43, 22, 91, 78, 34, 80, 48, 51, 58, 37, 6, 25, 85, 97, 40, 27, 32, 98, 36, 21, 39, 31, 15, 49, 66, 72, 67, 24, 20, 93, 87, 54, 90, 76, 99, 30, 53, 29, 82, 57, 65, 4, 19, 11, 14, 41, 16, 86, 59, 68, 35, 55, 38, 17, 33, 50, 81, 63, 42, 10, 18, 56, 44, 13, 46, 0, 47, 12])
			elif args.dataset == 'timgnet':
				self.split = np.array([121, 66, 149, 189, 103, 195, 0, 72, 179, 46, 7, 159, 70, 65, 123, 76, 54, 37, 186, 62, 96, 136, 124, 69, 181, 6, 57, 125, 161, 81, 134, 147, 132, 59, 20, 50, 93, 71, 117, 33, 135, 47, 36, 120, 94, 73, 75, 14, 102, 60, 113, 142, 175, 115, 184, 185, 152, 63, 12, 198, 24, 26, 119, 109, 165, 87, 144, 64, 48, 52, 21, 95, 116, 187, 137, 10, 84, 162, 90, 131, 88, 150, 110, 41, 146, 106, 86, 127, 151, 16, 107, 67, 129, 140, 4, 172, 39, 23, 51, 183, 197, 31, 157, 188, 171, 58, 13, 153, 32, 98, 173, 130, 97, 80, 133, 163, 53, 44, 141, 145, 155, 176, 156, 138, 22, 68, 112, 3, 174, 2, 42, 25, 29, 104, 170, 178, 193, 126, 122, 30, 196, 199, 182, 128, 91, 56, 49, 111, 83, 78, 89, 61, 192, 34, 148, 191, 180, 190, 74, 167, 158, 139, 1, 101, 166, 143, 28, 8, 43, 105, 38, 177, 118, 55, 108, 19, 5, 168, 15, 79, 160, 45, 169, 164, 85, 82, 77, 27, 40, 99, 92, 194, 18, 11, 154, 35, 100, 17, 9, 114])
			elif args.dataset == 'imgnet380':
				self.split = np.array([140, 567, 419, 284, 365, 155, 353, 410, 99, 98, 242, 167, 731, 94, 436, 391, 142, 737, 173, 492, 707, 744, 866, 879, 434, 973, 899, 815, 41, 928, 568, 122, 426, 108, 416, 675, 392, 890, 768, 153, 720, 924, 435, 289, 431, 414, 947, 525, 387, 427, 415, 655, 806, 573, 842, 323, 97, 349, 393, 70, 910, 237, 341, 401, 635, 282, 845, 480, 433, 78, 418, 150, 208, 76, 849, 69, 645, 463, 652, 975, 389, 604, 837, 417, 149, 329, 394, 498, 372, 923, 178, 466, 49, 288, 83, 421, 163, 627, 853, 25, 409, 704, 314, 367, 474, 605, 93, 171, 32, 186, 338, 733, 7, 196, 17, 179, 786, 14, 275, 402, 280, 172, 210, 877, 129, 231, 133, 354, 18, 106, 200, 113, 139, 440, 107, 425, 386, 199, 562, 274, 420, 687, 265, 716, 269, 96, 126, 271, 50, 2, 677, 245, 138, 26, 826, 61, 203, 614, 277, 390, 267, 48, 88, 87, 682, 957, 145, 206, 147, 241, 109, 954, 68, 485, 345, 325, 385, 470, 154, 400, 104, 438, 260, 424, 683, 9, 938, 758, 612, 281, 951, 86, 159, 309, 739, 678, 75, 65, 403, 272, 874, 422, 779, 315, 747, 382, 71, 734, 988, 458, 116, 917, 405, 195, 429, 432, 508, 243, 839, 801, 496, 183, 970, 263, 887, 44, 130, 950, 900, 31, 20, 238, 24, 4, 499, 286, 399, 439, 187, 125, 945, 412, 244, 557, 821, 256, 808, 308, 471, 781, 761, 46, 765, 963, 935, 457, 978, 397, 621, 774, 249, 398, 929, 264, 619, 207, 967, 45, 170, 197, 517, 114, 40, 89, 570, 79, 811, 285, 411, 287, 760, 229, 117, 16, 105, 278, 191, 28, 462, 55, 230, 625, 294, 131, 311, 115, 445, 148, 291, 234, 817, 90, 850, 151, 236, 112, 132, 888, 160, 190, 176, 447, 542, 408, 268, 123, 862, 388, 92, 565, 448, 3, 30, 85, 124, 319, 180, 118, 42, 932, 488, 135, 67, 111, 35, 962, 526, 283, 467, 246, 279, 266, 406, 201, 384, 19, 128, 313, 235, 81, 404, 62, 188, 430, 175, 54, 509, 239, 543, 63, 240, 209, 143, 909, 248, 423, 511, 347, 576, 972, 437, 735, 144, 964, 270, 301, 413, 873, 220, 273])
			else:
				raise NotImplementedError()
		elif args.class_order == 3:
			if args.dataset == 'mnist':
				self.split = np.arange(10)
			elif args.dataset == 'svhn':
				self.split = np.array([2, 6, 1, 5, 9, 8, 0, 4, 3, 7])
			elif args.dataset == 'cifar10':
				self.split = np.array([2, 6, 1, 5, 9, 8, 0, 4, 3, 7])
			elif args.dataset == 'cifar100':
				self.split = np.array([97, 1, 48, 88, 58, 46, 87, 18, 35, 71, 45, 6, 31, 69, 21, 96, 9, 44, 14, 68, 98, 27, 56, 38, 13, 63, 47, 57, 22, 64, 8, 73, 78, 94, 52, 4, 23, 28, 85, 2, 19, 10, 92, 7, 93, 76, 42, 34, 49, 80, 40, 37, 66, 83, 33, 99, 36, 12, 41, 39, 75, 25, 3, 95, 16, 0, 29, 53, 60, 11, 24, 82, 86, 32, 91, 43, 65, 89, 15, 81, 17, 62, 90, 54, 51, 20, 55, 30, 77, 59, 50, 5, 74, 84, 67, 79, 70, 61, 72, 26])
			elif args.dataset == 'timgnet':
				self.split = np.array([156, 137, 7, 123, 154, 38, 121, 40, 43, 6, 76, 129, 91, 18, 12, 149, 162, 189, 145, 107, 5, 85, 78, 111, 191, 71, 146, 87, 155, 92, 48, 49, 21, 34, 23, 187, 179, 110, 102, 186, 105, 184, 29, 90, 159, 79, 28, 108, 89, 128, 57, 96, 194, 54, 55, 167, 141, 51, 67, 0, 177, 99, 26, 173, 1, 163, 122, 115, 30, 101, 170, 198, 134, 69, 61, 58, 192, 171, 185, 37, 124, 15, 114, 132, 181, 9, 157, 83, 19, 131, 73, 86, 153, 138, 32, 8, 33, 165, 42, 180, 44, 168, 188, 81, 64, 166, 24, 172, 142, 95, 35, 161, 160, 13, 119, 199, 39, 100, 97, 125, 52, 195, 65, 158, 197, 127, 46, 4, 175, 20, 56, 190, 41, 174, 151, 84, 182, 183, 109, 75, 3, 93, 106, 136, 50, 17, 74, 10, 150, 60, 112, 164, 193, 53, 14, 169, 152, 82, 116, 80, 63, 77, 120, 117, 11, 72, 31, 104, 113, 68, 144, 88, 178, 47, 16, 27, 176, 98, 148, 94, 25, 126, 143, 62, 118, 70, 140, 45, 66, 130, 196, 147, 133, 59, 103, 36, 139, 2, 135, 22])
			elif args.dataset == 'imgnet380':
				self.split = np.array([25, 3, 62, 201, 92, 392, 153, 123, 815, 826, 716, 433, 430, 405, 277, 113, 625, 517, 480, 887, 687, 44, 950, 417, 186, 414, 839, 245, 388, 655, 145, 199, 76, 474, 614, 463, 288, 801, 178, 849, 341, 229, 354, 438, 175, 367, 308, 155, 866, 445, 284, 837, 274, 142, 604, 675, 65, 237, 645, 426, 410, 131, 69, 279, 183, 135, 87, 243, 111, 18, 747, 973, 154, 397, 329, 786, 508, 61, 287, 964, 850, 404, 842, 94, 41, 71, 45, 526, 283, 496, 240, 542, 605, 2, 195, 220, 231, 806, 873, 390, 309, 440, 765, 436, 200, 49, 286, 385, 81, 720, 313, 402, 264, 278, 677, 972, 85, 235, 467, 909, 117, 234, 733, 447, 311, 115, 945, 978, 256, 429, 400, 347, 246, 143, 808, 242, 108, 273, 78, 567, 17, 190, 406, 122, 365, 79, 7, 372, 349, 112, 319, 399, 19, 389, 398, 988, 208, 196, 543, 268, 511, 130, 425, 63, 409, 150, 899, 947, 9, 129, 353, 86, 263, 963, 431, 140, 951, 862, 761, 187, 109, 466, 323, 118, 16, 386, 270, 525, 419, 128, 499, 203, 272, 781, 171, 238, 98, 149, 739, 394, 917, 28, 315, 821, 627, 50, 423, 938, 132, 424, 557, 46, 271, 89, 492, 457, 413, 421, 682, 338, 910, 683, 437, 239, 68, 173, 774, 760, 439, 206, 54, 67, 975, 707, 207, 75, 97, 248, 133, 734, 879, 241, 488, 32, 179, 83, 853, 88, 210, 99, 42, 31, 967, 124, 70, 249, 420, 159, 265, 900, 434, 280, 147, 416, 160, 923, 291, 758, 735, 928, 282, 197, 576, 116, 929, 393, 180, 779, 612, 387, 874, 126, 418, 570, 4, 260, 90, 817, 924, 40, 167, 269, 151, 744, 888, 191, 236, 652, 435, 737, 957, 267, 244, 125, 458, 954, 704, 932, 877, 294, 30, 382, 845, 20, 768, 93, 471, 24, 427, 448, 470, 48, 562, 172, 935, 403, 285, 138, 26, 14, 619, 384, 890, 635, 731, 230, 565, 345, 170, 163, 188, 568, 391, 422, 105, 301, 485, 411, 509, 462, 314, 573, 289, 401, 962, 114, 412, 325, 281, 811, 176, 621, 144, 148, 408, 415, 55, 678, 209, 970, 35, 275, 139, 266, 498, 96, 106, 104, 107, 432])
			else:
				raise NotImplementedError()
		elif args.class_order == 4:
			if args.dataset == 'mnist':
				self.split = np.arange(10)
			elif args.dataset == 'svhn':
				self.split = np.array([1, 5, 7, 2, 0, 3, 4, 6, 8, 9])
			elif args.dataset == 'cifar10':
				self.split = np.array([1, 5, 7, 2, 0, 3, 4, 6, 8, 9])
			elif args.dataset == 'cifar100':
				self.split = np.array([34, 31, 97, 47, 83, 59, 39, 4, 32, 44, 26, 73, 45, 33, 56, 87, 82, 23, 88, 10, 51, 57, 65, 84, 43, 37, 9, 74, 28, 24, 90, 25, 60, 80, 5, 64, 63, 62, 40, 19, 49, 21, 77, 95, 99, 16, 12, 14, 70, 54, 53, 38, 8, 72, 18, 68, 15, 94, 36, 7, 1, 69, 2, 61, 98, 75, 85, 11, 17, 76, 22, 27, 92, 71, 3, 0, 66, 42, 96, 67, 35, 30, 46, 81, 48, 93, 79, 6, 13, 86, 20, 91, 78, 50, 89, 41, 52, 55, 29, 58])
			elif args.dataset == 'timgnet':
				self.split = np.array([187, 110, 38, 174, 97, 189, 39, 109, 122, 37, 42, 65, 101, 188, 134, 191, 153, 194, 3, 147, 78, 129, 52, 1, 185, 85, 22, 60, 98, 51, 155, 145, 24, 103, 2, 73, 139, 74, 18, 175, 48, 105, 46, 31, 161, 171, 14, 117, 69, 167, 12, 163, 25, 121, 13, 177, 16, 102, 56, 142, 107, 151, 53, 44, 62, 169, 176, 150, 67, 86, 91, 82, 5, 156, 128, 70, 149, 179, 144, 19, 146, 160, 21, 49, 0, 35, 119, 6, 141, 131, 94, 30, 162, 159, 76, 45, 17, 100, 118, 84, 66, 158, 15, 64, 54, 27, 89, 123, 193, 4, 80, 96, 58, 152, 93, 168, 108, 59, 113, 29, 34, 182, 83, 55, 11, 10, 111, 136, 133, 28, 192, 79, 127, 180, 140, 95, 68, 106, 61, 41, 157, 195, 90, 183, 130, 7, 125, 124, 40, 63, 116, 186, 199, 148, 120, 104, 75, 138, 178, 43, 181, 8, 143, 137, 20, 33, 99, 170, 184, 32, 87, 92, 154, 166, 88, 198, 26, 115, 190, 71, 72, 77, 50, 132, 165, 135, 164, 9, 47, 126, 57, 112, 114, 172, 197, 81, 36, 173, 23, 196])
			elif args.dataset == 'imgnet380':
				self.split = np.array([145, 98, 171, 645, 419, 30, 7, 16, 496, 853, 248, 170, 439, 142, 471, 237, 144, 153, 781, 278, 90, 422, 474, 488, 385, 877, 416, 277, 652, 964, 806, 315, 266, 267, 448, 744, 241, 929, 434, 951, 988, 463, 273, 808, 923, 319, 935, 200, 973, 188, 338, 542, 604, 147, 957, 40, 427, 187, 148, 31, 716, 704, 735, 423, 499, 269, 415, 87, 117, 143, 862, 92, 932, 758, 210, 619, 175, 236, 480, 126, 402, 683, 413, 45, 842, 436, 309, 275, 9, 274, 517, 50, 196, 249, 32, 230, 54, 178, 71, 421, 887, 132, 815, 573, 621, 866, 392, 420, 42, 195, 850, 231, 576, 888, 20, 149, 112, 747, 291, 543, 954, 431, 928, 349, 397, 707, 180, 485, 94, 282, 313, 106, 492, 967, 19, 975, 817, 774, 49, 24, 389, 155, 924, 682, 125, 811, 458, 437, 526, 128, 197, 69, 62, 25, 388, 627, 242, 265, 28, 116, 83, 85, 511, 35, 191, 97, 63, 133, 341, 760, 286, 963, 845, 400, 403, 779, 417, 462, 404, 418, 186, 739, 67, 325, 849, 283, 75, 562, 570, 308, 372, 972, 206, 839, 826, 508, 429, 978, 129, 438, 733, 557, 909, 678, 384, 183, 285, 18, 406, 173, 731, 435, 239, 426, 151, 950, 768, 900, 401, 268, 430, 354, 289, 568, 68, 131, 399, 314, 405, 135, 625, 107, 160, 387, 118, 123, 565, 44, 938, 179, 89, 509, 311, 353, 445, 677, 271, 55, 61, 207, 408, 386, 86, 246, 635, 425, 432, 105, 612, 154, 140, 163, 122, 412, 99, 873, 14, 288, 46, 124, 260, 270, 263, 220, 279, 111, 410, 347, 440, 159, 567, 391, 240, 108, 899, 367, 734, 264, 114, 2, 88, 65, 238, 150, 4, 76, 879, 765, 272, 962, 345, 235, 945, 3, 890, 720, 394, 605, 329, 245, 96, 203, 323, 411, 244, 176, 190, 139, 281, 79, 301, 256, 243, 172, 414, 167, 687, 78, 467, 48, 209, 525, 280, 498, 229, 675, 424, 109, 393, 737, 470, 409, 655, 234, 390, 433, 17, 874, 130, 138, 614, 466, 917, 970, 287, 294, 26, 199, 447, 457, 41, 93, 761, 786, 113, 398, 947, 104, 208, 70, 821, 201, 284, 81, 910, 837, 801, 365, 382, 115])
			else:
				raise NotImplementedError()
		else:
			raise NotImplementedError()
		args.logger.print("Class Order:", self.split)

	def relabel(self, dataset):
		assert len(np.unique(dataset.targets)) == len(self.split)

		target_copy = dataset.targets.copy()
		if not isinstance(target_copy, np.ndarray):
			target_copy = np.array(target_copy)

		# unique = np.unique(target_copy) # sorted in ascending order

		for new_y, y in enumerate(self.split):
			idx = dataset.targets == y
			target_copy[idx] = new_y

		dataset.targets = target_copy.copy()

		if self.dataset == 'timgnet' or self.dataset == 'imgnet380':
			samples_copy = dataset.samples.copy()
			for i, (img, lab) in enumerate(dataset.samples):
				# print('a', dataset.samples[i])
				temp = list(samples_copy[i])
				# temp[1] = self.split[lab]
				temp[1] = np.where(self.split == lab)[0].item()
				samples_copy[i] = tuple(temp)
				# print('b', samples_copy[i])

			dataset.samples = samples_copy.copy()
		
		return dataset
