import torch
import numpy as np
import argparse
import os
import time
import model.transformer as tf
from torch import nn
import utilities.data_generation as dg
import utilities.wikitext_data as wikitext
import grokking_experiments.addition_function as add
from utilities.logger import Logger
import sys
from torch.utils.data import DataLoader, Dataset
import noise_stability.measure_noise_stability as ns

# Seed the global NumPy and PyTorch RNGs at import time so runs are repeatable
np.random.seed(42)
torch.manual_seed(42)

# Dataset wrapper used to build the addition-task DataLoaders
class AdditionDataset(Dataset):
    """Map-style dataset that pairs input tensors with their labels.

    Every item is a dict with the keys ``input_ids``, ``attention_mask`` and
    ``labels``. The attention mask is an all-ones float tensor with the same
    shape as ``X``.
    """

    def __init__(self, X, y):
        """Store the inputs ``X`` and labels ``y`` and build the all-ones mask."""
        ones_mask = torch.ones_like(X, dtype=torch.float)
        self.X, self.y = X, y
        self.attention_mask = ones_mask

    def __len__(self):
        """Return ``len(X)``, i.e. the number of stored examples."""
        n_examples = len(self.X)
        return n_examples

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict of model inputs."""
        return dict(
            input_ids=self.X[idx],
            attention_mask=self.attention_mask[idx],
            labels=self.y[idx],
        )

# Script entry point
# main() parses the command-line arguments, builds the task-specific data and model configuration, and runs the multi-seed experiment.

def _str2bool(value):
    """Parse a command-line string into a bool for argparse.

    ``type=bool`` would treat every non-empty string (including "False") as
    True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Build the command-line parser for the grokking experiment."""
    parser = argparse.ArgumentParser(description="Grokking Experiment")
    parser.add_argument("--task", type=str, default="addition",
                        choices=["addition", "wikitext"],
                        help="Task to run: 'addition' for modular addition, 'wikitext' for language modeling")
    parser.add_argument("--d", type=int, default=18,
                        help="The embedding model dimension.")
    parser.add_argument("--epochs", type=int, default=20,
                        help="How many epochs to run training for.")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="The learning rate to set.")
    parser.add_argument("--layers", type=int, default=2,
                        help="The number of layers to add.")
    parser.add_argument("--heads", type=int, default=2,
                        help="The number of heads to use")
    parser.add_argument("--train_samples", type=int, default=5000,
                        help="Size of training dataset.")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Set the batch size (default = 32)")
    parser.add_argument("--num_seeds", type=int, default=5,
                        help="Number of seeds to run the experiment with.")
    parser.add_argument("--noise_reg", type=float, default=0.0,
                        help="Strength of noise regularization (default = 0.0)")
    parser.add_argument("--noise_reg_r", type=float, default=0.05,
                        help="Regularization parameter for noise regularization (default = 0.05)")
    parser.add_argument("--rho_list", type=float, nargs='+', default=[],
                        help="List of rho values to experiment with")
    # Accepts a bare `--no_rho` (True) as well as `--no_rho True/False`, so
    # existing invocations keep working while "False" now really means False.
    parser.add_argument("--no_rho", type=_str2bool, nargs='?', const=True, default=False,
                        help="Whether to disable rho (default = False)")
    parser.add_argument("--K", type=int, default=113,
                        help="The modulo value for the addition function.")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay for the optimizer (default = 0.01, balance for grokking)")
    parser.add_argument("--patience", type=int, default=10,
                        help="Early stopping patience (default = 10)")
    parser.add_argument("--lr_factor", type=float, default=0.5,
                        help="Learning rate reduction factor (default = 0.5)")
    parser.add_argument("--training_mode", type=str, default="single",
                        choices=["single", "dataparallel", "ddp"],
                        help="Training mode: 'single' for single GPU, 'dataparallel' for DataParallel, 'ddp' for DistributedDataParallel")
    parser.add_argument("--gpu", type=int, default=None,
                        help="GPU ID to use for single GPU training (e.g., --gpu 0, --gpu 1, etc.). If not specified, uses GPU 0.")
    parser.add_argument("--gpu_ids", type=int, nargs='+', default=None,
                        help="Specific GPU IDs to use for multi-GPU modes (e.g., --gpu_ids 0 1 2 3). If not specified, uses all available GPUs.")

    # WikiText-specific arguments
    parser.add_argument("--vocab_size", type=int, default=5000,
                        help="Vocabulary size for WikiText (default = 5000)")
    parser.add_argument("--seq_length", type=int, default=32,
                        help="Sequence length for WikiText (default = 32)")

    # Performance arguments
    parser.add_argument("--epoch_period", type=int, default=1,
                        help="How often to validate and measure metrics (default = 1, set higher for faster training)")
    parser.add_argument("--label_smoothing", type=float, default=0.0,
                        help="Label smoothing for better generalization (default = 0.0, try 0.1 if overfitting)")
    parser.add_argument("--dropout", type=float, default=0.0,
                        help="Dropout rate (default = 0.0, try 0.1-0.2 for more regularization)")
    return parser


def _log_configuration(args, gpu_ids):
    """Print the run configuration banner (called on the main process only)."""
    print("--START--")
    print(f"Task: {args.task}")
    print(f"Embedding Dimension: {args.d}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning Rate: {args.lr}")
    print(f"Epochs: {args.epochs}")
    print(f"Weight Decay: {args.weight_decay}")
    print(f"Early Stopping Patience: {args.patience}")
    print(f"Learning Rate Reduction Factor: {args.lr_factor}")
    print(f"Layers: {args.layers}")
    print(f"Attention heads: {args.heads}")
    print(f"Number of training examples: {args.train_samples}")
    print(f"Noise Regularization Strength: {args.noise_reg}")
    print(f"Noise Regularization Parameter: {args.noise_reg_r}")
    print(f"Rho values: {args.rho_list}")
    print(f"Training Mode: {args.training_mode}")
    if args.training_mode == "single" and args.gpu is not None:
        print(f"GPU ID: {args.gpu}")
    elif gpu_ids is not None:
        print(f"GPU IDs: {gpu_ids}")
    if args.task == "wikitext":
        print(f"Vocabulary Size: {args.vocab_size}")
        print(f"Sequence Length: {args.seq_length}")
    print("---------")


def _select_device(args, is_main_process):
    """Return the torch device string ('cuda', 'cuda:<id>' or 'cpu') for this run."""
    if not torch.cuda.is_available():
        if is_main_process:
            print("CUDA not available, using CPU")
        return 'cpu'
    # An explicit GPU id is only honoured in single-GPU mode.
    if args.training_mode == "single" and args.gpu is not None:
        if is_main_process:
            print(f"Using GPU {args.gpu}")
        return f'cuda:{args.gpu}'
    return 'cuda'


def _run_experiment(args, is_main_process, use_dp, use_ddp, gpu_ids):
    """Set up logging, data and model arguments, then run all seeds.

    Args:
        args: Parsed command-line namespace from ``_build_arg_parser``.
        is_main_process: True for rank 0 (or any non-DDP run); only this
            process creates the output folder, redirects stdout and prints.
        use_dp: Whether DataParallel mode was requested.
        use_ddp: Whether DistributedDataParallel mode is active.
        gpu_ids: GPU ids for DataParallel mode, or None.

    Raises:
        ValueError: if ``args.task`` is not a known task.
    """
    # Only main process handles folder creation and logging setup
    if is_main_process:
        # Create a folder for the results to be stored in.
        folder_name = f"plots/{time.strftime('%Y%m%d-%H%M%S')}"
        os.makedirs(folder_name, exist_ok=True)

        # Redirect stdout to our logger.
        sys.stdout = Logger(os.path.join(folder_name, "script_output.log"))
        _log_configuration(args, gpu_ids)
    else:
        # Non-main processes use a simple folder name (will be created by main process)
        # NOTE(review): each rank formats its own timestamp, so this name can
        # differ from rank 0's if the processes reach this line in different
        # seconds - confirm whether non-main ranks actually use folder_name.
        folder_name = f"plots/{time.strftime('%Y%m%d-%H%M%S')}"

    # Parameters
    batch_size = args.batch_size
    num_epochs = args.epochs

    device = _select_device(args, is_main_process)

    # Task-specific setup
    if args.task == "addition":
        # Generate training and validation data for addition
        if is_main_process:
            print("Generating training and validation data for modular addition...")

        train_X, train_y = dg.generate_data_addition(args.train_samples, args.K)
        val_X, val_y = dg.generate_data_addition(200, args.K)
        test_X, test_y = dg.generate_data_addition(200, args.K)

        # Create DataLoaders
        train_dataset = AdditionDataset(train_X, train_y)
        val_dataset = AdditionDataset(val_X, val_y)
        test_dataset = AdditionDataset(test_X, test_y)

        if is_main_process:
            print(f"Training data shape: {train_X.shape}, Labels shape: {train_y.shape}")
            print(f"Validation data shape: {val_X.shape}, Labels shape: {val_y.shape}")
            print(f"Test data shape: {test_X.shape}, Labels shape: {test_y.shape}")

        # Pinned host memory is only requested when training on a CUDA device.
        pin_memory = device.startswith('cuda')
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=4, pin_memory=pin_memory)
        val_loader = DataLoader(val_dataset, batch_size=batch_size,
                                num_workers=2, pin_memory=pin_memory)
        test_loader = DataLoader(test_dataset, batch_size=batch_size,
                                 num_workers=2, pin_memory=pin_memory)

        # Model arguments for addition task
        vocab_size_actual = args.K + 5
        model_args = {
            'vocab_size': vocab_size_actual,
            'd_model': args.d,
            'n_layers': args.layers,
            'n_heads': args.heads,
            'num_classes': args.K,
            'task_type': 'classification',
            'dropout_rate': args.dropout
        }
        task_type = 'classification'
        input_length = 2

    elif args.task == "wikitext":
        # Load WikiText data
        if is_main_process:
            print("Loading WikiText-2 data for language modeling...")

        # The tokenizer is returned by the helper but not needed here.
        _tokenizer, train_loader, val_loader, test_loader, vocab_size_actual = wikitext.get_wikitext_experiment_setup(
            vocab_size=args.vocab_size,
            seq_length=args.seq_length,
            train_samples=args.train_samples,
            val_samples=200,
            test_samples=200,
            batch_size=batch_size
        )

        # Model arguments for language modeling task
        model_args = {
            'vocab_size': vocab_size_actual,
            'd_model': args.d,
            'n_layers': args.layers,
            'n_heads': args.heads,
            'max_length': args.seq_length,
            'task_type': 'language_modeling',
            'dropout_rate': args.dropout
        }
        task_type = 'language_modeling'
        input_length = args.seq_length

    else:
        raise ValueError(f"Unknown task: {args.task}")

    # --no_rho overrides any values supplied through --rho_list.
    rho = [] if args.no_rho else args.rho_list

    # Define training arguments
    train_kwargs = {
        'lr': args.lr,
        'device': device,
        'weight_decay': args.weight_decay,
        'patience': args.patience,
        'lr_factor': args.lr_factor,
        'rho': rho,
        'input_length': input_length,
        'label_smoothing': args.label_smoothing
    }

    # Generate seeds (same across all processes for consistency)
    np.random.seed(42)  # Ensure same seeds across processes
    seeds = np.random.randint(0, 10000, size=args.num_seeds).tolist()

    # Run with multiple seeds
    results = tf.run_multiple_seeds(
        model_class=tf.SimpleTransformer,
        model_args=model_args,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        folder_name=folder_name,
        vocab_size=vocab_size_actual,
        seeds=seeds,
        noise_reg_strength=args.noise_reg,
        noise_reg_r=args.noise_reg_r,
        learn_function_stabilities=None,
        use_dp=use_dp,
        use_ddp=use_ddp,
        gpu_ids=gpu_ids,
        task_type=task_type,
        epoch_period=args.epoch_period,
        **train_kwargs
    )

    # Only main process prints final results
    if is_main_process:
        test_accuracies = results['test_accuracies']
        print(f"Test Accuracies: {test_accuracies}")
        print(f"Mean: {np.mean(test_accuracies):.2f}%, Std: {np.std(test_accuracies):.2f}%")


def main():
    """Parse the CLI arguments and run the experiment in the requested mode.

    Supports single-device, DataParallel and DistributedDataParallel runs.
    DDP mode must be launched externally so that the ``RANK`` and
    ``WORLD_SIZE`` environment variables are set.

    Raises:
        RuntimeError: if DDP mode is requested without ``RANK``/``WORLD_SIZE``.
        ValueError: if the training mode or task is not recognised.
    """
    args = _build_arg_parser().parse_args()

    # Determine training configuration based on arguments
    is_main_process = True
    use_dp = False
    use_ddp = False
    gpu_ids = None

    if args.training_mode == "dataparallel":
        # DataParallel mode: default to every visible GPU when none are listed.
        use_dp = True
        gpu_ids = args.gpu_ids if args.gpu_ids else list(range(torch.cuda.device_count()))

    elif args.training_mode == "ddp":
        # DDP mode. Raise explicitly: an assert would be stripped under -O.
        if 'RANK' not in os.environ or 'WORLD_SIZE' not in os.environ:
            raise RuntimeError("DDP mode requires external launch with torchrun or torch.distributed.launch. ")
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        is_main_process = (rank == 0)
        print(f"DDP Process {rank}/{world_size}")

        # Initialize DDP - this should be done once per process
        tf.setup_ddp(rank, world_size)
        use_ddp = True

    elif args.training_mode != "single":
        raise ValueError("Invalid training mode")

    try:
        _run_experiment(args, is_main_process, use_dp, use_ddp, gpu_ids)
    finally:
        # Cleanup DDP if we used it, even when the run fails part-way.
        if use_ddp:
            tf.cleanup_ddp()


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()