import logging
import os
import random
import sys
from collections import defaultdict

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from scipy.io.wavfile import read
from scipy.ndimage import gaussian_filter1d
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Directory containing this file; also used below as the Hydra config path.
dir_path = os.path.dirname(__file__)
# "speech-resynthesis" directory located in the parent of this file's directory.
resynth_path = os.path.dirname(dir_path) + "/speech-resynthesis"
# Extend the import path before the `dataset` import below
# (presumably `dataset` lives in `resynth_path` — TODO confirm).
sys.path.append(resynth_path)
from dataset import parse_speaker, parse_style
from .utils import F0Stat

# NOTE(review): not referenced elsewhere in this file; presumably the int16
# full-scale value used to normalize waveforms — confirm before removing.
MAX_WAV_VALUE = 32768.0
# Module-level logger (the code below currently reports progress via `print`).
logger = logging.getLogger(__name__)


def quantize_f0(speaker_to_f0, nbins, normalize, log):
    """Derive equal-population F0 bin boundaries from pooled speaker data.

    Every speaker's raw F0 values are optionally log-transformed and
    normalized with that speaker's own statistics, pooled together, and the
    pooled distribution is cut at multiples of ``100 / nbins`` percent of a
    100000-bin cumulative histogram.

    Args:
        speaker_to_f0: mapping from speaker to a stats object exposing
            ``raw_data`` (must support ``.log()`` and ``.tolist()``),
            ``mean``, ``std``, ``mean_log`` and ``std_log``.
        nbins: number of quantization bins; ``nbins - 1`` boundaries result.
        normalize: ``"mean"`` subtracts the speaker mean, ``"meanstd"``
            additionally divides by the speaker std; any other value leaves
            the data unnormalized.
        log: when true, work in the log domain (and use the ``*_log`` stats).

    Returns:
        A ``torch.FloatTensor`` holding the ``nbins - 1`` boundaries.
    """
    pooled = []
    for stats in speaker_to_f0.values():
        values = stats.raw_data
        if log:
            values = values.log()
            center, scale = stats.mean_log, stats.std_log
        else:
            center, scale = stats.mean, stats.std

        if normalize == "mean":
            values = values - center
        elif normalize == "meanstd":
            values = (values - center) / scale
        pooled.extend(values.tolist())

    # Fine-grained histogram -> cumulative distribution expressed in percent.
    counts, edges = np.histogram(pooled, 100000)
    cdf_percent = np.cumsum(counts) / len(pooled) * 100

    # Walk the percentiles step, 2*step, ... and pick, for each, the histogram
    # edge whose cumulative mass is closest to it.
    step = 100 / nbins
    target = step
    boundaries = []
    for _ in range(nbins - 1):
        closest = np.abs(cdf_percent - target).argmin()
        boundaries.append(edges[closest])
        target += step

    return torch.FloatTensor(np.array(boundaries))


def save_ckpt(model, path, model_class, f0_min, f0_max, f0_bins, speaker_stats):
    """Write the model weights plus everything `load_ckpt` needs to `path`.

    Args:
        model: object providing ``state_dict()`` and a ``padding_token``
            attribute.
        path: destination handed straight to ``torch.save``.
        model_class: model configuration stored under ``"model_class"``.
        f0_min, f0_max, f0_bins: F0 quantization range and bin boundaries.
        speaker_stats: per-speaker F0 statistics.
    """
    payload = dict(
        state_dict=model.state_dict(),
        padding_token=model.padding_token,
        model_class=model_class,
        speaker_stats=speaker_stats,
        f0_min=f0_min,
        f0_max=f0_max,
        f0_bins=f0_bins,
    )
    torch.save(payload, path)


