"""
Script to distill pretrained Transformers into linear attention variants
"""
import sys
import os
from os.path import join

import argparse
import torch
sys.path.append('./src')
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from utils.setup import (
     seed_everything,  get_run_name_from_args,
    
)

import torch.distributed as dist
import datetime
from utils.rotation_utils import get_basis

def get_args():
    """Build the command-line parser and return the parsed namespace.

    Every option is optional. The flags, their types and their defaults are
    listed in ``option_specs`` below, in the order they are registered on the
    parser.

    Returns:
        argparse.Namespace: the values parsed from ``sys.argv``.
    """
    option_specs = (
        ("--project_name", {"type": str, "default": 'lolcats'}),
        ("--pretrained_model_name_or_path", {"type": str, "default": None}),
        # Miscellaneous
        ("--huggingface_token", {"type": str, "default": None}),
        ("--rotation_path", {"type": str, "default": './rotations/R.bin'}),
        ("--seed", {"type": int, "default": 0}),
        # NOTE: the default is None (not False), so "flag not given" stays
        # distinguishable from an explicit boolean downstream.
        ("--bf16", {"action": 'store_true', "default": None}),
        ("--high_frac", {"type": float, "default": 0.03125}),
        ("--calib_seqlen", {"type": int, "default": 2048}),
        ("--calib_samples", {"type": int, "default": 512}),
        ("--calib_dataset", {"type": str, "default": "wikitext"}),
        ("--attn_implementation", {"type": str, "default": "flash_attention_2"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def get_local_rank() -> int:
    """Return the rank used to identify this process.

    The ``LOCAL_RANK`` environment variable wins when it is set to a
    non-empty string. Otherwise the value of
    ``torch.distributed.get_rank()`` is returned, which is the rank within
    the default process group rather than a per-node rank.

    Returns:
        int: the rank parsed from ``LOCAL_RANK`` or reported by
        ``torch.distributed``.
    """
    env_rank = os.environ.get("LOCAL_RANK")
    if not env_rank:
        # Unset or empty: fall back to what the process group reports.
        return torch.distributed.get_rank()
    return int(env_rank)
    
def main():
    """Set up the distributed environment and run ``get_basis``.

    Steps, in order:
      1. Parse the command-line arguments and call ``seed_everything`` with
         ``args.seed``.
      2. Initialize the default NCCL process group.
      3. Bind this process to one CUDA device so that the bare ``'cuda'``
         device stored in ``args.device`` and the NCCL collectives all
         resolve to the same GPU.
      4. Synchronize all ranks with a barrier and call ``get_basis(args)``.
      5. Destroy the process group once the work has completed.
    """
    # ------
    # SET UP
    # ------
    args = get_args()

    seed_everything(args.seed)
    args.device = torch.device('cuda')
    # Long timeout so that ranks waiting on a collective are not killed while
    # another rank is still busy.
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    local_rank = get_local_rank()

    # Without an explicit binding every rank's current device is GPU 0, so
    # the index-less 'cuda' device above would place all ranks on one GPU and
    # barrier() would have to guess which device to use. The modulo keeps the
    # index valid when get_local_rank() fell back to the global rank.
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        torch.cuda.set_device(local_rank % num_gpus)

    print("the rank is {}".format(local_rank))
    torch.distributed.barrier()
    get_basis(args)

    # Release NCCL resources on the normal exit path. This runs at the same
    # point on every rank; the guard covers the case where the group was
    # already torn down inside get_basis.
    if dist.is_initialized():
        dist.destroy_process_group()



    

        


# Run the pipeline only when this file is executed as a script, so importing
# the module does not initialize a process group as a side effect.
if __name__ == '__main__':
    main()
