import torch
import numpy as np
import argparse
import os
import time
import model.transformer as tf
import utilities.data_generation as dg
import experiments_with_boolean_functions.boolean_functions as bf
from utilities.logger import Logger
import sys
from torch.utils.data import DataLoader, Dataset
import noise_stability.measure_noise_stability as ns

# Pin the global RNG state of NumPy and PyTorch to one fixed value so that
# repeated runs of this script draw the same random numbers.
_GLOBAL_SEED = 42
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)

# Dataset wrapper from which the train/validation/test DataLoaders are built
class BooleanDataset(Dataset):
    """Map-style dataset pairing input tensors with their labels.

    Each item is a dict with the keys ``'input_ids'``, ``'attention_mask'``
    and ``'labels'``. The attention mask is an all-ones float tensor with the
    same shape as ``X``, built once at construction time.

    Args:
        X: Tensor of inputs, indexed along its first dimension.
        y: Labels, indexed in parallel with ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` do not hold the same number of examples.
    """

    def __init__(self, X, y):
        # Fail fast: a mismatch would otherwise surface as an IndexError in the
        # middle of an epoch, or silently ignore the surplus labels.
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same number of examples, "
                f"got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y
        self.attention_mask = torch.ones_like(X, dtype=torch.float)

    def __len__(self):
        """Return the number of examples (length of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict of input, mask and label."""
        return {
            'input_ids': self.X[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.y[idx]
        }

# Script entry point: train and evaluate a transformer on a Boolean function
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Boolean-function experiment.

    ``--function`` is restricted to the functions this script knows how to
    set up, so an unsupported name is rejected at parse time instead of
    leaving the target function unset. Numeric defaults are real numbers
    rather than strings.

    Returns:
        The configured parser; call ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description="Boolean Functions")
    parser.add_argument("--function", type=str, default="parity",
                        choices=("majority", "parity", "junta"),
                        help="The function to try and learn")
    parser.add_argument("--d", type=int, default=18,
                        help="The embedding model dimension.")
    parser.add_argument("--epochs", type=int, default=20,
                        help="How many epochs to run training for.")
    parser.add_argument("--n", type=int, default=50,
                        help="The input sequence length.")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="The learning rate to set. ")
    parser.add_argument("--layers", type=int, default=2,
                        help="The number of layers to add.")
    parser.add_argument("--heads", type=int, default=2,
                        help="The number of heads to use")
    parser.add_argument("--k", type=int, default=3,
                        help="The number of relevant junta variables.")
    parser.add_argument("--train_samples", type=int, default=5000,
                        help="Size of training dataset.")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Set the batch size (default = 32)")
    parser.add_argument("--num_seeds", type=int, default=5,
                        help="Number of seeds to run the experiment with.")
    parser.add_argument("--noise_reg", type=float, default=0.0,
                        help="Strength of noise regularization (default = 0.0)")
    parser.add_argument("--noise_reg_r", type=float, default=0.05,
                        help="Regularization parameter for noise regularization (default = 0.05)")
    parser.add_argument("--patience", type=int, default=5,
                        help="Early stopping patience (default = 5)")
    parser.add_argument("--lr_factor", type=float, default=0.5,
                        help="Learning rate reduction factor (default = 0.5)")
    parser.add_argument("--rho_list", type=float, nargs='+',
                        default=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
                        help="List of rho values to experiment with (default = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5])")
    return parser


