import argparse
import json
import random
import yaml

import torch
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms

from simulator.functions import TransformerNextTokenPredictionFunction
from simulator.fixed_time_utils import run_pipeline
from simulator.algorithms.ringmaster_sgd import RingmasterSGDServer
from simulator.algorithms.rennala_sgd import RennalaSGDServer, RennalaSGDWorker
from simulator.algorithms.synchronized_sgd import SynchronizedSGDServer
from simulator.worker import Worker, WorkerWithLocalSteps, WorkerWithTargetComputeCommunicateRatio

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2TokenizerFast, DataCollatorForLanguageModeling


def get_wikitext_dataset(tokenizer, model_type='gpt2', max_length=128, num_proc=4):
    """Load WikiText-2 ("wikitext-2-v1") with Hugging Face datasets and tokenize it.

    Args:
        tokenizer: Hugging Face tokenizer applied to the ``"text"`` column.
        model_type: ``'gpt2'`` tokenizes with truncation only (no padding is
            applied here). Any other value also pads every example to
            ``max_length`` and returns the special-tokens mask (BERT style).
        max_length: Maximum number of tokens kept per example; longer lines
            are truncated.
        num_proc: Number of worker processes used by ``.map``.

    Returns:
        The split-indexed dataset returned by ``load_dataset`` (e.g.
        ``["train"]``, ``["validation"]``), with the raw ``"text"`` column
        replaced by the tokenizer outputs.
    """
    # Named ``raw_datasets`` so it does not shadow the module-level
    # ``datasets`` import.
    raw_datasets = load_dataset("wikitext", "wikitext-2-v1")

    def tokenize_function(examples):
        # GPT-2 style: no special tokens / padding at this stage.
        if model_type == 'gpt2':
            return tokenizer(examples["text"], truncation=True, max_length=max_length)
        # BERT style: fixed-length padding plus the special-tokens mask.
        return tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_special_tokens_mask=True,
        )

    return raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=["text"],
    )

def get_wikitext_dataloaders(config):
    """Build WikiText-2 train/validation ``DataLoader``s for GPT-2 causal LM.

    Args:
        config: Experiment config. Reads ``'batch_size'`` (required) and the
            optional ``'drop_empty_sequences'`` flag (default ``True``).

    Returns:
        ``(train_dataloader, eval_dataloader)``. Both iterate in dataset order
        (no shuffling, no sampler) and drop the final incomplete batch.
    """
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    # Reuse EOS as the pad token so the collator can pad batches.
    tokenizer.pad_token = tokenizer.eos_token

    tokenized_datasets = get_wikitext_dataset(tokenizer, model_type='gpt2')

    # Empty text lines tokenize to zero tokens. Padding to a multiple of 8
    # cannot fix a batch made only of such examples (0 is already a multiple
    # of 8), which would produce a (batch_size, 0) tensor. These examples
    # carry no tokens to predict, so drop them up front.
    if config.get('drop_empty_sequences', True):
        tokenized_datasets = tokenized_datasets.filter(
            lambda example: len(example["input_ids"]) > 0
        )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # GPT-2 uses causal language modeling, not masked
        pad_to_multiple_of=8,  # round padded batch length up to a multiple of 8
    )

    def _make_loader(split):
        # Same loader settings for every split.
        return DataLoader(
            tokenized_datasets[split],
            batch_size=config['batch_size'],
            collate_fn=data_collator,
            drop_last=True,  # every batch has exactly batch_size examples
        )

    return _make_loader("train"), _make_loader("validation")

def _resolve_server_and_worker(server_name):
    """Return the ``(server_cls, worker_cls)`` pair for a config ``server`` name.

    Raises:
        RuntimeError: if ``server_name`` is not a supported algorithm.
    """
    algorithms = {
        'ringmaster_sgd': (RingmasterSGDServer, Worker),
        'rennala_sgd': (RennalaSGDServer, RennalaSGDWorker),
        'synchronized_sgd': (SynchronizedSGDServer, Worker),
        'local_sgd': (RennalaSGDServer, WorkerWithLocalSteps),
        'ringmaster_sgd_compcomm': (RingmasterSGDServer, WorkerWithTargetComputeCommunicateRatio),
    }
    try:
        return algorithms[server_name]
    except (KeyError, TypeError):
        raise RuntimeError(
            f"Unknown server {server_name!r}; expected one of {sorted(algorithms)}"
        ) from None


