from collections import defaultdict
import json
from pathlib import Path
from pathlib import Path
from typing import *
import math

import pandas as pd
import torch
import torchaudio
from tqdm import tqdm
from torchaudio.transforms import Resample

from main.dataset import is_silent


def sdr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Return the signal-to-distortion ratio (dB) of ``preds`` w.r.t. ``target``.

    Computed along the last dimension as
    ``10 * log10((||target||^2 + eps) / (||target - preds||^2 + eps))``.
    No gain alignment is performed, so a scale mismatch between the two
    signals is counted as distortion.

    Args:
        preds: estimated signal; must broadcast against ``target``.
        target: reference signal.
        eps: added to both energies to keep the ratio and the log finite.

    Returns:
        Tensor of SDR values with the last dimension reduced away.
    """
    residual = target - preds
    reference_energy = target.norm(dim=-1).pow(2) + eps
    residual_energy = residual.norm(dim=-1).pow(2) + eps
    return torch.log10(reference_energy / residual_energy).mul(10)


def sisnr(preds: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Return the scale-invariant SNR (dB) of ``preds`` w.r.t. ``target``.

    ``target`` is first rescaled by the least-squares gain
    ``(<preds, target> + eps) / (||target||^2 + eps)``. The result is then the
    energy ratio between that projection and the residual
    ``projection - preds``, reduced along the last dimension. Signals are
    not mean-centred before the projection.

    Args:
        preds: estimated signal; must broadcast against ``target``.
        target: reference signal.
        eps: stabilizer added to the gain's numerator/denominator and to
            both energies.

    Returns:
        Tensor of SI-SNR values with the last dimension reduced away.
    """
    dot = (preds * target).sum(dim=-1, keepdim=True)
    target_energy = target.pow(2).sum(dim=-1, keepdim=True)
    gain = (dot + eps) / (target_energy + eps)

    projection = gain * target
    residual = projection - preds

    numerator = projection.pow(2).sum(dim=-1) + eps
    denominator = residual.pow(2).sum(dim=-1) + eps
    return 10 * torch.log10(numerator / denominator)


def load_chunks(chunk_folder: Path) -> Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], int]:
    """Load the original and separated stems stored in one chunk folder.

    Original stems are the ``ori*.wav`` files and separated stems the
    ``sep*.wav`` files, both read with ``torchaudio.load``. Each stem is keyed
    by its file name with everything from the first ``.`` onward removed and
    the 3-character ``ori``/``sep`` prefix stripped (e.g. ``ori1.wav`` -> ``1``),
    so matching originals and separations share a key.

    Args:
        chunk_folder: directory holding the chunk's wav files.

    Returns:
        ``(original_tracks, separated_tracks, sample_rate)``: two dicts mapping
        stem key -> waveform tensor as returned by ``torchaudio.load``, and the
        single sample rate shared by every file.

    Raises:
        AssertionError: if the original and separated stem keys differ, if the
            folder contains no stems, or if the files do not all share one
            sample rate.
    """
    def stem_key(path: Path) -> str:
        # Drop the extension(s) and the 3-char "ori"/"sep" prefix.
        return path.name.split(".")[0][3:]

    originals = {stem_key(p): torchaudio.load(p) for p in sorted(chunk_folder.glob("ori*.wav"))}
    separations = {stem_key(p): torchaudio.load(p) for p in sorted(chunk_folder.glob("sep*.wav"))}

    # Downstream code pairs stems by key, so only key membership matters.
    assert originals.keys() == separations.keys(), (
        f"chunk [{chunk_folder}]: original stems {sorted(originals)} "
        f"do not match separated stems {sorted(separations)}"
    )
    assert originals, f"chunk [{chunk_folder}]: no ori*.wav / sep*.wav files found"

    sample_rates = {rate for _, rate in originals.values()} | {rate for _, rate in separations.values()}
    # The original code passed print(...) as the message, which yields
    # AssertionError(None); report the offending rates instead.
    assert len(sample_rates) == 1, (
        f"chunk [{chunk_folder}]: expected a single sample rate, found {sorted(sample_rates)}"
    )

    original_tracks = {k: wav for k, (wav, _) in originals.items()}
    separated_tracks = {k: wav for k, (wav, _) in separations.items()}
    (sr,) = sample_rates
    return original_tracks, separated_tracks, sr


