from pythonosc import osc_server
from pythonosc import dispatcher
import argparse
import time
import threading
import signal
import json
import sys
import keyboard
import copy
import socket
import os
import sounddevice as sd
import soundfile as sf
import numpy as np
import functools
from operator import sub
# OSC address patterns that ARKitReceiver registers with its dispatcher; every
# address listed here is routed to the callback chosen for the active mode.
# "/W" messages are handled as (index, value) weight pairs; the remaining
# addresses are stored or forwarded with their arguments unchanged, except
# that live mode merges "/ELR" and "/ERR" into a single "/ER" entry.
ADDRESS_LIST = ["/HT", "/HR", "/HRQ", "/ELR", "/ERR", "/W"]

class ARKitReceiver:
    """Receive OSC face-tracking messages and snapshot, record, or forward them.

    The operating mode is chosen by ``args.type``:

    * ``"snapshot"``: pressing ``r`` stores the latest received values.
    * ``"record"`` / ``"record_audio"``: pressing ``r`` toggles recording; a
      background thread samples the latest values at a fixed interval, and
      ``"record_audio"`` additionally captures microphone audio.
    * ``"live"``: every incoming message is forwarded as JSON over UDP to
      ``args.ip_renderer``.

    Pressing ``a`` stores the current "/W" weights as the neutral face that
    recorded frames are expressed relative to.
    """

    # UDP port of the live renderer that data_forwarder sends to.
    RENDERER_PORT = 9000
    # Number of weights kept under the "/W" key.
    NUM_BLENDSHAPES = 52

    def __init__(self, args):
        """Set up the selected mode, then serve OSC requests forever.

        Args:
            args: Parsed command-line namespace. Reads ``type``, ``ip``,
                ``port`` and ``ip_renderer``; the recording modes also read
                the optional ``fps`` (60 when the attribute is absent).

        Raises:
            ValueError: If ``args.type`` is not a supported mode, or ``fps``
                is not positive.

        Note:
            This constructor blocks: its last step is ``serve_forever()``.
        """
        self.disp = dispatcher.Dispatcher()
        self.type = args.type
        self.latest_data = {"/W": [0] * self.NUM_BLENDSHAPES}
        self.last_received_time = 0
        self.data_lock = threading.Lock()
        # Maps the index carried by a "/W" message to the slot it is stored
        # in (presumably Face Cap order -> ARKit order; confirm with sender).
        self.facecap2arkit = [
            2, 0, 1, 3, 4, 16, 17, 10, 11, 12, 13, 14, 15, 8, 9, 18, 19, 20,
            21, 5, 6, 7, 49, 50, 24, 22, 23, 25, 31, 37, 32, 38, 40, 39, 42,
            41, 26, 43, 44, 29, 30, 27, 28, 47, 48, 33, 34, 35, 36, 45, 46, 51,
        ]
        self.is_recording = False
        self.saved_data = []
        self.neutral_data = [0] * self.NUM_BLENDSHAPES
        self.ip_renderer = args.ip_renderer

        # Recording state is initialised for every mode so that the recorder
        # thread and save_to_file never touch a missing attribute.
        self.start_ts = None  # recording start time
        self.end_ts = None  # recording stop time
        self.audio_frames = []  # list of numpy arrays
        self.audio_fs = 44100
        self.audio_recorder = None

        if args.type == "snapshot":
            keyboard.add_hotkey("r", self.save_snapshot)
            callback = self.data_receiver
            signal.signal(signal.SIGINT, self.signal_handler)
            print("Press 'r' to take a snapshot on current blendshape values.\nPress Ctrl+C to stop the server and save the data.")
        elif args.type in ("record", "record_audio"):
            # Honour the --fps option instead of a hard-coded 60 Hz.
            fps = getattr(args, "fps", 60)
            if fps <= 0:
                raise ValueError(f"fps must be positive, got {fps!r}")
            keyboard.add_hotkey("r", self.toggle_record)
            self.save_thread = threading.Thread(
                target=self.record_loop,
                kwargs={"interval": 1 / fps},
                daemon=True,
            )
            self.save_thread.start()
            callback = self.data_receiver
            signal.signal(signal.SIGINT, self.signal_handler)
            print("Press 'r' to toggle recording ON/OFF.\nPress Ctrl+C to stop the server and save the data.")
        elif args.type == "live":
            self.forward_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            self.latest_data["/ER"] = [0, 0, 0, 0, 0, 0]
            self._frame_count = 0
            self._fps_start_time = time.time()
            # Report the address that is actually used by data_forwarder.
            print(f"live mode enabled: forwarding to {self.ip_renderer}:{self.RENDERER_PORT}.")
            callback = self.data_forwarder
        else:
            raise ValueError(f"Unsupported type: {args.type!r}")

        for address in ADDRESS_LIST:
            self.disp.map(address, callback)

        keyboard.add_hotkey("a", self.align_neutral_face)
        self.server = osc_server.ThreadingOSCUDPServer((args.ip, args.port), self.disp)
        print(f"OSC Server is running on {args.ip}:{args.port}.")
        self.server.serve_forever()

    @property
    def has_data(self):
        """Whether at least one OSC message has been received."""
        return self.last_received_time != 0

    def _store_weight(self, args):
        """Store one "/W" message; the caller must hold ``data_lock``.

        Args:
            args: OSC arguments; ``args[0]`` is the incoming weight index and
                ``args[1]`` the weight value.
        """
        slot = self.facecap2arkit[int(args[0])]
        self.latest_data["/W"][slot] = args[1]

    def data_receiver(self, address, *args):
        """OSC handler for snapshot/record modes: remember the latest values.

        Args:
            address: OSC address of the message.
            *args: OSC arguments of the message.
        """
        with self.data_lock:
            if address == "/W":
                self._store_weight(args)
            else:
                self.latest_data[address] = args
            self.last_received_time = time.time()

    def align_neutral_face(self):
        """Store the current "/W" weights as the neutral-face baseline."""
        with self.data_lock:
            print("Align neutral face.")
            self.neutral_data = copy.deepcopy(self.latest_data["/W"])

    def data_forwarder(self, address, *args):
        """OSC handler for live mode: forward the message as JSON over UDP.

        "/ELR" and "/ERR" are merged into one six-element "/ER" list
        (slots 0-1 and 3-4 respectively); "/W" is forwarded as the full
        weight list; any other address is forwarded with its own arguments.

        Args:
            address: OSC address of the message.
            *args: OSC arguments of the message.
        """
        msg = None
        with self.data_lock:
            if address == "/W":
                self._store_weight(args)
                forward_data = {address: self.latest_data[address]}
            elif address == "/ELR":
                self.latest_data["/ER"][0:2] = args
                forward_data = {"/ER": self.latest_data["/ER"]}
            elif address == "/ERR":
                self.latest_data["/ER"][3:5] = args
                forward_data = {"/ER": self.latest_data["/ER"]}
            else:
                forward_data = {address: args}
            # Serialise while holding the lock so the payload cannot be
            # modified by another handler thread half-way through.
            try:
                msg = json.dumps(forward_data).encode()
            except (TypeError, ValueError) as e:
                print(f"live forward error: {e}")

        if msg is not None:
            try:
                self.forward_socket.sendto(msg, (self.ip_renderer, self.RENDERER_PORT))
            except OSError as e:
                print(f"live forward error: {e}")

        # Weight index 0 is counted once per frame to estimate the frame rate.
        if address == "/W" and int(args[0]) == 0:
            now = time.time()
            self._frame_count += 1
            elapsed = now - self._fps_start_time
            if elapsed >= 1.0:
                fps = self._frame_count / elapsed
                print(f"[FPS] {fps:.1f}")
                self._frame_count = 0
                self._fps_start_time = now

    def record_loop(self, interval):
        """Sample the latest data every ``interval`` seconds, forever.

        Args:
            interval: Seconds to sleep between two sampling attempts.
        """
        while True:
            time.sleep(interval)
            self._record_tick(interval)

    def _record_tick(self, interval):
        """Append one frame to ``saved_data`` if recording and data is fresh.

        A frame is only stored when a message arrived within the last
        ``interval`` seconds. Stored "/W" weights are relative to the neutral
        face, and the timestamp is relative to the recording start.
        """
        with self.data_lock:
            timestamp = time.time()
            is_fresh = self.has_data and timestamp - self.last_received_time < interval
            if not is_fresh:
                print(f"[{timestamp:.3f}] No new data, skipping save.")
                return
            if not self.is_recording:
                print(f"[{timestamp:.3f}] Not recording, skipping save.")
                return

            aligned_data = [
                current - neutral
                for current, neutral in zip(self.latest_data["/W"], self.neutral_data)
            ]
            relative_ts = timestamp if self.start_ts is None else timestamp - self.start_ts
            snapshot = {"timestamp": relative_ts, **copy.deepcopy(self.latest_data)}
            snapshot["/W"] = aligned_data
            self.saved_data.append(snapshot)
            print(f"[{timestamp:.3f}] Frame Saved.")

    def save_snapshot(self):
        """Append a copy of the latest data, stamped with the current time."""
        with self.data_lock:
            if not self.has_data:
                print("No data received yet. Cannot save snapshot.")
                return
            timestamp = time.time()
            snapshot = {"timestamp": timestamp, **copy.deepcopy(self.latest_data)}
            self.saved_data.append(snapshot)
            print(f"Number [{len(self.saved_data)-1}] Snapshot Saved.")

    def save_to_file(self):
        """Write the collected frames into a new time-stamped directory.

        Creates ``<YYYYmmdd_HHMMSS>/arkit_raw.json``; the recording modes also
        write ``log.json`` and ``record_audio`` writes ``audio.wav``. Nothing
        is written when no frame was collected.
        """
        if not self.saved_data:
            print("No data to save.")
            return
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        os.makedirs(timestamp, exist_ok=True)

        # save audio
        if self.type == "record_audio":
            self.save_audio(timestamp)

        # save arkit
        file_path = os.path.join(timestamp, "arkit_raw.json")
        with open(file_path, "w") as f:
            json.dump(self.saved_data, f, indent=2)
        print(f"\nSaved {len(self.saved_data)} frames to {file_path}")

        if self.type in ("record", "record_audio"):
            # Build the log first so a failure cannot leave an empty file.
            has_span = self.start_ts is not None and self.end_ts is not None
            log_data = {
                "duration": round(self.end_ts - self.start_ts, 3) if has_span else None,
                "frame_count": len(self.saved_data),
            }
            with open(os.path.join(timestamp, "log.json"), "w") as f:
                json.dump(log_data, f, indent=2)

    def save_audio(self, timestamp):
        """Write the captured audio to ``<timestamp>/audio.wav``.

        Args:
            timestamp: Name of the (already existing) output directory.
        """
        if self.audio_frames:
            audio = np.concatenate(self.audio_frames, axis=0)
            wav_path = os.path.join(timestamp, "audio.wav")
            sf.write(wav_path, audio, self.audio_fs)
            print(f"Audio saved to {wav_path}")
        else:
            print("Warning: no audio captured.")

    def signal_handler(self, sig, frame):
        """SIGINT handler: finish any running recording, save, and exit.

        Args:
            sig: Signal number (unused).
            frame: Interrupted stack frame (unused).
        """
        print("\nStopping server and saving data...")
        # Stop an in-progress recording so the end time is set and the audio
        # stream no longer appends frames while they are being written.
        if self.is_recording:
            self.toggle_record()
        self.save_to_file()
        sys.exit(0)

    def toggle_record(self):
        """Switch recording on or off and update the start/end timestamps."""
        if not self.is_recording:
            # Set the start time before raising the flag: the recorder thread
            # reads start_ts as soon as it sees is_recording == True.
            self.start_ts = time.time()
            self.end_ts = None
            self.is_recording = True
            if self.type == "record_audio":
                self.audio_frames.clear()
                self._start_audio()
        else:
            self.is_recording = False
            self.end_ts = time.time()
            if self.type == "record_audio":
                self._stop_audio()
        print(f"\n[Toggle] Recording: {'ON' if self.is_recording else 'OFF'}")

    def _audio_callback(self, indata, frames, time_info, status):
        """Input-stream callback: keep a copy of each block while recording."""
        if self.is_recording:
            self.audio_frames.append(indata.copy())

    def _start_audio(self):
        """Open and start a mono input stream at ``audio_fs`` Hz."""
        print("Starting audio recording...")
        self.audio_recorder = sd.InputStream(
            samplerate=self.audio_fs,
            channels=1,
            callback=self._audio_callback,
        )
        self.audio_recorder.start()

    def _stop_audio(self):
        """Stop and close the input stream, if one is open."""
        if self.audio_recorder is not None:
            self.audio_recorder.stop()
            self.audio_recorder.close()
            self.audio_recorder = None


if __name__ == "__main__":
    # Command-line entry point: declare the options as a table, register them
    # on the parser in order, then hand the parsed namespace to the receiver
    # (whose constructor starts serving immediately).
    option_table = (
        ("--ip", {
            "default": "0.0.0.0",
            "help": "The ip to listen on incoming Face cap data.",
        }),
        ("--port", {
            "type": int,
            "default": 8080,
            "help": "The port to listen on incoming Face cap data.",
        }),
        ("--type", {
            "choices": ["snapshot", "record", "record_audio", "live"],
            "default": "record",
            "help": "The type of operation: 'snapshot', 'record', 'record_audio', 'live'",
        }),
        ("--fps", {
            "default": 60,
            "type": int,
            "help": "Frames per second for recording mode. (Frame alignment is still required to ensure consistent time intervals.)",
        }),
        ("--ip_renderer", {
            "default": "127.0.0.1",
            "help": "The IP address of the live renderer (GUI) in live mode.",
        }),
    )

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()
    receiver = ARKitReceiver(args)