import yaml
import argparse
import torch
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AutoTokenizer
from transformers import get_cosine_schedule_with_warmup
from datasets import load_dataset, Dataset
import matplotlib.pyplot as plt
from iams.llms.utils import compute_cross_entropy_loss, get_default_config, merge_configs, load_config, get_outputfile_from_configfile
from iams.llms.data import load_data
from iams.llms.train import train, get_scheduler
from iams.utils import smoothen_dict
import copy 
import json



def main(config_file=None):
    """Count the non-padding tokens of the training data and save the total.

    The configuration is the default config, or the result of
    ``load_config(default_config, config_file)`` when a config file is given.
    Every batch of the training dataloader is tokenized with the teacher
    model's tokenizer (truncated to ``training_params['max_length']``), the
    real (non-padding) tokens are summed, and the total is printed and written
    as JSON to ``<outputfile>token_count``.

    Args:
        config_file: Optional path to a config file. When omitted, the default
            configuration is used unchanged.
    """
    default_config = get_default_config()
    # Fall back to the defaults: previously `config` stayed unbound when no
    # config file was supplied, so the script crashed on its first use.
    config = load_config(default_config, config_file) if config_file else default_config

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only the training split is counted; the second dataloader is not needed.
    train_dataloader, _ = load_data(
        config['dataset']['name'],
        batch_size=config['training_params']['batch_size'],
    )

    # Load the tokenizer; it needs a pad token so batches can be padded.
    tokenizer = AutoTokenizer.from_pretrained(config['gpt_model']['teacher_model'])
    tokenizer.pad_token = tokenizer.eos_token

    training_params = config['training_params']
    max_length = training_params['max_length']
    print(f"Training with teacher model {config['gpt_model']['teacher_model']} on dataset {config['dataset']['name']}")

    total_tokens = 0
    for batch in train_dataloader:
        inputs = tokenizer(
            batch['text'],
            return_tensors='pt',
            max_length=max_length,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
        )
        # Count via the attention mask instead of comparing against
        # pad_token_id: the pad token is aliased to EOS above, so an id
        # comparison would also discard genuine EOS tokens in the text.
        total_tokens += int(inputs['attention_mask'].sum().item())

    print(f"Total number of tokens in the dataset: {total_tokens}")

    # NOTE(review): with no config file this receives None; whether
    # get_outputfile_from_configfile handles that is not visible here - confirm.
    outputfile = get_outputfile_from_configfile(config_file)
    with open(outputfile + 'token_count', 'w') as file:
        json.dump(total_tokens, file)
   
if __name__ == "__main__":
    # Command-line entry point; the config file is optional.
    # The description previously advertised GPT-2 training, but this script
    # only counts dataset tokens, so the --help text now says so.
    parser = argparse.ArgumentParser(
        description='Count the tokens in a dataset, with optional config file.'
    )
    parser.add_argument('--config', type=str, help='Path to config file', default=None)

    args = parser.parse_args()
    if args.config:
        print(f"Loading configuration from {args.config}")
    else:
        print("No config file provided, using default settings.")
    main(args.config)
    


