import torch
import itertools

def distinct(n, m, box_size=3):
    """Return the cells that share a row, column or box with Sudoku cell ``n``.

    Cells are numbered row-major from 0 to ``side * side - 1``, where
    ``side = box_size ** 2``. The default ``box_size=3`` gives the standard
    9x9 board with cells 0..80.

    Args:
        n: Cell index. Any value equal to an integer in range is accepted.
        m: Which unit to return: 0 = same row, 1 = same column, 2 = same box.
            This indexes a 3-element list, so negative values wrap like
            ordinary list indexing (``-1`` is the box).
        box_size: Edge length of one box. Keyword-compatible generalization;
            defaults to 3.

    Returns:
        Ascending tuple of the ``side - 1`` other cell indices in that unit.
        ``n`` itself is excluded.

    Raises:
        ValueError: ``n`` is not an integer cell index in ``[0, side * side)``.
        IndexError: ``m`` is outside ``-3..2``.
    """
    side = box_size * box_size
    cell = int(n)
    if cell != n or not 0 <= cell < side * side:
        raise ValueError(
            f"cell index must be an integer in [0, {side * side}), got {n!r}"
        )

    row, col = divmod(cell, side)
    # Top-left corner of the box that contains the cell.
    top = row - row % box_size
    left = col - col % box_size

    row_peers = tuple(row * side + c for c in range(side) if c != col)
    col_peers = tuple(r * side + col for r in range(side) if r != row)
    # Row-major iteration keeps the box peers in ascending order.
    box_peers = tuple(
        r * side + c
        for r in range(top, top + box_size)
        for c in range(left, left + box_size)
        if (r, c) != (row, col)
    )
    return [row_peers, col_peers, box_peers][m]

def pretrain(device, n):
    """Build a 0/1 lookup tensor of digits compatible with a set of peer values.

    The result ``x`` has shape ``(n + 1,) * (n - 1) + (n,)``. Take a digit
    ``d`` in ``1..n`` and indices ``a_1 .. a_{n-1}``, each in ``0..n``. Then
    ``x[a_1, ..., a_{n-1}, d - 1] == 1`` exactly when the nonzero ``a_k`` are
    pairwise distinct and none of them equals ``d``. Every other entry is 0.
    NOTE(review): presumably ``0`` means "empty cell" and the table is a
    Sudoku candidate lookup over a cell's ``n - 1`` peers; confirm with
    callers.

    Args:
        device: Device (or device string) the tensors are created on.
        n: Number of distinct digits (9 for standard Sudoku).

    Returns:
        Default-dtype (float32) tensor as described above.

    Note:
        The output has ``(n + 1) ** (n - 1) * n`` entries. For ``n = 9`` that is
        9e8 float32 values (about 3.6 GB). Each digit's index tensor has
        ``(n - 1)! * 2 ** (n - 1)`` rows.
    """
    k = n - 1  # number of peer positions (one tensor axis each)
    x = torch.zeros((n + 1,) * k + (n,), device=device)
    if n < 1:
        return x  # n == 0: no digits, nothing to fill

    # Every 0/1 pattern over the k positions, in binary-counting order.
    # Multiplying a full permutation by a pattern zeroes the masked-out
    # positions. Over all permutations and patterns this enumerates every
    # partial assignment of distinct digits. The patterns do not depend on
    # the digit, so they are built once.
    masks = torch.tensor(
        list(itertools.product((0, 1), repeat=k)), dtype=torch.long, device=device
    )

    for digit in range(1, n + 1):
        others = [v for v in range(1, n + 1) if v != digit]
        perms = torch.tensor(
            list(itertools.permutations(others)), dtype=torch.long, device=device
        )
        # (P, 1, k) * (1, M, k) -> (P, M, k) -> (P*M, k). flatten() is used
        # instead of reshape(-1, k) because reshape is ambiguous when k == 0.
        indices = (perms.unsqueeze(1) * masks.unsqueeze(0)).flatten(0, 1)
        # One index tensor per peer axis, plus the scalar digit axis. This
        # works for any n, unlike spelling out a fixed number of columns.
        x[tuple(indices.unbind(1)) + (digit - 1,)] = 1

    return x