"""
Script to distill pretrained Transformers into linear attention variants
"""
import sys
import os
from os.path import join

import argparse
import torch
from omegaconf import OmegaConf
# Add ./src (relative to the current working directory) to the import path.
# NOTE(review): presumably this is what makes the `utils`, `dataloaders`,
# `trainer` and `model` imports below resolvable -- confirm the repo layout.
sys.path.append('./src')
# NOTE(review): these variables are presumably read by Hugging Face
# transformers / tokenizers (silence advisory warnings, disable tokenizer
# parallelism) and are set before the project imports below -- confirm.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from utils.setup import (
    seed_everything, get_run_name_from_args,
    update_config_from_args, update_model_config_from_args,
)
from utils.logging import print_config, print_header

from dataloaders import load_data
from trainer import get_trainer, get_optimizer, get_scheduler

from model.pretrained import get_pretrained_loader
from model.load_model import load_and_convert_attns
from model.convert_model import toggle_attention, remove_base_attention
from model.utils import count_parameters
import torch.distributed as dist
import datetime
from utils.rotation_utils import add_rotations


def get_args():
    """Build the command line parser, parse the arguments, and attach a run name.

    Returns:
        The parsed ``argparse.Namespace`` with one extra attribute, ``run_name``,
        set to whatever ``get_run_name_from_args`` returns for the namespace.
    """
    # Shared keyword sets. Note that the boolean flags default to ``None``
    # (not ``False``) when they are absent from the command line.
    optional_str = {'type': str, 'default': None}
    optional_int = {'type': int, 'default': None}
    optional_flag = {'action': 'store_true', 'default': None}

    # (option string, add_argument keyword arguments), in registration order.
    option_table = (
        ("--project_name", {'type': str, 'default': 'kvlinc'}),
        ("--model_config", optional_str),
        ("--distill_config", optional_str),

        ("--pretrained_model_name_or_path", optional_str),
        ("--load_checkpoint", optional_str),
        ("--resume_distill", optional_flag),

        # Override default configs
        # Feature map / model
        ("--attention_type", optional_str),
        ("--learned_kernel", optional_str),  # always
        ("--lk_skip_connection", optional_flag),
        ("--lk_zero_init", optional_flag),
        ("--lk_normal_init", optional_flag),
        ("--tie_qk_kernels", optional_flag),
        ("--train_qk", optional_flag),
        ("--state_chunk_len", optional_int),
        ("--window_size", optional_int),

        # Miscellaneous
        ("--huggingface_token", optional_str),
        ("--checkpoint_dir", {'type': str, 'default': './checkpoints'}),
        ("--seed", {'type': int, 'default': 0}),
        ("--verbose", optional_flag),
        ("--resq_rotation_path", {'type': str, 'default': './rotations/R.bin'}),
    )

    cli = argparse.ArgumentParser()
    for option_name, option_kwargs in option_table:
        cli.add_argument(option_name, **option_kwargs)

    parsed = cli.parse_args()
    parsed.run_name = get_run_name_from_args(parsed)
    return parsed

def get_local_rank() -> int:
    """Return this process's rank.

    Uses the ``LOCAL_RANK`` environment variable when it is set to a non-empty
    value; otherwise falls back to ``torch.distributed.get_rank()``.
    """
    env_rank = os.environ.get("LOCAL_RANK")
    if not env_rank:
        # Unset or empty: ask the process group instead.
        return torch.distributed.get_rank()
    return int(env_rank)
    
