
import torch
import numpy as np
import os
import matplotlib.pyplot as plt

import importlib.util
import sys

import wandb
# import trackio as wandb
# from sklearn.model_selection import train_test_split
from typing import Callable

# Define an annealing scheduler
from torch.optim.lr_scheduler import ExponentialLR


def get_model_class(config, Model_default) -> Callable:
    """Resolve the model class named ``config.model``.

    Lookup order:
      * ``config.model_file`` is None: read the attribute from ``Model_default``.
      * ``config.model_file`` is a file: execute it as a fresh module (named
        "Model") and read the attribute from that module.
      * ``config.model_file`` is a directory: import (and reload) it as a package,
        with its parent directory temporarily prepended to ``sys.path``.

    Args:
        config: object with ``model`` (attribute name to look up) and
            ``model_file`` (None, a Python file, or a package directory).
        Model_default: module/object holding the built-in models.

    Returns:
        The object named ``config.model``.

    Raises:
        FileNotFoundError: ``config.model_file`` is neither a file nor a directory.
        ImportError: the file cannot be loaded as a Python module (unsupported suffix).
        ValueError: no package name can be derived from the directory path.
        AttributeError: the resolved module has no attribute ``config.model``.
    """
    if config.model_file is None:
        print(f"Loading model class {config.model} from built-in models.")
        return getattr(Model_default, config.model)

    print(f"Loading model class {config.model} from {config.model_file}")

    if os.path.isfile(config.model_file):
        # Load from a specific file.
        spec = importlib.util.spec_from_file_location("Model", config.model_file)
        if spec is None or spec.loader is None:
            # spec_from_file_location returns None when no loader handles the
            # file's suffix; without this check we'd crash on `None.loader`.
            raise ImportError(
                f"Cannot load model file {config.model_file}: not a loadable Python module."
            )
        Models = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(Models)
        return getattr(Models, config.model)

    if os.path.isdir(config.model_file):
        # Load from a package directory; normalising handles trailing slashes.
        normalized_path = os.path.normpath(os.path.abspath(config.model_file))
        parent_dir = os.path.dirname(normalized_path)
        package_name = os.path.basename(normalized_path)

        if not package_name:
            raise ValueError(f"Could not determine package name from path: {config.model_file}")

        added_to_path = parent_dir not in sys.path
        if added_to_path:
            sys.path.insert(0, parent_dir)

        try:
            Models = importlib.import_module(package_name)
            # Force a reload so edits made since a previous import are picked up.
            Models = importlib.reload(Models)
            return getattr(Models, config.model)
        finally:
            # Only undo the sys.path change we made ourselves.
            if added_to_path:
                sys.path.remove(parent_dir)

    raise FileNotFoundError(f"Model path {config.model_file} does not exist or is neither a file nor directory.")



