import torch
import torch.nn as nn
import os
import random

import sys
sys.path.append("./MusicTransformer_Pytorch")

from third_party.midi_processor.processor import decode_midi

from argument_funcs import parse_trak_args, print_trak_args
from model.music_transformer import MusicTransformer
from dataset.e_piano import create_epiano_datasets, EPianoDatasetSampler
from torch.utils.data import DataLoader

from utilities.constants import *
from utilities.device import get_device, use_cuda

from main import MusicTransformerModelOutput
from tqdm import tqdm
import numpy as np

import time


def _prediction_positions(last_k_tokens, cherrypick_token):
    """Return the window offsets at which the model's next-token prediction is scored.

    Offset ``i`` means the model is fed ``inputs[:, i:num_prime + i]`` and its output
    at the last step is compared with the corresponding target token. The returned
    offsets are always contiguous and in increasing order.

    Args:
        last_k_tokens: number of positions following the prime that may be scored.
        cherrypick_token: negative to score all ``last_k_tokens`` positions, or a
            1-based index selecting exactly one of them.

    Raises:
        ValueError: if the arguments select no valid position (previously these
            cases crashed later with an unbound-variable / ``NoneType`` error).
    """
    if last_k_tokens < 1:
        raise ValueError(f"generate_length must be >= 1, got {last_k_tokens}")
    if cherrypick_token < 0:
        return list(range(last_k_tokens))
    if not 1 <= cherrypick_token <= last_k_tokens:
        raise ValueError(
            f"trak_cherrypick must be negative or in 1..{last_k_tokens}, "
            f"got {cherrypick_token}")
    return [cherrypick_token - 1]


def _log_odds(logp):
    """Return ``log(p / (1 - p))`` elementwise, given ``logp = log(p)``.

    ``log(-expm1(logp))`` equals ``log(1 - p)`` but avoids the cancellation that makes
    ``log(1 - exp(logp))`` evaluate to ``-inf`` (and the result to ``+inf``) when ``p``
    is very close to 1.
    """
    return logp - torch.log(-torch.expm1(logp))


def _score_batch(model, inputs, targets, positions, num_prime, loss_fn, start_idx):
    """Return the per-window log-odds of the model predicting the target tokens.

    Args:
        model: callable mapping a token tensor to logits; assumed to return
            ``(batch, seq, vocab)`` (the original code permuted it so the vocab axis
            became the class axis for ``CrossEntropyLoss``).
        inputs: token windows; ``inputs[:, i:num_prime + i]`` is fed for offset ``i``.
        targets: labels aligned so that ``targets[:, j]`` is the label for the
            prediction made after ``inputs[:, j]`` (the alignment used by the
            cherrypick path of the original script).
        positions: contiguous, increasing offsets from ``_prediction_positions``.
        num_prime: length of the prefix fed to the model for every prediction.
        loss_fn: elementwise (``reduction="none"``) cross-entropy loss.
        start_idx: scored positions before this index are dropped from the sum.

    Returns:
        1-D tensor with one log-odds value per window in the batch.
    """
    # Last-step logits of each num_prime-long slice, stacked to (batch, vocab, k)
    # because CrossEntropyLoss expects the class dimension second.
    logits = torch.stack(
        [model(inputs[:, i:num_prime + i])[:, -1, :] for i in positions], dim=2)
    # The prediction after inputs[:, num_prime - 1 + i] is labelled by the target
    # at that same index; positions are contiguous, so a slice suffices.
    first = num_prime - 1 + positions[0]
    tgt = targets[:, first:first + len(positions)]
    # Summing per-token log-probs gives the log-probability of the whole scored span.
    logp = -loss_fn(logits, tgt)[:, start_idx:].sum(dim=1)
    return _log_odds(logp)