def main():
    """Distill a pretrained Transformer's attentions into a linear attention variant.

    Steps, in order:
      1. Parse CLI args, seed RNGs, and initialise the NCCL process group.
      2. Load ``./configs/experiment/<distill_config>.yaml`` and
         ``./configs/model/<model_config>.yaml`` and apply CLI overrides.
      3. Create the checkpoint directory (rank 0 only) under
         ``<checkpoint_dir>/<model_config>/seq_len_*/feature_dim_*``.
      4. Load the tokenizer and pretrained model through the model loader.
      5. Unless a checkpoint was given without ``--resume_distill``: convert the
         attentions, optionally add rotations, build data / optimizer /
         scheduler / trainer, run ``trainer.train()``, then move the model to
         CPU and call ``toggle_attention(train=False)`` and
         ``remove_base_attention``.

    Returns:
        None.
    """
    # ------
    # SET UP
    # ------
    args = get_args()

    seed_everything(args.seed)
    args.device = torch.device('cuda')
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    local_rank = get_local_rank()

    print("the rank is {}".format(local_rank))
    dist.barrier()

    # Load distillation + (hedgehog) attention configs
    distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml')
    distill_config = OmegaConf.load(distill_config_path)
    distill_config = update_config_from_args(distill_config, args)

    model_config_path = join('./configs/model', f'{args.model_config}.yaml')
    model_config = OmegaConf.load(model_config_path)
    model_config = update_model_config_from_args(model_config, args)

    seq_len = f"seq_len_{distill_config.dataset.dataset_config.chunk_size}"
    feature_dim = f"feature_dim_{model_config.attention.learned_kernel_kwargs.feature_dim}"

    args.checkpoint_dir = join(args.checkpoint_dir, args.model_config, seq_len, feature_dim)
    if local_rank == 0:
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    # Every rank waits here until rank 0 has created the directory.
    dist.barrier()

    args.run_name = args.run_name.replace('True', '1').replace('False', '0')  # concise hacks
    if hasattr(model_config.attention, 'apply_rotations'):
        args.apply_rot = model_config.attention.apply_rotations
    else:
        args.apply_rot = False

    # Update data tokenizer to match model
    for k in ['pretrained_model_name_or_path', 'cache_dir']:
        distill_config.dataset.pretrained_model_config[k] = model_config.model[k]

    # Update optimizer if specified
    if 'optimizer' in model_config:
        for k, v in model_config.optimizer.items():
            distill_config.optimizer[k] = v

    print_header('Distillation Config')
    print_config(distill_config)
    print_header('Model Config')
    print_config(model_config)

    # Get pretrained model
    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=args.huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    # Convert model
    try:
        args.attention_type = model_config['attention']['attention_type']
    except (AttributeError, KeyError):
        # Item access on a config with a missing key raises KeyError, so catch
        # it alongside AttributeError; otherwise this fallback is unreachable.
        args.attention_type = 'lolcats_llama'
    model = model_loader.load(model_type=args.attention_type)
    model.config.kvquant = model_config.attention.kvquant

    if args.verbose:
        print_header('*** Initial Model ***')
        print(model)
    # --------
    # TRAINING
    # --------
    # 1. Distill attentions
    if args.load_checkpoint is None or args.resume_distill:
        if args.resume_distill:
            # get_args() defines no --load_distill_checkpoint option, so reading
            # that attribute directly raised AttributeError. Honour it if some
            # other code set it; otherwise resume from --load_checkpoint.
            checkpoint_path = (getattr(args, 'load_distill_checkpoint', None)
                               or args.load_checkpoint)
        else:
            checkpoint_path = None
        # Swap initial attentions if applicable
        model = load_and_convert_attns(model, model_config,
                                       attention_type=args.attention_type,
                                       checkpoint_path=checkpoint_path,
                                       print_model=args.verbose,
                                       train_attention=True)
        if args.apply_rot:
            model = add_rotations(model, args)

        if distill_config.trainer.name is not None:  # Get data for distilling
            dataloaders = load_data(distill_config.dataset, distill_config.dataloader)
            train_loader = dataloaders[distill_config.trainer.train_split]
            eval_loader = dataloaders[distill_config.trainer.val_split]

            if args.verbose:
                # Each loop only peeks at the first batch (if any) and stops.
                print_header('*** Dataset preview ***')
                for data in train_loader:
                    print('-> Train data input_ids.shape:', data['input_ids'].shape)
                    break
                for data in eval_loader:
                    print('-> Eval  data input_ids.shape:', data['input_ids'].shape)
                    break

                for data in eval_loader:
                    print('-> Prompt:')
                    print(tokenizer.batch_decode(data['input_ids'])[0])
                    if 'position_ids' in data:
                        print('-> Position IDs:')
                        print('shape:', data['position_ids'].shape)
                        print(data['position_ids'])
                    break

            # Log some stats
            distill_config.model_train_params = count_parameters(model, requires_grad=True)
            distill_config.model_total_params = count_parameters(model, requires_grad=False)
            pct_trainable = distill_config.model_train_params / distill_config.model_total_params

            print_header('*** Distillation Parameter Counts ***')
            print(f'├── Number training to distill:  {distill_config.model_train_params}')
            print(f'├── Number of total parameters:  {distill_config.model_total_params}')
            print(f'├── Percent training to distill: {pct_trainable * 100:.3f}%')

            # Get optimizer and scheduler
            optimizer = get_optimizer(model=model, **distill_config.optimizer)
            scheduler = get_scheduler(optimizer=optimizer, **distill_config.lr_scheduler)

            # Load trainer: copy trainer settings (except its name) and the
            # dataloader / optimizer / scheduler configs onto args.
            for arg, argv in distill_config.trainer.items():
                if arg != 'name':
                    setattr(args, arg, argv)
            for _config in ['dataloader', 'optimizer', 'lr_scheduler']:
                setattr(args, _config, OmegaConf.to_container(getattr(distill_config, _config)))

            OurTrainer = get_trainer(distill_config.trainer.name)
            trainer = OurTrainer(model=model,
                                 args=args,
                                 train_loader=train_loader,
                                 eval_loader=eval_loader,
                                 optimizer_and_scheduler=(optimizer, scheduler),
                                 device=args.device,
                                 checkpoint_suffix='_distill',
                                 save_results=False,
                                 **distill_config.trainer)

            # Train / distill model
            print_header('*** Distilling Attentions ***')
            print(f'├── Experiment name: {args.run_name}')
            print(f'├── Device: {args.device}')
            print(f'├── Seed: {args.seed}')
            model = toggle_attention(model, train=True)

            # Line all ranks up before and after training.
            dist.barrier()
            model = trainer.train()
            dist.barrier()
            model = model.cpu()
            torch.cuda.empty_cache()

            # Prepare for downstream finetune / eval
            model = toggle_attention(model, train=False)
            model = remove_base_attention(model)

        else:
            print('-> No distillation')
            
    


# Run the distillation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
