"""Theme Transformer Inferencing Code

usage: inference.py [-h] [--model_path MODEL_PATH] [--chord CHORD]
                    [--velocity VELOCITY] [--program PROGRAM]
                    [--seq_length SEQ_LENGTH] [--seed SEED]
                    [--out_midi OUT_MIDI] [--cuda] [--max_len MAX_LEN]
                    [--temp TEMP] [--nbars NBARS] [--combination COMBINATION]
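  --model_path MODEL_PATH   model file
  --chord CHORD             chord progression string (use '-' to separate chords)
  --velocity VELOCITY       velocity: low, mid or high
  --program PROGRAM         program ids (use '-' to separate)
  --seq_length SEQ_LENGTH   generated seq length
  --seed SEED               random seed (set to -1 to use a random seed; try a different seed if the model gets stuck)
  --out_midi OUT_MIDI       output midi file
  --cuda                    use CUDA
  --max_len MAX_LEN         number of tokens to predict
  --temp TEMP               temperature
  --nbars NBARS             number of bars to generate
  --combination COMBINATION combination type in cross and self attention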

    Author: Joey Zhu
    Email: joe8273@qq.com
    Date: 2024/12/30
"""

import argparse
import numpy as np
import torch
import torch.optim
from model.myModel import myLM
from preprocess.music_data import MusicDataset
from preprocess.tokenizer import RemiPlus
import os
import json
import random

def set_global_random_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and configures cuDNN so that repeated runs with the same seed
    select the same kernels.

    Args:
        seed (int): seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every CUDA device; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can pick different algorithms from run to run, which
    # defeats the purpose of fixing the seed, so benchmarking is disabled.
    torch.backends.cudnn.benchmark = False

# Command-line interface of the inference script.
parser = argparse.ArgumentParser()
# Alternative checkpoint noted by the author:
# ./ckpts/12292207-avg-pure/model_ep2000.pt
parser.add_argument('--model_path', type=str,
                    default='./ckpts/02061509-avg-lmd/model_ep40.pt',
                    help='model file')

parser.add_argument('--chord', type=str, default='C:maj-F:maj-G:maj-C:maj',
                    help='chord progression')

parser.add_argument('--velocity', type=str, default='mid',
                    help='velocity:low, mid, high prepared')

parser.add_argument('--program', type=str, default='0-0-0',
                    help='program required, use id to choose and - to separate')

# A length is a number: parse it as int instead of keeping the raw string.
parser.add_argument('--seq_length', type=int, default=4096,
                    help='generated seq length')

parser.add_argument('--seed', type=int, default=-1,
                    help='random seed (set to -1 to use random seed)')

parser.add_argument('--out_midi', type=str, default='output.mid',
                    help='output midi file')

parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')

parser.add_argument('--max_len', type=int, default=512,  # default:512
                    help='number of tokens to predict')

parser.add_argument('--temp', type=float, default=1.2,
                    help='temperature')

# Bars are counted with an integer counter, so the target is an int too
# (a float here printed as e.g. "16.0" in the progress line).
parser.add_argument('--nbars', type=int, default=32,
                    help='number of bars to generate')

# Other values noted by the author: weighted-avg, full-connect
parser.add_argument('--combination', type=str, default='avg',
                    help='the combination type in cross and self attention')

args = parser.parse_args()

# Only seed the RNGs when an explicit seed was given; -1 keeps runs random.
if args.seed != -1:
    set_global_random_seed(args.seed)

# create vocab
myvocab = RemiPlus()

# devices
device = torch.device('cuda:0' if args.cuda else 'cpu')
device_cpu = torch.device('cpu')


# model definition
model = myLM(
    myvocab.vocab_size,
    args.combination,
    d_model=512,
    num_encoder_layers=6,
    xorpattern=[0, 0, 0, 1, 1, 1],
)

print("Loading model from {}".format(args.model_path))
# map_location remaps the stored tensors to the CPU while loading, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
# The model itself is moved to `device` later in this script.
model.load_state_dict(torch.load(args.model_path, map_location=device_cpu))
print("Using device {}".format(device))

def inference(n_bars, strategies, params, chord_seq):
    """Autoregressively sample a token sequence conditioned on chord tokens.

    Uses the module-level ``model``, ``myvocab``, ``device`` and ``args``.
    Generation starts from ``chord_seq[0]`` and runs until ``n_bars``
    "Bar_None" events have been produced. It gives up early, returning what
    was generated so far, when the padding token (id 0) has been sampled more
    than 1024 times in a row.

    Args:
        n_bars (int): number of bars to generate
        strategies (list or dict): inferencing strategies; when it contains
            'temperature', ``params['t']`` is used as the temperature,
            otherwise 1.0. Nucleus sampling is always applied.
        params (dict): parameters for inferencing strategies: 't'
            (temperature) and 'p' (nucleus threshold)
        chord_seq (list): token ids of the given chord condition
    Returns:
        list: token sequence of generated music
    Raises:
        ValueError: if ``chord_seq`` is empty
    """
    if len(chord_seq) == 0:
        raise ValueError("chord_seq must contain at least one token")

    model.eval()  # set the eval mode
    word2event = myvocab.id2event

    # The condition is passed to the model as a column of token ids.
    input_chord = torch.tensor(chord_seq).reshape((-1, 1)).to(device)

    # Without the 'temperature' strategy the logits are left unscaled.
    temperature = params['t'] if 'temperature' in strategies else 1.

    words = []       # generated token ids, seeded lazily with chord_seq[0]
    label_list = []  # one label per entry of `words`, passed as tgt_label
    fail_cnt = 0     # consecutive iterations that sampled the padding token
    bar_count = 0

    with torch.no_grad():
        while bar_count < n_bars:
            print("events #{} Generating Bars #{}/{}".format(len(words), bar_count, n_bars), end='\r')
            if fail_cnt:
                print('failed iterations:', fail_cnt)

            if fail_cnt > 1024:
                print('model stuck ...\nPlease change the seed and run inference again!')
                return words

            # prepare input
            if not words:
                # no prompt given: start from the first condition token
                words.append(chord_seq[0])
                label_list.append(0)

            # Only the most recent args.max_len tokens are fed back in.
            input_x = torch.tensor(words[-args.max_len:]).reshape((-1, 1))
            label_input = torch.tensor(label_list[-args.max_len:]).reshape((-1, 1))

            input_x_att_msk = model.transformer_model.generate_square_subsequent_mask(input_x.shape[0])
            input_x = input_x.to(device)
            label_input = label_input.to(device)
            input_x_att_msk = input_x_att_msk.to(device)

            logits = model(
                src=input_chord,
                tgt=input_x,
                tgt_label=label_input,
                tgt_mask=input_x_att_msk,
            )
            # Keep only the prediction for the last position.
            logits = torch.squeeze(logits[-1:]).cpu().numpy()

            probs = model.temperature(logits=logits, temperature=temperature)

            # sampling
            # word : the generated remi event
            word = model.nucleus(probs=probs, p=params['p'])

            # skip padding
            if word == 0:
                fail_cnt += 1
                continue

            # add new event to record sequence; every generated token
            # carries label 0 in this script
            words.append(word)
            label_list.append(0)

            if word2event[word] == "Bar_None":
                bar_count += 1

            fail_cnt = 0

    print('generated {} events'.format(len(words)))
    return words

# Generate songs for a fixed set of chord progressions, score each song by
# how many of its detected chords belong to the requested progression, and
# save it into a per-score bucket directory.

model.to(device)
chord_info = ['G:maj-C:maj-D:maj-G:maj',
              'A:maj-F#:maj-D:maj-E:maj',
              'G:maj-C:maj-G:maj-D:maj-G:maj',
              'D:maj-A:maj-B:min-G:maj-D:maj',
              'D:min-G:maj-A:min-F:maj',
              ]
out_root = './generated/midis/pop909/great'
best_path = './generated/midis/pop909/great/best.mid'
# Maximum number of files kept in bucket directories 0, 1, 2 and 3.
bucket_limits = [10, 20, 30, 100]  # last one was 61 at some point
i = 1
index = [0, 0, 0, 0]  # files saved so far in each bucket
while True:
    # transfer the given info to encoder tokens
    target = chord_info[i % len(chord_info)]
    given_chord = myvocab.generateChordProgressionToken(target, args.program, args.velocity)

    word_seq = inference(
        n_bars=args.nbars,
        strategies=['temperature', 'nucleus'],
        params={'t': args.temp, 'p': 0.9},
        chord_seq=given_chord,
    )

    # Evaluate the generated song.
    el = [myvocab.id2event[x] for x in word_seq]
    # Strip the "Chord_" prefix from the detected chord events.
    prefix = 'Chord_'
    chords = [s[len(prefix):] if s.startswith(prefix) else s
              for s in myvocab.getChords(el)]
    if not chords:
        continue
    cp = myvocab.getChordProgression(chords)
    if cp is None:
        continue

    # An exact match with the requested progression is kept as the best song.
    stop = False
    if target == '-'.join(cp):
        os.makedirs(out_root, exist_ok=True)
        myvocab.REMIID2midi(word_seq, best_path)
        # Report the file that was actually written (the original printed
        # `out_path`, which is undefined or stale at this point).
        print("{} saved".format(best_path))
        stop = True

    # Share of detected chords that appear in the requested progression.
    target_chords = set(target.split('-'))
    up = sum(1 for c in chords if c in target_chords)
    ratio = up / len(chords)
    print(f"generated new song ratio is:{ratio}")

    if ratio < 0.25:
        bucket = 0
    elif ratio < 0.5:
        bucket = 1
    elif ratio < 0.75:
        bucket = 2
    else:
        bucket = 3
    # A full bucket discards the song and retries the same progression.
    if index[bucket] >= bucket_limits[bucket]:
        continue
    index[bucket] += 1

    out_path = f'{out_root}/{bucket}/{target}_{i}.mid'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # save to disk
    myvocab.REMIID2midi(word_seq, out_path)
    print("{} saved".format(out_path))
    i += 1

    # NOTE(review): `stop` is reset on every iteration and this check is only
    # reached after a successful save, so the loop ends only if a song saved
    # at i >= 160 is also an exact match. Once all buckets are full it can
    # never be reached - confirm the intended stopping rule.
    if i >= 160 and stop:
        print("finish all tasks")
        break

