import os
import sys
sys.path.append(os.getcwd())
import warnings

from string import Template
from datetime import datetime
import socket
import time
import os
from gridsearcher import GridSearcher
from tqdm import tqdm
import argparse
import psutil
import math

def get_arg_parse(argv=None):
    """Parse the command line and optionally wait before the grid search starts.

    Supported options:
        --wait_pids PID [PID ...]  poll every 60 s until none of these process
                                   ids exist any more.
        --wait_secs N              sleep N seconds, showing a tqdm progress bar.
    If both options are given, only ``--wait_pids`` is honored.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The parsed ``argparse.Namespace`` (callers may ignore it).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--wait_pids', nargs='+', default=None, required=False)
    parser.add_argument('--wait_secs', type=int, default=None, required=False)
    args = parser.parse_args(argv)

    if args.wait_pids is not None:
        # Convert once up front rather than on every polling iteration;
        # a non-numeric pid still fails immediately with ValueError.
        pids = [int(pid) for pid in args.wait_pids]
        while any(psutil.pid_exists(pid) for pid in pids):
            print(f'waiting for processes {args.wait_pids} to end...')
            time.sleep(60)
    elif args.wait_secs is not None:
        print(f'Waiting {args.wait_secs} seconds')
        # tqdm needs an iterable; passing the bare int raised TypeError.
        for _ in tqdm(range(args.wait_secs)):
            time.sleep(1)
    return args

def add_model_params(model_size, tokens_ratio, batch_size, acc_steps, seq_len):
    """Return the preset architecture, learning rate and iteration count for a model size.

    Args:
        model_size: Model size in millions of parameters; one of 30, 350, 800, 1300.
        tokens_ratio: Multiplier on the parameter count giving the total token
            budget (``model_size * 1_000_000 * tokens_ratio``).
        batch_size: Batch size.
        acc_steps: Gradient accumulation steps.
        seq_len: Sequence length in tokens.

    Returns:
        dict with keys ``'lr'``, ``'n_layer'``, ``'n_embd'``, ``'n_head'`` and
        ``'iterations'``, where iterations is the token budget divided by
        ``batch_size * acc_steps * seq_len``, rounded up.

    Raises:
        RuntimeError: if ``model_size`` is not one of the supported sizes.
    """
    # Columns: (size, lr, n_layer, n_embd, n_head)
    presets = (
        (30, 0.0012, 6, 640, 5),
        (350, 0.01, 24, 1024, 32),  # from Dion
        (800, 0.000075, 16, 2048, 16),
        (1300, 0.01, 24, 2048, 32),  # from Dion
    )
    for size, lr, n_layer, n_embd, n_head in presets:
        if model_size == size:
            break
    else:
        raise RuntimeError(f'Unsupported model size {model_size}')

    total_tokens = model_size * 1_000_000 * tokens_ratio
    tokens_per_iteration = batch_size * acc_steps * seq_len
    return {
        'lr': lr,
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'iterations': math.ceil(total_tokens / tokens_per_iteration),
    }

def main(MODEL_SIZE, OPTIM, TOKENS_RATIO, gpus, wandb_project, dist=False, max_jobs=1, param_dict=None):
    """Configure and launch a GridSearcher sweep of ``main.py`` training runs.

    Args:
        MODEL_SIZE: Model size in millions of parameters, forwarded to
            ``add_model_params`` (which also picks lr / architecture).
        OPTIM: Optimizer name, passed as the ``opt`` parameter; it also selects
            which hyper-parameters are embedded in the W&B group name.
        TOKENS_RATIO: Token-budget multiplier used to derive ``iterations``.
        gpus: GPU ids handed to the scheduler.
        wandb_project: W&B project name; also the first path component of the
            experiment folder.
        dist: Forwarded to the scheduler as ``distributed_training``.
        max_jobs: Forwarded to the scheduler as ``max_jobs_per_gpu``.
        param_dict: Grid of parameter values forwarded as ``params_values``.

    Raises:
        RuntimeError: for an unsupported ``OPTIM`` (or an unsupported
            ``MODEL_SIZE``, raised by ``add_model_params``).
    """
    # Delete a leftover './exps/state.finished' file before launching
    # (presumably a completion marker written by an earlier sweep).
    state_finished = './exps/state.finished'
    if os.path.isfile(state_finished):
        os.remove(state_finished)

    gs = GridSearcher(script='main.py', defaults=dict())

    gs.add_param('distributed_backend', 'nccl')
    gs.add_param('latest_ckpt_interval', 1000)
    gs.add_param('compile', True)
    gs.add_param('model', 'llama')

    gs.add_param('vocab_size', 32000)

    # Keep these in locals so the iteration budget below does not have to read
    # them back from GridSearcher's private attributes.
    batch_size = 64  # use 32 for 1.3B
    acc_steps = 8  # use 16 for 1.3B
    # NOTE(review): seq_len is only used for the iteration budget and is not
    # passed to main.py; assumed to match main.py's default — confirm.
    seq_len = 512
    gs.add_param('batch_size', batch_size)
    gs.add_param('acc_steps', acc_steps)
    gs.add_param('dataset', 'c4')
    gs.add_param('datasets_dir', './datasets/')

    data = add_model_params(MODEL_SIZE, TOKENS_RATIO, batch_size=batch_size, acc_steps=acc_steps, seq_len=seq_len)
    for k, v in data.items():  # sets values for lr, n_layer, n_embd, n_head, iterations
        gs.add_param(k, v)

    gs.add_param('opt', OPTIM)
    gs.add_param('wandb', True)

    # ${...} placeholders are string.Template fields, filled per run.
    wandb_group = '${opt}'
    wandb_job_type = 'lr=${lr}'

    # Order matters: several branches match by substring, so optimizer variants
    # whose names contain e.g. 'frugal' or 'fira' share a group format.
    if OPTIM == 'adamw':
        pass
    elif 'dct-adamw' in OPTIM:
        wandb_group += '-r=${lowrank_rank}-p=${lowrank_proj}-d=${lowrank_distributed}-ef=${lowrank_use_ef}-q=${lowrank_q_ef}-rs=${lowrank_rotate_states}-ug=${lowrank_upd_gap}'
    elif OPTIM == 'ldadamw':
        wandb_group += '-r=${lowrank_rank}-ef=${lowrank_use_ef}-p=${lowrank_proj}'
    elif OPTIM == 'galoreadamw':
        wandb_group += '-p=${lowrank_proj}-r=${lowrank_rank}-ug=${lowrank_upd_gap}'
    elif 'frugal' in OPTIM:
        wandb_group += '-p=${lowrank_proj}-r=${lowrank_rank}-ug=${lowrank_upd_gap}-rs=${lowrank_rotate_states}'
    elif 'fira' in OPTIM:
        wandb_group += '-p=${lowrank_proj}-r=${lowrank_rank}-ug=${lowrank_upd_gap}'
    elif OPTIM == 'apollo':
        wandb_group += '-p=${lowrank_proj}-r=${lowrank_rank}-ug=${lowrank_upd_gap}'
    elif OPTIM == 'trion':
        wandb_group += '-r=${lowrank_rank}-ns=${muon_ns_type}-scaling=${scaling_type}-makhoul=${use_makhoul}'
    else:
        raise RuntimeError(f'Optimizer {OPTIM} is currently not supported')

    wandb_run_prefix = f'{MODEL_SIZE}M-tpp={TOKENS_RATIO}'
    gs.add_param('wandb_run_prefix', wandb_run_prefix)
    gs.add_param('wandb_group', Template(wandb_group))
    gs.add_param('wandb_job_type', Template(wandb_job_type))
    gs.add_param('wandb_project', wandb_project)

    gs.run(
        torchrun=True,
        launch_blocking=0,
        scheduling=dict(
            distributed_training=dist,
            max_jobs_per_gpu=max_jobs,
            gpus=gpus,
            params_values=param_dict,
        ),
        param_name_for_exp_root_folder='results_base_folder',
        # The f-string fills in the static parts; the ${...} fields coming from
        # wandb_group / wandb_job_type are left for Template substitution
        # (presumably done per run by GridSearcher — verify).
        exp_folder=Template(f'./exps/{wandb_project}/{wandb_run_prefix}_{wandb_group}_{wandb_job_type}'))

if __name__ == '__main__':
    # Handle --wait_pids / --wait_secs before anything is launched.
    get_arg_parse()

    # Supported sizes (see add_model_params): 30, 350, 800, 1300.
    MODEL_SIZE = 350

    # Alternatives: 'adamw', 'galoreadamw', 'ldadamw', 'dct-adamw',
    # 'frugal', 'fira', 'apollo'.
    optimizer = 'trion'

    # Alternatives: 20, 100.
    tokens_ratio = 1  # for debugging

    # GPUs 0..7.
    gpu_ids = list(range(8))

    # Hyper-parameter grid; commented values are alternatives to sweep.
    search_space = {
        'lr': [
            # '7.5e-5',
            # '1e-4',
            # '2.5e-4',
            '1e-2',
        ],
        'lowrank_rank': [
            128,
            # 256,
            # 512,
        ],
        'lowrank_rotate_states': [
            # 0,
            1,
        ],
        'lowrank_upd_gap': [
            1,
            # 200,
        ],
        'lowrank_proj': [
            # 'svd',
            'dct',
            # 'hdm',
            # 'random',
            # 'randperm',
        ],
        'lowrank_use_ef': [
            '0',
            # '1',
        ],
        'lowrank_q_ef': [
            '0',
            # '4',
            # '8',
        ],
        'lowrank_distributed': [
            '1',
        ],
        'lowrank_max_shape': [
            # '0',
            '32000',
        ],
        # Trion params:
        'muon_ns_type': [
            'torch',
            # 'triton',
        ],
        'scaling_type': [
            'kj',
            # 'none',
            # 'kimi',
            # 'dion',
        ],
        'use_makhoul': [
            0,
            # 1,
        ],
    }

    main(
        MODEL_SIZE=MODEL_SIZE,
        OPTIM=optimizer,
        TOKENS_RATIO=tokens_ratio,
        gpus=gpu_ids,
        wandb_project='dct',
        dist=True,  # use DDP; set False to use a single GPU
        param_dict=search_space,
        max_jobs=1,
    )
