import numpy as np
import torch
import gc
import torch.nn as nn


def optimize_memory():
    """Free unreferenced Python objects and PyTorch's cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is asked
    to give its unused blocks back. In the reverse order, blocks released by
    the collector would remain cached until the next call.
    """
    # Run garbage collection
    gc.collect()

    # Clear PyTorch cache (skipped when CUDA is not available)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def make_loader(x, y, batch_size=32, shuffle=True, is_float32=True):
    """Wrap ``x`` and ``y`` in a ``TensorDataset`` and return a ``DataLoader``.

    Args:
        x: Input samples as a ``torch.Tensor`` or anything ``torch.tensor``
            accepts (NumPy array, nested list, ...).
        y: Targets, same accepted types as ``x``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.
        is_float32: Use ``torch.float32`` when true, ``torch.float64``
            otherwise, as the floating-point dtype.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(x_batch, y_batch)``.

    ``x`` is always cast to the floating dtype. ``y`` is cast to the same
    floating dtype when it has the same number of dimensions as ``x``;
    otherwise it is treated as labels and cast to ``torch.long``.
    """
    float_dtype = torch.float32 if is_float32 else torch.float64

    if isinstance(x, torch.Tensor):
        # ``Tensor.to`` returns the tensor itself when the dtype already
        # matches, so no copy is made in that case.
        x_tensor = x.to(dtype=float_dtype)
    else:
        x_tensor = torch.tensor(x, dtype=float_dtype)

    # Take the rank from the converted tensor / ``np.ndim`` rather than from
    # ``.shape`` so that plain Python lists are accepted as well.
    y_is_tensor = isinstance(y, torch.Tensor)
    y_ndim = y.dim() if y_is_tensor else np.ndim(y)
    y_dtype = float_dtype if y_ndim == x_tensor.dim() else torch.long

    if y_is_tensor:
        y_tensor = y.to(dtype=y_dtype)
    else:
        y_tensor = torch.tensor(y, dtype=y_dtype)

    tensor_set = torch.utils.data.TensorDataset(x_tensor, y_tensor)
    return torch.utils.data.DataLoader(
        tensor_set,
        batch_size=batch_size,
        shuffle=shuffle,
    )


def concatenate_subjects(x, y, fold, is_x_y_equivalent=False):
    """Merge the entries of ``x`` and ``y`` selected by ``fold``.

    The selected arrays are concatenated along their last axis. The merged
    ``x`` then has its four axes reversed, so the concatenation axis comes
    first.

    Args:
        x: Indexable collection of 4-D arrays.
        y: Indexable collection of arrays, indexed with the same keys as
            ``x``.
        fold: Iterable of indices selecting which entries to merge.
        is_x_y_equivalent: When true, ``y`` is handled exactly like ``x``
            (4-D, axes reversed) and returned unshifted.

    Returns:
        ``(X, Y)``. When ``is_x_y_equivalent`` is false, ``Y`` is the
        concatenated ``y`` minus one.
    """
    # NOTE(review): going by the original axis comments, inputs are assumed to
    # be (channels, samples, multiband, batch) and the output is
    # (batch, multiband, samples, channels) -- confirm with the data loader.
    reversed_axes = (3, 2, 1, 0)

    merged_x = np.concatenate([x[subject] for subject in fold], axis=-1)
    merged_y = np.concatenate([y[subject] for subject in fold], axis=-1)

    # Swapping axes 0<->3 and then 1<->2 is the single permutation above.
    merged_x = merged_x.transpose(reversed_axes)

    if not is_x_y_equivalent:
        return merged_x, merged_y - 1

    return merged_x, merged_y.transpose(reversed_axes)


import matplotlib.pyplot as plt
import os


