import argparse
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle

from pprint import pprint
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from ds_src.initialize.dataset import initialize_dataset
from ds_src.utils import get_yaml_config


def get_args() -> argparse.Namespace:
    """ Parse the command line options of the token counting script.

    Returns:
        argparse.Namespace: The parsed options (config, model_name_or_path,
            mode, batch_size, sequence_length, dataloader_workers_number
            and save_data_tokens).
    """

    # Settings shared by several options.
    mandatory_text = {'type': str, 'required': True}
    optional_number = {'type': int, 'required': False}

    # One entry per option: (flags, keyword arguments for `add_argument`).
    # The order is kept because it defines the order of the `--help` output.
    option_table = (
        (
            ('-c', '--config'),
            dict(mandatory_text, help='The path to the task config file.'),
        ),
        (
            ('-mnop', '--model_name_or_path'),
            dict(mandatory_text, help='The name or path of the model.'),
        ),
        (
            ('-m', '--mode'),
            dict(mandatory_text, help='The mode of the task.'),
        ),
        (
            ('-bs', '--batch_size'),
            dict(optional_number, default=1, help='The batch size.'),
        ),
        (
            ('-sl', '--sequence_length'),
            dict(
                optional_number,
                default=2048,
                help='The sequence length of the model.',
            ),
        ),
        (
            ('-dwn', '--dataloader_workers_number'),
            dict(
                optional_number,
                default=32,
                help='The number of workers for the dataloader.',
            ),
        ),
        (
            ('-sdt', '--save_data_tokens'),
            dict(
                action='store_true',
                required=False,
                help='Whether to save the data tokens to a file.',
            ),
        ),
    )

    cli_parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        cli_parser.add_argument(*flags, **settings)

    return cli_parser.parse_args()


def collect_token_stats(batches, pad_token_id, sequence_length):
    """ Accumulate the token statistics over an iterable of collated batches.

    Args:
        batches: An iterable of mapping-like batches. Every batch provides a
            2-D `input_ids` array/tensor (one row per sample) and optionally
            a `source` entry.
        pad_token_id: The token id that is treated as padding and therefore
            not counted as an actual token.
        sequence_length: A sample whose actual token number is greater than
            or equal to this value is counted as an overflow.

    Returns:
        tuple: `(padded_tokens, token_lengths, overflow_count,
            overflow_sources)` where `padded_tokens` is the total number of
            positions of all the batches (padding included), `token_lengths`
            is the list of the non-padding token number of every sample,
            `overflow_count` is the number of overflowing samples and
            `overflow_sources` maps a source to its number of overflowing
            samples.
    """

    padded_tokens = 0
    token_lengths = []
    overflow_count = 0
    overflow_sources = {}

    for batch in batches:
        input_ids = batch['input_ids']
        rows, padded_length = input_ids.shape[0], input_ids.shape[1]

        # Count every position of the batch. Adding only the padded length
        # would ignore all the rows but one when the batch size is above 1.
        padded_tokens += rows * padded_length

        # NOTE(review): assumes `batch['source']` holds one entry per row of
        # `input_ids`, in the same order -- confirm against the collate
        # function of the dataset.
        sources = batch['source'] if 'source' in batch.keys() else None

        for row in range(rows):
            length = (input_ids[row] != pad_token_id).sum().item()
            token_lengths.append(length)

            if length < sequence_length:
                continue

            overflow_count += 1
            if sources is not None:
                # Only the source of the overflowing sample is counted, not
                # the sources of all the samples of its batch.
                source = sources[row]
                overflow_sources[source] = overflow_sources.get(source, 0) + 1

    return padded_tokens, token_lengths, overflow_count, overflow_sources


