"""Theme Transformer Inferencing Code

usage: inference.py [-h] [--model_path MODEL_PATH] --theme THEME
                    [--seq_length SEQ_LENGTH] [--seed SEED]
                    [--out_midi OUT_MIDI] [--cuda] [--max_len MAX_LEN]
                    [--temp TEMP] [--nbars NBARS]
  --model_path MODEL_PATH   model file
  --chord CHORD             chord string(use '_' to seperate)
  --seq_length SEQ_LENGTH   generated seq length
  --seed SEED               random seed (set to -1 to use random seed) (change different if the model stucks)
  --out_midi OUT_MIDI       output midi file
  --cuda                    use CUDA
  --max_len MAX_LEN         number of tokens to predict
  --temp TEMP               temperature
  --nbars NBARS             number of bars to generate

    Author: Joey Zhu
    Email: joe8273@qq.com
    Date: 2024/12/30
"""

import argparse
import numpy as np
import torch
import torch.optim
from model.myModel import myLM
from preprocess.music_data import MusicDataset
from preprocess.tokenizer import RemiPlus
import os
import json
import random
from server.producer import get_producer

CHORD_TRANSFORMER_POP = "Chord-Transformer(pop909)"
CHORD_TRANSFORMER_LMD = "Chord-Transformer(lmd)"
CHOBEL_TRANSFORMER = "Chobel-Transformer"


def set_global_random_seed(seed):
    """Seed every RNG used during generation so a run can be reproduced.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    module, and configures cuDNN to use deterministic kernels only.

    Args:
        seed (int): the seed value.
    """
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; safe without CUDA.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per run and may pick
    # non-deterministic ones, which defeats ``deterministic`` above.
    torch.backends.cudnn.benchmark = False

# Command-line options. ``--seed`` is also applied globally at import time
# below.
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='./ckpts/0215-1106-restart/model_ep60.pt',  # ./ckpts/12292207-avg-pure/model_ep2000.pt
                    help='model file')

parser.add_argument('--chord', type=str, default='C:maj-F:maj-G:maj-C:maj',
                    help='chord progression')

parser.add_argument('--velocity', type=str, default='mid',
                    help='velocity:low, mid, high prepared')

parser.add_argument('--program', type=str, default='0-0-0',
                    help='program required, use id to choose and - to separate')

parser.add_argument('--seq_length', type=str, default='4096',
                    help='generated seq length')

parser.add_argument('--seed', type=int, default=-1,
                    help='random seed (set to -1 to use random seed)')

parser.add_argument('--out_midi', type=str, default='output.mid',
                    help='output midi file')

parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')

parser.add_argument('--max_len', type=int, default=512,  # default:512
                    help='number of tokens to predict')

parser.add_argument('--temp', type=float, default=1.2,
                    help='temperature')

parser.add_argument('--nbars', type=float, default=32,
                    help='number of bars to generate')

# Other accepted values: weighted-avg, full-connect.
parser.add_argument('--combination', type=str, default='avg',
                    help='the combination type in cross and self attention')

if __name__ == '__main__':
    # Run as a script: keep strict parsing so misspelled options still error.
    args = parser.parse_args()
else:
    # Imported as a library (e.g. to call ``generate_music``): the host
    # program's argv may carry its own options, which must not make argparse
    # exit with "unrecognized arguments" during import.
    args = parser.parse_known_args()[0]

if args.seed != -1:
    set_global_random_seed(args.seed)

# Module-level vocabulary; ``inference`` reads ``myvocab.id2event`` from it.
myvocab = RemiPlus()

# devices
# device = torch.device('cuda:0' if args.cuda else 'cpu')
# device_cpu = torch.device('cpu')


# # model definition
# model = myLM(myvocab.vocab_size,args.combination,d_model=512,num_encoder_layers=6,xorpattern=[0,0,0,1,1,1])

# print("Loading model from {}".format(args.model_path))
# model.load_state_dict(torch.load(args.model_path))
# print("Using device {}".format(device))

