import functools
import json
from copy import copy
from pathlib import Path
from typing import *

import torch
import yaml
from audio_diffusion_pytorch import KarrasSchedule

from main.dataset import *
from main.module_base import Model
from main.separation import *
from script.misc import load_context, load_model

# Base directory for the checkpoint paths used below: the parent of the
# directory that contains this file. Path.resolve() already yields an
# absolute, symlink-resolved path, so no extra .absolute() call is needed.
ROOT_PATH = Path(__file__).parent.parent.resolve()

def load_diffusion_model(path: str, hparams: dict, device: str = "cpu") -> Model:
    """Instantiate a ``Model`` from ``hparams`` and load its weights.

    Args:
        path: Checkpoint file readable by ``torch.load``; it must contain a
            ``"state_dict"`` entry.
        hparams: Keyword arguments forwarded to the ``Model`` constructor.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The model with the checkpoint weights loaded.
    """
    network = Model(**hparams)
    network = network.to(device)
    checkpoint = torch.load(path, map_location=device)
    network.load_state_dict(checkpoint["state_dict"])
    return network


def get_subdataset(dataset: SeparationDataset, num_samples: Optional[int] = None, seed: int = 1) -> Tuple[SeparationDataset, Sequence[int]]:
    """Select a subset of ``dataset`` and return it together with its indices.

    Args:
        dataset: Dataset to draw chunks from.
        num_samples: Number of chunks to keep. When ``None``, every chunk is
            kept in its original order. Otherwise the first ``num_samples``
            entries of a seeded random permutation are used.
        seed: Seed for the permutation, so repeated runs pick the same chunks.

    Returns:
        ``(subset, indices)``: a ``ChunkedSeparationSubset`` view over
        ``dataset`` and the list of original indices it covers.
    """
    if num_samples is None:
        chosen = list(range(len(dataset)))
    else:
        rng = torch.Generator().manual_seed(seed)
        permutation = torch.randperm(len(dataset), dtype=torch.int, generator=rng)
        chosen = permutation[:num_samples].tolist()
    return ChunkedSeparationSubset(dataset, chosen), chosen


@torch.no_grad()
def separate_slakh_weakly(
    dataset_path: str,
    output_dir: str,
    num_samples: Optional[int] = None,
    num_resamples: int = 1,
    num_steps: int = 150,
    batch_size: int = 16,
    resume: bool = True,
    device: Union[str, torch.device] = torch.device("cuda:0"),
    s_churn: float = 20.0,
    source_id: int = -1,
    use_gaussian: bool = False,
    gamma: float = 1.0,
    sigma_min: float = 1e-4,
    sigma_max: float = 1.0,
    ):
    """Separate Slakh chunks using one separately loaded checkpoint per stem.

    Four single-stem checkpoints (bass, guitar, drums, piano) from
    ``ROOT_PATH / "ckpts"`` are combined by ``WeaklyMSDMSeparator``. The
    results are written to ``output_dir`` by ``separate_slakh``, and the call
    arguments are saved to ``output_dir/config.yaml``.

    Args:
        dataset_path: Directory passed to ``ChunkedSupervisedDataset``. The
            dataset is read at 44.1 kHz in fixed chunks of 262144 * 2 samples.
        output_dir: Destination directory for separated audio and metadata.
        num_samples: Number of randomly chosen chunks to separate (all if None).
        num_resamples: Forwarded to the separator.
        num_steps: Number of sampling steps forwarded to ``separate_slakh``.
        batch_size: Batch size forwarded to ``separate_slakh``.
        resume: Forwarded to ``separate_slakh``.
        device: Device the checkpoints are loaded onto.
        s_churn: Forwarded to the separator.
        source_id: Forwarded to ``differential_with_dirac``. Ignored when
            ``use_gaussian`` is True.
        use_gaussian: Use ``differential_with_gaussian`` instead of
            ``differential_with_dirac``.
        gamma: In the Gaussian variant, the callable ``s -> gamma * s`` is
            passed to ``differential_with_gaussian``.
        sigma_min: Lower bound of the Karras sigma schedule.
        sigma_max: Upper bound of the Karras sigma schedule.
    """
    # Snapshot the call arguments before any other local is bound. Path and
    # torch.device values are stringified so config.yaml contains only plain
    # scalars, instead of !!python/object tags that yaml.safe_load rejects.
    config = {
        name: str(value) if isinstance(value, (Path, torch.device)) else value
        for name, value in dict(locals()).items()
    }
    output_dir = Path(output_dir)

    dataset = ChunkedSupervisedDataset(
        audio_dir=dataset_path,
        stems=["bass", "drums", "guitar", "piano"],
        sample_rate=44100,
        max_chunk_size=262144 * 2,
        min_chunk_size=262144 * 2,
    )

    model_bass = load_model(ROOT_PATH / "ckpts/laced-dream-329-(SLAKH_BASS_v2)-epoch=443.ckpt", device)
    model_guitar = load_model(ROOT_PATH / "ckpts/honest-fog-332-(SLAKH_GUITAR_v2)-epoch=407.ckpt", device)
    model_drums = load_model(ROOT_PATH / "ckpts/ancient-voice-289-(SLAKH_DRUMS_v2)-epoch=258.ckpt", device)
    model_piano = load_model(ROOT_PATH / "ckpts/ruby-dew-290-(SLAKH_PIANO_v2)-epoch=236.ckpt", device)

    if use_gaussian:
        def diff_fn(x, sigma, denoise_fn, mixture):
            return differential_with_gaussian(x, sigma, denoise_fn, mixture, lambda s: gamma * s)
    else:
        diff_fn = functools.partial(differential_with_dirac, source_id=source_id)

    separator = WeaklyMSDMSeparator(
        stem_to_model={
            "bass": model_bass,
            "drums": model_drums,
            "guitar": model_guitar,
            "piano": model_piano,
        },
        sigma_schedule=KarrasSchedule(sigma_min=sigma_min, sigma_max=sigma_max, rho=7.0),
        differential_fn=diff_fn,
        s_churn=s_churn,
        num_resamples=num_resamples,
    )

    # Separation runs at 22.05 kHz; separate_slakh resamples the 44.1 kHz chunks.
    separate_slakh(
        output_dir=output_dir,
        dataset=dataset,
        separator=separator,
        sample_rate=22050,
        num_samples=num_samples,
        num_steps=num_steps,
        batch_size=batch_size,
        resume=resume,
    )

    with open(output_dir / "config.yaml", "w") as f:
        yaml.dump(config, f)