def evaluate_separations(
    separation_path: Union[str, Path],
    orig_sr: int = 44100,
    resample_sr: Optional[int] = None,
    filter_single_source: bool = True,
    eps: float = 1e-8,
    chunk_duration: float = 4.0,
    overlap_duration: float = 2.0) -> pd.DataFrame:
    """Score every separated chunk under ``separation_path`` with per-stem SI-SNRi.

    ``separation_path`` must contain a ``chunk_data.json`` file and one
    sub-folder per chunk holding ``ori*.wav`` / ``sep*.wav`` stems (loaded with
    ``load_chunks``). Non-directory entries are ignored. Each chunk is cut into
    windows of ``chunk_duration`` seconds that overlap by ``overlap_duration``
    seconds. For every window, the score of stem ``k`` is
    ``sisnr(separated_k, original_k) - sisnr(mixture, original_k)``, where the
    mixture is the sum of all original stems of the chunk.

    Args:
        separation_path: folder produced by the separation run.
        orig_sr: sample rate every stem file is required to have.
        resample_sr: if given and different from ``orig_sr``, all stems are
            resampled to this rate before scoring; window sizes and the
            reported sample indices are then expressed at this rate.
        filter_single_source: skip windows in which at most one original stem
            is non-silent according to ``is_silent``.
        eps: numerical stabilizer forwarded to ``sisnr``.
        chunk_duration: window length in seconds.
        overlap_duration: overlap between consecutive windows in seconds.

    Returns:
        DataFrame with one row per scored window: one column per stem holding
        its SI-SNRi in dB (averaged over channels), plus ``chunk_n`` (chunk
        folder name), ``start_sample`` and ``end_sample``. ``end_sample`` is the
        nominal window end and may exceed the audio length for the last window.

    Raises:
        AssertionError: if the folder or ``chunk_data.json`` is missing, or a
            chunk's sample rate differs from ``orig_sr``.
        ValueError: if ``overlap_duration`` is not shorter than ``chunk_duration``.
    """
    separation_folder = Path(separation_path)
    assert separation_folder.exists(), separation_folder
    # Only the presence of chunk_data.json is checked; its per-chunk metadata
    # is not needed for scoring.
    assert (separation_folder / "chunk_data.json").exists(), separation_folder

    eval_sr = orig_sr if resample_sr is None else resample_sr
    chunk_samples = int(chunk_duration * eval_sr)
    overlap_samples = int(overlap_duration * eval_sr)
    step_size = chunk_samples - overlap_samples
    if step_size <= 0:
        # A non-positive hop would divide by zero or walk backwards.
        raise ValueError(
            f"overlap_duration ({overlap_duration}) must be shorter than "
            f"chunk_duration ({chunk_duration})"
        )

    resampler = (
        Resample(orig_freq=orig_sr, new_freq=resample_sr)
        if resample_sr is not None and resample_sr != orig_sr
        else None
    )

    # Only sub-directories hold chunks (chunk_data.json lives alongside them);
    # sort so that row order is reproducible across file systems.
    chunk_folders = sorted(p for p in separation_folder.iterdir() if p.is_dir())

    df_entries = defaultdict(list)
    for chunk_folder in tqdm(chunk_folders):
        original_tracks, separated_tracks, sr = load_chunks(chunk_folder)
        assert sr == orig_sr, f"chunk [{chunk_folder.name}]: expected freq={orig_sr}, track freq={sr}"

        if resampler is not None:
            original_tracks = {k: resampler(wav) for k, wav in original_tracks.items()}
            separated_tracks = {k: resampler(wav) for k, wav in separated_tracks.items()}

        mixture = sum(original_tracks.values())
        num_subchunks = math.ceil((mixture.shape[-1] - overlap_samples) / step_size)

        for i in range(num_subchunks):
            start_sample = i * step_size
            end_sample = start_sample + chunk_samples
            window = slice(start_sample, end_sample)

            if filter_single_source:
                num_active_signals = sum(
                    not is_silent(original_tracks[k][:, window]) for k in separated_tracks
                )
                if num_active_signals <= 1:
                    continue

            m = mixture[:, window]
            for k in separated_tracks:
                o = original_tracks[k][:, window]
                s = separated_tracks[k][:, window]
                improvement = sisnr(s, o, eps) - sisnr(m, o, eps)
                # sisnr reduces only the last dim; average the remaining
                # (channel) values so multi-channel stems yield one score.
                # For mono stems this equals the single value.
                df_entries[k].append(improvement.mean().item())

            df_entries["chunk_n"].append(chunk_folder.name)
            df_entries["start_sample"].append(start_sample)
            df_entries["end_sample"].append(end_sample)

    return pd.DataFrame(df_entries)