def load_ckpt(path):
    """Rebuild a pitch predictor from a checkpoint written by `save_ckpt`.

    The stored model configuration is instantiated through Hydra (with its
    ``_target_`` forced to ``emotion_models.pitch_predictor.CnnPredictor``),
    the weights are restored, and the F0 statistics are attached via
    ``setup_f0_stats`` so the model is ready for ``inference``.

    Args:
        path: checkpoint location, as accepted by ``torch.load``.

    Returns:
        The restored model, living on the CPU; move it with ``.to(device)``.
    """
    # Map every tensor to the CPU so checkpoints saved from a CUDA model can
    # also be loaded on hosts without a GPU. The model is instantiated on the
    # CPU anyway, so the result is the same.
    # NOTE(review): the checkpoint contains pickled Python objects (model
    # config, speaker stats) — only load files from trusted sources. Newer
    # PyTorch releases default to `weights_only=True`, which may reject such
    # objects; confirm against the pinned torch version.
    ckpt = torch.load(path, map_location="cpu")
    ckpt["model_class"]["_target_"] = "emotion_models.pitch_predictor.CnnPredictor"
    model = hydra.utils.instantiate(ckpt["model_class"])
    model.load_state_dict(ckpt["state_dict"])
    model.setup_f0_stats(
        ckpt["f0_min"],
        ckpt["f0_max"],
        ckpt["f0_bins"],
        ckpt["speaker_stats"],
    )
    return model


def freq2bin(f0, f0_min, f0_max, bins):
    """Map F0 values to bin indices.

    Values are first limited to ``[f0_min, f0_max]`` and then located among
    the boundaries in ``bins`` with ``torch.bucketize``. The input tensor is
    left untouched.
    """
    limited = f0.clamp(f0_min, f0_max)
    return torch.bucketize(limited, bins)


