from typing import Any, Dict, Optional
from datetime import datetime
import re
import glob
import os

from huggingface_hub import login as hf_login

from llamafactory.hparams import get_train_args

from llamafactory.train.callbacks import LogCallback

# Hugging Face Hub credentials must come from the environment (e.g. HF_TOKEN or a prior
# `huggingface-cli login`), never from source. The token previously committed here is exposed
# in version control and should be revoked.

# Month-day stamp (e.g. "0614") appended to W&B group names; evaluated once at import time.
date = datetime.now().strftime('%m%d')


def delete_checkpoints(output_dir):
    """Delete every ``checkpoint-*`` entry directly inside *output_dir*.

    Directories are removed recursively and regular files are deleted. Symbolic links are
    unlinked without touching their targets. Entries of any other kind are left in place.
    A non-existent *output_dir* is a no-op, because the glob simply matches nothing.

    Args:
        output_dir: Directory whose top-level ``checkpoint-*`` entries should be removed.
    """
    import shutil

    # Escape the directory part so names containing glob metacharacters ([, ], *, ?) are
    # matched literally; only the "checkpoint-*" suffix is meant to be a pattern.
    pattern = os.path.join(glob.escape(output_dir), "checkpoint-*")
    for path in glob.glob(pattern):
        if os.path.islink(path):
            # Must precede the isdir check: isdir follows links, and rmtree raises OSError on a
            # symlink. Unlinking removes only the link, never the directory it points to.
            os.remove(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)
        elif os.path.isfile(path):
            os.remove(path)


def _dataset_prefix(dataset_names) -> str:
    """Return the part of the first dataset name before its first ``_``, or ``''`` if none is set."""
    if not dataset_names:
        return ''
    return dataset_names[0].split('_')[0]


def _wandb_group_name(model_args, data_args, training_args) -> str:
    """Build the W&B group name for this run from its output directory and dataset names.

    Evaluation runs are those whose output directory's last component starts with
    ``eval_results``. They are named ``test_<eval>-<parent dir><flag>-<date>``, where
    ``<flag>`` is ``-uwp``/``-nuwp`` when memory is enabled (per ``update_while_predicting``)
    and empty otherwise.

    All other runs are named ``train_<train>-<output dir>-<date>``.

    ``<eval>``/``<train>`` is the prefix of the first eval/train dataset name (see
    ``_dataset_prefix``). ``<date>`` is the module-level ``MMDD`` stamp.
    """
    # normpath drops trailing separators, so "runs/exp/" yields "exp" instead of an empty name.
    output_dir = os.path.normpath(training_args.output_dir)
    run_dir = os.path.basename(output_dir)

    if run_dir.startswith('eval_results'):
        eval_dataset = _dataset_prefix(data_args.eval_dataset)
        if model_args.enable_mem:
            update_flag = '-uwp' if model_args.update_while_predicting else '-nuwp'
        else:
            update_flag = ''
        # The experiment name is the directory that contains eval_results*.
        parent_dir = os.path.basename(os.path.dirname(output_dir))
        return f'test_{eval_dataset}-{parent_dir}{update_flag}-{date}'

    train_dataset = _dataset_prefix(data_args.dataset)
    return f'train_{train_dataset}-{run_dir}-{date}'


def _init_wandb(model_args, data_args, training_args) -> None:
    """Log in to W&B, start a grouped run, and record ``shot_n`` / ``method`` in its config."""
    import wandb

    # Credentials come from WANDB_API_KEY (or wandb's own credential lookup when unset),
    # never from source code. The key previously hard-coded here should be revoked.
    wandb.login(key=os.environ.get('WANDB_API_KEY'))

    group_name = _wandb_group_name(model_args, data_args, training_args)
    # First '-'-separated token of the group name, e.g. "train_<dataset>" or "test_<dataset>".
    tag = group_name.split('-')[0]
    # `tags` takes a sequence of strings; wrap the single tag in a list.
    wandb.init(project='MeMv2', group=group_name, tags=[tag], resume=False)

    # Shot count encoded as "shot<N>" somewhere in the group name; 0 when absent.
    # Always logged as int so runs with and without a match have the same config type.
    shot_match = re.search(r'shot(\d+)', group_name)
    wandb.config['shot_n'] = int(shot_match.group(1)) if shot_match else 0

    # Method inferred by substring match; 'ict' is checked first.
    # NOTE(review): substring matching also fires on words like "predict"/"dict" in
    # directory names — confirm the naming convention guarantees this is unambiguous.
    if 'ict' in group_name:
        wandb.config['method'] = 'ict'
    elif 'mem' in group_name:
        wandb.config['method'] = 'mem'


def main(args: Optional[Dict[str, Any]] = None):
    """Parse training arguments, run the selected SFT variant, then delete intermediate checkpoints.

    Args:
        args: Optional argument dict forwarded to ``get_train_args``. When ``None``,
            ``get_train_args`` presumably reads the CLI / config file (TODO confirm).
    """
    callbacks = [LogCallback()]
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    # Generation-based evaluation on any Spider eval set calls download_punkt() once first.
    if (
        data_args.eval_dataset is not None
        and training_args.predict_with_generate
        and any('spider' in name for name in data_args.eval_dataset)
    ):
        from utils.spider_utils import download_punkt

        download_punkt()

    # Guard against report_to being None/empty before the membership test.
    if training_args.report_to and 'wandb' in training_args.report_to:
        _init_wandb(model_args, data_args, training_args)

    run_args = (model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    if model_args.enable_mem:
        from llamafactory.train.sft import run_mem_sft

        run_mem_sft(*run_args)
    elif finetuning_args.finetuning_type == "prefix":
        from llamafactory.train.sft import run_prefix_sft

        run_prefix_sft(*run_args)
    else:
        from llamafactory.train.sft import run_sft

        run_sft(*run_args)

    # After training finishes, remove intermediate checkpoint-* entries from output_dir.
    delete_checkpoints(training_args.output_dir)


# Script entry point: run training only when executed directly, not when imported.
if __name__ == "__main__":
    main()