@torch.no_grad()
def separate_slakh_msdm(
    dataset_path: str,
    model_path: str,
    output_dir: str,
    num_samples: Optional[int] = None,
    num_resamples: int = 1,
    num_steps: int = 150,
    batch_size: int = 16,
    resume: bool = True,
    device: Union[str, torch.device] = torch.device("cuda:0"),
    s_churn: float = 20.0,
    source_id: int = -1,
    sigma_min: float = 1e-4,
    sigma_max: float = 1.0,
    use_gaussian: bool = False,
    gamma: float = 1.0,
    ):
    """Separate Slakh chunks with a single multi-stem model (``MSDMSeparator``).

    The results are written to ``output_dir`` by ``separate_slakh``, and the
    call arguments are saved to ``output_dir/config.yaml``.

    Args:
        dataset_path: Directory passed to ``ChunkedSupervisedDataset``. The
            dataset is read at 44.1 kHz in fixed chunks of 262144 * 2 samples.
        model_path: Checkpoint loaded with ``load_context``.
        output_dir: Destination directory for separated audio and metadata.
        num_samples: Number of randomly chosen chunks to separate (all if None).
        num_resamples: Forwarded to the separator.
        num_steps: Number of sampling steps forwarded to ``separate_slakh``.
        batch_size: Batch size forwarded to ``separate_slakh``.
        resume: Forwarded to ``separate_slakh``.
        device: Device the model is loaded onto.
        s_churn: Forwarded to the separator.
        source_id: Forwarded to ``differential_with_dirac``. Ignored when
            ``use_gaussian`` is True.
        sigma_min: Lower bound of the Karras sigma schedule.
        sigma_max: Upper bound of the Karras sigma schedule.
        use_gaussian: Use ``differential_with_gaussian`` instead of
            ``differential_with_dirac``.
        gamma: In the Gaussian variant, the callable ``s -> gamma * s`` is
            passed to ``differential_with_gaussian``.
    """
    # Snapshot the call arguments before any other local is bound. Path and
    # torch.device values are stringified so config.yaml contains only plain
    # scalars, instead of !!python/object tags that yaml.safe_load rejects.
    config = {
        name: str(value) if isinstance(value, (Path, torch.device)) else value
        for name, value in dict(locals()).items()
    }
    output_dir = Path(output_dir)
    stems = ["bass", "drums", "guitar", "piano"]

    dataset = ChunkedSupervisedDataset(
        audio_dir=dataset_path,
        stems=stems,
        sample_rate=44100,
        max_chunk_size=262144 * 2,
        min_chunk_size=262144 * 2,
    )

    model = load_context(model_path, device)

    if use_gaussian:
        def diff_fn(x, sigma, denoise_fn, mixture):
            return differential_with_gaussian(x, sigma, denoise_fn, mixture, lambda s: gamma * s)
    else:
        diff_fn = functools.partial(differential_with_dirac, source_id=source_id)

    separator = MSDMSeparator(
        model=model,
        stems=list(stems),
        sigma_schedule=KarrasSchedule(sigma_min=sigma_min, sigma_max=sigma_max, rho=7.0),
        differential_fn=diff_fn,
        s_churn=s_churn,
        num_resamples=num_resamples,
    )

    # Separation runs at 22.05 kHz; separate_slakh resamples the 44.1 kHz chunks.
    separate_slakh(
        output_dir=output_dir,
        dataset=dataset,
        separator=separator,
        sample_rate=22050,
        num_samples=num_samples,
        num_steps=num_steps,
        batch_size=batch_size,
        resume=resume,
    )

    with open(output_dir / "config.yaml", "w") as f:
        yaml.dump(config, f)


