import numpy as np
import torch.nn as nn

from myutils import print_debug
import torch
import torch.nn.functional as F
import os

import math
from transformers import LlamaTokenizer, LlamaForCausalLM
from myutils import visualize_output

# get the student model
# info_tuples = [(2, 3, 5), (4, 5, 5), (6, 7, 5)] # better than the following one
# info_tuples = [(2, 3, 5), (2, 5, 5), (2, 7, 5)] 

# info_tuples = [(28, 29, 5)] # not bad, the same as [(2, 3, 5)]

# info_tuples = [(i, i+1, 5) for i in range(2, 30, 2)]
# info_tuples += [(i, i, 5) for i in range(2, 30, 2)] # adding these is slightly better (6.3), but not a big deal I think. The key may be data efficiency.
# info_tuples = [(i, i, 5) for i in range(2, 30, 2)] # almost equivalent to LoRA. For a 0.005 ratio of arxiv-math the ppl is 3.1, so no big difference from before.
# info_tuples = [(i, i, 5) for i in range(0, 32, 1)] # equivalent to LoRA; ppl is 3.15 for a 0.005 ratio of arxiv-math.

class InfoTupleManager():
    """Catalogue of layer-pairing schemes and staged training schedules.

    A scheme is a list of ``(a, b, rank)`` tuples. The ``>= 10000`` branch of
    :meth:`get_info_tuples` calls the first entry the *reference* layer and the
    second the *target* layer; the other schemes presumably use the same layout
    (TODO confirm against the code that consumes these tuples).

    The trailing numbers in the per-scheme comments (e.g. ``# 10.3``) are the
    original author's experiment notes and are kept verbatim.
    """

    def __init__(self, info_tuples_type=0):
        # Integer id selecting which scheme get_info_tuples() returns.
        self.info_tuples_type = info_tuples_type

    def get_info_tuples(self, rank=5, start_layer=2, end_layer=30, step=2):
        """Return the list of ``(a, b, rank)`` tuples for ``self.info_tuples_type``.

        Args:
            rank: value placed in the third slot of every tuple.
            start_layer, end_layer, step: ``range`` arguments used only by the
                generated schemes 0, 100, 200 and 300. Scheme 200 ignores the
                passed ``step`` and always uses 3.

        Returns:
            A new list of 3-tuples.

        Raises:
            AssertionError: for schemes 0, 100 and 300 when ``step != 2``.
            ValueError: when ``self.info_tuples_type`` matches no known scheme
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.info_tuples_type == 6:
            # (i, i) for each of 32 layers -- the author's LoRA baseline
            info_tuples = [(i, i, rank) for i in range(32)] # lora baseline

        elif self.info_tuples_type == 0: # need to store 18 layers
            # pair every layer i with layer i+1
            assert step == 2, 'step should be 2'
            info_tuples = [(i, i+1, rank) for i in range(start_layer, end_layer, step)]
            # with the default arguments this is equivalent to [
            #  (2,3,rank), (4,5,rank), (6,7,rank), (8,9,rank), (10,11,rank), (12,13,rank), (14,15,rank),
            # (16,17,rank), (18,19,rank), (20,21,rank), (22,23,rank), (24,25,rank), (26,27,rank), (28,29,rank)
            #]

        elif self.info_tuples_type == 100:
            # (i, i+1) pairs followed by (i, i) pairs for the same i
            assert step == 2, 'step should be 2'
            info_tuples = [(i, i+1, rank) for i in range(start_layer, end_layer, step)]
            info_tuples += [(i, i, rank) for i in range(start_layer, end_layer, step)]

        elif self.info_tuples_type == 200:
            # each i is paired with the next two layers; step is forced to 3
            # assert step >= 3, 'step should be >= 3'
            step = 3
            info_tuples = [(i, i+1, rank) for i in range(start_layer, end_layer, step)]
            info_tuples += [(i, i+2, rank) for i in range(start_layer, end_layer, step)]

        elif self.info_tuples_type == 300:
            # lora baseline
            assert step == 2, 'step should be 2'
            info_tuples = [(i, i, rank) for i in range(start_layer, end_layer, step)]
        # type > 900 -> self defined

        elif self.info_tuples_type == 901:
            info_tuples = [(2, 3, rank),  (4, 5, rank), (6, 7, rank), (8, 9, rank), (10, 11, rank),
                           (12, 13, rank), (12, 14, rank), (15, 16, rank), (15, 17, rank), (18, 19, rank), (18, 20, rank),
                           (21, 22, rank), (21, 23, rank), (24, 25, rank), (24, 26, rank), (27, 28, rank), (27, 29, rank)]

        elif self.info_tuples_type == 902:
            info_tuples = [(2, 3, rank), (4, 5, rank), (6, 7, rank), (8, 9, rank), (10, 11, rank),
                           (12, 13, rank), (12, 14, rank),
                           (15, 16, rank), (15, 17, rank), (15, 18, rank), (15, 19, rank), (15, 20, rank), (15, 21, rank),
                           (22, 23, rank), (22, 24, rank), (22, 25, rank), (22, 26, rank), (22, 27, rank), (22, 28, rank), (22, 29, rank)]
        elif self.info_tuples_type == 903:
            info_tuples = [(14, 15, rank), (14, 16, rank), (14, 17, rank), (14, 18, rank), (14, 19, rank), (14, 20, rank), (14, 21, rank),
                           (14, 22, rank), (14, 23, rank), (14, 24, rank), (14, 25, rank), (14, 26, rank), (14, 27, rank), (14, 28, rank),
                           (14, 29, rank)]

        elif self.info_tuples_type == 913: # start from front -> nan for (0.01 & 0.001)
            info_tuples = [(2, 3, rank), (2, 4, rank), (2, 5, rank), (2, 6, rank), (2, 7, rank), (2, 8, rank), (2, 9, rank),
                            (2, 10, rank), (2, 11, rank), (2, 12, rank), (2, 13, rank), (2, 14, rank), (2, 15, rank), (2, 16, rank),
                            (2, 17, rank)]

        elif self.info_tuples_type == 914: # start from middle -> 17 for (0.01 & 0.001)
            info_tuples = [(8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank), (8, 13, rank), (8, 14, rank), (8, 15, rank),
                            (8, 16, rank), (8, 17, rank), (8, 18, rank), (8, 19, rank), (8, 20, rank), (8, 21, rank), (8, 22, rank),
                            (8, 23, rank)]

        elif self.info_tuples_type == 915: # start from middle -> 15 for (0.01 & 0.001)
            info_tuples = [(2, 3, rank), (4, 5, rank), (6, 7, rank),
                           (8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank), (8, 13, rank), (8, 14, rank), (8, 15, rank), (8, 16, rank),
                           (17, 18, rank), (17, 19, rank), (17, 20, rank), (17, 21, rank), (17, 22, rank), (17, 23, rank), (17, 24, rank), (17, 25, rank), (17, 26, rank), (17, 27, rank), (17, 28, rank), (17, 29, rank)]

        elif self.info_tuples_type == 916: # 10.3
            info_tuples = [
                (2, 3, rank), (4, 5, rank), (6, 7, rank), (8, 9, rank),
                (10, 11, rank), (10, 12, rank), (10, 13, rank), (10, 14, rank),
                (15, 16, rank), (15, 17, rank), (15, 18, rank), (15, 19, rank), (15, 20, rank), (15, 21, rank),
                (22, 23, rank), (22, 24, rank), (22, 25, rank), (22, 26, rank), (22, 27, rank), (22, 28, rank), (22, 29, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (4, 4, rank), (6, 6, rank), (8, 8, rank), (10, 10, rank), (15, 15, rank), (22, 22, rank), (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 917: # 17.4
            info_tuples = [
                (2, 3, rank), (4, 5, rank), (6, 7, rank), (8, 9, rank),
                (10, 11, rank), (10, 12, rank), (10, 13, rank), (10, 14, rank), (10, 15, rank), (10, 16, rank),
                (17, 18, rank), (17, 19, rank), (17, 20, rank), (17, 21, rank), (17, 22, rank), (17, 23, rank), (17, 24, rank), (17, 25, rank), (17, 26, rank), (17, 27, rank), (17, 28, rank), (17, 29, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (4, 4, rank), (6, 6, rank), (8, 8, rank), (10, 10, rank), (17, 17, rank), (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 918: # 11.3
            info_tuples = [
                (2, 3, rank), (4, 5, rank), (6, 7, rank),
                (8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank),
                (13, 14, rank), (13, 15, rank), (13, 16, rank), (13, 17, rank), (13, 18, rank), (13, 19, rank),
                (20, 21, rank), (20, 22, rank), (20, 23, rank), (20, 24, rank), (20, 25, rank), (20, 26, rank), (20, 27, rank), (20, 28, rank), (20, 29, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (4, 4, rank), (6, 6, rank), (8, 8, rank), (13, 13, rank), (20, 20, rank), (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 908: # 11.3 -> 4.66
            info_tuples = [
                (2, 3, rank), (4, 5, rank), (6, 7, rank),
                (8, 9, rank), (8, 10, rank), (8, 11, rank),
                (12, 13, rank), (12, 14, rank), (12, 15, rank), (12, 16, rank),
                (17, 18, rank), (17, 19, rank), (17, 20, rank), (17, 21, rank), (17, 22, rank),
                (23, 24, rank), (23, 25, rank), (23, 26, rank), (23, 27, rank), (23, 28, rank), (23, 29, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (4, 4, rank), (6, 6, rank), (8, 8, rank), (12, 12, rank), (17, 17, rank), (23, 23, rank)
            ]

        elif self.info_tuples_type == 919: # 331, so you should not let 6 predicts too much
            info_tuples = [
                (2, 3, rank), (4, 5, rank),
                (6, 7, rank), (6, 8, rank), (6, 9, rank), (6, 10, rank),
                (11, 12, rank), (11, 13, rank), (11, 14, rank), (11, 15, rank), (11, 16, rank), (11, 17, rank), (11, 18, rank),
                (19, 20, rank), (19, 21, rank), (19, 22, rank), (19, 23, rank), (19, 24, rank), (19, 25, rank), (19, 26, rank), (19, 27, rank), (19, 28, rank), (19, 29, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (4, 4, rank), (6, 6, rank), (11, 11, rank), (19, 19, rank), (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 920: # 12.9
            info_tuples = [
                (1, 2, rank), (3, 4, rank), (5, 6, rank),
                (7, 8, rank), (7, 9, rank), (7, 10, rank), (7, 11, rank),
                (12, 13, rank), (12, 14, rank), (12, 15, rank), (12, 16, rank), (12, 17, rank), (12, 18, rank), (12, 19, rank),
                (20, 21, rank), (20, 22, rank), (20, 23, rank), (20, 24, rank), (20, 25, rank), (20, 26, rank), (20, 27, rank), (20, 28, rank), (20, 29, rank),
                (0, 0, rank), (1, 1, rank), (3, 3, rank), (5, 5, rank), (7, 7, rank), (12, 12, rank), (20, 20, rank), (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 921: # 13.8
            info_tuples = [
                (1, 2, rank), (3, 4, rank), (5, 6, rank),
                (7, 8, rank), (7, 9, rank), (7, 10, rank), (7, 11, rank),
                (12, 13, rank), (12, 14, rank), (12, 15, rank), (12, 16, rank), (12, 17, rank), (12, 18, rank), (12, 19, rank),
                (20, 21, rank), (20, 22, rank), (20, 23, rank), (20, 24, rank), (20, 25, rank), (20, 26, rank), (20, 27, rank), (20, 28, rank), (20, 29, rank), (20, 30, rank),
                (0, 0, rank), (1, 1, rank), (3, 3, rank), (5, 5, rank), (7, 7, rank), (12, 12, rank), (20, 20, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 922: # 11.5
            info_tuples = [
                (7, 8, rank), (7, 9, rank), (7, 10, rank), (7, 11, rank),
                (12, 13, rank), (12, 14, rank), (12, 15, rank), (12, 16, rank), (12, 17, rank), (12, 18, rank), (12, 19, rank),
                (20, 21, rank), (20, 22, rank), (20, 23, rank), (20, 24, rank), (20, 25, rank), (20, 26, rank), (20, 27, rank), (20, 28, rank), (20, 29, rank), (20, 30, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (3, 3, rank), (4, 4, rank), (5, 5, rank), (6, 6, rank), (7, 7, rank), (12, 12, rank), (20, 20, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 923: # 10.9
            info_tuples = [
                (2, 3, rank), (4, 5, rank), (6, 7, rank), (8, 9, rank),
                (10, 11, rank), (10, 12, rank), (10, 13, rank), (10, 14, rank),
                (15, 16, rank), (15, 17, rank), (15, 18, rank), (15, 19, rank), (15, 20, rank), (15, 21, rank),
                (22, 23, rank), (22, 24, rank), (22, 25, rank), (22, 26, rank), (22, 27, rank), (22, 28, rank), (22, 29, rank),
                (0, 0, rank), (1, 1, rank), (2, 2, rank), (4, 4, rank), (6, 6, rank), (8, 8, rank), (10, 10, rank), (15, 15, rank), (22, 22, rank), (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 925: # start from middle -> 35 for (0.01 & 0.001)
            info_tuples = [(2, 3, rank), (4, 5, rank), (6, 7, rank),
                           (8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank), (8, 13, rank), (8, 14, rank), (8, 15, rank), (8, 16, rank),
                           (17, 18, rank), (17, 19, rank), (17, 20, rank), (17, 21, rank), (17, 22, rank), (17, 23, rank), (17, 24, rank), (17, 25, rank), (17, 26, rank), (17, 27, rank), (17, 28, rank), (17, 29, rank), (17, 30, rank)]

        elif self.info_tuples_type == 935: # also 35, strange
            info_tuples = [(2, 3, rank), (4, 5, rank),
                           (6, 7, rank), (6, 8, rank),
                           (9, 10, rank), (9, 11, rank), (9, 12, rank), (9, 13, rank), (9, 14, rank), (9, 15, rank),
                           (16, 17, rank), (16, 18, rank), (16, 19, rank), (16, 20, rank), (16, 21, rank), (16, 22, rank), (16, 23, rank), (16, 24, rank), (16, 25, rank), (16, 26, rank), (16, 27, rank), (16, 28, rank), (16, 29, rank)]

        elif self.info_tuples_type == 945:
            info_tuples = [(0, 0, rank), (1, 1, rank),
                           (2, 3, rank), (4, 5, rank), (6, 7, rank),
                           (8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank), (8, 13, rank), (8, 14, rank), (8, 15, rank), (8, 16, rank),
                           (17, 18, rank), (17, 19, rank), (17, 20, rank), (17, 21, rank), (17, 22, rank), (17, 23, rank), (17, 24, rank), (17, 25, rank), (17, 26, rank), (17, 27, rank), (17, 28, rank), (17, 29, rank),
                           (30, 30, rank), (31, 31, rank)]

        elif self.info_tuples_type == 955: # 17
            info_tuples = [
                (2, 3, rank), (4, 5, rank), (6, 7, rank),
                (8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank), (8, 13, rank),
                (14, 15, rank), (14, 16, rank), (14, 17, rank), (14, 18, rank), (14, 19, rank), (14, 20, rank), (14, 21, rank),
                (14, 22, rank), (14, 23, rank), (14, 24, rank), (14, 25, rank), (14, 26, rank), (14, 27, rank), (14, 28, rank),
                (14, 29, rank)
            ]

        elif self.info_tuples_type == 965: # ft:1443 -> boom, don't know why. ft:8 -> 339, bad
            info_tuples = [
                (0, 0, rank), (1, 1, rank),
                (2, 3, rank),
                (4, 5, rank), (4, 6, rank), (4, 7, rank),
                (8, 9, rank), (8, 10, rank), (8, 11, rank), (8, 12, rank), (8, 13, rank), (8, 14, rank), (8, 15, rank), (8, 16, rank),
                (17, 18, rank), (17, 19, rank), (17, 20, rank), (17, 21, rank), (17, 22, rank), (17, 23, rank), (17, 24, rank), (17, 25, rank), (17, 26, rank), (17, 27, rank), (17, 28, rank), (17, 29, rank),
                (30, 30, rank), (31, 31, rank)
            ]

        elif self.info_tuples_type == 970: # store 18 layers
            info_tuples = [
                 (6,7,rank), (8,9,rank), (10,11,rank), (12,13,rank), (12,14, rank), (15, 16, rank), (15, 17, rank),
                 (18,19,rank), (18, 20, rank), (21, 22, rank), (21, 23, rank),
                 (24,25,rank), (26,27,rank), (28,29,rank)
            ]
        elif self.info_tuples_type == 971: # store 18 layers
            info_tuples = [
                 (4,5,rank), (6,7,rank), (8,9,rank), (10,11,rank), (10,12,rank), (13,14, rank), (13, 15, rank), (16, 17, rank),
                 (16,18,rank), (19, 20, rank), (19, 21, rank), (22, 23, rank),
                 (24,25,rank), (26,27,rank),
            ]
        elif self.info_tuples_type == 972: # store 18 layers
            info_tuples = [
                 (6,7,rank), (8,9,rank), (10,11,rank), (12,13,rank), (14, 15, rank), (16, 17, rank), (18,19,rank), (18, 20, rank), (21, 22, rank), (21, 23, rank),
                (24,25,rank), (24,26,rank), (27,28,rank), (27,29,rank)
            ]
        elif self.info_tuples_type == 973: # store 17 layers, seems very bad
            info_tuples = [
                 (2,3,rank),(2,4,rank), (5,6,rank),(5,7,rank),(8,9,rank),(8,10,rank), (11,12,rank),(11,13,rank),
                 (14,15,rank),(16,17,rank),(18,19,rank),(20,21,rank),(22,23,rank),(24,25,rank),(26,27,rank),
            ]

        elif self.info_tuples_type == 979: # 7.3
            info_tuples = [
                (1, 10, rank), (1, 11, rank), (1, 12, rank), (1, 13, rank), (1, 14, rank), (1, 15, rank),
                (3, 17, rank), (3, 18, rank), (3, 19, rank), (3, 20, rank), (3, 21, rank), (3, 22, rank), (3, 23, rank), (3, 24, rank), (3, 25, rank), (3, 26, rank), (3, 27, rank),
            ]
        elif self.info_tuples_type == 978: # 8.6
            info_tuples = [
                (1, 7, rank), (1, 8, rank), (1, 9, rank), (1, 10, rank), (1, 11, rank),
                (3, 12, rank), (3, 13, rank), (3, 14, rank), (3, 15, rank), (3, 16, rank), (3, 17, rank), (3, 18, rank), (3, 19, rank), (3, 20, rank), (3, 21, rank), (3, 22, rank), (3, 23, rank), (3, 24, rank), (3, 25, rank), (3, 26, rank), (3, 27, rank), (3, 28, rank)
            ]
        elif self.info_tuples_type == 977: # 5.52 (warmup) -> 3.75 (0.1)
            info_tuples = [
                (4, 5, rank), (8, 9, rank),
                (1, 10, rank), (1, 11, rank), (1, 12, rank), (1, 13, rank), (1, 14, rank),
                (3, 18, rank), (3, 19, rank), (3, 20, rank), (3, 21, rank),
                (3, 23, rank), (3, 24, rank), (3, 25, rank), (3, 26, rank), (3, 27, rank),
            ]

        elif self.info_tuples_type == 980: # store 8 layers
            info_tuples = [(2, 3, rank), (2, 4, rank), (2, 5, rank),
                           (6, 7, rank), (6, 8, rank), (6, 9, rank), (6, 10, rank),
                           (12, 13, rank), (12, 14, rank), (12, 15, rank), (12, 16, rank), (12, 18, rank), (12, 19, rank), (12, 20, rank), (12, 21, rank),
                           (22, 23, rank), (22, 24, rank), (22, 25, rank), (22, 26, rank), (22, 27, rank), (22, 28, rank), (22, 29, rank), (22, 30, rank)
                           ]
        elif self.info_tuples_type == 981: # store 5 layers
            info_tuples = [(1, 2, rank), (1, 3, rank), (1, 4, rank), (1, 5, rank), (1, 6, rank), (1, 7, rank), (1, 8, rank), (1, 9, rank),
                           (10, 11, rank), (10, 12, rank), (10, 13, rank), (10, 14, rank), (10, 15, rank), (10, 16, rank), (10, 17, rank), (10, 18, rank), (10, 19, rank),
                            (20, 21, rank), (20, 22, rank), (20, 23, rank), (20, 24, rank), (20, 25, rank), (20, 26, rank), (20, 27, rank), (20, 28, rank), (20, 29, rank), (20, 30, rank)
                           ]

        elif self.info_tuples_type == 988: # store 12 layers
            info_tuples = [
                (2, 3, rank), (2, 4, rank), (2, 5, rank), (2, 6, rank), (2, 7, rank), (2, 8, rank), (2, 9, rank), (2, 10, rank), (2, 11, rank),
                (12, 13, rank), (12, 14, rank), (12, 15, rank), (12, 16, rank), (12, 17, rank), (12, 18, rank),
                (19, 20, rank), (19, 21, rank),
                (22, 23, rank), (24, 25, rank), (26, 27, rank), (28, 29, rank)
                ]

        elif self.info_tuples_type >= 800 and self.info_tuples_type < 900:
            # 8xx: a single pair (xx, xx + 1)
            idx = self.info_tuples_type - 800
            # just replace this one layer
            info_tuples = [(idx, idx+1, rank)]

        elif self.info_tuples_type >= 10000:
            # 1abcd: ab is the reference layer, cd is the target layer, for example 10002, means (0, 2, rank)
            idx = self.info_tuples_type - 10000
            ref_layer = idx // 100
            target_layer = idx % 100
            info_tuples = [(ref_layer, target_layer, rank)]

        else:
            # Fail loudly instead of hitting an unbound local at the return.
            raise ValueError(f'Unsupported info_tuples_type: {self.info_tuples_type}')

        # in general, you should not include 30, which will significantly increase ppl.

        return info_tuples

    @staticmethod
    def _strided_group(info_tuples, start, size):
        """Pick ``size`` tuples beginning at ``start``, spaced ceil(n / size) apart, wrapping around."""
        n = len(info_tuples)
        stride = math.ceil(n / size)
        return [info_tuples[(start + k * stride) % n] for k in range(size)]

    def get_subtrain_splits(self, info_tuples, subtrain_type):
        """Group ``info_tuples`` into staged sub-training splits.

        ``stages_splits[i]`` is the list of splits of stage ``i`` and
        ``stages_splits[i][j]`` is the list of info tuples in its ``j``-th split.

        Args:
            info_tuples: sequence of tuples, e.g. from :meth:`get_info_tuples`.
            subtrain_type: schedule id.
                0   -- one stage, every tuple on its own.
                1   -- like 0, but tuples whose first two entries are equal are dropped.
                100 -- four stages with groups of 1, 2, 4 and 8 tuples.
                200 -- two stages with groups of 1 and 2 tuples.

        Returns:
            A list of stages as described above.

        Raises:
            ValueError: when ``subtrain_type`` is not one of the ids above
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if subtrain_type == 0:
            # train each one separately in 1 stage
            stages_splits = [[[info_tuple] for info_tuple in info_tuples]]

        elif subtrain_type == 1:
            # each one except those with info_tuple[0] == info_tuple[1]
            stages_splits = [[[info_tuple] for info_tuple in info_tuples if info_tuple[0] != info_tuple[1]]]

        elif subtrain_type == 100:
            n = len(info_tuples)
            # stage 1: train each 1 separately
            stage1_splits = [[info_tuple] for info_tuple in info_tuples]
            # Stages 2-4: groups of 2 / 4 / 8. Group starts advance by the group
            # size while members are ceil(n / size) apart, so with 5 tuples the
            # pairs are (0, 3), (2, 0), (4, 2).
            # author's note: strange, starting at range(0, n, 2) is better than
            # range(0, math.ceil(n/2)), and similar for 4 and 8
            stage2_splits = [self._strided_group(info_tuples, i, 2) for i in range(0, n, 2)]
            stage3_splits = [self._strided_group(info_tuples, i, 4) for i in range(0, n, 4)]
            stage4_splits = [self._strided_group(info_tuples, i, 8) for i in range(0, n, 8)]

            stages_splits = [stage1_splits, stage2_splits, stage3_splits, stage4_splits]
            # stages_splits = [stage1_splits, stage2_splits]

        elif subtrain_type == 200:
            n = len(info_tuples)
            # stage 1: train each 1 separately
            stage1_splits = [[info_tuple] for info_tuple in info_tuples]
            # stage 2: pairs starting at every index below ceil(n / 2); with
            # 5 tuples the pairs are (0, 3), (1, 4), (2, 0)
            step2 = math.ceil(n/2)
            stage2_splits = [self._strided_group(info_tuples, i, 2) for i in range(0, step2)]

            stages_splits = [stage1_splits, stage2_splits]

        else:
            raise ValueError(f'Unsupported subtrain_type: {subtrain_type}')

        return stages_splits
        


