import os
from datetime import datetime
import uuid
import numpy as np
import torch
# import torch.nn.functional as F
# from train import validate
# from plot import plot_prediction, plot_lograte, plot_flowfield, plot_heatmaps

def merge_windows(data, overlap):
    """
    Merge overlapping windows back into a single continuous signal.

    Given windowed data of shape (num_segments, segment_length, num_features),
    consecutive windows are taken to start ``segment_length - overlap``
    samples apart. The windows are laid back onto one timeline of shape
    (orig_len, num_features). Every sample covered by more than one window
    is the unweighted mean of all values contributed to it.

    Parameters
    ----------
    data : np.ndarray or torch.Tensor
        Windowed data of shape (num_segments, segment_length, num_features).
        Tensors may live on any device and may require grad. They are
        detached and copied to the CPU for the computation. Other objects
        exposing ``.numpy()`` are converted with it; anything else goes
        through ``np.asarray``.
    overlap : int
        Number of samples shared by consecutive windows. Must satisfy
        ``0 <= overlap <= segment_length``.

    Returns
    -------
    np.ndarray or torch.Tensor
        The merged float64 signal of shape (orig_len, num_features), where
        ``orig_len = (num_segments - 1) * (segment_length - overlap)
        + segment_length``. If the input was a tensor, a tensor on the
        input's device is returned; otherwise a NumPy array. An input with
        zero segments yields an empty (0, num_features) result.

    Raises
    ------
    ValueError
        If ``data`` is not 3-D, or ``overlap`` lies outside
        ``[0, segment_length]``.
    """
    # Normalise the input to a NumPy array, remembering whether to convert back.
    is_tensor = isinstance(data, torch.Tensor)
    if is_tensor:
        device = data.device
        # Plain .numpy() fails on CUDA tensors and on tensors requiring grad.
        data = data.detach().cpu().numpy()
    elif isinstance(data, np.ndarray):
        pass
    elif hasattr(data, 'numpy'):
        data = data.numpy()
    else:
        data = np.asarray(data)

    if data.ndim != 3:
        raise ValueError(
            "merge_windows expects data of shape "
            f"(num_segments, segment_length, num_features), got {data.shape}"
        )
    num_segments, segment_length, num_features = data.shape

    # A negative stride would wrap indices around the array; a negative
    # overlap leaves uncovered gaps that would become NaN.
    if not 0 <= overlap <= segment_length:
        raise ValueError(
            f"overlap must be in [0, segment_length={segment_length}], got {overlap}"
        )
    stride = segment_length - overlap

    if num_segments == 0:
        # Nothing to merge; avoid a 0/0 division producing NaNs.
        result = np.zeros((0, num_features))
    else:
        orig_len = (num_segments - 1) * stride + segment_length
        merged = np.zeros((orig_len, num_features))
        # One count per time step; broadcasts across features when dividing.
        counts = np.zeros((orig_len, 1))
        for i, window in enumerate(data):
            start = i * stride
            # Each window covers a contiguous, duplicate-free index range, so
            # plain slice accumulation is equivalent to np.add.at (and faster).
            merged[start:start + segment_length] += window
            counts[start:start + segment_length] += 1
        # Average the overlapping elements.
        result = merged / counts

    if is_tensor:
        return torch.from_numpy(result).to(device)
    return result

def setup_model_directories(model_type, readout_type, epochs, database, *, root='results'):
    """
    Set up the directory structure for one run's model results and plots.

    A unique run directory is created at
    ``<root>/<database>/<model_type>/<readout_type>/<epochs>epochs_<unique_id>``.
    It contains ``plots``, ``checkpoints``, ``logs`` and ``metrics``
    subdirectories. Already-existing directories are reused rather than
    raising (``exist_ok=True``).

    Parameters
    ----------
    model_type : str
        Model name; used as a path component.
    readout_type : str
        Readout name; used as a path component.
    epochs : int
        Number of epochs; embedded in the run directory name.
    database : str
        Dataset name; used as a path component.
    root : str, optional
        Top-level results directory. Defaults to ``'results'``, relative to
        the current working directory, which was the previously hard-coded
        location.

    Returns
    -------
    dict
        ``'BASE_PATH'`` is the ``plots`` subdirectory (not the run directory
        itself). ``'CHECKPOINT_PATH'``, ``'LOG_PATH'`` and ``'METRICS_PATH'``
        are the other three subdirectories. ``'UNIQUE_ID'`` is the run
        identifier, formatted as ``YYYYmmdd_HHMMSS_`` followed by 8 hex
        characters.
    """
    # The timestamp keeps runs sortable; the random suffix avoids collisions
    # between runs started within the same second.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    unique_id = f"{timestamp}_{str(uuid.uuid4())[:8]}"
    base_dir = os.path.join(root, database, model_type, readout_type,
                            f"{epochs}epochs_{unique_id}")

    dirs = {
        name: os.path.join(base_dir, name)
        for name in ('plots', 'checkpoints', 'logs', 'metrics')
    }
    for dir_path in dirs.values():
        os.makedirs(dir_path, exist_ok=True)

    return {
        'BASE_PATH': dirs['plots'],
        'CHECKPOINT_PATH': dirs['checkpoints'],
        'LOG_PATH': dirs['logs'],
        'METRICS_PATH': dirs['metrics'],
        'UNIQUE_ID': unique_id,
    }

