import json
import random

import torch
from torch.nn.functional import one_hot
from torch.utils.data import Dataset


class JSBDataset(Dataset):
    """
    JSB Dataset

    Fixed-length windows of four-voice chorales stored as integer tokens.
    Each item is served as a one-hot tensor together with a binary label
    telling whether the window was left intact (1) or corrupted (0).

    Parameters
    ----------
    fname
        name of json file with raw JSB data
    seq_len
        size of windows for training and testing
    num_tokens
        size of one-hot vectors
    split
        one of {'train', 'valid', 'test'}
    """

    # Offsets subtracted from the raw values of each voice column
    # (soprano, alto, tenor, bass).
    _VOICE_OFFSETS = (45, 42, 38, 35)
    # Token that replaces the raw value -1 (presumably a rest/silence marker
    # in the raw data -- confirm against the json file).
    _REST = 0
    # Inclusive token range kept by transposition and the Gaussian corruption.
    # NOTE(review): assumes num_tokens >= 37 so one_hot can encode _MAX_NOTE.
    _MIN_NOTE = 1
    _MAX_NOTE = 36

    def __init__(self, fname, seq_len, num_tokens, split="train"):
        with open(fname, encoding="utf-8") as file:
            raw_data = json.load(file)[split]
        self.seq_len = seq_len
        self.num_tokens = num_tokens

        # Every window spans 2 * seq_len raw steps and keeps every second one,
        # so consecutive windows do not overlap.
        span = self.seq_len * 2
        windows = []
        for chorale in raw_data:
            for start in range(0, len(chorale), span):
                ex = self._window2tensor(chorale[start : start + span : 2])
                # Drop the trailing window of a chorale when it is too short.
                if ex.size(0) == self.seq_len:
                    windows.append(ex)
        # Size = (N, seq_len, 4) integer tokens; the one-hot encoding to
        # (seq_len, 4, num_tokens) is applied per item in __getitem__.
        self.data = torch.stack(windows)

    def _window2tensor(self, ex):
        """
        Convert a window (list of 4-value time steps) into a token tensor.

        Values other than -1 are shifted by their voice offset; every entry
        equal to -1 after that step is mapped to the rest token 0.
        """
        ex = torch.tensor(ex)
        for voice, offset in enumerate(self._VOICE_OFFSETS):
            column = ex[:, voice]  # view: the in-place update writes into ex
            column[column != -1] -= offset
        ex[ex == -1] = self._REST
        return ex

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        """
        Return ``(inp, tgt)`` for window ``idx``.

        ``inp`` is a float one-hot tensor of size (seq_len, 4, num_tokens) and
        ``tgt`` is 1 for a true (only transposed) example or 0 for a corrupted
        one. Both the transposition and the corruption are drawn from the
        ``random`` module, so repeated calls return different samples.
        """
        assert isinstance(idx, int)
        inp = self.data[idx].clone()  # never modify the stored window

        # data augmentation: pitch transposition
        sounding = inp != self._REST
        if sounding.any():
            notes = inp[sounding]
            # Shift by -3..4 tokens, clipped so every note stays in range.
            pitch_del = random.randint(
                max(-3, self._MIN_NOTE - int(notes.min())),
                min(4, self._MAX_NOTE - int(notes.max())),
            )
            inp[sounding] += pitch_del

        # label choice: corrupted example (0) or true example (1)
        tgt = random.choice([0, 1])
        if tgt == 0:
            num_corrupts = random.randint(2, 3)
            # Distinct flat positions, so no corruption overwrites another.
            for r in random.sample(range(self.seq_len * 4), num_corrupts):
                self._corrupt_note(inp, (r % self.seq_len, r // self.seq_len))

        inp = one_hot(inp, self.num_tokens)
        return inp.float(), tgt

    def _corrupt_note(self, inp, pos):
        """
        Replace the token at ``pos = (time step, voice)`` of ``inp`` in place.

        The replacement always differs from the current token, so an example
        labelled as corrupted can never be identical to its clean version.
        """
        step, voice = pos
        current = int(inp[pos])
        corrupt_choice = random.choices(
            (0, 1, 2, 3), weights=(0.6, 0.2, 0.1, 0.1)
        )[0]

        # Gaussian: resample around the current token until a different,
        # in-range note comes out.
        if corrupt_choice == 0:
            note = current
            while (
                note == current
                or note < self._MIN_NOTE
                or note > self._MAX_NOTE
            ):
                note = round(random.gauss(current, 3))
            inp[pos] = note
            return

        # Harmonically sound: another non-rest note that the same voice has
        # elsewhere in this window.
        if corrupt_choice == 1:
            candidates = [
                note
                for note in inp[:, voice].tolist()
                if note != self._REST and note != current
            ]
            if candidates:
                inp[pos] = random.choice(candidates)
                return
        # Extend prev. note: only when it actually changes the token,
        # otherwise the corruption would be a no-op.
        elif corrupt_choice == 2 and step != 0:
            previous = int(inp[step - 1, voice])
            if previous != current:
                inp[pos] = previous
                return

        # Uniform: any token except the current one. Also the fallback when
        # the selected strategy cannot produce a different note.
        inp[pos] = random.choice(
            [token for token in range(self.num_tokens) if token != current]
        )