class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight and trainable low-rank corrections.

    ``forward`` computes ``x @ W_eff.t() (+ bias)``, so ``weight`` has shape
    ``(dim0, dim1)`` with ``dim1`` matching the last dimension of ``x``.
    ``forward_type`` selects how ``W_eff`` is assembled from the frozen weight
    ``W`` (``*`` is element-wise, juxtaposition is matmul):

        0   W + u v
        1   W + lc (ld W) + u v
        2   W + W rc rd + u v
        3   alpha * W + lc (ld W) + u v
        4   alpha * W + W rc rd + u v
        5   (alpha * lc) (ld W) + u v
        6   (alpha * W) rc rd + u v
        7   alpha * W + lc (ld W) + W rc rd + u v
        8   alpha * W + u v
        9   W + lc (ld W) + W rc rd + u v
        -1  (yy zz) * W + u v

    Shapes: ``u`` (dim0, rank), ``v`` (rank, dim1), ``lc`` (dim0, rank),
    ``ld`` (rank, dim0), ``rc`` (dim1, rank), ``rd`` (rank, dim1),
    ``yy`` (dim0, rank), ``zz`` (rank, dim1), ``alpha`` (1,).
    """

    def __init__(self, weight, bias=None, rank=5, scale=1.0, forward_type = 0, device='cpu'):
        """Wrap ``weight`` (and optional ``bias``) and create the trainable factors.

        Args:
            weight: 2-D tensor holding the frozen base weight.
            bias: optional frozen bias added to the output.
            rank: inner dimension of every low-rank factor pair.
            scale: multiplier on the init std of the factors (see init_loras).
            forward_type: variant id, see the class docstring.
            device: device the tensors are moved to.
        """
        super().__init__()

        # NOTE(review): Tensor.to() returns the same object only when no copy is
        # needed. If `weight` is already on `device` the Parameter itself is
        # assigned and registered (so it shows up in state_dict()); otherwise
        # the result is presumably a plain tensor attribute that is not
        # registered. Left unchanged because altering it would change the keys
        # of stored checkpoints -- confirm which behaviour is intended.
        self.weight = nn.Parameter(weight, requires_grad=False).to(device) # frozen weight
        self.dim0 = weight.size(0)
        self.dim1 = weight.size(1)

        self.device = device

        if bias is not None:
            # Same registration caveat as for `weight` above.
            self.bias = nn.Parameter(bias, requires_grad=False).to(device)
        else:
            self.bias = None

        self.type = forward_type # determine the type of LoRA

        # Additive low-rank pair u @ v; every variant uses it.
        self.u, self.v = self.init_loras(rank, scale, self.dim0, self.dim1, self.device)

        # Left factors lc, ld (both built from dim0), applied as lc @ (ld @ W).
        if self.type in [1, 9, 5, 7, 3]:
            if self.type == 5:
                self.lc, self.ld = self.init_loras(rank, scale*1000, self.dim0, self.dim0, self.device) # otherwise, there is no gradient
            else:
                self.lc, self.ld = self.init_loras(rank, scale, self.dim0, self.dim0, self.device)

        # Right factors rc, rd (both built from dim1), applied as W @ rc @ rd.
        if self.type in [2, 9, 6, 7, 4]:
            if self.type == 6:
                self.rc, self.rd = self.init_loras(rank, scale*1000, self.dim1, self.dim1, self.device) # otherwise, there is no gradient
            else:
                self.rc, self.rd = self.init_loras(rank, scale, self.dim1, self.dim1, self.device)

        # Element-wise multiplicative mask yy @ zz.
        if self.type in [-1]:
            self.yy, self.zz = self.init_mul(rank, self.dim0, self.dim1, self.device)

        # Learnable scalar alpha.
        if self.type in [7, 8, 3, 4, 5, 6]:
            self.alpha = self.init_alpha(self.device)
            print_debug(f'alpha: {self.alpha.size()}')

        # Debug dump of every registered entry. The state dict is built once
        # instead of once per key.
        for key, value in self.state_dict().items():
            print_debug(f'key = {key}, {value}')

    def init_loras(self, rank, scale, dim1, dim2, device='cpu'):
        """Create a trainable factor pair ``(dim1, rank)`` / ``(rank, dim2)``.

        Both factors are drawn from N(0, (1e-5 * scale)^2), so neither starts
        at exactly zero.
        """
        std = 0.00001 * scale
        # Move the data first and wrap it in nn.Parameter afterwards so the
        # result is still a Parameter and gets registered when assigned.
        u = nn.Parameter(torch.zeros(dim1, rank).to(device), requires_grad=True)
        v = nn.Parameter(torch.zeros(rank, dim2).to(device), requires_grad=True)
        nn.init.normal_(u, mean=0.0, std=std)
        nn.init.normal_(v, mean=0.0, std=std)

        return u, v

    def init_alpha(self, device='cpu'):
        """Create the trainable scalar ``alpha`` (shape ``(1,)``, initialised to 1)."""
        alpha = nn.Parameter(torch.ones(1).to(device), requires_grad=True)
        return alpha

    def init_mul(self, rank, dim1, dim2, device='cpu'):
        """Create the factors of the multiplicative mask used by type -1.

        Entries are drawn around 1/sqrt(rank) with std 1e-5, which makes every
        entry of ``yy @ zz`` start close to 1.
        """
        std = 0.00001
        yy = nn.Parameter(torch.zeros(dim1, rank).to(device), requires_grad=True)
        zz = nn.Parameter(torch.zeros(rank, dim2).to(device), requires_grad=True)
        nn.init.normal_(yy, mean=1/np.sqrt(rank), std=std)
        nn.init.normal_(zz, mean=1/np.sqrt(rank), std=std)
        print(f'yy.shape = {yy.shape}, zz.shape = {zz.shape}, yy[:3, :3] = {yy[:3, :3]}')
        return yy, zz

    def get_new_weight(self):
        """Assemble the effective weight for ``self.type``.

        Returns:
            Tensor with the same shape as ``self.weight``.

        Raises:
            ValueError: when ``self.type`` is not a supported variant
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.type == 0:
            new_weight = self.weight + self.u @ self.v
        elif self.type == 1:
            new_weight = self.weight + self.lc @ (self.ld @ self.weight) + self.u @ self.v
        elif self.type == 2:
            new_weight = self.weight + self.weight @ self.rc @ self.rd + self.u @ self.v

        # yd's idea -> not work for different loras init. the eval loss has almost no change. I think the reason is that the lora ranks are too low.
        # These two variants have no bare `weight` term. An earlier formulation
        # was alpha * W + lc (ld W)  /  alpha * W + W rc rd.
        elif self.type == 5:
            new_weight = self.alpha * self.lc @ (self.ld @ self.weight) + self.u @ self.v
        elif self.type == 6:
            new_weight = self.alpha * self.weight @ self.rc @ self.rd + self.u @ self.v

        # alpha-scaled base weight
        elif self.type == 7:
            new_weight = self.alpha * self.weight + self.lc @ (self.ld @ self.weight) + self.weight @ self.rc @ self.rd + self.u @ self.v
        elif self.type == 8:
            new_weight = self.alpha * self.weight + self.u @ self.v

        elif self.type == 3:
            new_weight = self.alpha * self.weight + self.lc @ (self.ld @ self.weight) + self.u @ self.v
        elif self.type == 4:
            new_weight = self.alpha * self.weight + self.weight @ self.rc @ self.rd + self.u @ self.v

        elif self.type == 9:
            new_weight = self.weight + self.lc @ (self.ld @ self.weight) + self.weight @ self.rc @ self.rd + self.u @ self.v

        elif self.type == -1:
            new_weight = (self.yy @ self.zz) * self.weight + self.u @ self.v

        else:
            raise ValueError(f'Unsupported LoRALinear forward_type: {self.type}')

        return new_weight

    def forward(self, x):
        """Apply the layer: ``x @ W_eff.t()`` plus the bias when one was given."""
        new_weight = self.get_new_weight()
        if self.bias is not None:
            return x @ new_weight.t() + self.bias
        else:
            return x @ new_weight.t()
    

