import copy
import glob
import logging
import os
import re
import subprocess
import sys
import random
from datetime import datetime
from functools import partial

import numpy as np
import pandas as pd
import torch
from torch import optim

try:
    import wandb
except ImportError:
    wandb = None

try:
    import torch.utils.tensorboard as tensorboard
except ImportError:
    tensorboard = None

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from open_clip_train.retrieval import recall_at_k

from open_clip import create_model_and_transforms_probing, trace_model, get_tokenizer, create_loss
from open_clip_train.data import get_data
from open_clip_train.distributed import is_master, init_distributed_device, broadcast_object
from open_clip_train.logger import setup_logging
from open_clip_train.params import parse_args
from open_clip_train.scheduler import cosine_lr, const_lr, const_lr_cooldown
from open_clip_train.train_probing import train_one_epoch, evaluate
from open_clip_train.file_utils import pt_load, check_exists, start_sync_process, remote_sync





class CFG_modifier():
    """Callable that overwrites bit-widths and PTQ settings on a PTQ4ViT quantization config.

    Every keyword argument given to the constructor becomes an attribute. ``__call__`` reads:
      - ``bit_setting``: indexable as (weight bits, activation bits),
      - ``linear_ptq_setting``: indexable as (n_V, n_H, n_a),
      - ``metric``: the calibration metric name.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def __call__(self, cfg):
        """Write the stored settings into ``cfg`` (mutated in place) and return it."""
        weight_bits = self.bit_setting[0]
        act_bits = self.bit_setting[1]

        # Bit-widths: conv/fc layers get separate weight/activation widths; both matmul
        # operand maps (A_bit / B_bit) use the activation width.
        cfg.bit = self.bit_setting
        cfg.w_bit = dict.fromkeys(cfg.conv_fc_name_list, weight_bits)
        cfg.a_bit = dict.fromkeys(cfg.conv_fc_name_list, act_bits)
        cfg.A_bit = dict.fromkeys(cfg.matmul_name_list, act_bits)
        cfg.B_bit = dict.fromkeys(cfg.matmul_name_list, act_bits)

        n_v = self.linear_ptq_setting[0]
        n_h = self.linear_ptq_setting[1]

        # Conv2d and linear layers share the n_V / n_H split; linear layers also take n_a.
        cfg.ptqsl_conv2d_kwargs.update({
            "n_V": n_v,
            "n_H": n_h,
            "metric": self.metric,
            "init_layerwise": False,
        })
        cfg.ptqsl_linear_kwargs.update({
            "n_V": n_v,
            "n_H": n_h,
            "n_a": self.linear_ptq_setting[2],
            "metric": self.metric,
            "init_layerwise": False,
        })
        cfg.ptqsl_matmul_kwargs.update({
            "metric": self.metric,
            "init_layerwise": False,
        })

        return cfg

# File name of the rolling "latest" checkpoint. main() looks for it inside the checkpoint
# directory when args.resume == 'latest' and args.save_most_recent is set.
LATEST_CHECKPOINT_NAME = "epoch_latest.pt"

def random_seed(seed=42, rank=0):
    """Seed torch, NumPy and Python's ``random`` with ``seed + rank``.

    Adding ``rank`` gives each process a distinct but reproducible stream.
    """
    offset_seed = seed + rank
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(offset_seed)


def natural_key(string_):
    """Return a case-insensitive "natural" sort key: digit runs compare as integers.

    See http://www.codinghorror.com/blog/archives/001018.html
    e.g. "Epoch_10.pt" -> ["epoch_", 10, ".pt"], so epoch_10 sorts after epoch_9.
    """
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def get_latest_checkpoint(path: str, remote : bool):
    """Return the last ``*.pt`` checkpoint under ``path`` in natural sort order, or None.

    Args:
        path: local directory to search (recursively), or a remote prefix listed with
            ``aws s3 ls`` when ``remote`` is True.
        remote: list the remote prefix instead of the local filesystem.

    Returns:
        The path of the checkpoint that sorts last by ``natural_key``, or None if no
        checkpoint is found or the remote listing exits with a non-zero status.
    """
    if remote:
        result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Treat any non-zero exit status as "nothing to resume from", not just status 1.
        if result.returncode != 0:
            return None
        # The entry name is the last space-separated token of each listing line. Keep only
        # ``*.pt`` entries, matching the local branch (other entries, presumably sub-folder
        # prefixes, would otherwise be returned as "checkpoints").
        names = [line.split(' ')[-1] for line in result.stdout.decode().splitlines()]
        checkpoints = [os.path.join(path, name) for name in names if name.endswith('.pt')]
    else:
        # '**' only recurses when it is a path component of its own, so join it rather than
        # concatenating onto ``path`` (which breaks when ``path`` has no trailing slash).
        checkpoints = glob.glob(os.path.join(path, '**', '*.pt'), recursive=True)
    if checkpoints:
        return sorted(checkpoints, key=natural_key)[-1]
    return None

def main(args):
    """Parse CLI args, set up logging/sync, then run one evaluation or the full token search.

    When ``args.eval_only`` is set, the configuration given on the command line is evaluated
    once and its metric(s) printed. Otherwise three searches run back to back, each keeping
    the configuration with the highest metric:

      1. prefix placement: target block x prefix count x global prefix rank
         (see ``_search_prefix_placement``);
      2. prefix number: ``args.prefix_number`` = 1 .. ``args.max_prefix_number``;
      3. token deletion: ``args.token_delete_number`` = 0 .. ``args.max_token_delete_number``.

    After searches 2 and 3 the winning configuration is re-evaluated on the full ImageNet
    validation set and written with ``log_result``.

    Args:
        args: raw command-line argument list, passed to ``parse_args``.

    Returns:
        -1 if setup fails (existing experiment, unsupported resume/sync options, failed
        initial remote sync), otherwise None.
    """
    args = parse_args(args)
    timm_kwargs = {
        "pos_embed_order": args.pos_embed_order,
        "wbits": args.wbits,
        "abits": args.abits,
        "w_quant_type": args.w_quant_type,
        "a_quant_type": args.a_quant_type,
        "register_num": args.register_num,
        "prefix_type": args.prefix_type,
        "embed_prefix": args.embed_prefix,
        "q_group_size": args.q_group_size,
        "num_classes": args.num_classes,
        "mask_percentage": args.masking_percentage,
        "model_name": args.model_name,
    }

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    device = init_distributed_device(args)

    if args.name is None:
        model_name_safe = args.model.replace('/', '-')
        date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
        if args.distributed:
            # sync date_str from master to all ranks
            date_str = broadcast_object(args, date_str)
        args.name = '-'.join([
            date_str,
            f"model_{model_name_safe}",
            f"lr_{args.lr}",
            f"b_{args.batch_size}",
            f"j_{args.workers}",
            f"p_{args.precision}",
        ])

    resume_latest = args.resume == 'latest'
    log_base_path = os.path.join(args.logs, args.name)
    args.log_path = None
    if is_master(args, local=args.log_local):
        os.makedirs(log_base_path, exist_ok=True)
        log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
        args.log_path = os.path.join(log_base_path, log_filename)
        if os.path.exists(args.log_path) and not resume_latest:
            print(
                f"Error. Experiment already exists at {log_base_path}. "
                "Use --name to specify a new experiment."
            )
            return -1

    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)
    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
    args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
    if is_master(args):
        args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        args.tensorboard_path = ''

    if args.wandb and is_master(args):
        assert wandb is not None, 'Please install wandb.'
        logging.debug('Starting wandb.')
        wandb.init()
        # wandb.watch() is not called: no model exists at this point. Models are built
        # per configuration inside run_eval_with_config.
        logging.debug('Finished loading wandb.')
        args.sweep_config = wandb.config

        # Sweep parameters override the model kwargs built from the CLI.
        for key, val in wandb.config.items():
            timm_kwargs[key] = val

    if resume_latest:
        resume_from = None
        checkpoint_path = args.checkpoint_path
        if args.remote_sync is not None:
            checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
            if args.save_most_recent:
                print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
                return -1
            if args.remote_sync_protocol != 's3':
                print('Error. Sync protocol not supported when using resume latest.')
                return -1
        if is_master(args):
            if args.save_most_recent:
                resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
                if not os.path.exists(resume_from):
                    resume_from = None
            else:
                resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
            if resume_from:
                logging.info(f'Found latest resume checkpoint at {resume_from}.')
            else:
                logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
        if args.distributed:
            resume_from = broadcast_object(args, resume_from)
        args.resume = resume_from

    if args.copy_codebase:
        copy_codebase(args)

    remote_sync_process = None
    if is_master(args) and args.remote_sync is not None:
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('remote sync successful.')
        else:
            logging.info('Error: remote sync failed. Exiting.')
            return -1
        remote_sync_process = start_sync_process(
            args.remote_sync_frequency,
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        remote_sync_process.start()

    if args.precision == 'fp16':
        logging.warning(
            'It is recommended to use AMP mixed-precision instead of FP16. '
            'FP16 support needs further verification and tuning, especially for train.')
    if args.horovod:
        logging.info(
            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    elif args.distributed:
        logging.info(
            f'Running in distributed mode with multiple processes. Device: {args.device}.'
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    else:
        logging.info(f'Running with a single process. Device {args.device}.')

    args.distill = args.distill_model is not None and args.distill_pretrained is not None
    if args.distill:
        assert args.accum_freq == 1
        assert 'coca' not in args.model.lower()

    if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:
        # arg is nargs, single (square) image size list -> int
        args.force_image_size = args.force_image_size[0]
    random_seed(args.seed, 0)
    model_kwargs = {}
    if args.siglip:
        model_kwargs['init_logit_scale'] = np.log(10)  # different from CLIP
        model_kwargs['init_logit_bias'] = -10

    # args attributes forwarded to the model (via timm_kwargs) and recorded per evaluation.
    target_keys = [
        "prefix_add",
        "cache_option",
        "prefix_add_block",
        "prefix_number",
        "token_delete",
        "token_delete_block",
        "token_delete_number",
        "token_delete_previous_layer",
        "token_delete_method",
        "head_index_for_score",
        "global_prefix_rank",
        "target_block",
        "target_layer",
        "bit",
    ]

    if args.eval_only:
        metric, _ = run_eval_with_config(args, target_keys, device, model_kwargs, timm_kwargs)
        if args.ms_coco:
            # For MS-COCO in eval-only mode the metric is a dict of recall@k values.
            print("=== MS-COCO Zero-Shot Retrieval Results ===")
            for key in sorted(metric.keys()):
                print(f"{key:20s} : {metric[key]:.4f}")
        else:
            print("ACC:", metric)
        _shutdown(args, remote_sync_process)
        return

    # Winners are reported on the full validation set; the delete-number search runs on the
    # per-class training subset.
    full_val_path = "/data/ILSVRC2012/val"
    search_subset_path = "/home/user/data/ILSVRC2012/train_sample_50_per_class"

    # 1) Prefix placement search.
    args.prefix_search_mode = True
    best_acc1, best_args = _search_prefix_placement(args, target_keys, device, model_kwargs, timm_kwargs)
    log_result(args, best_args, best_acc1)
    for k, v in best_args.items():
        print(f"{k}: {v}")
    for key in ["target_block", "prefix_add_block", "global_prefix_rank"]:
        if key in best_args:
            setattr(args, key, best_args[key])

    # 2) Prefix number search (placement fixed to the winner above).
    args.prefix_num_search_mode = True
    args.prefix_search_mode = False
    best_acc1, best_args = _search_counter(
        args, "prefix_number", 1, args.max_prefix_number,
        target_keys, device, model_kwargs, timm_kwargs,
    )
    if "prefix_number" in best_args:
        args.prefix_number = best_args["prefix_number"]
    # Re-evaluate the best prefix number on the full validation set.
    args.imagenet_val = full_val_path
    acc1, selected_config = run_eval_with_config(args, target_keys, device, model_kwargs, timm_kwargs)
    log_result(args, selected_config, acc1)

    # 3) Token-delete number search on the training subset.
    args.imagenet_val = search_subset_path
    args.prefix_num_search_mode = False
    args.num_of_delete_search_mode = True
    args.token_delete = True
    best_acc1, best_args = _search_counter(
        args, "token_delete_number", 0, args.max_token_delete_number,
        target_keys, device, model_kwargs, timm_kwargs,
    )
    if "token_delete_number" in best_args:
        args.token_delete_number = best_args["token_delete_number"]
    # Re-evaluate the best delete number on the full validation set.
    args.imagenet_val = full_val_path
    acc1, selected_config = run_eval_with_config(args, target_keys, device, model_kwargs, timm_kwargs)
    log_result(args, selected_config, acc1)

    print("✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨")
    print("Best acc1:", best_acc1)
    print("Best config")
    log_result(args, best_args, best_acc1)
    for k, v in best_args.items():
        print(f"{k}: {v}")

    _shutdown(args, remote_sync_process)
    return


def _search_prefix_placement(args, target_keys, device, model_kwargs, timm_kwargs,
                             prefix_number_candidates=(1, 2, 4, 8), min_target_block=3,
                             num_blocks=12):
    """Grid-search where prefixes are inserted; return (best_acc1, best_config).

    Loop order, innermost first:
      - ``args.global_prefix_rank`` from 1 to ``args.max_global_prefix_rank``,
      - ``args.prefix_number`` over ``prefix_number_candidates``,
      - ``args.target_block`` from its current value down to ``min_target_block``.
    Prefixes go into blocks ``target_block .. num_blocks - 1``. Every evaluated
    configuration is appended with ``log_process``. The loop counters are left on ``args``.
    """
    args.target_block_best = {}
    args.global_prefix_rank = 1
    args.prefix_number_idx = 0
    args.prefix_number = prefix_number_candidates[0]
    args.prefix_add_block = list(range(args.target_block, num_blocks))

    # -inf guarantees the first evaluation is recorded, so best_config is never None.
    best_acc1, best_args = float('-inf'), None
    while True:
        acc1, selected_config = run_eval_with_config(args, target_keys, device, model_kwargs, timm_kwargs)
        # Log before advancing so the final configuration is recorded too.
        log_process(args, selected_config, acc1)
        if acc1 > best_acc1:
            best_acc1, best_args = acc1, selected_config

        args.global_prefix_rank += 1
        if args.global_prefix_rank > args.max_global_prefix_rank:
            args.global_prefix_rank = 1
            args.prefix_number_idx += 1
            if args.prefix_number_idx >= len(prefix_number_candidates):
                args.prefix_number_idx = 0
                args.target_block -= 1
                if args.target_block < min_target_block:
                    break
                args.prefix_add_block = list(range(args.target_block, num_blocks))
            args.prefix_number = prefix_number_candidates[args.prefix_number_idx]
    return best_acc1, best_args


def _search_counter(args, attr, start, stop, target_keys, device, model_kwargs, timm_kwargs):
    """Evaluate ``args.<attr>`` = start, start + 1, ..., stop; return (best_acc1, best_config).

    At least one evaluation (with ``start``) always runs. Every evaluated configuration is
    appended with ``log_process``.
    """
    setattr(args, attr, start)
    # -inf guarantees the first evaluation is recorded, so best_config is never None.
    best_acc1, best_args = float('-inf'), None
    while True:
        acc1, selected_config = run_eval_with_config(args, target_keys, device, model_kwargs, timm_kwargs)
        log_process(args, selected_config, acc1)
        if acc1 > best_acc1:
            best_acc1, best_args = acc1, selected_config
        setattr(args, attr, getattr(args, attr) + 1)
        if getattr(args, attr) > stop:
            break
    return best_acc1, best_args


def _shutdown(args, remote_sync_process):
    """Finish the wandb run (master only) and stop the background remote-sync process, if any."""
    if args.wandb and is_master(args):
        wandb.finish()
    if remote_sync_process is not None:
        # main() starts this process with .start() and nothing else stops it. Assumes a
        # multiprocessing.Process-like object (TODO confirm against start_sync_process).
        # Push the logs one last time after stopping the periodic sync.
        logging.info('Final remote sync.')
        remote_sync_process.terminate()
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol,
        )
        if result:
            logging.info('Final remote sync successful.')
        else:
            logging.info('Final remote sync failed.')

def log_process(args, selected_config, acc1,
                base_root="/home/user/regcache/result/last_search_process"):
    """Append one evaluated search configuration and its metric as a row of a CSV file.

    The dataset name comes from the first truthy dataset flag on ``args`` (checked in the
    order below; "unknown" if none is set). The sub-directory comes from the active
    search-mode flag. The target file is
    ``<base_root>/<dataset>[_val]/<mode>/search_results_<model_name>_<baseline>_<bit>bit.csv``.
    It is read, extended by one row, and rewritten.

    Args:
        args: parsed arguments. Dataset flags, search-mode flags, ``model_name``,
            ``baseline`` and ``bit`` are read.
        selected_config: mapping of configuration keys to values for this evaluation.
        acc1: metric obtained for this configuration.
        base_root: root directory for the search logs.
    """
    if args.cifar_100_train or args.cifar_100_val:
        data = "cifar100"
    elif args.imagenet_val:
        data = "imagenet1k"
    elif args.caltech101:
        data = "caltech-101"
    elif args.stanford:
        data = "stanford"
    elif args.flowers_102:
        data = "flowers-102"
    elif args.ucf101:
        data = "ucf101"
    elif args.food_101:
        data = "food-101"
    elif getattr(args, "ms_coco", False):
        data = "ms-coco"
    else:
        # No known dataset flag is set: still log instead of failing on an unbound name.
        data = "unknown"

    # 1. Pick the output directory from the active search mode.
    if getattr(args, "prefix_search_mode", False):
        base_dir = os.path.join(base_root, data, "prefix_search")
    elif getattr(args, "prefix_num_search_mode", False):
        base_dir = os.path.join(base_root, f"{data}_val", "prefix_num_search")
    elif getattr(args, "num_of_delete_search_mode", False):
        base_dir = os.path.join(base_root, f"{data}_val", "delete_num_search")
    else:
        base_dir = os.path.join(base_root, data, "other")

    os.makedirs(base_dir, exist_ok=True)
    save_path = os.path.join(base_dir, f"search_results_{args.model_name}_{args.baseline}_{args.bit}bit.csv")

    # 2. Build the row to record.
    row = {**selected_config, "acc1": acc1}

    # 3. Append to the CSV. concat aligns columns if the config keys ever change.
    if os.path.exists(save_path):
        df = pd.read_csv(save_path)
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    else:
        df = pd.DataFrame([row])

    df.to_csv(save_path, index=False)
    print(f"[LOGGED] Result saved at {save_path}")

def log_result(args, selected_config, acc1,
               base_dir="/home/user/regcache/result/last_search_process/"):
    """Print a configuration and its metric as an aligned listing and append it to a CSV.

    The listing (one ``key value`` line per entry, ``acc1`` last) is stored as a single
    multi-line string in the ``result`` column of
    ``<base_dir>/<model_name>_<baseline>_<bit>bit.csv``. The file is read, extended by
    one row, and rewritten.

    Args:
        args: parsed arguments; ``model_name``, ``baseline`` and ``bit`` name the file.
        selected_config: mapping of configuration keys to values. May be None, in which
            case only ``acc1`` is recorded.
        acc1: metric for the configuration.
        base_dir: directory that receives the CSV file.
    """
    os.makedirs(base_dir, exist_ok=True)
    save_path = os.path.join(base_dir, f"{args.model_name}_{args.baseline}_{args.bit}bit.csv")

    # 1. Render the record as aligned "key value" lines.
    log_dict = {**(selected_config or {}), "acc1": acc1}
    formatted_row = "\n".join([f"{k:<35} {v}" for k, v in log_dict.items()])

    # 2. Print it.
    print("✨" * 20)
    print(formatted_row)

    # 3. Append to the CSV as a single string cell.
    row = {"result": formatted_row}
    if os.path.exists(save_path):
        df = pd.read_csv(save_path)
        df = pd.concat([df, pd.DataFrame([row])], ignore_index=True)
    else:
        df = pd.DataFrame([row])

    df.to_csv(save_path, index=False)
    print(f"[LOGGED] Result saved at {save_path}")

def _build_calib_loader(args, data):
    """Return the calibration loader from the ``data`` dict produced by ``get_data``.

    For MS-COCO, a non-shuffling DataLoader with batch size ``args.batch_size`` is wrapped
    around ``data['calib-coco']``. Otherwise ``data['calib'].dataloader`` is used.
    """
    if args.ms_coco:
        import torch.utils.data as dutils
        return dutils.DataLoader(data['calib-coco'], batch_size=args.batch_size, shuffle=False)
    return data['calib'].dataloader


def run_eval_with_config(args, target_keys, device, model_kwargs, timm_kwargs):
    """Build the model for the current ``args``, optionally quantize it, and evaluate it.

    ``timm_kwargs`` is updated IN PLACE with every non-None ``args.<key>`` for ``key`` in
    ``target_keys`` before the model is built, so overrides persist across calls.

    Quantization baseline (``args.baseline``):
      - 'awq': load precomputed AWQ results from disk and apply them;
      - 'ptq4vit': Hessian-guided PTQ4ViT calibration of ``model.visual``;
      - 'noisyquant' / 'repqvit': calibrate on data loaded with batch size
        ``args.calib_num``, then reload the evaluation data with the normal batch size;
      - anything else: evaluate the model as created.

    Returns:
        (metric, selected_config). ``selected_config`` maps each key in ``target_keys`` to
        ``getattr(args, key, None)``. For MS-COCO retrieval, ``metric`` is a dict of
        recall@k values when ``args.eval_only`` is set, otherwise the mean of text-to-image
        and image-to-text R@1. For classification it is the first metric found among
        ``POSSIBLE_KEYS`` (0.0 if none).

    Raises:
        ValueError: if ``args.baseline == 'awq'`` and ``args.model_name`` has no AWQ module.
    """
    # Forward every search-controlled setting currently on args to the model constructor.
    timm_kwargs.update({
        k: getattr(args, k)
        for k in target_keys
        if hasattr(args, k) and getattr(args, k) is not None
    })

    model, preprocess_train, preprocess_val = create_model_and_transforms_probing(
        args.model,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        force_custom_text=args.force_custom_text,
        force_patch_dropout=args.force_patch_dropout,
        force_image_size=args.force_image_size,
        image_mean=args.image_mean,
        image_std=args.image_std,
        image_interpolation=args.image_interpolation,
        image_resize_mode=args.image_resize_mode,  # only effective for inference
        aug_cfg=args.aug_cfg,
        pretrained_image=args.pretrained_image,
        output_dict=True,
        cache_dir=args.cache_dir,
        timm_kwargs=timm_kwargs,
        **model_kwargs,
    )
    preprocess = (preprocess_train, preprocess_val)

    # Data preparation.
    random_seed(args.seed, args.rank)
    start_epoch = 0
    tokenizer = get_tokenizer(args.model, cache_dir=args.cache_dir)
    if args.baseline not in ('noisyquant', 'repqvit'):
        # noisyquant / repqvit load calibration data and then evaluation data themselves below.
        data = get_data(args, preprocess, epoch=start_epoch, tokenizer=tokenizer)
        assert len(data), 'At least one train or eval dataset must be specified.'

    # Only the master process writes logs.
    args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)

    print("_____________________[Config]_______________________")
    print("prefix number", args.prefix_number)
    print("num to delete", args.token_delete_number)
    print("____________________________________________")

    if args.baseline == 'awq':
        if args.model_name in ('clip', 'openclip'):
            from open_clip_train.llm_awq.awq.quantize.clip_awq import apply_awq
        elif args.model_name in ('siglip', 'siglip2'):
            from open_clip_train.llm_awq.awq.quantize.siglip_awq import apply_awq
        else:
            raise ValueError(f"AWQ baseline is not supported for model_name={args.model_name!r}")
        # AWQ results are precomputed offline with run_awq (q_group_size=128, zero_point=False)
        # and only loaded here. Keep everything on the same device the model was built on.
        awq_path = (
            f'/home/user/regcache/src/open_clip_train/llm_awq/awq_results/'
            f'{args.model_name}/{args.bit}bit_128group.pt'
        )
        awq_results = torch.load(awq_path, map_location=device)
        apply_awq(model, awq_results)
        model.to(device)
    elif args.baseline == 'ptq4vit':
        import open_clip_train.PTQ4ViT.utils.net_wrap as net_wrap
        import open_clip_train.PTQ4ViT.utils.datasets as datasets
        from open_clip_train.PTQ4ViT.utils.quant_calib import HessianQuantCalibrator
        from open_clip_train.PTQ4ViT.utils.models import get_net_model
        from open_clip_train.PTQ4ViT.example.test_vit import init_config

        config_name = "PTQ4ViT"
        metric = "hessian"
        linear_ptq_setting = (1, 1, 1)  # n_V, n_H, n_a
        calib_size = 32
        bit = args.wbits
        bit_setting = (bit, bit)  # weight, activation

        quant_cfg = init_config(config_name)
        cfg_modifier = CFG_modifier(linear_ptq_setting=linear_ptq_setting, metric=metric, bit_setting=bit_setting)
        quant_cfg = cfg_modifier(quant_cfg)

        net = get_net_model(model.visual)
        wrapped_modules = net_wrap.wrap_modules_in_net(net, quant_cfg)
        g = datasets.ViTImageNetLoaderGenerator('/data/ILSVRC2012/', 'imagenet', 32, 32, 16, kwargs={"model": net})
        calib_loader = g.calib_loader(num=calib_size)

        quant_calibrator = HessianQuantCalibrator(net, wrapped_modules, calib_loader, sequential=False, batch_size=4)
        quant_calibrator.quant_calib()
        model.visual = net
    elif args.baseline == 'noisyquant':
        from open_clip_train.NoisyQuant.fast_quant import fast_quant

        # Calibration data is loaded with batch size args.calib_num; the normal batch size
        # is restored even if calibration fails, since args is shared across search runs.
        original_batch_size = args.batch_size
        args.batch_size = args.calib_num
        try:
            calib_data = get_data(args, preprocess, epoch=start_epoch, tokenizer=tokenizer)
            calib_loader = _build_calib_loader(args, calib_data)

            model = fast_quant(model, bit=args.bit, with_noisy_quant=args.with_noisy_quant,
                               percentile=args.percentile,
                               search_noisy=args.search_noisy, search_mean=args.search_mean)
            with torch.no_grad():
                # Only the first calibration batch is used (and saved for reuse).
                first_batch = next(iter(calib_loader), None)
                if first_batch is None:
                    raise RuntimeError("NoisyQuant calibration loader yielded no batches.")
                calib_images = first_batch[0].to(device)
                torch.save(calib_images, f"./calib_data_{args.calib_num}.pt")
                if args.percentile:
                    # at first, do percentile
                    print("Begin percentile search!")
                    model(calib_images)
                    print("Finish percentile search!")
                if args.search_noisy or args.search_mean:
                    # then, search noisy bias
                    print("Begin noisy bias search!")
                    model(calib_images)
            print(f"Finished Calibration on {len(calib_images)} samples!")
            model.eval()
        finally:
            args.batch_size = original_batch_size
        # Reload the evaluation data with the normal batch size.
        data = get_data(args, preprocess, epoch=start_epoch, tokenizer=tokenizer)
    elif args.baseline == 'repqvit':
        from open_clip_train.RepQ_ViT.classification.test_quant import repqvit

        # repqvit receives args while batch_size == calib_num, as before; the normal batch
        # size is restored even if calibration fails.
        original_batch_size = args.batch_size
        args.batch_size = args.calib_num
        try:
            calib_data = get_data(args, preprocess, epoch=start_epoch, tokenizer=tokenizer)
            calib_loader = _build_calib_loader(args, calib_data)
            model = repqvit(model, args, calib_loader)
            model.eval()
        finally:
            args.batch_size = original_batch_size
        # Reload the evaluation data with the normal batch size.
        data = get_data(args, preprocess, epoch=start_epoch, tokenizer=tokenizer)

    if args.ms_coco:  # zero-shot image retrieval
        k_vals = [1, 5]
        t2i, i2t = recall_at_k(model, data, args.device, k_vals, args.batch_size)
        if args.eval_only:
            metric_value = {f"ms-coco-t2i-R@{k}": float(x) for k, x in zip(k_vals, t2i)}
            metric_value.update({f"ms-coco-i2t-R@{k}": float(x) for k, x in zip(k_vals, i2t)})
        else:
            # The search optimizes the mean of text->image and image->text R@1.
            metric_value = float((t2i[0] + i2t[0]) / 2)
    else:  # zero-shot image classification
        metrics = evaluate(model, data, start_epoch, args, tb_writer=None, tokenizer=tokenizer)
        # First metric present wins; top-5 is only a fallback when no top-1 key exists.
        POSSIBLE_KEYS = [
            'caltech-101-zeroshot-val-top1',
            'cifar-100-zeroshot-val-top1',
            'imagenet-zeroshot-val-top1',
            'imagenetv2-zeroshot-val-top1',
            'imagenet-zeroshot-val-top5',
        ]
        metric_value = next((float(metrics[k]) for k in POSSIBLE_KEYS if k in metrics), 0.0)

    selected_config = {k: getattr(args, k, None) for k in target_keys}
    return metric_value, selected_config


def copy_codebase(args):
    """Snapshot the source tree into ``<args.logs>/<args.name>/code``.

    The copied tree is the directory three levels above this file. Directories or files
    named 'log', 'logs' and 'wandb' are skipped.

    Returns:
        1 after copying, or -1 (nothing copied) if the destination already exists.
    """
    import shutil

    destination = os.path.join(args.logs, args.name, "code")
    if os.path.exists(destination):
        print(f"Error. Experiment already exists at {destination}. Use --name to specify a new experiment.")
        return -1

    print(f"Copying codebase to {destination}")
    source_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
    skip = shutil.ignore_patterns('log', 'logs', 'wandb')
    shutil.copytree(source_root, destination, ignore=skip)
    print("Done copying code.")
    return 1


if __name__ == "__main__":
    # main() returns -1 on setup errors and None on success. Pass that to sys.exit so
    # failures produce a non-zero exit status (255 on POSIX) instead of always exiting 0.
    sys.exit(main(sys.argv[1:]))