def inference(n_bars, strategies, params, chord_seq, max_len, model,
              device, song_id, user_id, start_kafka=False, producer=None):
    """Autoregressively sample a token sequence conditioned on chord tokens.

    The decoder is seeded with ``chord_seq[0]``. Tokens are then sampled one
    at a time until ``n_bars`` ``"Bar_None"`` events have been produced, or
    until sampling keeps failing (see below).

    Args:
        n_bars (int): number of bars to generate.
        strategies (Iterable[str]): inference strategies. If it contains
            ``'temperature'`` the logits are scaled by ``params['t']``;
            otherwise a temperature of 1.0 is used.
        params (dict): strategy parameters. ``'t'`` is the temperature (read
            only when ``'temperature'`` is in ``strategies``). ``'p'`` is the
            nucleus-sampling threshold (always required).
        chord_seq (list[int]): encoder condition tokens. ``chord_seq[0]`` is
            also used as the first decoder token.
        max_len (int): only the most recent ``max_len`` generated tokens are
            fed back to the decoder at each step.
        model: model exposing ``transformer_model``, ``temperature`` and
            ``nucleus`` (e.g. ``myLM``).
        device (torch.device): device the model lives on.
        song_id: id reported in Kafka progress messages.
        user_id: id reported in Kafka progress messages.
        start_kafka (bool): when True, publish a progress message to the
            ``'song-status'`` topic after every generated bar.
        producer: producer used to publish progress. When ``start_kafka`` is
            True and this is None, one is obtained from ``get_producer()``.

    Returns:
        list: generated token ids, starting with ``chord_seq[0]``. Generation
        stops early (returning what was produced so far) after more than 1024
        consecutive rejected samples.
    """
    model.eval()  # set the eval mode
    word2event = myvocab.id2event

    if start_kafka and producer is None:
        # ``producer`` is not a module-level name, so fetch one here rather
        # than failing with NameError on the first generated bar.
        producer = get_producer()

    # Encoder condition as a (len, 1) column tensor.
    input_chord = torch.tensor(chord_seq).reshape((-1, 1)).to(device)

    words = []       # generated token ids (decoder input history)
    label_list = []  # one label per entry of ``words``
    # Labelling is currently disabled: with this False every new token gets 0.
    previous_labeled = False

    fail_cnt = 0
    bar_count = 0
    with torch.no_grad():
        while bar_count < n_bars:
            print("events #{} Generating Bars #{}/{}".format(len(words), bar_count, n_bars), end='\r')
            if fail_cnt:
                print('failed iterations:', fail_cnt)
            if fail_cnt > 1024:
                print('model stuck ...\nPlease change a seed and inference again!')
                return words

            if not words:
                # No prompt given: seed the decoder with the first chord token.
                words.append(chord_seq[0])
                label_list.append(0)

            # Feed back only the most recent ``max_len`` tokens and labels.
            input_x = torch.tensor(words[-max_len:]).reshape((-1, 1))
            label_input = torch.tensor(label_list[-max_len:]).reshape((-1, 1))
            input_x_att_msk = model.transformer_model.generate_square_subsequent_mask(input_x.shape[0])

            logits = model(
                src=input_chord,
                tgt=input_x.to(device),
                tgt_label=label_input.to(device),
                tgt_mask=input_x_att_msk.to(device),
            )
            # Keep only the prediction for the last position.
            logits = torch.squeeze(logits[-1:]).cpu().numpy()

            temperature = params['t'] if 'temperature' in strategies else 1.
            probs = model.temperature(logits=logits, temperature=temperature)
            word = model.nucleus(probs=probs, p=params['p'])

            if word == 0:
                # Token 0 is padding: reject it and resample.
                fail_cnt += 1
                continue

            words.append(word)
            label_list.append(label_list[-1] + 1 if previous_labeled else 0)

            if word2event[word] == "Bar_None":
                bar_count += 1
                if start_kafka:
                    progress = (bar_count / n_bars) * 100
                    producer.send_message('song-status', {
                        'song_id': song_id,
                        'midi_url': '',
                        'mp3_url': '',
                        'status': 1,
                        'user_id': user_id,
                        'progress': progress,
                    })
                    print(f"progress:{progress} ")

            fail_cnt = 0

    print('generated {} events'.format(len(words)))
    return words

# transfer the given info to encoder tokens 

# model.to(device)
    
# given_chord = myvocab.generateChordProgressionToken(args.chord, args.program, args.velocity)

# word_seq = inference(
#             n_bars = args.nbars,
#             strategies=['temperature', 'nucleus'],
#             params={'t': args.temp, 'p': 0.9},
#             chord_seq=given_chord,
#             max_len=args.max_len
# )


# myvocab.REMIID2midi(word_seq,args.out_midi)
# print("{} saved".format(args.out_midi))





"""
下面的为其他模块调用功能
"""

def _resolve_model_path(model_type, model_path):
    """Return the checkpoint registered for ``model_type``, else ``model_path``."""
    checkpoints = {
        # NOTE(review): machine-specific absolute path; a relative alternative
        # is './ckpts/12292207-avg-pure/model_ep1000.pt'.
        CHORD_TRANSFORMER_POP: '/home/jianheng/Workspace/ChobelTransformer/ckpts/12292207-avg-pure/model_ep1000.pt',
        CHORD_TRANSFORMER_LMD: './ckpts/0215-1106-restart/model_ep100.pt',
        CHOBEL_TRANSFORMER: './ckpts/label-t1/model_ep400.pt',
    }
    return checkpoints.get(model_type, model_path)


def _oss_config_from_env():
    """Build the OSS upload config from environment variables.

    Credentials come from ``OSS_ACCESS_KEY_ID`` / ``OSS_ACCESS_KEY_SECRET`` so
    they are never committed to source control. The endpoint and bucket
    default to the previous values and can be overridden with
    ``OSS_ENDPOINT`` / ``OSS_BUCKET_NAME``.
    """
    # SECURITY: the key pair that used to be hard-coded here is exposed in the
    # repository history and must be rotated.
    return {
        'access_key_id': os.environ.get('OSS_ACCESS_KEY_ID', ''),
        'access_key_secret': os.environ.get('OSS_ACCESS_KEY_SECRET', ''),
        'endpoint': os.environ.get('OSS_ENDPOINT', 'oss-cn-beijing.aliyuncs.com'),
        'bucket_name': os.environ.get('OSS_BUCKET_NAME', 'music-backend'),
    }