class LoRAMLP(nn.Module):
    """Gated MLP ``down(silu(gate(x)) * up(x))`` built from three LoRALinear layers.

    ``forward_type`` is decoded digit-wise:

    * a value >= 10000 carries the block variant in ``forward_type // 10000``
      (stored as ``sub_type``); the remainder is decoded as below;
    * a value >= 1000 has the form ``1GUD``: G, U and D are the LoRALinear
      types of the gate, up and down projections;
    * any smaller value is used unchanged for all three projections.

    With ``sub_type == 1`` a trainable row vector ``vec`` is created and
    ``x * vec`` is added to the output.
    """

    def __init__(self, gate_weights, up_weights, down_weights, rank=5, scale=1.0, forward_type=0, device='cpu'):
        super().__init__()

        # Split off the block-level variant stored above the thousands.
        self.sub_type = 0
        if forward_type >= 10000:
            self.sub_type, forward_type = divmod(forward_type, 10000)

        # Resolve one LoRALinear type per projection.
        if forward_type >= 1000:
            # hundreds, tens, ones -> gate, up, down
            forward_type_gate, tens_and_ones = divmod(forward_type - 1000, 100)
            forward_type_up, forward_type_down = divmod(tens_and_ones, 10)
        else:
            forward_type_gate = forward_type_up = forward_type_down = forward_type

        # g1: 1556: r=163
        # g2: 1665: r=259

        print(f'in LoRAMLP, device = {device}, ft_gate = {forward_type_gate}, ft_up = {forward_type_up}, ft_down = {forward_type_down}')

        # An earlier variant kept the gate as a plain frozen nn.Linear; all
        # three projections are LoRALinear now.
        common = dict(rank=rank, scale=scale, device=device)
        self.gate = LoRALinear(gate_weights, forward_type=forward_type_gate, **common)
        self.up = LoRALinear(up_weights, forward_type=forward_type_up, **common)
        self.down = LoRALinear(down_weights, forward_type=forward_type_down, **common)

        self.device = device

        if self.sub_type == 1:
            # One trainable scale per row of down_weights, started near zero.
            vec_init = torch.zeros(1, down_weights.shape[0]).to(device)
            self.vec = torch.nn.Parameter(vec_init, requires_grad=True)
            nn.init.normal_(self.vec, mean=0.0, std=0.00001)
            print(f'~~~~~~~~~~~~ vec: {self.vec.size()} ~~~~~~~~~~~~~~')

    def forward(self, x):
        """Run the gated MLP; with ``sub_type == 1`` add ``x * vec`` to the result."""
        hidden = F.silu(self.gate(x)) * self.up(x)
        output = self.down(hidden)

        if self.sub_type == 1:
            output = output + x * self.vec

        return output

