import copy
from pathlib import Path
import argparse

from utils.logs import set_logger, SharedLogger
import os
# Process-wide settings. They are assigned here, ahead of `import torch` below,
# so they are already present in the environment when torch is loaded.
# Ask the PyTorch CUDA caching allocator to use expandable segments.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Enumerate CUDA devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # Set to your preferred GPU
# Turn off parallelism inside the HuggingFace tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import time
import torch
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 for CUDA matmuls (e.g. on A100)
torch.set_float32_matmul_precision("high")  # "high" permits faster reduced-precision float32 matmuls (e.g. TF32); the default "highest" does not
# Disabled: would restrict scaled-dot-product attention to the flash / memory-efficient kernels.
# torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False)
from torch import nn
from torch.utils.data import Dataset, Subset, DataLoader
from torchinfo import summary

from prune.prune import prune
from utils.parse_configs import parse_config
from utils.base_arguments import BaseArguments
from utils.dataset_utils import CalibrationDatasetArguments, get_calibration_dataloader, get_hybrid_calibration_dataset
from utils.model_utils import ModelArguments, load_model_and_tokenizer
from utils.pruning_utils import PruningArguments

from huggingface_hub import login
# Note: Replace with your own HuggingFace token for model access
# login(token="YOUR_HUGGINGFACE_TOKEN_HERE")

def main():
    """Run the pruning pipeline end to end.

    Steps, in order:
      1. Parse the argument groups via ``parse_config()``.
      2. If ``base_args.cache_dir`` is set, create it and point the
         HuggingFace cache environment variables at it.
      3. Create (optionally wiping first) the run's output directory and
         configure the shared logger under ``<output>/logs``.
      4. Load the model bundle and tokenizer.
      5. Build the calibration dataset, then load it back as a dataloader
         from ``<output>/<data_args.output_path>/calib_dataset.pt``.
      6. Call ``prune(...)``, log how long it took, and save the returned
         model's ``state_dict`` to ``<output>/Pruning/pruned_model.bin``.

    Returns:
        None. Results are written to disk and to the log.
    """
    # parse configs (the evaluation arguments are not used by this script)
    base_args, model_args, data_args, pruning_args, _evaluation_args = parse_config()

    # Setup global cache directory for all HuggingFace downloads.
    # All three stay None when no cache directory was configured.
    cache_dir = None
    models_cache_dir = None
    datasets_cache_dir = None

    if base_args.cache_dir is not None:
        cache_dir = Path(base_args.cache_dir).resolve()
        cache_dir.mkdir(parents=True, exist_ok=True)

        models_cache_dir = str(cache_dir / "transformers")
        datasets_cache_dir = str(cache_dir / "datasets")

        # Set HuggingFace environment variables to use our cache directory.
        # NOTE(review): HF libraries may already have been imported through the
        # utils modules by now and may read these variables at import time --
        # confirm. The explicit cache_dir arguments passed to the loaders
        # below do not depend on that.
        os.environ["HF_HOME"] = str(cache_dir)
        os.environ["TRANSFORMERS_CACHE"] = models_cache_dir
        os.environ["HF_DATASETS_CACHE"] = datasets_cache_dir

    # set output directory: <output_dir>/<model>_<dataset>_<speedup>
    run_name = f"{model_args.model_name_or_path}_{data_args.dataset_name}_{pruning_args.inference_speedup}"
    base_output_dir = Path(base_args.output_dir) / run_name
    if base_args.overwrite_output and base_output_dir.exists():
        # Start from a clean directory when the user asked to overwrite.
        import shutil
        shutil.rmtree(base_output_dir)
    base_output_dir.mkdir(parents=True, exist_ok=True)

    # create logger to log messages
    log_output = base_output_dir / 'logs'
    log_output.mkdir(parents=True, exist_ok=True)
    SharedLogger.configure(log_output)
    logger = SharedLogger.get_logger("Main")
    logger.info('Starting Pruning with config: ')
    logger.info(f'Base Arguments: {base_args}')
    logger.info(f'Model Arguments: {model_args}')
    logger.info(f'Dataset Arguments: {data_args}')
    logger.info(f'Pruning Arguments: {pruning_args}')
    logger.info(f'Output Directory: {base_output_dir}')
    if cache_dir is not None:
        logger.info(f'Cache Directory: {cache_dir}')
        logger.info(f'  - Models: {cache_dir / "transformers"}')
        logger.info(f'  - Datasets: {cache_dir / "datasets"}')
    else:
        logger.info('Cache Directory: Using HuggingFace default cache locations')

    # Load Model and Processor
    logger.info('Loading Model and Tokenizer')
    model_data = load_model_and_tokenizer(logger, model_args, base_args.device, cache_dir=models_cache_dir)
    logger.info(f'Teacher Model: {model_data.teacher_model}')
    logger.info(f'Config: {model_data.config}')
    logger.info(f'Tokenizer: {model_data.tokenizer}')
    logger.info(f'Student Model: {model_data.student_model}')

    # Load Calibration Data for pruning
    logger.info(f'Loading Calibration of size {data_args.num_samples} from {data_args.dataset_name}/{data_args.dataset_config_name}/{data_args.split}')
    # The return value is intentionally not kept: the dataloader used for
    # pruning is loaded from disk just below. Presumably this call is what
    # writes calib_dataset.pt under the output directory -- TODO confirm.
    get_hybrid_calibration_dataset(logger, model_data.tokenizer, data_args, base_output_dir, cache_dir=datasets_cache_dir)

    # Compute the path once so the log line reports the file that is actually loaded.
    calib_path = base_output_dir / data_args.output_path / 'calib_dataset.pt'
    logger.info(f'Loading Calibration dataloader with batch size {data_args.batch_size} from {calib_path}')
    calibration_dataloader = get_calibration_dataloader(logger, str(calib_path), data_args.batch_size, num_workers=base_args.num_workers)
    logger.info('Calibration dataset loaded')

    ################################################################################### Pruning
    pruning_output = base_output_dir / 'Pruning'
    logger.info(f'Generating Pruning results in {pruning_output}')
    pruning_output.mkdir(parents=True, exist_ok=True)

    # Start pruning
    logger.info('Starting Pruning')

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments during a long pruning run.
    tick = time.perf_counter()
    pruned_model = prune(model_data, pruning_args, calibration_dataloader, base_args.device)
    elapsed = time.perf_counter() - tick

    logger.info('Finished Pruning')
    logger.info(f'Pruning took {elapsed / 60:.2f} minutes ({elapsed:.2f} seconds)')
    # logger.info(f'Pruned Model: {summary(pruned_model, input_size=(1, 160000))}')
    # Log the object that is saved below, so the log and the artifact agree.
    logger.info(f'Pruned Model: {pruned_model}')
    pruned_model_path = pruning_output / "pruned_model.bin"
    logger.info(f'Saving pruned model to {pruned_model_path}')
    torch.save(pruned_model.state_dict(), pruned_model_path)



# Run the pruning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
