import serial
import time
import struct
import datetime
import threading
import pandas as pd
import os
import atexit

class EnergyMonitor:
    """Energy consumption monitor for computationally intensive workloads.

    Polls a PMD-USB power-measurement device over a serial port from a
    background thread and records per-sensor voltage, current and power.
    Typical use::

        monitor = EnergyMonitor(port='/dev/ttyUSB0')
        monitor.start_monitoring(task_id='train_run')
        ...  # workload
        df = monitor.stop_monitoring()   # also writes energy_train_run.csv
        monitor.print_summary()
    """

    # UART Commands
    CMD_WELCOME = 0x00
    CMD_READ_ID = 0x01
    CMD_READ_SENSORS = 0x02
    CMD_READ_SENSOR_VALUES = 0x03
    CMD_READ_CONFIG = 0x04

    # Length of a CMD_READ_SENSOR_VALUES reply: 4 sensors x (voltage, current) int16.
    SAMPLE_SIZE = 16
    # Seconds to wait for a complete sample frame before discarding it and resyncing.
    READ_TIMEOUT = 0.1

    def __init__(self, port='/dev/ttyUSB0', sampling_rate=1000, output_dir='energy_data'):
        """Initialize the energy monitor.

        Args:
            port: Serial port for the PMD-USB device.
            sampling_rate: Target sampling rate in Hz (default=1000). The sampler
                never exceeds this rate; it may run slower if the device round
                trip takes longer than one period. ``None`` or <= 0 disables pacing.
            output_dir: Directory to store energy data CSV files (created if missing).
        """
        self.port = port
        self.sampling_rate = sampling_rate
        self.output_dir = output_dir
        self.serial = None
        self.sensors = ['PCIE1', 'PCIE2', 'EPS1', 'EPS2']

        self.running = False
        self.samples = []
        self.monitor_thread = None
        self.task_id = None
        # Per-session stop signal; lets the sampler wake from its pacing wait at once.
        self._stop_event = None

        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Make sure an active session is stopped (and saved) at interpreter exit.
        atexit.register(self.cleanup)

    def cleanup(self):
        """Stop any active monitoring session and close the serial port.

        Registered with ``atexit``; calling it repeatedly is harmless.
        """
        # A session whose thread died on an error still has unsaved samples.
        if self.running or self.monitor_thread is not None:
            self.stop_monitoring()
        self._close_serial()

    def _close_serial(self):
        """Close the serial port if one is open (best effort) and forget the handle."""
        if self.serial is None:
            return
        try:
            self.serial.close()
        except OSError as e:  # pyserial's SerialException derives from OSError
            print(f"Error closing serial port: {e}")
        finally:
            self.serial = None

    def connect(self):
        """Open the serial port and check that a PMD-USB device answers.

        The device is accepted if its welcome reply contains "PMD-USB" or,
        as a fallback, if its 3-byte ID reply starts with 0xEE 0x0A.

        Returns:
            bool: True if the device was identified. On any failure the port
            is closed again and ``self.serial`` is None.
        """
        # Never leak a handle from an earlier connection attempt.
        self._close_serial()
        try:
            self.serial = serial.Serial(
                port=self.port,
                baudrate=115200,  # 921600
                bytesize=8,
                parity='N',
                stopbits=1,
                timeout=0.0001
            )

            # Check device response
            self.serial.reset_input_buffer()
            self.serial.write(bytes([self.CMD_WELCOME]))
            time.sleep(0.01)
            response = self.serial.read(self.serial.in_waiting)
            # errors='replace' guarantees decoding cannot raise.
            if "PMD-USB" in response.decode('ascii', errors='replace'):
                return True

            # Try ID command as fallback
            self.serial.reset_input_buffer()
            self.serial.write(bytes([self.CMD_READ_ID]))
            time.sleep(0.1)
            response = self.serial.read(3)
            if len(response) == 3 and response[0] == 0xEE and response[1] == 0x0A:
                return True

            print("Warning: Device did not respond as expected")
        except Exception as e:
            print(f"Error connecting to energy monitor: {e}")

        # Leave no unverified port open, so the next start_monitoring() reconnects.
        self._close_serial()
        return False

    def start_monitoring(self, task_id=None):
        """Start energy monitoring in a background daemon thread.

        Args:
            task_id: Identifier for the task (used in the output filename);
                defaults to the current local time as YYYYmmdd_HHMMSS.

        Returns:
            bool: True if monitoring started successfully.
        """
        if self.running:
            print("Monitoring already running")
            return False

        if not self.serial and not self.connect():
            print("Failed to connect to energy monitor")
            return False

        self.task_id = task_id or datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.samples = []
        self._stop_event = threading.Event()
        self.running = True

        # The thread gets this session's list and event, so a straggler from
        # an earlier session can neither write here nor be kept alive by it.
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop,
            args=(self.samples, self._stop_event),
            daemon=True,
        )
        self.monitor_thread.start()

        return True

    def _monitor_loop(self, samples, stop_event):
        """Background sampler: poll the device until stopped or an error occurs.

        Args:
            samples: List that parsed samples are appended to.
            stop_event: Event set by ``stop_monitoring`` to end the loop.
        """
        rate = self.sampling_rate
        period = 1.0 / rate if rate and rate > 0 else 0.0
        command = bytes([self.CMD_READ_SENSOR_VALUES])
        try:
            # Drop anything stale so the first reply starts on a frame boundary.
            self.serial.reset_input_buffer()
            next_due = time.perf_counter()

            while self.running and not stop_event.is_set():
                self.serial.write(command)
                response = self._read_frame(self.SAMPLE_SIZE, self.READ_TIMEOUT)

                if len(response) == self.SAMPLE_SIZE:
                    samples.append(self._parse_sample(response))
                else:
                    # The rest of a short frame may still arrive; discard it so the
                    # next reply is not misaligned across two frames.
                    self.serial.reset_input_buffer()

                if period:
                    next_due += period
                    delay = next_due - time.perf_counter()
                    if delay > 0:
                        stop_event.wait(delay)  # returns early when stopped
                    else:
                        # Running behind the target rate: don't burst to catch up.
                        next_due = time.perf_counter()

        except Exception as e:
            print(f"Monitoring error: {e}")
            # Only flag the session this thread belongs to as stopped.
            if self._stop_event is stop_event:
                self.running = False

    def _read_frame(self, size, timeout):
        """Read exactly ``size`` bytes, or fewer if ``timeout`` seconds elapse.

        The port uses a very short per-read timeout, so a single ``read`` can
        return only part of a frame; keep reading until it is complete.
        """
        buf = bytearray()
        deadline = time.perf_counter() + timeout
        while len(buf) < size:
            chunk = self.serial.read(size - len(buf))
            if chunk:
                buf += chunk
            elif time.perf_counter() >= deadline:
                break
        return bytes(buf)

    def stop_monitoring(self, save=True):
        """Stop energy monitoring and optionally save the data.

        Also collects the data of a session whose sampling thread already
        stopped on an error.

        Args:
            save: Whether to save the data to ``<output_dir>/energy_<task_id>.csv``.

        Returns:
            DataFrame: Collected energy data with an added ``elapsed_seconds``
            column, or None if nothing was running or no data was collected.
        """
        if not self.running and self.monitor_thread is None:
            print("Monitoring not running")
            return None

        self.running = False
        if self._stop_event is not None:
            self._stop_event.set()

        # Wait for thread to stop
        if self.monitor_thread is not None:
            self.monitor_thread.join(timeout=1.0)
            self.monitor_thread = None

        # Snapshot in case a thread that missed the join deadline still appends.
        samples = list(self.samples)
        if not samples:
            print("No samples collected")
            return None

        df = pd.DataFrame(samples)

        # Add timestamp relative to start
        start_time = df['timestamp'].min()
        df['elapsed_seconds'] = (df['timestamp'] - start_time).dt.total_seconds()

        duration = df['elapsed_seconds'].max()
        sample_count = len(df)
        sample_rate = sample_count / duration if duration > 0 else 0

        print(f"Collected {sample_count} samples over {duration:.2f} seconds ({sample_rate:.2f} Hz)")

        if save and self.task_id:
            filename = os.path.join(self.output_dir, f"energy_{self.task_id}.csv")
            df.to_csv(filename, index=False)
            print(f"Energy data saved to {filename}")

        return df

    def get_summary(self, data=None):
        """Get summary statistics of the energy data.

        Energy is estimated as mean power x covered time span (mean-power
        approximation, not a timestamp-weighted integral).

        Args:
            data: DataFrame of energy data (e.g. from ``stop_monitoring`` or
                reloaded from its CSV, where timestamps are strings), or None
                to use the most recent samples.

        Returns:
            dict: Summary statistics, or None if there is no data.
        """
        if data is None:
            if not self.samples:
                return None
            data = pd.DataFrame(list(self.samples))
        if data.empty:
            return None

        # to_datetime is a no-op on datetime columns and parses CSV strings.
        timestamps = pd.to_datetime(data['timestamp'])
        duration = (timestamps.max() - timestamps.min()).total_seconds()

        summary = {
            'duration_seconds': duration,
            'samples_count': len(data),
            'samples_per_second': len(data) / duration if duration > 0 else 0,
        }

        suffix = '_power'
        for column in data.columns:
            if not (isinstance(column, str) and column.endswith(suffix)):
                continue
            component = column[:-len(suffix)]
            power = data[column]

            avg_power = power.mean()
            summary[f'{component}_avg_watts'] = avg_power
            summary[f'{component}_min_watts'] = power.min()
            summary[f'{component}_max_watts'] = power.max()

            energy_joules = avg_power * duration
            summary[f'{component}_energy_joules'] = energy_joules
            summary[f'{component}_energy_wh'] = energy_joules / 3600

        return summary

    def print_summary(self, summary=None):
        """Print a formatted summary of the energy monitoring session.

        Args:
            summary: Summary dict from ``get_summary``, or None to generate one
                from the most recent samples.
        """
        if summary is None:
            summary = self.get_summary()

        if not summary:
            print("No data available for summary")
            return

        print("\n=== Energy Monitoring Summary ===")
        print(f"Duration: {summary['duration_seconds']:.2f} seconds")
        print(f"Samples: {summary['samples_count']} ({summary['samples_per_second']:.2f} Hz)")

        print("\nEnergy Consumption:")

        # Components are identified by their "<component>_energy_wh" keys.
        components = {key.replace('_energy_wh', '') for key in summary if key.endswith('_energy_wh')}

        for component in sorted(components):
            print(f"  {component}:")
            if f"{component}_avg_watts" in summary:
                print(f"    Average Power: {summary[component + '_avg_watts']:.2f} W")
                print(f"    Min Power: {summary[component + '_min_watts']:.2f} W")
                print(f"    Max Power: {summary[component + '_max_watts']:.2f} W")
            print(f"    Total Energy: {summary[component + '_energy_joules']:.2f} J "
                  f"({summary[component + '_energy_wh']:.4f} Wh)")

    def _parse_sample(self, response_bytes):
        """Decode one CMD_READ_SENSOR_VALUES reply into a sample dict.

        Each sensor occupies 4 bytes: a little-endian signed int16 holding
        voltage * 100, followed by one holding current * 10.

        Args:
            response_bytes: 16-byte reply from the device.

        Returns:
            dict: ``timestamp`` (naive local datetime at parse time), then
            ``<sensor>_voltage``/``_current``/``_power`` for each sensor in
            ``self.sensors``, then ``GPU_power``, ``CPU_power``, ``Total_power``.
        """
        data = {'timestamp': datetime.datetime.now()}

        for index, name in enumerate(self.sensors):
            raw_voltage, raw_current = struct.unpack_from("<hh", response_bytes, index * 4)
            voltage = raw_voltage / 100.0
            current = raw_current / 10.0
            data[f'{name}_voltage'] = voltage
            data[f'{name}_current'] = current
            data[f'{name}_power'] = voltage * current

        # Aggregates hard-code this rig's wiring: GPU = PCIE2 + EPS2, CPU = EPS1;
        # PCIE1 is not attributed to either. NOTE(review): confirm against the setup.
        gpu_power = data['PCIE2_power'] + data['EPS2_power']
        cpu_power = data['EPS1_power']
        data['GPU_power'] = gpu_power
        data['CPU_power'] = cpu_power
        data['Total_power'] = gpu_power + cpu_power

        return data