import pickle
import sys
sys.path.append("../")
import os
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch
import re 
from g2p_en import G2p
import numpy as np
from model_training.dataset import getDatasetLoaders
from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_V2
import time
import numpy as np
from edit_distance import SequenceMatcher
import tqdm
import pytorch_lightning as pl
import jiwer
import nltk
from nltk.corpus import cmudict
from pytorch_lightning.loggers import WandbLogger
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import copy
from difflib import get_close_matches
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer
import pandas as pd
from torchaudio.models.decoder import ctc_decoder
import string
import os
from torch.utils.data import Subset, DataLoader

#import seed_everything
from pytorch_lightning import seed_everything

# from model.ctc_modelling import Light

# Download the CMU Pronouncing Dictionary corpus through NLTK. This call runs
# unconditionally every time the script is executed/imported; the download
# itself is only needed on first use.
nltk.download("cmudict")

# Load CMUdict via nltk.corpus.cmudict.
# NOTE(review): cmu_dict is not referenced anywhere else in this file;
# confirm it is still needed before removing it.
cmu_dict = cmudict.dict()

# Root folder for checkpoints (train() creates one subfolder per W&B run name)
# and the name passed to the WandbLogger.
OUTPUT_NAME = "gru_ctc"

# Build the DataLoaders once at import time (batch size 64, SHUFFLE_TRAIN=True).
# train() uses the module-level train_loader/val_loader for every sweep trial;
# test_loader is not used in this file.
train_loader, val_loader, test_loader = getDatasetLoaders(BATCH_SIZE=64, SHUFFLE_TRAIN=True)

# Fixed model dimensions passed to LightningGRUDecoder_V2 in train().
nInputFeatures = 512  # channels; passed as neural_dim
nClasses = 40  # passed as n_classes

# NOTE(review): none of the values from here down to total_steps are read by
# train() below. Each sweep trial takes its architecture, augmentation and
# optimizer settings from wandb.config (see sweep_config) instead. These look
# like defaults left over from a single, non-sweep run; confirm before
# relying on them.
dropout = 0.4
hidden_dim = 1024
nlayers = 5
stride_len = 4
kernel_len = 16
gaussian_smooth_width = 2
bidirectional = True

white_noise_SD = 0.8
constant_offset_SD = 0.2
seq_len = 150
max_time_series_len = 12000

lr_start = 1e-4
lr_end = 1e-5
l2_decay = 1e-5


# Learning-rate schedule lengths in optimizer steps, assuming one optimizer
# step per training batch. NOTE(review): train() runs with max_epochs=100,
# not target_epoch, and does not use these step counts.
warmup_epoch = 5
steps_per_epoch = len(train_loader)
warmup_steps = warmup_epoch * steps_per_epoch

target_epoch = 60
total_steps = target_epoch * steps_per_epoch

# W&B sweep definition. Every trial samples one value per hyperparameter from
# the discrete lists below; train() reads the chosen values via wandb.config.
sweep_config = {
    # Sampling strategy; W&B also accepts "grid" and "bayes".
    "method": "random",
    # Objective to optimize: the same val_CER that train()'s checkpoint monitors.
    "metric": {"name": "val_CER", "goal": "minimize"},
    # Hyperparameter name -> {"values": [candidate, ...]}, in a fixed order.
    "parameters": {
        param: {"values": list(choices)}
        for param, choices in (
            ("learning_rate", (1e-4, 1e-5, 5e-5)),
            ("weight_decay", (0.0, 1e-5, 1e-4, 1e-2)),
            ("dropout", (0.0, 0.2, 0.3, 0.4)),
            ("hidden_dim", (512, 1024, 2048)),
            ("nlayers", (2, 3, 4, 5, 6)),
            ("stride_len", (2, 4, 8)),
            ("kernel_len", (8, 16, 32)),
            ("gaussian_smooth_width", (1, 2, 3, 4)),
            ("bidirectional", (True, False)),
            ("white_noise_SD", (0.0, 0.2, 0.5, 0.8, 1.0)),
            ("constant_offset_SD", (0.0, 0.1, 0.2, 0.5)),
        )
    },
}