def _resolve_optimizer(optimizer_name):
    """Return the ``torch.optim`` class for a config ``optimizer`` name.

    Raises:
        RuntimeError: if ``optimizer_name`` is not a supported optimizer.
    """
    optimizers = {
        'adamw': torch.optim.AdamW,
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    try:
        return optimizers[optimizer_name]
    except (KeyError, TypeError):
        raise RuntimeError(
            f"Unknown optimizer {optimizer_name!r}; expected one of {sorted(optimizers)}"
        ) from None


def _device_name(index, num_of_gpus):
    """Return ``"cuda:<index mod num_of_gpus>"``, or ``"cpu"`` when no GPU is visible."""
    if num_of_gpus <= 0:
        return "cpu"
    return f"cuda:{index % num_of_gpus}"


def train_next_token_prediction(train_dataloader, test_dataloader, save_path, config, config_name):
    """Run one simulated distributed next-token-prediction experiment and save its stats.

    Args:
        train_dataloader: Training ``DataLoader`` given to every function instance.
        test_dataloader: Evaluation ``DataLoader`` given to every function instance.
        save_path: Path of the JSON results file. Periodic temporary snapshots
            use the same path with ``.json`` replaced by ``_temp.json``.
        config: Experiment config. Reads ``batch_size``, ``gamma``, ``sim_time``,
            ``num_workers``, ``metric_check_num``, ``times_to_calculate``,
            ``times_to_communicate``, ``server``, ``optimizer``,
            ``server_params``, and optionally ``worker_params`` / ``local_steps``.
        config_name: Label printed with the experiment parameters.

    Raises:
        RuntimeError: if ``config['server']`` or ``config['optimizer']`` is unknown.
    """
    print("Dataset shape:", len(train_dataloader.dataset))
    print("Test dataset shape:", len(test_dataloader.dataset))

    # Hyperparameters
    batch_size = config['batch_size']
    gamma = config['gamma']
    sim_time = config['sim_time']
    num_workers = config['num_workers']
    metric_check_num = config['metric_check_num']
    metric_check_period = sim_time / metric_check_num
    times_to_calculate = config['times_to_calculate']
    times_to_communicate = config['times_to_communicate']
    local_steps = config.get('local_steps', None)

    # Validate algorithm/optimizer names before building any models, so a bad
    # config fails fast instead of after num_workers model constructions.
    server_cls, worker_cls = _resolve_server_and_worker(config['server'])
    optimizer_cls = _resolve_optimizer(config['optimizer'])

    num_of_gpus = torch.cuda.device_count()
    print(f"Devices available: {num_of_gpus}")
    # Workers are assigned round-robin over the visible GPUs (CPU if there are none).
    functions = [
        TransformerNextTokenPredictionFunction(
            train_dataloader, test_dataloader, batch_size=batch_size,
            device=_device_name(w_id, num_of_gpus),
        )
        for w_id in range(num_workers)
    ]

    # Server-side function: its model is the one the server updates.
    # It is placed on a randomly chosen GPU (or CPU if there are none).
    server_device = (
        _device_name(random.randint(0, num_of_gpus - 1), num_of_gpus)
        if num_of_gpus > 0
        else "cpu"
    )
    function_test = TransformerNextTokenPredictionFunction(
        train_dataloader, test_dataloader, batch_size=batch_size, device=server_device
    )

    # The "point" is the model itself; otherwise .parameters() and the
    # optimizer would not work.
    point = function_test.model

    class Metric:
        """Periodic evaluator handed to ``run_pipeline`` as ``calculate_metrics``.

        Evaluates at most once per ``metric_check_period`` of simulated time.
        """

        def __init__(self, metric_check_period, saving_period=20):
            self._metric_check_period = metric_check_period
            self._metric_checked_times = 0
            # Roughly every ``saving_period``-th evaluation also returns a temp save path.
            self._saving_period = saving_period

        def calculate_metrics(self, env, iter, point):
            """Return ``(metrics, temp_save_path_or_None, config)``.

            Returns ``(None, None, None)`` while the simulated clock ``env.now``
            has not reached the next checkpoint. ``point`` is unused because
            ``function_test`` references the server's model parameters.
            """
            if env.now < self._metric_check_period * self._metric_checked_times:
                return None, None, None
            self._metric_checked_times += 1

            # Pass None: function_test already references the server's model parameters.
            loss_train, perplexity_train = function_test.perplexity(None, train=True)
            loss_test, perplexity_test = function_test.perplexity(None)

            save_path_temporal = (
                save_path.replace(".json", "_temp.json")
                if (self._metric_checked_times + 1) % self._saving_period == 0
                else None
            )
            print(f"Time {env.now}, Loss: {loss_train}, Perplexity: {perplexity_train}, Loss Test: {loss_test}, Perplexity Test: {perplexity_test}")
            metrics = {
                'value': float(loss_train), 'time': env.now, "iter": iter,
                'loss': loss_test, 'perplexity': perplexity_train, 'perplexity_test': perplexity_test,
            }
            return metrics, save_path_temporal, config

    # Read once so the printed and the passed worker params always agree
    # (and a missing key no longer crashes the print).
    worker_params = config.get('worker_params', {})
    print(f"Experiment name: {config_name}, Server params: {config['server_params']}, worker params: {worker_params}")
    _, stats = run_pipeline(server_cls, worker_cls,
                            functions, point, gamma, optimizer_cls, sim_time=sim_time,
                            times_to_calculate=times_to_calculate,
                            times_to_communicate=times_to_communicate,
                            server_params=config['server_params'],
                            worker_params=worker_params,
                            calculate_metrics=Metric(metric_check_period).calculate_metrics,
                            local_steps=local_steps)

    stats["params"] = config
    with open(save_path, 'w') as f:
        json.dump(stats, f, indent=4)
    print(f"Results saved to {save_path}")

if __name__ == "__main__":
    # CLI entry point: load a YAML config, build WikiText-2 loaders, run the experiment.
    parser = argparse.ArgumentParser(description="Train a transformer for next-token prediction on WikiText-2 and save experiment results.")
    parser.add_argument('--save_path', type=str, required=True, help='Path to save experiment results')
    parser.add_argument('--config', type=str, required=True, help='Config with params')
    args = parser.parse_args()

    # Close the config file deterministically instead of leaking the handle.
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)

    train_dataloader, test_dataloader = get_wikitext_dataloaders(config)
    train_next_token_prediction(train_dataloader, test_dataloader, save_path=args.save_path, config=config, config_name=args.config)
    print("Training completed.")
