"""
Benchmark the average `generate()` latency of a KVLinC-converted model against the same pretrained model running FlashAttention-2.
"""
import sys
import os
from os.path import join

import argparse
import torch
import time
from omegaconf import OmegaConf
sys.path.append('./src')
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from utils.setup import (
    seed_everything, get_run_name_from_args,
    update_model_config_from_args,
)
from utils.logging import print_config, print_header

from model.pretrained import get_pretrained_loader
from model.load_model import load_and_convert_attns
import torch.distributed as dist
import datetime
import transformers
import random
import datasets
from utils.rotation_utils import add_rotations

def get_args():
    """Parse command-line arguments for the generation-latency benchmark.

    Returns:
        argparse.Namespace with the options defined below.

    Exits via ``parser.error`` (SystemExit) when ``--model_config`` is missing
    or when a length/step option is out of range, instead of letting the run
    fail much later (e.g. building a config path from ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Benchmark generate() latency of a KVLinC model "
                    "against a FlashAttention-2 reference.")
    parser.add_argument("--project_name", type=str, default='kvlinc',
                        help="Project name.")
    parser.add_argument("--model_config", type=str, required=True,
                        help="Name of a YAML file under ./configs/model "
                             "(without the .yaml suffix).")

    parser.add_argument("--load_distill_checkpoint", type=str, default=None,
                        help="Checkpoint passed to load_and_convert_attns.")
    parser.add_argument("--huggingface_token", type=str, default=None,
                        help="Hugging Face token passed to the model loader.")
    parser.add_argument("--seed", type=int, default=0,
                        help="Random seed.")
    parser.add_argument("--prefill_len", type=int, default=1024,
                        help="Prompt length in tokens (>= 1).")
    parser.add_argument("--generate_len", type=int, default=512,
                        help="New tokens generated per call (>= 1).")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Number of identical prompts per call (>= 1).")
    parser.add_argument("--bench_steps", type=int, default=5,
                        help="Timed generate() calls to average over (>= 1).")
    parser.add_argument("--warmup_steps", type=int, default=2,
                        help="Untimed warm-up generate() calls (>= 0).")

    args = parser.parse_args()

    # Reject values that would crash or be meaningless later; bench_steps in
    # particular is the divisor when the mean latency is computed.
    for name in ("prefill_len", "generate_len", "batch_size", "bench_steps"):
        value = getattr(args, name)
        if value < 1:
            parser.error(f"--{name} must be >= 1 (got {value})")
    if args.warmup_steps < 0:
        parser.error(f"--warmup_steps must be >= 0 (got {args.warmup_steps})")
    return args

def get_local_rank() -> int:
    """Return the local rank of this process.

    Uses the ``LOCAL_RANK`` environment variable when it is set and non-empty.
    Otherwise falls back to ``torch.distributed.get_rank()``, which is the
    *global* rank and requires an initialized process group.
    """
    env_value = os.environ.get("LOCAL_RANK")
    if not env_value:
        return torch.distributed.get_rank()
    return int(env_value)

def _benchmark_generate(model, prompt, generate_len, warmup_steps, bench_steps,
                        **gen_kwargs):
    """Return the mean latency in milliseconds of ``model.generate(prompt, ...)``.

    Runs ``warmup_steps`` untimed calls, then times ``bench_steps`` calls with
    ``torch.cuda.synchronize()`` before and after the timed region so queued
    GPU work is included. Extra keyword arguments are forwarded to ``generate``.

    Raises:
        ValueError: if ``bench_steps`` < 1 (the mean would divide by zero).
    """
    if bench_steps < 1:
        raise ValueError(f"bench_steps must be >= 1, got {bench_steps}")
    with torch.no_grad():
        torch.cuda.synchronize()
        for _ in range(warmup_steps):
            model.generate(prompt, max_new_tokens=generate_len, **gen_kwargs)

        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(bench_steps):
            model.generate(prompt, max_new_tokens=generate_len, **gen_kwargs)
        torch.cuda.synchronize()
        elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000 / bench_steps


def main():
    """Benchmark generation latency of a KVLinC model against FlashAttention-2.

    Every rank of the NCCL process group loads both models from the same
    pretrained loader, builds a batch of ``--batch_size`` identical prompts
    from the first ``--prefill_len`` tokens of the WikiText-2 test split, and
    prints the mean ``generate()`` latency of each model over ``--bench_steps``
    timed calls (after ``--warmup_steps`` untimed ones).
    """
    # ------
    # SET UP
    # ------
    args = get_args()
    if args.bench_steps < 1:
        # Fail fast, before spending time loading two models.
        raise ValueError(f"--bench_steps must be >= 1, got {args.bench_steps}")
    seed_everything(args.seed)
    args.device = torch.device('cuda')
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    local_rank = get_local_rank()
    # Bind this process to its own GPU; otherwise every rank's `.cuda()` /
    # `.to("cuda")` calls below would all target cuda:0.
    torch.cuda.set_device(local_rank)

    print("the rank is {}".format(local_rank))
    torch.distributed.barrier()

    model_config_path = join('./configs/model', f'{args.model_config}.yaml')
    model_config = OmegaConf.load(model_config_path)

    # Stored on args, presumably for add_rotations(), which receives args
    # below -- TODO confirm which fields it reads.
    args.k_bits = model_config.attention.k_bits
    args.v_bits = model_config.attention.v_bits

    print_header('Model Config')
    print_config(model_config)

    # One loader / tokenizer shared by both models.
    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=args.huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    ######## KVLINC MODEL ########
    kvlinc_model = model_loader.load(model_type='kv_linc')
    # KV-cache quantization settings taken from the attention config.
    kvlinc_model.config.kvquant = {"nbits": model_config.attention.nbits,
                                   "residual_length": model_config.attention.window_size,
                                   "q_group_size": model_config.attention.quant_group_size}

    kvlinc_model, _ = load_and_convert_attns(kvlinc_model, model_config,
                                             attention_type='kv_linc',
                                             checkpoint_path=args.load_distill_checkpoint,
                                             print_model=False,
                                             merge_loras=False,
                                             train_attention=False)
    args.apply_rot = model_config.attention.apply_rotations
    kvlinc_model = add_rotations(kvlinc_model, args)

    kvlinc_model.cuda()
    kvlinc_model.eval()

    ######## FLASHATTN MODEL ########
    flashattn_model = model_loader.load(model_type='flash_attention_2')
    flashattn_model.cuda()
    flashattn_model.eval()

    torch.distributed.barrier()
    prefill_len = args.prefill_len
    generate_len = args.generate_len
    bs = args.batch_size

    testenc = get_wikitext2(
        seed=args.seed,
        seqlen=prefill_len,
        tokenizer=tokenizer,
        eval_mode=True,
        vision=False,
    )
    input_ids = testenc.input_ids  # (1, text_len)
    if input_ids.shape[-1] < prefill_len:
        raise ValueError(
            f"--prefill_len={prefill_len} exceeds the tokenized WikiText-2 "
            f"test split ({input_ids.shape[-1]} tokens)")
    # Prompt = the first prefill_len tokens of the corpus, repeated bs times.
    prompt = input_ids[0, :prefill_len].unsqueeze(0).repeat(bs, 1).to("cuda")

    # Force fixed-length decoding so both models are timed over exactly
    # generate_len new tokens. Without min_new_tokens, generate() stops early
    # once every row emits EOS, and the two models' outputs (hence their
    # stopping points) can differ. The batch is unpadded, so the attention
    # mask is all ones; the pad id is the EOS id set on the tokenizer above.
    gen_kwargs = {
        "attention_mask": torch.ones_like(prompt),
        "pad_token_id": tokenizer.pad_token_id,
        "min_new_tokens": generate_len,
    }

    ms_kvlinc = _benchmark_generate(kvlinc_model, prompt, generate_len,
                                    args.warmup_steps, args.bench_steps,
                                    **gen_kwargs)
    # Release cached allocator blocks from the KVLinC runs before the reference.
    torch.cuda.empty_cache()
    ms_flash = _benchmark_generate(flashattn_model, prompt, generate_len,
                                   args.warmup_steps, args.bench_steps,
                                   **gen_kwargs)

    print(f"Avg time: KVLINC={ms_kvlinc:.3f} ms   Reference={ms_flash:.3f} ms\n")
    dist.destroy_process_group()

def get_wikitext2(
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="",
    tokenizer=None,
    eval_mode=False,
    vision=False,
):
    """Load and tokenize WikiText-2 (``Salesforce/wikitext``, ``wikitext-2-raw-v1``).

    Args:
        nsamples: Number of random training windows to draw (train mode only).
        seed: Seed for choosing window offsets (train mode only).
        seqlen: Length in tokens of each training window (train mode only).
        model: Model name/path used to load a tokenizer (or, with ``vision``,
            a processor) when ``tokenizer`` is None.
        tokenizer: Tokenizer/processor to use instead of loading one.
        eval_mode: If True, return the tokenized ``test`` split; otherwise
            sample windows from the ``train`` split.
        vision: When loading from ``model``, use ``AutoProcessor`` instead of
            ``AutoTokenizer``.

    Returns:
        eval_mode=True: the tokenizer output (``return_tensors="pt"``) for the
        whole test split, with documents joined by blank-line separators.
        eval_mode=False: a list of ``nsamples`` ``(inp, tar)`` pairs where
        ``inp`` is ``input_ids[:, i:i + seqlen]`` for a random offset ``i`` and
        ``tar`` is a copy of ``inp`` with every position but the last set to -100.

    Raises:
        ValueError: in train mode with ``nsamples > 0``, if the tokenized
            split is not longer than ``seqlen``.
    """
    print("get_wikitext2")

    if tokenizer is None:
        if vision:
            tokenizer = transformers.AutoProcessor.from_pretrained(model)
        else:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model, use_fast=False
            )

    split = "test" if eval_mode else "train"
    data = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")[split]
    enc = tokenizer(text="\n\n".join(data["text"]), return_tensors="pt")
    if eval_mode:
        return enc

    total_len = enc.input_ids.shape[1]
    if nsamples > 0 and total_len <= seqlen:
        raise ValueError(
            f"seqlen={seqlen} must be shorter than the tokenized train split "
            f"({total_len} tokens)"
        )

    # Private RNG: picks the same offsets as random.seed(seed) + random.randint
    # (identically seeded Mersenne Twister) without resetting the global
    # `random` state for the rest of the process.
    rng = random.Random(seed)
    trainloader = []
    for _ in range(nsamples):
        i = rng.randint(0, total_len - seqlen - 1)
        inp = enc.input_ids[:, i:i + seqlen]
        tar = inp.clone()
        # Mask all but the final position; -100 is the default ignore_index
        # of torch's CrossEntropyLoss (presumably how tar is consumed).
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
