import os
import torch
import numpy as np
from pathlib import Path

def create_attribution_dirs(root_dir: Path):
    """Create the attribution output directory tree under ``root_dir``.

    Directories created (already-existing ones are left as they are):
        <root_dir>/attribution/combined/plots
        <root_dir>/attribution/combined/data
        <root_dir>/attribution/df_analysis/plots
        <root_dir>/attribution/df_analysis/data

    Args:
        root_dir: Base output directory; must be a ``Path`` since ``/`` is
            used to join components.

    Returns:
        Tuple ``(attribution_dir, combined_plots_dir, combined_data_dir,
        df_dir, df_plots_dir, df_data_dir)`` of ``Path`` objects.
    """
    base = root_dir / "attribution"
    combined_plots = base / "combined/plots"
    combined_data = base / "combined/data"

    # Discourse features attribution output folder
    discourse = base / "df_analysis"
    discourse_plots = discourse / "plots"
    discourse_data = discourse / "data"

    # Only the leaf directories need creating; makedirs builds the parents.
    for leaf in (combined_plots, combined_data, discourse_plots, discourse_data):
        os.makedirs(leaf, exist_ok=True)

    return base, combined_plots, combined_data, discourse, discourse_plots, discourse_data

def get_delay_idxs(run_ids, num_delays):
    """For each TR, return the indices of the TRs 0..num_delays-1 steps earlier.

    Delays never cross a run boundary. A delay that would reach before the
    first TR of the current run is replaced with -1.

    Args:
        run_ids: Run identifier for every TR (1-D array-like of length
            num_trs). A new run starts wherever the identifier differs from
            the previous TR's. Identifiers need not start at 1 or increase.
        num_delays: Number of delays (0 through num_delays - 1) to compute.

    Returns:
        Integer array of shape (num_trs, num_delays). Entry [i, d] is i - d,
        or -1 if TR i - d lies in an earlier run (or before TR 0).
        NOTE(review): -1 presumably indexes a padding row appended by the
        caller; confirm.
    """
    run_ids = np.asarray(run_ids).reshape(-1)
    num_trs = run_ids.shape[0]
    if num_trs == 0:
        return np.zeros((0, num_delays), dtype=int)

    tr_idx = np.arange(num_trs)

    # True on the first TR of each run, i.e. wherever the run id changes.
    is_run_start = np.empty(num_trs, dtype=bool)
    is_run_start[0] = True
    is_run_start[1:] = run_ids[1:] != run_ids[:-1]

    # Propagate each run's starting index forward over the TRs of that run.
    run_start = np.maximum.accumulate(np.where(is_run_start, tr_idx, 0))

    delay_idxs = tr_idx[:, None] - np.arange(num_delays)[None, :]
    # Mask delays that would reach into the previous run (or before TR 0).
    delay_idxs[delay_idxs < run_start[:, None]] = -1

    return delay_idxs.astype(int, copy=False)  # (num_trs, num_delays)

def normalize_attributions(attributions: torch.Tensor, type: str = "abs-max") -> torch.Tensor:
    """Normalize attributions independently along dim 0.

    Every statistic is computed over dim 0 with ``keepdim=True``, so it
    broadcasts back over that dimension. For a 2-D input this normalizes
    each column separately.

    Args:
        attributions: Attribution tensor to normalize.
        type: Normalization scheme. The name shadows the builtin but is kept
            so existing keyword callers keep working.
            - 'min-max': linearly rescale to [-1, 1]. A constant slice maps
              to -1.
            - 'abs-max': divide by the largest absolute value, giving [-1, 1]
              (default).
            - 'l1': divide by the sum of absolute values.
            - 'l2': divide by the Euclidean norm.

    Returns:
        Tensor with the same shape as ``attributions``.

    Raises:
        ValueError: If ``type`` is not one of the schemes above.
    """
    eps = 1e-10  # avoids division by zero for all-zero / constant slices
    if type == "min-max":
        min_val = attributions.min(dim=0, keepdim=True).values
        max_val = attributions.max(dim=0, keepdim=True).values
        normalized = 2 * (attributions - min_val) / (max_val - min_val + eps) - 1
    elif type == "abs-max":
        max_val = attributions.abs().max(dim=0, keepdim=True).values
        normalized = attributions / (max_val + eps)
    elif type == "l1":
        normalized = attributions / (attributions.abs().sum(dim=0, keepdim=True) + eps)
    elif type == "l2":
        normalized = attributions / (torch.linalg.norm(attributions, dim=0, ord=2, keepdim=True) + eps)
    else:
        # Previously fell through and raised UnboundLocalError on `normalized`.
        raise ValueError(
            f"Unknown normalization type {type!r}; "
            "expected one of 'min-max', 'abs-max', 'l1', 'l2'."
        )
    return normalized