import os
import sys
import argparse
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

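# Hack: append the parent of the *current working directory* (os.pardir is the
# relative path "..") to sys.path so the `src.*` imports below resolve. This only
# works when the script is launched from a directory whose parent contains `src`.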
import sys

sys.path.append(os.pardir)
from src.pruner import Pruner
from src.utils.random import fix_seed

from data_utils import get_loaders
from engine import eval_perplexity


def build_parser():
    """Return the command-line parser for this pruning script.

    Split out of ``main`` so the CLI definition can be parsed or inspected
    without launching a pruning run.
    """
    parser = argparse.ArgumentParser(
        description="One-shot pruning of Hugging Face causal language models "
                    "with perplexity evaluation on wikitext2 and c4."
    )
    # Model params
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        required=True,
        help="The name or path to the model being pruned",
    )
    # Data params
    parser.add_argument(
        '--dataset_name_or_path',
        type=str,
        required=True,
        help="The name of the dataset or path used for calibration.",
    )
    parser.add_argument(
        '--sequence_length',
        default=2048,
        type=int,
        help="length of extracted sequences."
    )
    # Sparsification params
    parser.add_argument(
        '--iterations',
        default=10,
        type=int,
        help="Number of prune/evaluate rounds (finetuning runs between rounds); "
             "forced to 1 when --sparsity is 0.",
    )
    parser.add_argument(
        '--pruning_method',
        default="FastOBC",
        choices=["FastOBC", "OBC"],
        type=str,
        help="Pruning algorithm used by the Pruner.",
    )
    parser.add_argument(
        '--sparsity',
        default=0.5,
        type=float,
        help="Target sparsity passed to Pruner.prune; 0 disables pruning.",
    )
    parser.add_argument(
        '--module_regex',
        type=str,
        required=True,
        help="Regex selecting the modules to prune.",
    )
    parser.add_argument(
        '--decoder_blocks',
        required=True,
        type=str,
        help="Decoder blocks specification forwarded to the Pruner and to perplexity evaluation.",
    )
    parser.add_argument(
        '--pre_decoder_modules',
        required=True,
        nargs="+",
        type=str,
        help="Modules preceding the decoder blocks (forwarded to the Pruner and to evaluation).",
    )
    parser.add_argument(
        '--post_decoder_modules',
        required=True,
        nargs="+",
        type=str,
        help="Modules following the decoder blocks (forwarded to perplexity evaluation).",
    )
    parser.add_argument(
        '--calibration_dataset_size',
        default=None,
        type=int,
        help="Size of calibration dataset (also used as the Pruner's max_samples)."
    )
    parser.add_argument(
        '--block_size',
        default=64,
        type=int,
        help="Block size for FastOBC.",
    )
    parser.add_argument(
        '--rows_in_parallel',
        default=None,
        type=int,
        help="Number of rows processed in parallel for OBC.",
    )
    parser.add_argument(
        '--perturbation',
        default='gradient',
        choices=['gradient', 'interpolation'],
        type=str,
        help="Perturbation type (note: not currently used by this script).",
    )
    parser.add_argument(
        '--sequential',
        action='store_true',
        help='Whether to prune sequentially'
    )
    # Training params
    parser.add_argument(
        '--finetune_batch_size',
        default=1,
        type=int,
        help="Finetuning batch size."
    )
    parser.add_argument(
        '--lr',
        default=1e-5,
        type=float,
        help="Adam learning rate used for finetuning between pruning rounds."
    )
    parser.add_argument(
        '--rel_damp',
        default=1e-2,
        type=float,
        help="Relative dampening forwarded to the Pruner.",
    )
    # Misc params
    parser.add_argument(
        '--seed',
        default=0,
        type=int,
        help="random seed."
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Output directory where model checkpoints and results are stored.'
    )
    parser.add_argument(
        '--save_model',
        action='store_true',
        help='Whether to save pruned model (requires --output_dir)'
    )
    parser.add_argument(
        '--alpha',
        default=0.0,
        type=float,
        help="Alpha value forwarded to Pruner.prune.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "bfloat16", "float16", "float32"],
        help="dtype to load the model.",
    )
    parser.add_argument(
        '--low_cpu_mem_usage',
        action='store_true',
        help='Whether to load model with the use of `low_cpu_mem_usage`'
    )
    parser.add_argument(
        '--load_model_on_device',
        action='store_true',
        help='Whether to load model on device on pruning (note: not currently used by this script)'
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default=None,
        choices=["eager", "sdpa", "flash_attention_2"],
        help="Attention implementation used when loading the model: eager, sdpa, or flash_attention_2",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether to use gradient checkpointing",
    )
    return parser


def main():
    """Parse command-line arguments and launch ``run`` with them."""
    args = build_parser().parse_args()
    run(args)


def _next_token_loss(model, input_ids):
    """Return the mean next-token cross-entropy of ``model`` on ``input_ids``.

    Kept as a separate function so the logits and their shifted copies go out of
    scope as soon as it returns, rather than staying referenced by the caller's
    locals until the next batch overwrites them.
    """
    lm_logits = model(input_ids).logits
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = lm_logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))


def finetune(model, dataloader, args, device="cuda"):
    """Fine-tune ``model`` in place for one pass over ``dataloader``.

    Minimizes next-token cross-entropy with a freshly created Adam optimizer
    (learning rate ``args.lr``), taking one optimizer step per batch.

    Args:
        model: causal LM whose forward returns an object with a ``.logits`` attribute.
        dataloader: sized iterable yielding ``input_ids`` tensors.
        args: parsed CLI namespace; only ``args.lr`` is read.
        device: device the model and every batch are moved to (default ``"cuda"``).

    The model is switched to train mode for the loop (Hugging Face models, for
    example, only apply gradient checkpointing while ``model.training`` is True).
    Its previous train/eval mode is restored afterwards, even if an error is raised.
    """
    was_training = model.training
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    try:
        for input_ids in tqdm(dataloader, total=len(dataloader), desc="Finetuning model", leave=False):
            loss = _next_token_loss(model, input_ids.to(device=device))
            loss.backward()
            optimizer.step()
            # Free gradient memory between steps (only the default from torch 2.0 on).
            optimizer.zero_grad(set_to_none=True)
    finally:
        model.train(was_training)
        # Drop the Adam moment buffers before releasing cached blocks; while the
        # optimizer is still referenced, empty_cache() cannot hand that memory back.
        del optimizer
        torch.cuda.empty_cache()


def _max_sequence_length(config):
    """Return the context length advertised by a model ``config``.

    The attributes ``max_sequence_length``, ``max_position_embeddings`` and
    ``max_seq_len`` are tried in that order; attributes that are missing or set
    to ``None`` are skipped, and ``float('inf')`` is returned when none applies.
    """
    for attr_name in ('max_sequence_length', 'max_position_embeddings', 'max_seq_len'):
        value = getattr(config, attr_name, None)
        if value is not None:
            return value
    return float('inf')


def run(args):
    """Prune a causal LM iteratively, evaluating perplexity after every round.

    Each of ``args.iterations`` rounds:
      * prunes to ``args.sparsity`` with ``Pruner`` (skipped when sparsity is 0);
      * evaluates perplexity on wikitext2 and c4;
      * fine-tunes on the calibration set, except after the last round.

    Perplexities are collected under keys ``eval/iteration_{i}/{dataset}``. When
    ``args.output_dir`` is set they are written to ``eval_results.pth`` there, and
    with ``args.save_model`` the final model and tokenizer are saved there too.

    Raises:
        RuntimeError: CUDA is not available.
        ValueError: ``--save_model`` without ``--output_dir``, an unsupported
            pruning method, or ``--sequence_length`` above the model's limit.
    """
    fix_seed(args.seed)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but torch.cuda.is_available() returned False.")
    device = "cuda"

    # Validate/create the output directory before the expensive model and data loading.
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    elif args.save_model:
        raise ValueError("--save_model requires --output_dir to be set.")

    if args.pruning_method == "FastOBC":
        obc_util_kwargs = {"block_size": args.block_size}
    elif args.pruning_method == "OBC":
        obc_util_kwargs = {"rows_in_parallel": args.rows_in_parallel}
    else:
        raise ValueError(f"Unsupported pruning method: {args.pruning_method!r}")

    # model
    torch_dtype = args.dtype if args.dtype == 'auto' else getattr(torch, args.dtype)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        attn_implementation=args.attn_implementation,
    )
    max_sequence_length = _max_sequence_length(model.config)
    if args.sequence_length > max_sequence_length:
        raise ValueError(
            f"--sequence_length={args.sequence_length} exceeds the model's maximum "
            f"sequence length ({max_sequence_length})."
        )
    model.sequence_length = args.sequence_length

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Same trust_remote_code setting as the model, so repositories that ship a
    # custom tokenizer class load as well.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    calibration_data = get_loaders(
        args.dataset_name_or_path,
        args.calibration_dataset_size,
        args.seed,
        model.sequence_length,
        False,
        tokenizer,
    )
    # Each sample is an (args, kwargs) pair; presumably unpacked as
    # model(*args, **kwargs) by Pruner - verify. The same pairs back the finetuning loader.
    data_loader = [([], {'input_ids': input_ids}) for input_ids in calibration_data]

    train_loader = DataLoader(
        data_loader,
        batch_size=args.finetune_batch_size,
        collate_fn=lambda samples: torch.cat([sample[1]['input_ids'] for sample in samples], dim=0),
    )

    # create pruner
    pruner = Pruner(
        model,
        data_loader=data_loader,
        module_regex=args.module_regex,
        weights_orig={},
        pruning_method=args.pruning_method,
        rel_damp=args.rel_damp,
        obc_util_kwargs=obc_util_kwargs,
        sequential=args.sequential,
        device=device,
        cpu_offload=True,
        blocks=args.decoder_blocks,
        pre_modules=args.pre_decoder_modules,
        max_samples=args.calibration_dataset_size,
    )

    if args.sparsity == 0:
        # Without pruning a single evaluation round is enough.
        args.iterations = 1

    eval_stats = {}
    # Evaluation loaders are built on first use and reused in later rounds
    # (assumes get_loaders returns the same split for a fixed seed).
    eval_loaders = {}
    print(f'{args.output_dir=}, {args.alpha=}')
    for i in range(args.iterations):
        print(f"Iteration {i}/{args.iterations} | samples={args.calibration_dataset_size} | {args.model_name_or_path}")
        if args.sparsity > 0.0:
            pruner.prune(args.sparsity, args.alpha)
        print('---Evaluation after pruning---')

        for eval_dataset_name in ['wikitext2', 'c4']:
            if eval_dataset_name not in eval_loaders:
                test_data = get_loaders(
                    eval_dataset_name,
                    0,
                    args.seed,
                    args.sequence_length,
                    True,
                    tokenizer,
                )
                eval_loaders[eval_dataset_name] = [([], {'input_ids': input_ids}) for input_ids in test_data]
            ppl = eval_perplexity(
                model,
                eval_loaders[eval_dataset_name],
                args.decoder_blocks,
                args.pre_decoder_modules,
                args.post_decoder_modules,
                device,
                cpu_offload=True,
            )
            print(f'Dataset: {eval_dataset_name}\nPerplexity: {ppl:.2f}')
            eval_stats[f'eval/iteration_{i}/{eval_dataset_name}'] = ppl

            # Written after every dataset so results so far are on disk even if a later step fails.
            if args.output_dir is not None:
                torch.save(eval_stats, os.path.join(args.output_dir, 'eval_results.pth'))

        # Finetune on the calibration set between rounds (not after the last one).
        if i + 1 < args.iterations:
            finetune(model, train_loader, args)

    if args.output_dir is not None:
        torch.save(eval_stats, os.path.join(args.output_dir, 'eval_results.pth'))
        if args.save_model:
            model.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":  # pragma: no cover
    # main() returns None, which sys.exit maps to exit status 0.
    exit_status = main()
    sys.exit(exit_status)