def generate_music(
    model_path,
    song_id=0,
    user_id=0,
    chord_progression='C:maj-F:maj-G:maj-C:maj',
    velocity='mid',
    program='0-0-0',
    num_bars=32,
    temperature=1.2,
    max_len=512,
    use_cuda=True,
    seed=-1,
    start_kafka=False,
    upload_oss=False,
    out_path='./sample/1.mid',
    model_type=CHORD_TRANSFORMER_POP,
    arousal=0.5,
    valence=0.5,
    danceability=0.5,
    energy=0.5,
    instrumentalness=0.5,
    liveness=0.5,
):
    """Generate one piece of music and write it as MIDI.

    Args:
        model_path (str): checkpoint path. It is used only when ``model_type``
            is not one of the known model constants; otherwise that type's
            registered checkpoint is loaded instead.
        song_id: id sent in Kafka status messages. When ``start_kafka`` is
            True it also names the output file ``mids/<song_id>.mid``.
        user_id: id sent in Kafka status messages.
        chord_progression (str): chord progression, chords separated by '-'.
        velocity (str): velocity level ('low', 'mid' or 'high').
        program (str): instrument program ids separated by '-'.
        num_bars (int): number of bars to generate.
        temperature (float): sampling temperature.
        max_len (int): number of most recent tokens fed back to the model at
            each sampling step.
        use_cuda (bool): use ``cuda:0`` when CUDA is available.
        seed (int): random seed; -1 leaves the RNGs unseeded.
        start_kafka (bool): publish progress and completion messages to the
            ``'song-status'`` topic.
        upload_oss (bool): forwarded to ``REMIID2midi`` as ``upload_to_oss``.
            OSS credentials are read from the ``OSS_ACCESS_KEY_ID`` and
            ``OSS_ACCESS_KEY_SECRET`` environment variables.
        out_path (str): output MIDI path (replaced when ``start_kafka``).
        model_type (str): CHORD_TRANSFORMER_POP, CHORD_TRANSFORMER_LMD or
            CHOBEL_TRANSFORMER.
        arousal, valence, danceability, energy, instrumentalness, liveness:
            currently unused; accepted for API compatibility.

    Returns:
        str: the output MIDI path passed to ``REMIID2midi``.

    Raises:
        ValueError: if ``upload_oss`` is True but OSS credentials are missing
            from the environment.
    """
    if seed != -1:
        set_global_random_seed(seed)

    model_path = _resolve_model_path(model_type, model_path)

    # Validate upload configuration before spending time on generation.
    oss_config = _oss_config_from_env()
    if upload_oss and not (oss_config['access_key_id'] and oss_config['access_key_secret']):
        raise ValueError(
            'upload_oss=True requires the OSS_ACCESS_KEY_ID and '
            'OSS_ACCESS_KEY_SECRET environment variables to be set')

    producer = get_producer() if start_kafka else None

    vocab = RemiPlus()
    device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')

    model = myLM(vocab.vocab_size, 'avg', d_model=512, num_encoder_layers=6, xorpattern=[0, 0, 0, 1, 1, 1])
    # map_location lets a checkpoint saved on GPU load when we fell back to CPU.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Encode the chord condition for the encoder.
    given_chord = vocab.generateChordProgressionToken(chord_progression, program, velocity)

    word_seq = inference(
        n_bars=num_bars,
        strategies=['temperature', 'nucleus'],
        params={'t': temperature, 'p': 0.9},
        chord_seq=given_chord,
        max_len=max_len,
        device=device,
        model=model,
        song_id=song_id,
        user_id=user_id,
        start_kafka=start_kafka,
    )

    if start_kafka:
        out_path = f"mids/{song_id}.mid"
    midi_url, mp3_url = vocab.REMIID2midi(word_seq, out_path, upload_to_oss=upload_oss, oss_config=oss_config)

    if start_kafka:
        # status codes: 0 = not started, 1 = in progress, 2 = finished
        producer.send_message('song-status', {
            'song_id': song_id,
            'midi_url': midi_url,
            'mp3_url': mp3_url,
            'status': 2,
            'user_id': user_id,
            'progress': 100,
        })
    return out_path
    
if __name__ == '__main__':
    # Command-line entry point: forward the parsed CLI options so flags such
    # as --chord, --temp, --nbars, --cuda and --out_midi take effect.
    # NOTE: ``generate_music`` picks the checkpoint from ``model_type``
    # (default CHORD_TRANSFORMER_POP), so ``--model_path`` only matters for an
    # unrecognized model type.
    generate_music(
        model_path=args.model_path,
        song_id=1,
        chord_progression=args.chord,
        velocity=args.velocity,
        program=args.program,
        num_bars=int(args.nbars),
        temperature=args.temp,
        max_len=args.max_len,
        use_cuda=args.cuda,
        seed=args.seed,
        out_path=args.out_midi,
    )