import json
def save_data_artifact(model_config, data):
    """Attach the dataset in ``model_config.data_file`` to the current W&B run.

    Tries to reuse the ``listops-dataset-<name>:latest`` artifact. If fetching it
    fails, creates a new ``dataset`` artifact containing the data file plus a
    JSON dump of ``data['metadata']`` (written into ``model_config.save_path``)
    and logs it. In both cases the artifact is linked to the run and its name is
    stored in ``wandb.config['dataset_artifact']``.

    Args:
        model_config: object with ``data_file`` and ``save_path`` attributes.
        data: mapping whose ``'metadata'`` entry is JSON-serialisable; only read
            when a new artifact has to be created.
    """
    # Get filename for artifact naming.
    base_name = os.path.basename(model_config.data_file)
    # Remove a ".pkl" extension *before* sanitising. The previous
    # `.rstrip('.pkl')` stripped any trailing run of the characters
    # '.', 'p', 'k', 'l', so "model.pkl" became "mode" and "pool" became "poo".
    if base_name.endswith('.pkl'):
        base_name = base_name[:-len('.pkl')]
    # Keep only characters that are safe in an artifact name.
    data_filename = ''.join(c for c in base_name if c.isalnum() or c in ('_', '-'))

    # Check if this dataset already exists as an artifact.
    try:
        # Try to use the existing artifact if available.
        artifact = wandb.use_artifact(f"listops-dataset-{data_filename}:latest", type="dataset")
        print(f"Using existing dataset artifact: {artifact.name}")
    except Exception:
        # Best effort: any failure to fetch is treated as "not found", so we
        # create and log a new artifact.
        print(f"Creating new dataset artifact for {data_filename}")
        dataset_artifact = wandb.Artifact(
            name=f"listops-dataset-{data_filename}",
            type="dataset",
            description="ListOps dataset with metadata"
        )

        # Add the data file to the artifact.
        dataset_artifact.add_file(model_config.data_file)

        # Add metadata as JSON.
        metadata = data['metadata']
        metadata_path = os.path.join(model_config.save_path, f"dataset_metadata_{data_filename}.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f)
        dataset_artifact.add_file(metadata_path)

        # Log the artifact to wandb - this uploads it once.
        artifact = wandb.log_artifact(dataset_artifact)

    # Link the artifact to this run - this doesn't upload again, just creates a reference.
    wandb.run.use_artifact(artifact)

    # Add the artifact name to the config for reference.
    wandb.config.update({"dataset_artifact": artifact.name})
        
        
def strip_num(func_name):
    """Drop the final ``_``-separated segment of ``func_name``.

    Meant for suffixed names such as ``"sum_3"`` -> ``"sum"``. The last segment
    is removed whether or not it is numeric (``"sum_mod"`` -> ``"sum"``). Names
    without an underscore are returned unchanged.
    """
    head, separator, _suffix = func_name.rpartition('_')
    return head if separator else func_name

class AnnealingLR(ExponentialLR):
    """Exponential learning-rate decay that freezes after ``epochs`` steps.

    The decay factor is ``gamma = (lr_min / lr_max) ** (1 / epochs)``. Starting
    from a learning rate of ``lr_max``, the rate therefore reaches ``lr_min``
    after exactly ``epochs`` calls to :meth:`step`. Later calls leave it unchanged.

    Note: the starting rate is whatever the optimizer's param groups hold.
    ``lr_max`` is only used to compute ``gamma``, so configure the optimizer with
    ``lr=lr_max`` for the schedule to end at ``lr_min``.
    """

    def __init__(self, optimizer, lr_min, lr_max, epochs, last_epoch=-1):
        if epochs <= 0:
            # 0 would divide by zero in get_decay_rate. A negative value would
            # make step() a no-op even during base-class initialisation.
            raise ValueError(f"epochs must be positive, got {epochs}")
        self.epochs = epochs
        self.epoch_counter = 0  # number of decay steps applied so far
        self.get_decay_rate(lr_min, lr_max, epochs)
        super().__init__(optimizer, self.gamma, last_epoch)
        # The base-class constructor calls self.step() once to initialise the
        # learning rate. That call goes through the override below and bumped
        # epoch_counter although no decay happened, which cut the schedule short
        # by one step (ending at lr_max * gamma**(epochs - 1)). Start counting now.
        self.epoch_counter = 0

    def get_decay_rate(self, lr_min, lr_max, epochs):
        """Set ``self.gamma`` so ``epochs`` multiplications take lr_max to lr_min."""
        self.gamma = (lr_min / lr_max) ** (1 / epochs)

    def step(self):
        """Apply one decay step, or do nothing once ``epochs`` steps have been taken."""
        if self.epoch_counter < self.epochs:
            self.epoch_counter += 1
            super().step()

class EarlyStopping:
    """Patience-based early stopping on a scalar "energy" where lower is better.

    Each observation either improves on the best value seen so far by more than
    ``min_delta`` (which makes it the new best and zeroes the patience counter) or
    increments the counter. The stop flag is set whenever the counter has
    reached ``patience``.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.reset()

    def reset(self, energy=float('inf')):
        """Clear the counter and stop flag, taking ``energy`` as the best value."""
        self.best_energy = energy
        self.patience_counter = 0
        self.early_stopping_triggered = False

    def check_early_stopping(self, energy):
        """Fold one observation into the state (see class docstring)."""
        improvement_bar = self.best_energy - self.min_delta
        if not energy < improvement_bar:
            self.patience_counter += 1
        else:
            self.reset(energy)
        # Whatever branch ran, the flag simply reflects whether patience ran out.
        self.early_stopping_triggered = bool(self.patience_counter >= self.patience)

    def __call__(self, energy):
        """Record ``energy`` and report whether training should stop."""
        self.check_early_stopping(energy)
        return self.early_stopping_triggered
    
    

# Train and test splits
def train_test_split(data, test_size=0.1):
    """
    Split ``data`` in order, without shuffling: the leading part is train, the rest is test.

    The train part holds the first ``int((1 - test_size) * len(data))`` items.
    Returns a ``(train, test)`` tuple of slices of ``data``.
    """
    cut = int((1 - test_size) * len(data))
    return data[:cut], data[cut:]

def get_batch(data, block_size=64, batch_size=64, device=None):
    """Sample a random batch of (input, shifted-target) windows from ``data``.

    ``data`` is sliced along its first dimension; presumably a 1-D tensor of
    token ids (TODO confirm with callers).

    Args:
        data: sequence to sample from; needs more than ``block_size`` elements.
        block_size: length of each window.
        batch_size: number of windows.
        device: optional device to move ``x`` and ``y`` to; None leaves them
            where ``data`` lives.

    Returns:
        ``(x, y)``, each stacked to ``(batch_size, block_size, ...)``, where
        ``y`` is ``x`` shifted one position forward in ``data``.
    """
    # Start indices go up to len(data) - block_size - 1 so that the shifted
    # target window data[i+1 : i+block_size+1] stays in range.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    if device is not None:
        x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss(model, data, eval_iters=10, block_size=64, batch_size=64):
    """Average ``model``'s loss over ``eval_iters`` random batches, without gradients.

    Batches come from ``get_batch(data, ...)``. ``model(X, Y)`` must return a
    ``(logits, loss)`` pair whose ``loss`` supports ``.item()``. The model is put
    in eval mode for the measurement, then returned to the train/eval mode it
    was in before the call, even if a forward pass raises.

    Returns:
        The mean loss (numpy float). With ``eval_iters=0`` this is nan, and
        numpy warns about averaging an empty list.
    """
    was_training = model.training
    model.eval()
    try:
        losses = []
        for _ in range(eval_iters):
            X, Y = get_batch(data, batch_size=batch_size, block_size=block_size)
            _, loss = model(X, Y)
            losses.append(loss.item())
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        model.train(was_training)
    return np.mean(losses)


def early_stopping(metric_list,
                   small_window=32,
                   big_window=1000,
                   stop_delta_ratio=1e-3, verbose=False):
    """Return True when ``metric_list`` has stopped changing meaningfully.

    Two changes are measured between means of ``small_window``-long windows:

    * short-range: |mean(last window) - mean(the window just before it)|
    * long-range:  |mean(last window) - mean(window starting ``big_window``
      entries from the end)|; the start is clamped to the beginning of the list.

    The result is ``short / long < stop_delta_ratio``. It is always False while
    fewer than ``2 * small_window`` values exist, and also when the long-range
    change is exactly zero (the ratio is treated as infinite).

    Args:
        metric_list: sliceable sequence of scalar values (e.g. losses), oldest first.
        small_window: length of each averaging window.
        big_window: distance back to the long-range reference window; raised to
            at least ``2 * small_window``.
        stop_delta_ratio: ratio below which the metric counts as stalled.
        verbose: if True, print the step count and ratio on a single updating line.
    """
    if len(metric_list) < 2 * small_window:
        return False
    # Check whether the change within the small window is tiny compared with the
    # change across the big window. Forcing big_window >= 2*small_window keeps the
    # reference window from overlapping the most recent one.
    big_window = max(big_window, 2 * small_window)
    last = np.mean(metric_list[-small_window:])
    dl_small = abs(last - np.mean(metric_list[-2 * small_window:-small_window]))
    idx = max(0, len(metric_list) - big_window)
    dl_big = abs(last - np.mean(metric_list[idx:idx + small_window]))
    if dl_big == 0:
        # Dividing by zero would emit a numpy RuntimeWarning and yield inf/nan,
        # neither of which triggers a stop; make that outcome explicit.
        ratio = float('inf')
    else:
        ratio = dl_small / dl_big
    if verbose:
        print(f'step: {len(metric_list)}, Loss change ratio: {ratio:.3g}', end='\r')
    return ratio < stop_delta_ratio



def plot_loss(model_config, history, figsize=(10, 5)):
    """Plot train and validation loss against cumulative time and save the figure as a PDF.

    Args:
        model_config: object providing ``run_name`` (used in the title and file
            name) and ``save_path`` (output directory; it is not created here).
        history: mapping with ``'time'`` (summed cumulatively to form the
            x-axis, labelled in seconds), ``'train_loss'`` and ``'val_loss'``
            sequences of matching length.
        figsize: matplotlib figure size in inches.

    The plot uses log-log axes and is written to
    ``<save_path>/loss_<run_name>.pdf``.
    """
    elapsed = np.cumsum(history['time'])
    fig = plt.figure(figsize=figsize)
    try:
        plt.plot(elapsed, history['train_loss'], label='train')
        plt.plot(elapsed, history['val_loss'], label='val')
        plt.xlabel('Time (s)')
        plt.ylabel('Loss')
        plt.title(f'Loss history {model_config.run_name}')
        plt.grid()
        plt.xscale('log')
        plt.yscale('log')
        plt.legend()
        plt.savefig(os.path.join(model_config.save_path, f'loss_{model_config.run_name}.pdf'), bbox_inches='tight')
    finally:
        # Close even if plotting or saving fails, so figures don't pile up across calls.
        plt.close(fig)
        
        