def save_plot_list(
    data,
    save_path,
    plot_type="line",
    title="Plot of Data",
    xlabel="epoch",
    ylabel="loss",
):
    """Plot ``data`` as a line chart or histogram and save it to ``save_path``.

    Args:
        data: Values to plot; passed straight to ``plt.plot`` / ``plt.hist``.
        save_path: Destination image path. Missing parent directories are
            created; a bare file name (no directory part) is also accepted.
        plot_type: ``"line"`` or ``"hist"``.
        title: Figure title.
        xlabel: X-axis label.
        ylabel: Y-axis label.

    Raises:
        ValueError: If ``plot_type`` is neither ``"line"`` nor ``"hist"``.
    """
    # Validate before creating a figure so an invalid call leaks nothing.
    if plot_type not in ("line", "hist"):
        raise ValueError("plot_type must be either 'line' or 'hist'")

    fig = plt.figure(figsize=(8, 4))
    try:
        if plot_type == "line":
            plt.plot(data, marker="o")
        else:
            plt.hist(data, bins=10, edgecolor="black")
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.grid(True)
        plt.tight_layout()

        # ``os.makedirs("")`` raises, so only create a directory when the
        # path actually has one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


import matplotlib.pyplot as plt
import os


def save_35_subplots(data_list, save_path, titles=None, plot_type="line"):
    """Draw every series in ``data_list`` on a 5-column grid and save it.

    Despite the name, any positive number of series is accepted: the grid
    gets as many rows as needed and unused cells in the last row are hidden.

    Args:
        data_list: Sequence of series, one per subplot.
        save_path: Destination image path. Missing parent directories are
            created; a bare file name (no directory part) is also accepted.
        titles: Optional sequence of subplot titles, indexed like
            ``data_list``. Defaults to ``"Plot <n>"``.
        plot_type: ``"line"`` or ``"hist"``.

    Raises:
        ValueError: If ``plot_type`` is invalid or ``data_list`` is empty.
    """
    # Validate before creating a figure so an invalid call leaks nothing.
    if plot_type not in ("line", "hist"):
        raise ValueError("plot_type must be 'line' or 'hist'")

    n_plots = len(data_list)
    if n_plots == 0:
        raise ValueError("data_list must contain at least one series")

    cols = 5
    # Ceiling division: a partially filled last row still needs a grid row.
    rows = (n_plots + cols - 1) // cols

    # squeeze=False keeps ``axes`` two-dimensional for every grid size.
    fig, axes = plt.subplots(rows, cols, figsize=(20, 15), squeeze=False)
    try:
        axes = axes.flatten()

        for i, (ax, data) in enumerate(zip(axes, data_list)):
            if plot_type == "line":
                ax.plot(data)
            else:
                ax.hist(data, bins=10, edgecolor="black")

            ax.set_title(titles[i] if titles else f"Plot {i+1}")
            ax.set_xticks([])
            ax.set_yticks([])

        # Hide grid cells that have no series to show.
        for ax in axes[n_plots:]:
            ax.axis("off")

        plt.tight_layout()

        # ``os.makedirs("")`` raises, so only create a directory when the
        # path actually has one.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

def set_dropout_rate(model, layer_name, new_p):
    """Change the drop probability of the ``nn.Dropout`` named ``layer_name``.

    The layer is looked up by its qualified name in ``model.named_modules()``.
    If no such module exists, or it is not an ``nn.Dropout``, a warning is
    printed and the model is left untouched.
    """
    modules_by_name = dict(model.named_modules())
    candidate = modules_by_name.get(layer_name)

    if not (candidate and isinstance(candidate, nn.Dropout)):
        print(f"[WARNING] Layer '{layer_name}' not found or not a Dropout.")
        return

    candidate.p = new_p
    print(f"[INFO] Set {layer_name}.p to {new_p}")