# model IO

def get_pth_loramlp_name(ref_layer, target_layer, note=''):
    """Build the checkpoint file name ``ref<ref>_tar<target>_<note>.pth``."""
    stem = f'ref{ref_layer}_tar{target_layer}_{note}'
    return stem + '.pth'

def store_loramlp_util(loramlp, ref_layer, target_layer, store_model_dir, note=''):
    """Save ``loramlp.state_dict()`` under ``store_model_dir``.

    The directory is created when missing. The file name is derived from
    ``ref_layer``, ``target_layer`` and ``note`` via ``get_loramlp_paths``.
    """
    # the destination directory must exist before torch.save writes into it
    os.makedirs(store_model_dir, exist_ok=True)

    stored_path = get_loramlp_paths(ref_layer, target_layer, store_model_dir, note)
    weights = loramlp.state_dict()
    torch.save(weights, stored_path)
    print(f'=> Store the LoRAMLP model: ref {ref_layer}, target {target_layer} to {stored_path}')


def load_lora_model_util(gate_weights, up_weights, down_weights, rank, forward_type, device, ref_layer, target_layer, store_model_dir, note='', strict=True):
    """Build a LoRAMLP and restore its stored state dict, if one exists.

    Args:
        gate_weights, up_weights, down_weights: weights handed to the
            ``LoRAMLP`` constructor.
        rank, forward_type, device: forwarded to the ``LoRAMLP`` constructor.
        ref_layer, target_layer, note: identify the checkpoint file inside
            ``store_model_dir`` (see ``get_loramlp_paths``).
        store_model_dir: directory that holds the stored checkpoints.
        strict: forwarded to ``load_state_dict``.

    Returns:
        The loaded ``LoRAMLP``, or ``None`` when the checkpoint file does not
        exist (a message is printed in that case).
    """
    file_path = get_loramlp_paths(ref_layer, target_layer, store_model_dir, note)
    if not os.path.exists(file_path):
        print(f'!!! ### Cannot find the stored LoRAMLP model: ref {ref_layer}, target {target_layer} from {file_path}')
        return None

    loramlp = LoRAMLP(gate_weights=gate_weights, up_weights=up_weights, down_weights=down_weights, rank=rank, forward_type=forward_type, device=device)
    # Deserialize onto CPU: without map_location, tensors are restored to the
    # device they were saved from, which fails when that GPU is not present.
    # load_state_dict then copies the values into loramlp's own parameters.
    loramlp_state_dict = torch.load(file_path, map_location='cpu')
    loramlp.load_state_dict(loramlp_state_dict, strict=strict)
    print(f'=> Load the LoRAMLP model: ref {ref_layer}, target {target_layer} from {file_path}')
    return loramlp