# main
def main():
    """Compute TRAK target scores for every model checkpoint and save them to disk.

    Builds a ``MusicTransformer`` and a dataset of in-order, non-overlapping windows
    over the music under ``args.trak_generated_music_root``. For every checkpoint
    ``<trak_model_checkpoints_root>/<i>/<trak_model_checkpoints_extension>`` with
    ``i = 0 .. trak_model_checkpoints_number`` (inclusive), it records per window the
    log-odds ``log(p / (1 - p))`` of the model predicting the scored target tokens.
    The resulting 1-D tensor is saved in the checkpoint's directory.

    ``args.trak_cherrypick`` selects the scored tokens:
      * negative: all ``args.generate_length`` positions after the
        ``args.num_prime``-token prime;
      * ``k`` in ``1..args.generate_length``: only the ``k``-th of those positions.

    Raises:
        ValueError: if ``trak_cherrypick`` / ``generate_length`` select no position.
        RuntimeError: if the dataset or result size differs from the expected one.
    """
    args = parse_trak_args()

    if args.force_cpu:
        use_cuda(False)
        print("WARNING: Forced CPU usage, expect model to perform slower")
        print("")

    # Resolve the device after use_cuda() so a forced-CPU run really stays on the CPU.
    device = get_device()

    model = MusicTransformer(n_layers=args.n_layers, num_heads=args.num_heads,
                             d_model=args.d_model, dim_feedforward=args.dim_feedforward,
                             max_sequence=args.max_sequence, rpr=args.rpr,
                             enable_new_ver=args.enable_new_ver).to(device)

    test_dataset = create_epiano_datasets(args.trak_generated_music_root,
                                          args.test_windows_size,
                                          random_seq=False,   # windows taken in order
                                          full_version=True,  # non-overlapping windows over the whole sequence
                                          split=False)

    # NOTE(review): the saved score files are expected to cover a fixed set of 178
    # windows; confirm this constant if the generated music set changes. Checking the
    # dataset (not the loader length) keeps the check independent of the batch size.
    expected_windows = 178
    if len(test_dataset) != expected_windows:
        raise RuntimeError(
            f"expected {expected_windows} windows, dataset has {len(test_dataset)}")

    MusicTransformerModelOutput.last_k_tokens = args.generate_length
    MusicTransformerModelOutput.cherrypick_token = args.trak_cherrypick
    # Validate before loading any checkpoint so bad arguments fail fast.
    positions = _prediction_positions(MusicTransformerModelOutput.last_k_tokens,
                                      MusicTransformerModelOutput.cherrypick_token)

    loader = DataLoader(test_dataset, batch_size=args.trak_train_batchsize, num_workers=1)
    print("length of train_loader", len(loader))

    # reduction="none" keeps per-token losses so they can be summed per window;
    # padded targets contribute zero.
    loss_fn = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD, reduction="none")

    if args.trak_cherrypick < 0:
        out_name = "loss_test_generate_seed_0_length_1.pt"
    else:
        out_name = f"loss_test_generate_seed_0_length_1_cherrypick_{args.trak_cherrypick}.pt"

    ckpt_paths = [
        os.path.join(args.trak_model_checkpoints_root,
                     f"{i}/" + args.trak_model_checkpoints_extension)
        for i in range(args.trak_model_checkpoints_number + 1)
    ]

    for model_id, ckpt_path in enumerate(tqdm(ckpt_paths)):
        print(model_id, ckpt_path)
        # map_location lets checkpoints saved on a GPU load on a CPU-only run.
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        model.eval()

        scores = []
        with torch.no_grad():
            for idx, batch in enumerate(loader):
                batch = [x.to(device) for x in batch]
                scores.append(_score_batch(model, batch[0], batch[1], positions,
                                           args.num_prime, loss_fn,
                                           MusicTransformerModelOutput.start_idx))
                if idx % 100 == 0:
                    print(idx)
        res = torch.cat(scores, dim=0)
        print(res)
        if res.shape[0] != expected_windows:
            raise RuntimeError(
                f"expected {expected_windows} scores, got {res.shape[0]}")
        torch.save(res.cpu(), os.path.join(args.trak_model_checkpoints_root,
                                           f"{model_id}/" + out_name))

if __name__ == "__main__":
    # Run the scoring pipeline only when executed as a script, not on import.
    main()