def loss_with_l2_reg(base_loss_fn, model, l2_lambda):
    """Return a loss callable that adds an L2 penalty to ``base_loss_fn``.

    The returned ``total_loss(pred, target)`` evaluates
    ``base_loss_fn(pred, target)`` and adds ``l2_lambda`` times the sum of
    squared 2-norms of every parameter of ``model`` that requires gradients.
    Parameters are read on each call, so the penalty tracks the current
    weights.
    """
    def total_loss(pred, target):
        data_term = base_loss_fn(pred, target)

        # Start from plain 0, so the penalty stays 0 if nothing is trainable.
        penalty = 0
        for weight in model.parameters():
            if not weight.requires_grad:
                continue
            penalty = penalty + torch.norm(weight, p=2) ** 2

        return data_term + l2_lambda * penalty

    return total_loss

from log import log

def construct_full_period_data_w_mask(reconstructed, target, known_len=50):
    """Append the tail of ``reconstructed`` to ``target`` along the last axis.

    ``reconstructed`` is sliced from index ``known_len`` onward on its fourth
    axis, and that slice is concatenated after ``target``.
    """
    # NOTE(review): presumably ``target`` spans the first ``known_len`` steps,
    # so the result covers the full period -- confirm with callers.
    reconstructed_tail = reconstructed[:, :, :, known_len:]
    return torch.cat([target, reconstructed_tail], dim=-1)

import os
import yaml
import torch
import torch.nn as nn

def config_to_model_loss_init(dataset_config):
    """Build the model and loss function selected by ``dataset_config``.

    Only model names containing ``"tmaeformer_"`` are supported. For those,
    the model comes from ``tmaeformer_init(dataset_config)`` and the loss is
    ``causal_masked_loss``. Neither is defined in this part of the file;
    both are assumed to be provided elsewhere in the module.

    Side effects on success: the attributes of ``dataset_config`` are dumped
    to ``<exp_dir>/dataset_config.yaml`` and the model is logged to
    ``dataset_config.log_path``.

    Args:
        dataset_config: Object exposing ``MODEL_NAME``, ``exp_dir`` and
            ``log_path`` attributes.

    Returns:
        ``(model, loss_fn)``.

    Raises:
        ValueError: If ``MODEL_NAME`` does not name a supported model.
    """
    # There is no fallback model, so reject unknown names before any work.
    if "tmaeformer_" not in dataset_config.MODEL_NAME:
        raise ValueError("Model not found")

    model = tmaeformer_init(dataset_config)
    loss_fn = causal_masked_loss

    config_path = os.path.join(dataset_config.exp_dir, "dataset_config.yaml")
    with open(config_path, "w") as f:
        yaml.dump(vars(dataset_config), f, default_flow_style=False)

    log(f"Model initialization: {model}", dataset_config.log_path)
    return model, loss_fn


def form_dataset(
    dataset_config,
    x_train_meta,
    y_train_meta,
    test_index,
    is_shuffle=True,
    is_float32=True,
    is_x_y_equivalent=False,
):
    """Build a ``DataLoader`` over the subjects selected by ``test_index``.

    The selected entries of ``x_train_meta`` / ``y_train_meta`` are merged
    with ``concatenate_subjects`` and wrapped with ``make_loader``, using
    ``dataset_config.BATCH_SIZE`` as the batch size.

    Args:
        dataset_config: Object exposing a ``BATCH_SIZE`` attribute.
        x_train_meta: Per-subject inputs, indexable by the values in
            ``test_index``.
        y_train_meta: Per-subject targets, indexed the same way.
        test_index: Iterable of subject indices to include.
        is_shuffle: Forwarded to the loader as ``shuffle``.
        is_float32: Forwarded to ``make_loader``.
        is_x_y_equivalent: Forwarded to ``concatenate_subjects``.

    Returns:
        The ``DataLoader`` produced by ``make_loader``.
    """
    merged_inputs, merged_targets = concatenate_subjects(
        x_train_meta,
        y_train_meta,
        test_index,
        is_x_y_equivalent,
    )
    return make_loader(
        merged_inputs,
        merged_targets,
        batch_size=dataset_config.BATCH_SIZE,
        shuffle=is_shuffle,
        is_float32=is_float32,
    )