def _build_callbacks(run_folder):
    """Create the checkpointing and early-stopping callbacks for one trial.

    Args:
        run_folder: Directory where the best checkpoints are written.

    Returns:
        ``[best_val_loss_checkpoint, best_val_cer_checkpoint, early_stopping]``,
        ready to pass as ``pl.Trainer(callbacks=...)``.
    """
    # Keep only the single checkpoint with the lowest validation loss.
    # The model's validation step must log "val_loss".
    checkpoint_callback_loss = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=f"{run_folder}/",
        filename="best_model_loss",
        verbose=True,
    )

    # Keep only the single checkpoint with the lowest validation CER (the
    # sweep's objective). The filename is historical: the metric is val_CER.
    checkpoint_callback_per = ModelCheckpoint(
        monitor="val_CER",
        mode="min",
        save_top_k=1,
        dirpath=f"{run_folder}/",
        filename="best_model_per",
        verbose=True,
    )

    # Stop the trial once val_loss has not improved for 25 consecutive
    # validation checks (one per epoch with Lightning's default frequency).
    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        patience=25,
        mode="min",
        verbose=True,
    )

    return [checkpoint_callback_loss, checkpoint_callback_per, early_stopping_callback]


def train():
    """Run one W&B sweep trial: build a model from ``wandb.config`` and fit it.

    Intended to be passed as ``function=`` to ``wandb.agent``. It opens a W&B
    run, builds ``LightningGRUDecoder_V2`` from the sampled hyperparameters,
    and trains it on the module-level ``train_loader`` / ``val_loader``.
    Best checkpoints go under ``OUTPUT_NAME/<run name>/``.

    The W&B run is always finished, even if building the model or training
    raises, so a failed trial does not leave the run open.
    """
    # Open the run first: the sweep-sampled hyperparameters only become
    # available in wandb.config after init().
    wandb.init()
    try:
        config = wandb.config

        model = LightningGRUDecoder_V2(
            neural_dim=nInputFeatures,
            n_classes=nClasses,
            hidden_dim=config.hidden_dim,
            layer_dim=config.nlayers,
            nDays=45,
            strideLen=config.stride_len,
            kernelLen=config.kernel_len,
            gaussianSmoothWidth=config.gaussian_smooth_width,
            bidirectional=config.bidirectional,
            dropout=config.dropout,
            white_noise_SD=config.white_noise_SD,
            constant_offset_SD=config.constant_offset_SD,
            weight_decay=config.weight_decay,
            learning_rate=config.learning_rate,
        )

        # One checkpoint folder per W&B run.
        run_folder = f"{OUTPUT_NAME}/{wandb.run.name}"
        os.makedirs(run_folder, exist_ok=True)

        # NOTE(review): a run is already active from wandb.init() above.
        # Lightning's WandbLogger reuses wandb.run in that case, so name/reinit
        # here are likely ignored. Confirm before relying on them.
        wandb_logger = WandbLogger(project="B2TXT25", name=OUTPUT_NAME, reinit=True)

        trainer = pl.Trainer(
            max_epochs=100,
            devices=[0],
            callbacks=_build_callbacks(run_folder),
            logger=wandb_logger,
        )
        trainer.fit(model, train_loader, val_loader)
    finally:
        # Close the run even when construction or training raised, so it is
        # not left open.
        wandb.finish()


if __name__ == "__main__":
    # Register the search space with W&B, then let this process act as an
    # agent that executes up to 100 trials of train(). Guarded so that
    # importing this module does not launch a sweep.
    sweep_id = wandb.sweep(sweep_config, project="B2TXT25")
    wandb.agent(sweep_id, function=train, count=100)

