import json
from pathlib import Path
from shutil import copyfile

from safetensors.torch import load_file
import torch
from tap import Tap

class Args(Tap):
    # Command-line options declared as a Tap subclass, each with a default.
    # NOTE(review): nothing in this file instantiates Args; main() below uses
    # hard-coded paths and run names instead -- confirm whether it is still
    # needed by other tooling.
    #
    # Documentation is kept in standalone `#` comment lines on purpose so the
    # attribute definitions themselves stay exactly as they were.

    output_dir: str = 'output'

    # Alternative: 'mamba2-370m'
    model_size = 'mamba2-130m'

    # Previously used run names, kept for reference:
    #   'mamba2-130m-lr0.001_T2048_B16_GA4_P1_RD0_240620075448'
    #   'mamba2-130m-lr0.0005_T2048_B16_GA4_P1_RD0_240620152422/ckpt_84400'
    #   'mamba2-370m_lr0.001_T8192_B1_GA4_P4_SR1_RD0_240625134012'
    #   'L8192-P2-GA4-B1-R1-lr1e-4'
    #   'L8192-P4-GA4-B1-R0-lr1e-4'
    run_name = 'mamba2-130m_lr0.0005_T512_B64_GA2_P1_SR1_RD0_240630011625'

    # Was assigned twice in a row with the same value; one assignment suffices.
    output_name = None

    step: int = 10000
    save_dir: str = 'ckpts'


def copy_ckpt(src_dir: Path, dst_dir: Path):
    """Convert one safetensors checkpoint into a ``pytorch_model.bin`` directory.

    Loads ``src_dir / 'model.safetensors'``, adds an ``lm_head.weight`` entry
    that aliases the ``backbone.embedding.weight`` tensor, and writes the
    resulting state dict to ``dst_dir / 'pytorch_model.bin'`` with
    ``torch.save``. The ``config.json`` located two directory levels above
    ``src_dir`` is copied next to it.

    Args:
        src_dir: Checkpoint directory containing ``model.safetensors``. Its
            parent directory must contain ``args.json`` and its grandparent
            directory must contain ``config.json``.
        dst_dir: Destination directory; created (including parents) if it
            does not exist yet.

    Raises:
        FileNotFoundError: If ``args.json`` (or, later, ``config.json``) is
            missing.
        KeyError: If the checkpoint has no ``backbone.embedding.weight``
            entry.
    """
    print("Loading args and config")
    # The parsed run arguments are not used further at the moment; the file is
    # still read so that a missing or invalid args.json fails before any
    # checkpoint data is loaded or written. The context manager makes sure the
    # file handle is closed again.
    with open(src_dir.parent / 'args.json', 'r') as args_file:
        json.load(args_file)

    print(f"Loading checkpoint from {src_dir}")
    state_dict = load_file(src_dir / 'model.safetensors')
    # Expose the embedding matrix under the LM-head key as well (same tensor
    # object, not a copy).
    state_dict['lm_head.weight'] = state_dict['backbone.embedding.weight']

    # Save checkpoint
    dst_dir.mkdir(exist_ok=True, parents=True)
    print(f"Saving checkpoint to {dst_dir}")
    torch.save(state_dict, dst_dir / 'pytorch_model.bin')
    copyfile(src_dir.parent.parent / 'config.json', dst_dir / 'config.json')


def main():
    """Convert the checkpoints of a hard-coded list of training runs.

    For every run name in ``src_names`` and every step in ``steps``, looks for
    ``output/<model_name>/<run_name>/ckpt_<step>`` and, if that directory
    exists, converts it via ``copy_ckpt`` into
    ``ckpts/<model_name>/<dst_name>/ckpt_<step>``. Missing checkpoint
    directories are reported and skipped.

    Run names are expected to look like
    ``<model_name>_<lr...>_<field>_..._<last>``, split on underscores.
    To convert other runs or steps, edit ``src_names`` / ``steps`` below.
    """
    print("MAIN")
    src_dir = Path("output")
    dst_dir = Path("ckpts")

    # Runs to convert. Earlier revisions assigned this list many times in a
    # row; only the final assignment ever took effect, so only it is kept.
    src_names = [
        'mamba2-780m_lr0.0005_T8192_B1_GA1_P8_SR64_RD0_RI0',
    ]
    # Training steps whose checkpoints should be converted (end exclusive).
    steps = range(20000, 60000, 2000)

    for src_name in src_names:
        print("==================================================")
        print(f"Handling {src_name = }")
        fields = src_name.split('_')
        model_name = fields[0]
        lr_str = fields[1]
        # The final underscore-separated field is left out of the destination
        # name (the original code labelled it a timestamp).
        # NOTE(review): for names without a timestamp suffix, such as the one
        # above, this drops the trailing "RI*" field -- confirm that is
        # intended.
        hparam = '_'.join(fields[2:-1])

        this_src_dir = Path(src_dir, model_name, src_name)
        dst_name = f"{hparam}_{lr_str}"
        for step in steps:
            ckpt_dir = this_src_dir / f"ckpt_{step}"
            if not ckpt_dir.exists():
                print(f"{ckpt_dir} does not exist, skipping...")
                continue

            this_dst_dir = Path(dst_dir, model_name, dst_name, f"ckpt_{step}")
            print("==================================================")
            print(f"Copying checkpoint from {ckpt_dir} to {this_dst_dir}")
            copy_ckpt(ckpt_dir, this_dst_dir)
    

if __name__ == "__main__":
    main()