def main() -> None:
    """ Count the tokens of a dataset and plot the token length distribution.

    The data number of each source is also printed for every dataset except
    'c4'.
    """

    args = get_args()

    config = get_yaml_config(
        config_path=args.config,
        version='TOKEN_CALCULATOR',
    )
    dataset = initialize_dataset(
        config=config,
        mode=args.mode,
    )
    dataset_name = config[args.mode]['dataset']['name']

    ## Check the subset data number (skipped for the 'c4' dataset).
    if dataset_name != 'c4':
        source_counts = {}
        for sample in tqdm(
                iterable=dataset,
                desc='[Counting the data number of each source]',
                dynamic_ncols=True,
        ):
            source = sample['source']
            source_counts[source] = source_counts.get(source, 0) + 1

        print('The data number of each source:')
        pprint(source_counts)
    ## -----

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.model_name_or_path)
    if tokenizer.pad_token is None:
        # NOTE: With this fallback the padding id equals the EOS id, so the
        # EOS tokens of a sample are not counted as actual tokens below.
        tokenizer.pad_token = tokenizer.eos_token

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_workers_number,
        collate_fn=dataset.get_collate_fn(
            config=config,
            tokenizer=tokenizer,
        ),
    )

    (
        tokens_counter,
        actual_tokens_list,
        overflow_count,
        classes,
    ) = collect_token_stats(
        batches=tqdm(iterable=dataloader),
        pad_token_id=tokenizer.pad_token_id,
        sequence_length=args.sequence_length,
    )

    if not actual_tokens_list:
        # `max()` and `min()` raise an opaque ValueError on an empty list, so
        # stop here with an explicit message instead.
        raise SystemExit(
            f'No data found in the dataset: {dataset_name} ({args.mode}).')

    max_tokens = max(actual_tokens_list)
    min_tokens = min(actual_tokens_list)
    avg_tokens = np.mean(actual_tokens_list)

    print(f'The name of the dataset: {dataset_name}')
    # The data number is the number of samples really seen: the last batch
    # may hold fewer than `batch_size` samples.
    print(f'There are {tokens_counter:,} tokens in {len(dataloader):,} '
          f'batches ({len(actual_tokens_list):,} data).')
    print(f'Maximum tokens per data: {max_tokens}')
    print(f'Minimum tokens per data: {min_tokens}')
    print(f'Average tokens per data: {avg_tokens:.2f}')
    print(f'Overflow count: {overflow_count}')
    print(f'Classes: {classes}')

    if args.save_data_tokens:
        root_path = 'ds_src/initialize/datasets'
        model_name = args.model_name_or_path.split('/')[-1]
        os.makedirs(root_path, exist_ok=True)
        # The `with` statement closes the file, no explicit close is needed.
        with open(
                file=os.path.join(
                    root_path,
                    f'{model_name}_{dataset_name}_{args.mode}_token_count.pkl',
                ),
                mode='wb',
        ) as f:
            pickle.dump(
                obj=actual_tokens_list,
                file=f,
            )

    plt.figure(figsize=(12, 8))

    # Plot token length distribution.
    plt.subplot(2, 1, 1)
    plt.title(label=f'Token Length Distribution - {dataset_name}')
    plt.xlabel(xlabel='Number of Tokens')
    plt.ylabel(ylabel='Frequency')
    plt.grid(
        visible=True,
        alpha=0.3,
    )
    plt.hist(
        x=actual_tokens_list,
        bins=50,
        alpha=0.7,
        color='skyblue',
        edgecolor='black',
    )

    # Plot token length box plot.
    plt.subplot(2, 1, 2)
    plt.title(label='Token Length Box Plot')
    plt.xlabel(xlabel='Number of Tokens')
    plt.grid(
        visible=True,
        alpha=0.3,
    )
    plt.boxplot(
        x=actual_tokens_list,
        vert=False,
    )

    plt.tight_layout()

    plt.savefig(
        f'token_distribution_{dataset_name}_{args.mode}_{args.sequence_length}.png',
        dpi=300,
        bbox_inches='tight',
    )
    # Release the figure once it has been written to disk.
    plt.close()


if __name__ == '__main__':
    main()