def get_loramlp_paths(ref_layer, target_layer, store_model_dir, note=''):
    """Return the checkpoint path for a (ref_layer, target_layer, note) triple.

    The file name comes from ``get_pth_loramlp_name`` and is joined onto
    ``store_model_dir``.
    """
    checkpoint_name = get_pth_loramlp_name(ref_layer, target_layer, note)
    return os.path.join(store_model_dir, checkpoint_name)



@torch.no_grad()
def visualize_model_named_parameters(model, print_grad=False, prefix=''):
    """Print norms of the trainable parameters and of the first three frozen ones.

    Args:
        model: object exposing ``named_parameters()``.
        print_grad: when true, trainable parameters that currently hold a
            gradient are printed with their gradient norm instead of their
            size.
        prefix: text prepended to every parameter line.
    """
    print('=================== Trainable Parameters =====================')
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if print_grad and param.grad is not None:
            print('{}--{}: norm: {:.4g}, grad norm: {:.4g}'.format(prefix, name, torch.norm(param).item(), torch.norm(param.grad).item()))
        else:
            print('{}--{}: norm: {:.4g}, param size : {}'.format(prefix, name, torch.norm(param).item(), param.size()))

    # print the norm of the first 3 frozen (requires_grad == False) parameters
    shown = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            continue
        # Parameters with fewer than two dimensions (biases, norm weights)
        # have no column axis; indexing them with [:, :20] would raise, so
        # fall back to the first 20 elements of the flattened tensor.
        if param.dim() >= 2:
            head = param[:, :20]
        else:
            head = param.reshape(-1)[:20]
        print(f'{prefix}--{name}: norm: {torch.norm(param).item()}, random 20 columns tensor norm: {torch.norm(head).item()}')
        # stu init--gate.weight: norm: 130.26380920410156, random 20 columns tensor norm: 9.116674423217773
        shown += 1
        if shown == 3:
            break