# def generate_plots(model, data, config, device=None):
#     data_type = config.get('data_type', '')
#     neurons_map = config.get('neurons', {})
#     should_plot_heatmap = config.get('heatmap', False)
    
#     config['data'] = 'val' 
    
#     pred_latents, pred_rates, spikes, _, pred_logrates, influence = validate(model, data, config=config, device=device)
    
#     # Ensure pred_latents is converted to numpy only if it's not None
#     numpy_pred_latents = pred_latents.detach().cpu().numpy() if pred_latents is not None else None

#     # Move other relevant tensors to CPU before passing to plotting functions
#     # that might iterate and call .item() or expect CPU data.
#     spikes_cpu = spikes.cpu() if spikes is not None else None
#     pred_rates_cpu = pred_rates.cpu() if pred_rates is not None else None
#     pred_logrates_cpu = pred_logrates.cpu() if pred_logrates is not None else None

#     for name, _ in neurons_map.items():
#         if spikes_cpu is not None and pred_rates_cpu is not None:
#             plot_prediction(spikes_cpu, pred_rates_cpu, config=config, target=name)
#         if pred_logrates_cpu is not None:
#             plot_lograte(pred_logrates_cpu, config=config, target=name)
    
#     if 'wang' in data_type.lower() and numpy_pred_latents is not None:
#         plot_flowfield(model, numpy_pred_latents, config) # moves model to cpu

#     if should_plot_heatmap:
#         weights_for_heatmap = None
#         script_model_type = config.get('model_type', 'Unknown') # Get model_type from config for warnings

#         # Strategy 1: Monotonic Readout (common in RNNs with linear/monotonic readout)
#         if hasattr(model, 'monotonic_readout') and \
#            hasattr(model.monotonic_readout, 'log_weight') and \
#            hasattr(model.monotonic_readout, 'freeze_mask'):
#             weights_for_heatmap = F.softplus(model.monotonic_readout.log_weight).detach().cpu().numpy() * \
#                                   model.monotonic_readout.freeze_mask.detach().cpu().numpy()

#         # Strategy 2: MLP Readout (common in NODE/RNN with MLP/MiniMLP readout)
#         elif hasattr(model, 'readout_mlp'):
#             if hasattr(model.readout_mlp, '0') and isinstance(model.readout_mlp, torch.nn.Sequential) and hasattr(model.readout_mlp[0], 'weight'): # Sequential MLP
#                 weights_for_heatmap = model.readout_mlp[0].weight.detach().cpu().numpy()
#             elif isinstance(model.readout_mlp, torch.nn.ModuleList) and len(model.readout_mlp) > 0 and hasattr(model.readout_mlp[0], 'weight'): # ModuleList MLP
#                 weights_for_heatmap = model.readout_mlp[0].weight.detach().cpu().numpy()
#             elif hasattr(model.readout_mlp, 'weight'): # Direct MLP layer (not a Sequential or ModuleList)
#                  weights_for_heatmap = model.readout_mlp.weight.detach().cpu().numpy()
        
#         # Strategy 3: Linear Readout Layer 'readout_l' (common in NODE Linear script)
#         elif hasattr(model, 'readout_l') and hasattr(model.readout_l, 'weight'):
#             weights_for_heatmap = model.readout_l.weight.detach().cpu().numpy()
        
#         # Strategy 4: Generic 'readout' attribute if it's a nn.Linear
#         elif hasattr(model, 'readout') and hasattr(model.readout, 'weight'):
#              weights_for_heatmap = model.readout.weight.detach().cpu().numpy()
        
#         if weights_for_heatmap is not None:
#             plot_heatmaps(weights_for_heatmap, config)
#         else:
#             print(f"Warning: Could not determine weights for heatmap plotting for model type '{script_model_type}' using model instance {type(model)}. Please check model structure or update 'generate_plots' in training/utils.py.") 