@torch.no_grad()
def separate_slakh(
        output_dir: Union[str, Path],
        dataset: SeparationDataset,
        separator: Separator,
        sample_rate: Optional[int] = None,
        num_samples: Optional[int] = None,
        num_steps: int = 150,
        batch_size: int = 16,
        resume: bool = False,
    ):
    """Separate a (sub)set of dataset chunks and write the results to ``output_dir``.

    Steps:
    1. Pick the chunks with ``get_subdataset``.
    2. Wrap them in ``ResampleDataset`` at ``sample_rate``.
    3. Write ``chunk_data.json`` describing each selected chunk.
    4. Delegate the separation to ``separate_dataset``.

    Args:
        output_dir: Output directory. It is created, including missing parents.
        dataset: Source dataset. It must provide ``get_chunk_indices``,
            ``get_chunk_track`` and ``sample_rate``.
        separator: Separator passed to ``separate_dataset``.
        sample_rate: Target sample rate passed to ``ResampleDataset``.
        num_samples: Number of randomly chosen chunks (all if None).
        num_steps: Forwarded to ``separate_dataset``.
        batch_size: Forwarded to ``separate_dataset``.
        resume: Forwarded to ``separate_dataset``.
    """
    # Choose which chunks to separate (a seeded random subset, or all of them).
    subset, indices = get_subdataset(dataset, num_samples)

    # Resample only the chunks that will actually be processed.
    resampled_dataset = ResampleDataset(dataset=subset, new_sample_rate=sample_rate)

    output_dir = Path(output_dir)
    # parents=True so a nested output path works even when its parents are missing.
    output_dir.mkdir(parents=True, exist_ok=True)

    # One metadata record per selected chunk. "chunk_index" is the position in
    # the subset (presumably matching separate_dataset's output order — verify).
    # Sample offsets and seconds use the ORIGINAL dataset's sample rate, not
    # `sample_rate`. The key names (including the asymmetric
    # "end_chunk_in_seconds") are kept as-is for downstream readers.
    chunk_data = []
    for chunk_index, dataset_index in enumerate(indices):
        start_sample, end_sample = dataset.get_chunk_indices(dataset_index)
        chunk_data.append(
            {
                "chunk_index": chunk_index,
                "track": dataset.get_chunk_track(dataset_index),
                "start_chunk_sample": start_sample,
                "end_chunk_sample": end_sample,
                "start_chunk_seconds": start_sample / dataset.sample_rate,
                "end_chunk_in_seconds": end_sample / dataset.sample_rate,
            }
        )

    with open(output_dir / "chunk_data.json", "w") as f:
        json.dump(chunk_data, f)

    separate_dataset(
        dataset=resampled_dataset,
        separator=separator,
        save_path=output_dir,
        num_steps=num_steps,
        batch_size=batch_size,
        resume=resume,
    )