def main() -> None:
    """Run the experiment end to end.

    Steps, in order: parse the command line, create a timestamped folder
    under ``plots/`` and redirect stdout to a ``Logger`` writing into it,
    print the configuration, generate train/validation/test data for the
    selected function, measure the target function's noise stability for
    every requested rho, hand everything to ``tf.run_multiple_seeds`` and
    print the test accuracies it returns.
    """
    args = build_arg_parser().parse_args()

    # Create a folder for the results to be stored in.
    folder_name = time.strftime("%Y%m%d-%H%M%S")
    folder_name = f"plots/{folder_name}"

    os.makedirs(folder_name, exist_ok=True)

    log_file = os.path.join(folder_name, "script_output.log")

    # Redirect stdout to our logger; every print below goes through it.
    sys.stdout = Logger(log_file)

    print("--START--")
    print(f"Function to learn: {args.function}")
    print(f"Sequence Length: {args.n}")
    print(f"Embedding Dimension: {args.d}")
    print(f"Batch size: {args.batch_size}")
    print(f"Learning Rate: {args.lr}")
    print(f"Epochs: {args.epochs}")
    print(f"Layers: {args.layers}")
    print(f"Attention heads: {args.heads}")
    print(f"Number of training examples: {args.train_samples}")
    print(f"Noise Regularization Strength: {args.noise_reg}")
    print(f"Noise Regularization Parameter: {args.noise_reg_r}")
    print(f"Early Stopping Patience: {args.patience}")
    print(f"Learning Rate Reduction Factor: {args.lr_factor}")
    print(f"Rho values: {args.rho_list}")
    print("---------")

    # Parameters
    batch_size = args.batch_size
    seq_length = args.n
    d_model = args.d
    num_epochs = args.epochs
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Select the target function. Only the junta variant comes with a list of
    # relevant coordinates; for the others it stays None.
    relevant_coords = None
    if args.function == "majority":
        learn_function = bf.majority
    elif args.function == "parity":
        learn_function = bf.parity
    else:
        # "junta": the parser's `choices` rule out any other value.
        learn_function, relevant_coords = bf.create_junta_parity(seq_length, args.k)
        print(f"Trying to learn a junta on the following variables: {relevant_coords}")

    # Generate training and validation data
    print("Generating training and validation data...")
    train_X, train_y = dg.generate_data(args.train_samples, seq_length, learn_function, relevant_coords)
    val_X, val_y = dg.generate_data(200, seq_length, learn_function, relevant_coords)
    test_X, test_y = dg.generate_data(200, seq_length, learn_function, relevant_coords)

    # Create DataLoaders
    train_dataset = BooleanDataset(train_X, train_y)
    val_dataset = BooleanDataset(val_X, val_y)
    test_dataset = BooleanDataset(test_X, test_y)

    print(f"Training data shape: {train_X.shape}, Labels shape: {train_y.shape}")
    print(f"Validation data shape: {val_X.shape}, Labels shape: {val_y.shape}")
    print(f"Test data shape: {test_X.shape}, Labels shape: {test_y.shape}")

    # Only the training loader shuffles.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    rho = args.rho_list

    # Measure the noise stability of the target function itself at every rho;
    # these values are passed to the training run as `learn_function_stabilities`.
    benchmark_stabilities = {
        r: ns.measure_noise_stability_of_function(
            learn_function,
            seq_length,
            r,
            num_trials=10000,
            device=device,
            relevant_coords=relevant_coords)
        for r in rho
    }

    # Define model arguments
    model_args = {
        'vocab_size': 2,
        'd_model': d_model,
        'n_layers': args.layers,
        'n_heads': args.heads
    }

    # Define training arguments
    train_kwargs = {
        'lr': args.lr,
        'device': device,
        'weight_decay': 0,
        'patience': args.patience,  # Early stopping patience
        'lr_factor': args.lr_factor,  # Learning rate reduction factor
        'rho': rho,
        'input_length': seq_length
    }

    # Run with multiple seeds, drawn from NumPy's global RNG.
    seeds = np.random.randint(0, 10000, size=args.num_seeds).tolist()

    results = tf.run_multiple_seeds(
        model_class=tf.SimpleTransformer,
        model_args=model_args,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        folder_name=folder_name,
        vocab_size=2,
        seeds=seeds,
        noise_reg_strength=args.noise_reg,
        noise_reg_r=args.noise_reg_r,
        learn_function_stabilities=benchmark_stabilities,
        epoch_period=1,
        **train_kwargs
    )

    # Print final results
    test_accuracies = results['test_accuracies']
    print(f"Test Accuracies: {test_accuracies}")
    print(f"Mean: {np.mean(test_accuracies):.2f}%, Std: {np.std(test_accuracies):.2f}%")


if __name__ == "__main__":
    main()