def get_online_warmup_postfix(subset_size, subset_idx):
    """Return ``(subset_idx + 1) * subset_size`` formatted as a string."""
    completed_subsets = subset_idx + 1
    return f'{completed_subsets * subset_size}'


def find_max_possible_postfix_num(store_model_dir, ref_layer, target_layer, subset_size=1024, subset_num=10, use_postfix=0):
    """Find the checkpoint postfix to use for a (ref_layer, target_layer) pair.

    With ``use_postfix == 0`` the empty postfix ``''`` is returned without
    touching the file system. Otherwise subset indices are probed from
    ``subset_num - 1`` down to ``0`` and the postfix of the first checkpoint
    file that exists is returned; ``None`` means no such file was found.
    The outcome is printed before returning.
    """
    max_subidx = subset_num - 1

    if use_postfix == 0:
        result = ''
    else:
        # lazily walk the candidate postfixes, highest subset index first
        candidates = (
            get_online_warmup_postfix(subset_size, subidx)
            for subidx in reversed(range(subset_num))
        )
        result = next(
            (
                postfix
                for postfix in candidates
                if os.path.exists(get_loramlp_paths(ref_layer, target_layer, store_model_dir, postfix))
            ),
            None,
        )

    print(f'find_max_possible_postfix_num: {result}, original max_subidx = {max_subidx}')
    return result
    
    

