from kafka import KafkaConsumer
import json
from typing import Dict
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import logging
from inference import generate_music

class CreateSong:
    """Plain container for a single song-generation request.

    Every constructor argument is stored unchanged as an attribute of the same
    name; no validation or type conversion happens here.

    Args:
        user_id: Requesting user's id (stored as given).
        song_id: Id of the song to generate (stored as given).
        model_type: Name of the generation model to use.
        chord: Chord progression string for the generator.
        velocity: Note velocity setting.
        program: Presumably a MIDI program (instrument) number -- confirm.
        nbars: Number of bars to generate.
        arousal: Conditioning feature; value range not enforced here.
        valence: Conditioning feature; value range not enforced here.
        danceability: Conditioning feature; value range not enforced here.
        energy: Conditioning feature; value range not enforced here.
        instrumentalness: Conditioning feature; value range not enforced here.
        liveness: Conditioning feature; value range not enforced here.
    """

    def __init__(
        self,
        user_id: str,
        song_id: str,
        model_type: str,
        chord: str,
        velocity: int,
        program: int,
        nbars: int,
        arousal: float,
        valence: float,
        danceability: float,
        energy: float,
        instrumentalness: float,
        liveness: float,
    ):
        # NOTE: keep this assignment order stable -- the instance ``__dict__``
        # is logged verbatim, so its key order is visible in the logs.

        # Request identity.
        self.user_id = user_id
        self.song_id = song_id

        # Generation settings.
        self.model_type = model_type
        self.chord = chord
        self.velocity = velocity
        self.program = program
        self.nbars = nbars

        # Conditioning features.
        self.arousal = arousal
        self.valence = valence
        self.danceability = danceability
        self.energy = energy
        self.instrumentalness = instrumentalness
        self.liveness = liveness
class KafkaMessageConsumer:
    """Consume song-creation requests from a Kafka topic and run generation for each.

    Message values are decoded as UTF-8 JSON objects. Each object is turned into a
    ``CreateSong`` and passed to ``generate_music``, one message at a time on the
    calling thread.
    """

    def __init__(self, bootstrap_servers: str, topic: str,
                 group_id: str = 'music_generation_group'):
        """Create the underlying ``KafkaConsumer`` subscribed to ``topic``.

        Args:
            bootstrap_servers: Kafka bootstrap address(es), e.g. ``'localhost:9092'``.
            topic: Topic to subscribe to.
            group_id: Consumer group id (defaults to ``'music_generation_group'``).
        """
        self.logger = logging.getLogger(__name__)
        self.consumer = KafkaConsumer(
            topic,
            bootstrap_servers=bootstrap_servers,
            # Start from the oldest retained message when the group has no committed offset.
            auto_offset_reset='earliest',
            # NOTE(review): offsets are auto-committed even when process_message fails
            # (its errors are logged and swallowed), so failed requests are not retried.
            enable_auto_commit=True,
            group_id=group_id,
            value_deserializer=self._deserialize_value,
        )

    @staticmethod
    def _deserialize_value(raw):
        """Decode a raw message value (bytes) as UTF-8 JSON.

        Returns None instead of raising for tombstones (``raw is None``) and for
        payloads that are not valid UTF-8 JSON. As a result, a bad payload is
        skipped rather than raising out of the consumer iterator and stopping
        ``start_consuming``.
        """
        if raw is None:
            return None
        try:
            return json.loads(raw.decode('utf-8'))
        except (UnicodeDecodeError, json.JSONDecodeError):
            logging.getLogger(__name__).warning(
                "Skipping undecodable message value: %r", raw[:200]
            )
            return None

    def process_message(self, message_value: Dict):
        """Build a ``CreateSong`` from one decoded message and run generation for it.

        Args:
            message_value: Decoded JSON object. It must contain the keys user_id,
                song_id, model_type, chord, velocity, program, nbars, arousal,
                valence, danceability, energy, instrumentalness and liveness.

        Any error is logged with its traceback and swallowed, so one bad request
        does not stop the consumer. This covers missing keys, bad values and
        failures raised by ``generate_music``.
        """
        if not isinstance(message_value, dict):
            self.logger.error("Ignoring message that is not a JSON object: %r", message_value)
            return
        try:
            song = CreateSong(
                user_id=message_value['user_id'],
                song_id=message_value['song_id'],
                model_type=message_value['model_type'],
                chord=message_value['chord'],
                velocity=message_value['velocity'],
                program=message_value['program'],
                nbars=int(message_value['nbars']),
                arousal=message_value['arousal'],
                valence=message_value['valence'],
                danceability=message_value['danceability'],
                energy=message_value['energy'],
                instrumentalness=message_value['instrumentalness'],
                liveness=message_value['liveness'],
            )

            self.logger.info("Processing song generation request: %s", song.__dict__)
            # NOTE(review): song.program is parsed but not forwarded to
            # generate_music -- confirm whether that is intended.
            # NOTE(review): if generation takes longer than the consumer's
            # max_poll_interval_ms, the group may rebalance -- confirm timings.
            generate_music(
                user_id=int(song.user_id),
                song_id=int(song.song_id),
                velocity=song.velocity,
                chord_progression=song.chord,
                num_bars=song.nbars,
                arousal=song.arousal,
                valence=song.valence,
                danceability=song.danceability,
                energy=song.energy,
                instrumentalness=song.instrumentalness,
                liveness=song.liveness,
                model_type=song.model_type,
                start_kafka=True,
                upload_oss=True,
            )
        except Exception as e:
            # Boundary handler: keep consuming, but preserve the traceback in the log.
            self.logger.exception("Error processing message: %s", e)

    def start_consuming(self):
        """Iterate over the subscription and process each message sequentially.

        Messages whose value is None (tombstones or undecodable payloads) are
        skipped. A non-KeyboardInterrupt error ends the loop and is logged. The
        consumer is always closed when the loop exits, including on
        KeyboardInterrupt.
        """
        self.logger.info("Starting Kafka consumer...")
        try:
            for message in self.consumer:
                if message.value is None:
                    self.logger.warning(
                        "Skipping message with empty/undecodable value at %s[%s]@%s",
                        message.topic, message.partition, message.offset,
                    )
                    continue
                self.logger.info("Received message: %s", message.value)
                # TODO: consider processing messages concurrently (e.g. a worker pool).
                self.process_message(message.value)
        except Exception as e:
            self.logger.exception("Consumer error: %s", e)
        finally:
            self.consumer.close()

def main():
    """Configure logging and run the song-creation consumer.

    The connection settings default to ``localhost:9092`` and the topic
    ``song-create``. They can be overridden without editing the code through the
    ``KAFKA_BOOTSTRAP_SERVERS`` and ``KAFKA_SONG_TOPIC`` environment variables.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Create the consumer and block in its processing loop.
    consumer = KafkaMessageConsumer(
        bootstrap_servers=os.environ.get('KAFKA_BOOTSTRAP_SERVERS', 'localhost:9092'),
        topic=os.environ.get('KAFKA_SONG_TOPIC', 'song-create'),
    )
    consumer.start_consuming()
 
if __name__ == "__main__":
    # Start the consumer only when run as a script, not when imported.
    main()