def bin2freq(x, f0_min, f0_max, bins, mode):
    """Turn per-bin scores back into a single F0 value per frame.

    Each bin is represented by its lower edge: ``f0_min`` for the first bin
    and the entries of ``bins`` for the remaining ones.

    Args:
        x: scores whose last dimension has ``len(bins) + 1`` entries.
        f0_min: value representing the first bin.
        f0_max: accepted for signature symmetry with `freq2bin`; not used.
        bins: 1-D tensor of bin boundaries.
        mode: ``"mean"`` for the score-weighted average of the bin values,
            ``"argmax"`` for the value of the highest-scoring bin.

    Returns:
        Tensor shaped like ``x`` without its last dimension.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    n_bins = 1 + len(bins)
    assert x.shape[-1] == n_bins
    first_edge = torch.tensor([f0_min])
    bin_values = torch.cat([first_edge, bins]).to(x.device)

    if mode == "argmax":
        winner = F.one_hot(x.argmax(-1), num_classes=n_bins)
        return (winner * bin_values).sum(-1)
    if mode == "mean":
        weighted_total = (x * bin_values).sum(-1)
        return weighted_total / x.sum(-1)
    raise NotImplementedError()


def load_wav(full_path):
    """Read a wav file and return ``(samples, sampling_rate)``.

    Note the order: `scipy.io.wavfile.read` yields ``(rate, data)``; this
    helper hands the pair back reversed.
    """
    rate_then_samples = read(full_path)
    return rate_then_samples[1], rate_then_samples[0]


def l1_loss(input, target):
    """Return the element-wise absolute error between `input` and `target`.

    Both tensors are cast to float first; no reduction is applied, so the
    result has the same shape as the inputs.
    """
    # `reduction="none"` replaces the deprecated `reduce=False` argument.
    return F.l1_loss(input=input.float(), target=target.float(), reduction="none")


def l2_loss(input, target):
    """Return the element-wise squared error between `input` and `target`.

    Both tensors are cast to float first; no reduction is applied, so the
    result has the same shape as the inputs.
    """
    # `reduction="none"` replaces the deprecated `reduce=False` argument.
    return F.mse_loss(input=input.float(), target=target.float(), reduction="none")


class Collator:
    """Batch assembler for samples produced by `PitchDataset.__getitem__`.

    Each sample is indexed as ``(tokens, f0, f0_raw, spk, gst)``. The three
    sequence fields are right-padded with ``padding_idx`` to the longest
    sequence in the batch; speaker and style ids become ``LongTensor``s.
    """

    def __init__(self, padding_idx):
        self.padding_idx = padding_idx

    def _pad(self, sequences):
        # Batch-first padding that uses the token padding index as fill value.
        return torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=self.padding_idx
        )

    def __call__(self, batch):
        """Collate `batch` into ``(tokens, f0, f0_raw, spk, gst, mask, lengths)``.

        ``mask`` is true wherever ``tokens`` differs from the padding index and
        ``lengths`` lists the unpadded token-sequence lengths.
        """

        def field(position):
            return [sample[position] for sample in batch]

        token_seqs = field(0)
        lengths = [len(seq) for seq in token_seqs]
        tokens = self._pad(token_seqs)
        f0 = self._pad(field(1))
        f0_raw = self._pad(field(2))
        spk = torch.LongTensor(field(3))
        gst = torch.LongTensor(field(4))
        mask = tokens != self.padding_idx
        return tokens, f0, f0_raw, spk, gst, mask, lengths


class CnnPredictor(nn.Module):
    """Convolutional model predicting per-frame scores over quantized F0 bins.

    Discrete unit tokens are embedded, concatenated with a (time-broadcast)
    style embedding, passed through a stack of residual 1-D convolution
    blocks and projected to ``n_bins`` scores per frame.

    Args:
        n_tokens: number of unit tokens; index ``n_tokens`` is reserved as
            the padding token.
        emb_dim: token embedding size.
        channels: width of every convolution block.
        kernel: convolution kernel size (padding is ``(kernel - 1) // 2``,
            which preserves the sequence length for odd kernels).
        dropout: dropout probability inside each block.
        n_layers: total number of convolution blocks.
        spk_emb: stored on the module but not used to build the network.
        gst_emb: style embedding size.
        n_bins: number of output F0 bins.
        f0_pred: ``"mean"`` (sigmoid outputs) or ``"argmax"`` (softmax
            outputs); also selects the decoding mode in `inference`.
        f0_log: whether the bins live in the log-F0 domain.
        f0_norm: ``"mean"`` or ``"meanstd"`` to undo per-speaker
            normalization in `inference`; other values skip that step.
    """

    def __init__(
        self,
        n_tokens,
        emb_dim,
        channels,
        kernel,
        dropout,
        n_layers,
        spk_emb,
        gst_emb,
        n_bins,
        f0_pred,
        f0_log,
        f0_norm,
    ):
        super().__init__()
        self.n_tokens = n_tokens
        self.emb_dim = emb_dim
        self.f0_log = f0_log
        self.f0_pred = f0_pred
        self.padding_token = n_tokens
        self.f0_norm = f0_norm
        # add 1 extra embedding for padding token, set the padding index to be the last token
        # (tokens from the clustering start at index 0)
        self.token_emb = nn.Embedding(
            n_tokens + 1, emb_dim, padding_idx=self.padding_token
        )

        self.spk_emb = spk_emb
        # The style vocabulary size is fixed at 20 entries.
        self.gst_emb = nn.Embedding(20, gst_emb)
        # Flipped to True by `setup_f0_stats`; `inference` refuses to run before that.
        self.setup = False

        # Input width of the first block: token embedding + style embedding.
        # (`spk_emb` does not contribute any features.)
        feats = emb_dim + gst_emb
        layers = [self._conv_block(feats, channels, kernel, dropout)]
        layers += [
            self._conv_block(channels, channels, kernel, dropout)
            for _ in range(n_layers - 1)
        ]
        self.conv_layer = nn.ModuleList(layers)
        self.proj = nn.Linear(channels, n_bins)

    @staticmethod
    def _conv_block(in_channels, out_channels, kernel, dropout):
        """Build one (b, t, c) -> (b, t, c) Conv1d/ReLU/LayerNorm/Dropout block."""
        return nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size=kernel,
                padding=(kernel - 1) // 2,
            ),
            Rearrange("b c t -> b t c"),
            nn.ReLU(),
            nn.LayerNorm(out_channels),
            nn.Dropout(dropout),
        )

    def forward(self, x, gst=None):
        """Score every frame of `x` over the F0 bins.

        Args:
            x: token indices of shape (batch, time).
            gst: optional style indices of shape (batch,). The first
                convolution is sized for ``emb_dim + gst_emb`` input
                features, so omitting `gst` only matches that width when
                the style embedding size is 0.

        Returns:
            Sigmoid (``f0_pred == "mean"``) or softmax
            (``f0_pred == "argmax"``) scores of shape (batch, time, n_bins);
            a trailing dimension of size 1 is squeezed away.

        Raises:
            NotImplementedError: for any other ``f0_pred``.
        """
        x = self.token_emb(x)
        feats = [x]

        if gst is not None:
            # Repeat the single style vector along the time axis.
            gst = self.gst_emb(gst)
            gst = rearrange(gst, "b c -> b c 1")
            gst = F.interpolate(gst, x.shape[1])
            gst = rearrange(gst, "b c t -> b t c")
            feats.append(gst)

        x = torch.cat(feats, dim=-1)

        for i, conv in enumerate(self.conv_layer):
            # The first block changes the feature width, so only the later
            # blocks get a residual connection.
            if i != 0:
                x = conv(x) + x
            else:
                x = conv(x)

        x = self.proj(x)
        x = x.squeeze(-1)

        if self.f0_pred == "mean":
            x = torch.sigmoid(x)
        elif self.f0_pred == "argmax":
            x = torch.softmax(x, dim=-1)
        else:
            raise NotImplementedError
        return x

    def setup_f0_stats(self, f0_min, f0_max, f0_bins, speaker_stats):
        """Attach the F0 range, bin boundaries and speaker stats used by `inference`."""
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.f0_bins = f0_bins
        self.speaker_stats = speaker_stats
        self.setup = True

    def inference(self, x, spk_id=None, gst=None):
        """Predict F0 values (one per frame) for a batch of token sequences.

        The bin scores from `forward` are decoded with `bin2freq`, the
        per-speaker normalization selected by ``f0_norm`` is undone, and the
        result is exponentiated when ``f0_log`` is set.

        Args:
            x: token indices of shape (batch, time).
            spk_id: tensor of speaker ids, one per batch item. Only needed
                when ``f0_norm`` is ``"mean"`` or ``"meanstd"``.
            gst: optional style indices, forwarded to `forward`.

        Returns:
            Tensor of shape (batch, time) with the predicted F0.

        Raises:
            AssertionError: if `setup_f0_stats` has not been called.
            ValueError: if de-normalization is required but `spk_id` is None.
        """
        assert (
            self.setup
        ), "make sure that `setup_f0_stats` was called before inference!"
        probs = self(x, gst)
        f0 = bin2freq(probs, self.f0_min, self.f0_max, self.f0_bins, self.f0_pred)

        # Speaker statistics are only consulted when there is a normalization
        # to undo, so `spk_id` may really be omitted otherwise.
        if self.f0_norm in ("mean", "meanstd"):
            if spk_id is None:
                raise ValueError(
                    f"spk_id is required to undo f0_norm={self.f0_norm!r}"
                )
            for i in range(f0.shape[0]):
                stats = self.speaker_stats[spk_id[i].item()]
                mean = stats.mean_log if self.f0_log else stats.mean
                if self.f0_norm == "mean":
                    f0[i] = f0[i] + mean
                else:
                    std = stats.std_log if self.f0_log else stats.std
                    f0[i] = (f0[i] * std) + mean
        if self.f0_log:
            f0 = f0.exp()
        return f0


class PitchDataset(Dataset):
    """Dataset pairing discrete unit sequences with quantized F0 targets.

    Args:
        tsv_path: manifest file; its first line is the audio root directory,
            every following line describes one utterance (tab-separated,
            relative wav path first).
        km_path: file with one space-separated unit sequence per utterance,
            aligned line-by-line with the manifest entries.
        substring: if non-empty, keep only utterances whose manifest line
            contains it (case-insensitive).
        spk: passed to `parse_speaker` together with a manifest line.
        spk2id: mapping from parsed speaker to integer id.
        gst: passed to `parse_style` together with a manifest line.
        gst2id: mapping from parsed style to integer id.
        f0_bins: number of F0 bins; replaced by the tensor of bin boundaries
            once they are computed.
        f0_bin_type: ``"adaptive"`` (equal-population bins via `quantize_f0`)
            or ``"uniform"`` (evenly spaced between the observed min/max).
        f0_smoothing: sigma of the Gaussian applied to the one-hot targets;
            values ``<= 0`` disable smoothing.
        f0_norm: ``"mean"`` / ``"meanstd"`` per-speaker normalization; other
            values disable it.
        f0_log: whether to work on log-F0.

    Raises:
        NotImplementedError: for an unknown ``f0_bin_type``.
    """

    def __init__(
        self,
        tsv_path,
        km_path,
        substring,
        spk,
        spk2id,
        gst,
        gst2id,
        f0_bins,
        f0_bin_type,
        f0_smoothing,
        f0_norm,
        f0_log,
    ):
        # Read both files inside `with` blocks so the handles are closed.
        with open(tsv_path, "r") as tsv_file:
            lines = tsv_file.readlines()
        self.root, self.tsv = lines[0], lines[1:]
        self.root = self.root.strip()
        with open(km_path, "r") as km_file:
            self.km = km_file.readlines()
        print(f"loaded {len(self.km)} files")

        self.spk = spk
        self.spk2id = spk2id
        self.gst = gst
        self.gst2id = gst2id

        self.f0_bins = f0_bins
        self.f0_smoothing = f0_smoothing
        self.f0_norm = f0_norm
        self.f0_log = f0_log

        if substring != "":
            # Filter manifest and unit lines in lockstep to keep them aligned.
            tsv, km = [], []
            for tsv_line, km_line in zip(self.tsv, self.km):
                if substring.lower() in tsv_line.lower():
                    tsv.append(tsv_line)
                    km.append(km_line)
            self.tsv, self.km = tsv, km
            print(f"after filtering: {len(self.km)} files")

        # Speaker stats must exist before min/max: `_preprocess_f0` reads them.
        self.speaker_stats = self._compute_f0_stats()
        self.f0_min, self.f0_max = self._compute_f0_minmax()
        if f0_bin_type == "adaptive":
            self.f0_bins = quantize_f0(
                self.speaker_stats, self.f0_bins, self.f0_norm, self.f0_log
            )
        elif f0_bin_type == "uniform":
            # Drop both end points: only the interior boundaries are kept.
            self.f0_bins = torch.linspace(self.f0_min, self.f0_max, self.f0_bins + 1)[
                1:-1
            ]
        else:
            raise NotImplementedError
        print(f"f0 min: {self.f0_min}, f0 max: {self.f0_max}")
        print(f"bins: {self.f0_bins} (shape: {self.f0_bins.shape})")

    def __len__(self):
        """Return the number of utterances (unit-sequence lines)."""
        return len(self.km)

    def _load_f0(self, tsv_line):
        """Load the F0 track stored next to the wav named in `tsv_line`.

        The file is ``<root>/<wav path with ".wav" -> ".yaapt.f0.npy">``.
        """
        tsv_line = tsv_line.split("\t")[0]
        f0 = self.root + "/" + tsv_line.replace(".wav", ".yaapt.f0.npy")
        f0 = np.load(f0)
        f0 = torch.FloatTensor(f0)
        return f0

    def _preprocess_f0(self, f0, spk):
        """Apply the optional log transform and speaker normalization in place.

        `f0` is modified and returned; pass a clone to keep the original.
        """
        mask = f0 != -999999  # process all frames
        # mask = (f0 != 0)  # only process voiced frames
        mean = (
            self.speaker_stats[spk].mean_log
            if self.f0_log
            else self.speaker_stats[spk].mean
        )
        std = (
            self.speaker_stats[spk].std_log
            if self.f0_log
            else self.speaker_stats[spk].std
        )
        if self.f0_log:
            # Replace exact zeros by a small positive value before taking the log.
            f0[f0 == 0] = 1e-5
            f0[mask] = f0[mask].log()
        if self.f0_norm == "mean":
            f0[mask] = f0[mask] - mean
        if self.f0_norm == "meanstd":
            f0[mask] = (f0[mask] - mean) / std
        return f0

    def _compute_f0_minmax(self):
        """Return the min and max of the preprocessed F0 over all utterances."""
        f0_min, f0_max = float("inf"), -float("inf")
        for tsv_line in tqdm(self.tsv, desc="computing f0 minmax"):
            spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
            f0 = self._load_f0(tsv_line)
            f0 = self._preprocess_f0(f0, spk)
            f0_min = min(f0_min, f0.min().item())
            f0_max = max(f0_max, f0.max().item())
        return f0_min, f0_max

    def _compute_f0_stats(self):
        """Accumulate per-speaker F0 statistics over the non-zero frames."""
        from functools import partial

        # NOTE(review): the meaning of the positional `True` passed to F0Stat
        # is defined in `.utils` — confirm there.
        speaker_stats = defaultdict(partial(F0Stat, True))
        for tsv_line in tqdm(self.tsv, desc="computing speaker stats"):
            spk = self.spk2id[parse_speaker(tsv_line, self.spk)]
            f0 = self._load_f0(tsv_line)
            mask = f0 != 0
            f0 = f0[mask]  # compute stats only on voiced parts
            speaker_stats[spk].update(f0)
        return speaker_stats

    def __getitem__(self, i):
        """Return ``(tokens, f0_target, f0_raw, spk_id, gst_id)`` for utterance `i`.

        ``f0_target`` is a (time, n_bins) one-hot (optionally smoothed)
        encoding of the preprocessed F0; ``f0_raw`` is the unprocessed F0.
        Both are resampled to the token sequence length.
        """
        x = self.km[i]
        x = x.split(" ")
        x = list(map(int, x))
        x = torch.LongTensor(x)

        gst = parse_style(self.tsv[i], self.gst)
        gst = self.gst2id[gst]
        spk = parse_speaker(self.tsv[i], self.spk)
        spk = self.spk2id[spk]

        f0_raw = self._load_f0(self.tsv[i])
        # Clone: `_preprocess_f0` works in place and `f0_raw` must stay raw.
        f0 = self._preprocess_f0(f0_raw.clone(), spk)

        # Resample both tracks to one value per token.
        f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]
        f0_raw = F.interpolate(f0_raw.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0]

        f0 = freq2bin(f0, f0_min=self.f0_min, f0_max=self.f0_max, bins=self.f0_bins)
        f0 = F.one_hot(f0.long(), num_classes=len(self.f0_bins) + 1).float()
        if self.f0_smoothing > 0:
            # gaussian_filter1d filters along the last axis by default,
            # i.e. across neighbouring bins.
            f0 = torch.tensor(
                gaussian_filter1d(f0.float().numpy(), sigma=self.f0_smoothing)
            )
        return x, f0, f0_raw, spk, gst


def train(cfg):
    """Train the pitch predictor described by `cfg` and keep the best checkpoint.

    Builds the train/validation datasets and loaders, instantiates the model
    and optimizer through Hydra, then runs ``cfg.epochs`` epochs. Whenever the
    validation voiced-frame L1 error improves, a checkpoint is written to
    ``<cwd>/pitch_predictor.ckpt``.

    Args:
        cfg: configuration providing the dataset options (``train_tsv``,
            ``train_km``, ``valid_tsv``, ``valid_km``, ``substring``, ``spk``,
            ``spk2id``, ``gst``, ``gst2id``, ``f0_*``), ``n_tokens``,
            ``batch_size``, ``epochs``, ``model`` and ``optimizer``.
    """
    # Prefer the first GPU but fall back to the CPU instead of crashing on
    # hosts without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # add 1 extra embedding for padding token, set the padding index to be the last token
    # (tokens from the clustering start at index 0)
    padding_token = cfg.n_tokens
    collate_fn = Collator(padding_idx=padding_token)

    # Options shared by the training and validation datasets.
    dataset_kwargs = dict(
        substring=cfg.substring,
        spk=cfg.spk,
        spk2id=cfg.spk2id,
        gst=cfg.gst,
        gst2id=cfg.gst2id,
        f0_bins=cfg.f0_bins,
        f0_bin_type=cfg.f0_bin_type,
        f0_smoothing=cfg.f0_smoothing,
        f0_norm=cfg.f0_norm,
        f0_log=cfg.f0_log,
    )
    train_ds = PitchDataset(cfg.train_tsv, cfg.train_km, **dataset_kwargs)
    # NOTE(review): the validation set derives its own speaker stats, min/max
    # and bin boundaries, while the model below is set up with the training
    # set's values — confirm this mismatch is intended.
    valid_ds = PitchDataset(cfg.valid_tsv, cfg.valid_km, **dataset_kwargs)
    train_dl = DataLoader(
        train_ds,
        num_workers=0,
        batch_size=cfg.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    valid_dl = DataLoader(
        valid_ds, num_workers=0, batch_size=16, shuffle=False, collate_fn=collate_fn
    )

    f0_min = train_ds.f0_min
    f0_max = train_ds.f0_max
    f0_bins = train_ds.f0_bins
    speaker_stats = train_ds.speaker_stats

    model = hydra.utils.instantiate(cfg["model"]).to(device)
    model.setup_f0_stats(f0_min, f0_max, f0_bins, speaker_stats)

    optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())

    best_loss = float("inf")
    for epoch in range(cfg.epochs):
        # `run_epoch` returns (training objective, L1 error, voiced-only L1 error).
        train_loss, train_l1, train_l1_voiced = run_epoch(
            model, train_dl, optimizer, device, cfg, mode="train"
        )
        valid_loss, valid_l1, valid_l1_voiced = run_epoch(
            model, valid_dl, None, device, cfg, mode="valid"
        )
        # The reported metrics are L1 errors, so label them as such.
        print(
            f"[epoch {epoch}] train loss: {train_loss:.3f}, l1 loss: {train_l1:.3f}, l1 voiced loss: {train_l1_voiced:.3f}"
        )
        print(
            f"[epoch {epoch}] valid loss: {valid_loss:.3f}, l1 loss: {valid_l1:.3f}, l1 voiced loss: {valid_l1_voiced:.3f}"
        )
        # Model selection is driven by the voiced-frame validation error.
        if valid_l1_voiced < best_loss:
            path = f"{os.getcwd()}/pitch_predictor.ckpt"
            save_ckpt(model, path, cfg["model"], f0_min, f0_max, f0_bins, speaker_stats)
            best_loss = valid_l1_voiced
            print(f"saved checkpoint: {path}")
        print(f"[epoch {epoch}] best loss: {best_loss:.3f}")


def run_epoch(model, loader, optimizer, device, cfg, mode):
    """Run one pass over `loader`, training the model when ``mode == "train"``.

    Args:
        model: predictor exposing ``__call__(x, gst)`` and
            ``inference(x, spk_id, gst)``.
        loader: iterable of batches as produced by `Collator`; must support
            ``len()``.
        optimizer: optimizer stepped in training mode (unused otherwise).
        device: device the batch tensors are moved to.
        cfg: configuration; ``cfg.f0_pred`` selects the objective
            (``"mean"``: binary cross-entropy, ``"argmax"``: cross-entropy).
        mode: ``"train"`` enables training; any other value evaluates.

    Returns:
        Tuple of per-batch averages: (objective, L1 error between predicted
        and raw F0 on unpadded frames, the same restricted to frames whose
        raw F0 is non-zero).

    Raises:
        ValueError: if `loader` yields no batches.
        NotImplementedError: for an unknown ``cfg.f0_pred``.
    """
    n_batches = len(loader)
    if n_batches == 0:
        # Otherwise the summary prints below would hit undefined names.
        raise ValueError(f"cannot run a {mode} epoch on an empty loader")

    is_train = mode == "train"
    if is_train:
        model.train()
    else:
        model.eval()

    epoch_loss = 0
    l1 = 0
    l1_voiced = 0
    for x, f0_bin, f0_raw, spk_id, gst, mask, _ in tqdm(loader):
        x, f0_bin, f0_raw, spk_id, gst, mask = (
            x.to(device),
            f0_bin.to(device),
            f0_raw.to(device),
            spk_id.to(device),
            gst.to(device),
            mask.to(device),
        )
        b, t, n_bins = f0_bin.shape

        # Only track gradients when they are going to be used.
        with torch.set_grad_enabled(is_train):
            yhat = model(x, gst)
            if cfg.f0_pred == "mean":
                expanded_mask = mask.unsqueeze(-1).expand(-1, -1, n_bins)
                loss = F.binary_cross_entropy(
                    yhat[expanded_mask], f0_bin[expanded_mask]
                ).mean()
            elif cfg.f0_pred == "argmax":
                # NOTE(review): `CnnPredictor.forward` already applies a softmax
                # in "argmax" mode and `F.cross_entropy` applies log-softmax
                # again — confirm this is intended before changing it.
                loss = F.cross_entropy(
                    rearrange(yhat, "b t d -> (b t) d"),
                    rearrange(f0_bin.argmax(-1), "b t -> (b t)"),
                    reduction="none",
                )
                loss = rearrange(loss, "(b t) -> b t", b=b, t=t)
                # Average over unpadded frames only.
                loss = (loss * mask).sum() / mask.float().sum()
            else:
                raise NotImplementedError

        # The monitoring metrics never need gradients.
        with torch.no_grad():
            yhat_raw = model.inference(x, spk_id, gst)
            nonzero_mask = (f0_raw != 0).logical_and(mask)
            l1 += F.l1_loss(yhat_raw[mask], f0_raw[mask]).item()
            l1_voiced += F.l1_loss(
                yhat_raw[nonzero_mask], f0_raw[nonzero_mask]
            ).item()
        epoch_loss += loss.item()

        if is_train:
            # Clear gradients left over from the previous step; without this
            # they would accumulate across batches.
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

    # Show a short slice of the last batch as a sanity check.
    print(f"{mode} example    y: {f0_bin.argmax(-1)[0, 50:60].tolist()}")
    print(f"{mode} example yhat: {yhat.argmax(-1)[0, 50:60].tolist()}")
    print(f"{mode} example    y: {f0_raw[0, 50:60].round().tolist()}")
    print(f"{mode} example yhat: {yhat_raw[0, 50:60].round().tolist()}")
    return epoch_loss / n_batches, l1 / n_batches, l1_voiced / n_batches


@hydra.main(config_path=dir_path, config_name="pitch_predictor.yaml")
def main(cfg):
    """Hydra entry point: seed the RNGs, print the config and start training.

    Args:
        cfg: configuration composed by Hydra from ``pitch_predictor.yaml``
            (looked up in this file's directory) plus command-line overrides.
    """
    # Fixed seeds for NumPy, Python and PyTorch.
    np.random.seed(1)
    random.seed(1)
    torch.manual_seed(1)
    # The previous version also built a dict of command-line overrides that
    # was never used and could raise on overrides without "="; it was dropped.
    print(f"{cfg}")
    train(cfg)


# Script entry point: `main` is decorated with `hydra.main` above, so it is
# called without arguments here.
if __name__ == "__main__":
    main()