class LoRALlama(nn.Module):
    # Define the LoRALlama model, which replaces some MLP layers in the llama model with LoRAMLP
    # info_tuples: [(ref_layer, target_layer, rank), ...], for more flexible usage
    
    def __init__(self, llama_model, info_tuples, store_model_dir='', forward_type=0, use_stored=0, subset_size=1024, subset_num=10, use_postfix=0):
        super(LoRALlama, self).__init__()
        self.llama_model = llama_model
        self.info_tuples = info_tuples
        self.store_model_dir = store_model_dir
        self.forward_type = forward_type
        self.config = self.llama_model.config
        

        # first freeze the base model
        for name, param in self.llama_model.named_parameters():
            if 'weight' in name:
                param.requires_grad = False
        
        # init the LoRAMLP with new lora components
        print(f'!!! use_stored = {use_stored} !!!')
        self.replace_loramlp(use_stored=use_stored, subset_size=subset_size, subset_num=subset_num, use_postfix=use_postfix)
        
    
    def replace_original(self, original_model, device=None):
        # replace the original model with the new weight
        for info_tuple in self.info_tuples:
            _, target_layer, _ = info_tuple
            
            # get the new weight for each layer
            new_gate_data = self.llama_model.model.layers[target_layer].mlp.gate.get_new_weight().data
            new_up_data = self.llama_model.model.layers[target_layer].mlp.up.get_new_weight().data
            new_down_data = self.llama_model.model.layers[target_layer].mlp.down.get_new_weight().data
            
            if device is None:
                device_t = new_gate_data.device
            else:
                device_t = device
            
            # get the original device
            original_model.model.layers[target_layer].mlp.gate_proj.weight.data = new_gate_data.to(device_t)
            original_model.model.layers[target_layer].mlp.up_proj.weight.data = new_up_data.to(device_t)
            original_model.model.layers[target_layer].mlp.down_proj.weight.data = new_down_data.to(device_t)
            
            print(f'renew {target_layer}')
            
        return original_model
    
    def replace_original_all(self, original_model, device=None):
        # replace the original model with the new weight, both mlp and self_attn
        for info_tuple in self.info_tuples:
            ref_layer, target_layer, _ = info_tuple
            
            # get the new weight for each layer
            new_gate_data = self.llama_model.model.layers[target_layer].mlp.gate.get_new_weight().data
            new_up_data = self.llama_model.model.layers[target_layer].mlp.up.get_new_weight().data
            new_down_data = self.llama_model.model.layers[target_layer].mlp.down.get_new_weight().data
            
            if device is None:
                device_t = new_gate_data.device
            else:
                device_t = device
                
            # self_attn
            new_q = self.llama_model.model.layers[ref_layer].self_attn.q_proj.weight
            new_k = self.llama_model.model.layers[ref_layer].self_attn.k_proj.weight
            new_v = self.llama_model.model.layers[ref_layer].self_attn.v_proj.weight
            new_o = self.llama_model.model.layers[ref_layer].self_attn.o_proj.weight
            
            with torch.no_grad():
                # get the original device
                original_model.model.layers[target_layer].mlp.gate_proj.weight.data = new_gate_data.to(device_t)
                original_model.model.layers[target_layer].mlp.up_proj.weight.data = new_up_data.to(device_t)
                original_model.model.layers[target_layer].mlp.down_proj.weight.data = new_down_data.to(device_t)
                
                # replace the self_attn directly using the ref_layer's self_attn
                # use copy_ to avoid the inplace operation
                original_model.model.layers[target_layer].self_attn.q_proj.weight.data.copy_(new_q.to(device_t))
                original_model.model.layers[target_layer].self_attn.k_proj.weight.data.copy_(new_k.to(device_t))
                original_model.model.layers[target_layer].self_attn.v_proj.weight.data.copy_(new_v.to(device_t))
                original_model.model.layers[target_layer].self_attn.o_proj.weight.data.copy_(new_o.to(device_t))
                
            
            print(f'renew {target_layer}, both mlp and self_attn')
            
        return original_model
        
                    
    def replace_loramlp(self, use_stored=0, subset_size=1024, subset_num=10, use_postfix=0):
        # replace the original mlp layer with loramlp. 
        for info_tuple in self.info_tuples:
            ref_layer, target_layer, rank = info_tuple
            
            # create the LoRAMLP using refer layer
            ref_gate_weights = self.llama_model.model.layers[ref_layer].mlp.gate_proj.weight
            ref_up_weights = self.llama_model.model.layers[ref_layer].mlp.up_proj.weight
            ref_down_weights = self.llama_model.model.layers[ref_layer].mlp.down_proj.weight
            # loramlp's device should be the same as target layer's device
            target_device = self.llama_model.model.layers[target_layer].mlp.gate_proj.weight.device
            loramlp = LoRAMLP(gate_weights=ref_gate_weights, up_weights=ref_up_weights, down_weights=ref_down_weights, rank=rank, scale=1.0, forward_type=self.forward_type, device=target_device)
            
            # use the stored LoRAMLP to replace the original mlp if use_stored == 1 and can successfully load the stored LoRAMLP
            if use_stored:
                # file_path = os.path.join(self.store_model_dir, get_pth_loramlp_name(ref_layer, target_layer))
                postfix = find_max_possible_postfix_num(self.store_model_dir, ref_layer, target_layer, subset_size, subset_num, use_postfix)
                if postfix is not None:
                    file_path = get_loramlp_paths(ref_layer, target_layer, self.store_model_dir, postfix)
                else:
                    raise ValueError(f'Cannot find the stored LoRAMLP model: ref {ref_layer}, target {target_layer} from {self.store_model_dir}, for subset_size = {subset_size}, subset_num = {subset_num}, use_postfix = {use_postfix}')
                
                print(f'=> Load the stored LoRAMLP model: ref {ref_layer}, target {target_layer} from {file_path}')
                if os.path.exists(file_path):
                    loramlp_state_dict = torch.load(file_path)
                    loramlp.load_state_dict(loramlp_state_dict, strict=False)
                else:
                    print(f'!!! ### Cannot find the stored LoRAMLP model: ref {ref_layer}, target {target_layer} from {file_path}')
                
                # loramlp = load_lora_model_util(ref_gate_weights, ref_up_weights, ref_down_weights, rank, self.forward_type, target_device, ref_layer, target_layer, self.store_model_dir, note)
            
            # for debug
            print(f'replace layer {target_layer} with layer {ref_layer}, rank: {rank}, device = {target_device} | use_stored: {use_stored}, loramlp is None: {loramlp is None}')
            
            # replace target layer with LoRAMLP
            self.llama_model.model.layers[target_layer].mlp = loramlp
            
    def store_loramlp(self): # it is not optimized!
        # store the LoRAMLP model
        for info_tuple in self.info_tuples:
            ref_layer, target_layer, _ = info_tuple
            loramlp = self.llama_model.model.layers[target_layer].mlp
            store_loramlp_util(loramlp, ref_layer, target_layer, self.store_model_dir)
    

        
        
    def visualize_trainable_params(self, print_grad=False):
        print(f'=================== Trainable Parameters =====================')
        for name, param in self.named_parameters():
            if param.requires_grad:
                if print_grad:
                    print(f'{name}: norm: {torch.norm(param).item()}, grad norm: {torch.norm(param.grad).item()}')
                else:
                    print(f'{name}: norm: {torch.norm(param).item()}')
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.llama_model(input_ids, attention_mask=attention_mask, labels=labels)
    
    def generate(self, input_ids, attention_mask=None, max_new_tokens=200, do_sample=True, **model_kwargs):
        # just not support max_length to avoid the error
        return self.llama_model.generate(input_ids, attention_mask=attention_mask, max_length=max_new_tokens, do_sample=do_sample, **model_kwargs)
        

    def store_whole_llama_model(self, ori_model_dir, tokenizer, if_remove_current_model=1, note='', if_replace_self_attn=0): # for evaluation
        """Export a plain Llama checkpoint carrying the current LoRAMLP weights.

        A fresh ``LlamaForCausalLM`` is loaded from ``ori_model_dir``, patched
        through ``replace_original`` (MLP only) or ``replace_original_all``
        (MLP + self-attention), saved with ``save_pretrained`` into
        ``<self.store_model_dir>/whole_model<note>``, and the ``tokenizer*``
        entries of ``ori_model_dir`` are copied next to it.

        Args:
            ori_model_dir: source of the original model and tokenizer files.
            tokenizer: passed to ``visualize_output`` for the sanity check.
            if_remove_current_model: when truthy, the wrapped model is moved
                to CPU, the patched model is moved to ``cuda:0`` and its output
                is visualised; otherwise the patched model stays on CPU.
            note: suffix appended to the output directory name.
            if_replace_self_attn: when truthy, self-attention weights are
                replaced as well.
        """
        # local imports: these modules are only needed for the tokenizer copy
        import glob
        import shutil

        whole_model_dir = os.path.join(self.store_model_dir, f'whole_model{note}')
        os.makedirs(whole_model_dir, exist_ok=True)

        ####### store the whole model
        # create an original model on cpu
        new_model = LlamaForCausalLM.from_pretrained(ori_model_dir)
        patch_model = self.replace_original_all if if_replace_self_attn else self.replace_original

        if if_remove_current_model:
            modified_model = patch_model(new_model)

            device = 'cuda:0'
            # move the current model to cpu to free GPU memory for the export
            self.llama_model.to('cpu')
            torch.cuda.empty_cache()
            modified_model.to(device)
            # double check the output
            visualize_output(modified_model, device, tokenizer)
        else:
            modified_model = patch_model(new_model, device='cpu')

        print(f'=> Store the whole model to {whole_model_dir}, if store self_attn: {if_replace_self_attn}')
        modified_model.save_pretrained(whole_model_dir)

        # Copy tokenizer* entries from the original model directory. Done in
        # Python rather than through a shell `cp` string so that paths with
        # spaces or shell metacharacters work and failures raise instead of
        # being silently ignored.
        tokenizer_pattern = os.path.join(glob.escape(ori_model_dir), 'tokenizer*')
        tokenizer_paths = sorted(glob.glob(tokenizer_pattern))
        if not tokenizer_paths:
            print(f'!!! No tokenizer files found in {ori_model_dir}')
        for src_path in tokenizer_paths:
            dst_path = os.path.join(whole_model_dir, os.path.basename(src_path))
            if os.path.isdir(src_path):
                shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
            else:
                shutil.copy(src_path, dst_path)

        # visualize the stored files
        for entry in os.listdir(whole_model_dir):
            print(entry)
            

            
    