import os
from filelock import FileLock
import boto3
import psutil
import logging

from desktop_env.providers.base import VMManager
from botocore.exceptions import ClientError
import random
import time
from desktop_env.providers.aws.aws_config import IMAGE_ID_MAP,INSTANCE_TYPE

# Module-level logger for the AWS VM manager; records below INFO are filtered out.
logger = logging.getLogger("desktopenv.providers.aws.AWSVMManager")
logger.setLevel(logging.INFO)

# Default (relative) path of the plain-text VM registry file.
REGISTRY_PATH = ".aws_vms"

# Region used whenever a caller does not pass one explicitly.
DEFAULT_REGION = "us-east-1"

# TODO: Document how to configure the image, security group and network interface.
# TODO: Publish the AMI images.

# Value passed as the "NetworkInterfaces" parameter of EC2 run_instances,
# looked up by region name. It is empty here, so a region must be added
# before a VM can be allocated in it.
NETWORK_INTERFACE_MAP = {}


def _allocate_vm(region=DEFAULT_REGION, aws_ami="reddit", max_retries=5):
    """Launch one EC2 instance and block until it reaches the running state.

    Args:
        region: AWS region name. Used to create the EC2 client and as the key
            into ``IMAGE_ID_MAP`` and ``NETWORK_INTERFACE_MAP``.
        aws_ami: Key selecting the image in ``IMAGE_ID_MAP[region]`` and the
            instance type in ``INSTANCE_TYPE``.
        max_retries: How many times a throttled request is retried after the
            first attempt.

    Returns:
        The instance ID of the newly launched instance.

    Raises:
        KeyError: If ``region`` or ``aws_ami`` is missing from one of the
            configuration maps.
        botocore.exceptions.ClientError: For any error other than
            ``RequestLimitExceeded``, or for ``RequestLimitExceeded`` once
            all retries are used up.
    """
    run_instances_params = {
        "MaxCount": 1,
        "MinCount": 1,
        "ImageId": IMAGE_ID_MAP[region][aws_ami],
        "InstanceType": INSTANCE_TYPE[aws_ami],
        "EbsOptimized": True,
        "NetworkInterfaces": NETWORK_INTERFACE_MAP[region]
    }

    ec2_client = boto3.client('ec2', region_name=region)

    for attempt in range(max_retries + 1):
        try:
            response = ec2_client.run_instances(**run_instances_params)
            instance_id = response['Instances'][0]['InstanceId']
            logger.info(f"Waiting for instance {instance_id} to be running...")
            ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
            logger.info(f"Instance {instance_id} is ready.")
            return instance_id
        except ClientError as e:
            # Only API throttling is worth retrying; everything else is fatal.
            if e.response['Error']['Code'] != 'RequestLimitExceeded':
                raise
            # Give up right away on the last attempt instead of sleeping
            # through one more back-off interval that nothing follows.
            if attempt >= max_retries:
                logger.error("Maximum retry attempts reached")
                raise
            # Exponential back-off with jitter.
            wait_time = (2 ** attempt) + random.random()
            logger.warning(f"Request limit exceeded, retrying in {wait_time:.2f} seconds...")
            time.sleep(wait_time)

class AWSVMManager(VMManager):
    """Registry of AWS EC2 instances kept in a local text file.

    Every registry line has the form ``<instance_id>@<region>|<owner>``,
    where ``<owner>`` is either the literal ``free`` or the PID of the
    process that occupies the instance. Public methods serialise access to
    the file through a ``FileLock`` on ``.aws_lck``; the underscore-prefixed
    variants assume the caller already holds that lock.
    """

    def __init__(self, aws_ami=None, registry_path=REGISTRY_PATH):
        """Create the manager and make sure the registry file exists.

        Args:
            aws_ami: Image key forwarded to ``_allocate_vm`` when a new
                instance has to be launched.
            registry_path: Path of the registry file.
        """
        self.aws_ami = aws_ami
        self.registry_path = registry_path
        self.lock = FileLock(".aws_lck", timeout=60)
        self.initialize_registry()

    def initialize_registry(self):
        """Create an empty registry file if none exists yet."""
        with self.lock:  # Locking during initialization
            if not os.path.exists(self.registry_path):
                with open(self.registry_path, 'w') as file:
                    file.write('')

    def add_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
        """Register ``vm_path`` in ``region`` as a free instance.

        Pass ``lock_needed=False`` only when the lock is already held.
        """
        if lock_needed:
            with self.lock:
                self._add_vm(vm_path, region)
        else:
            self._add_vm(vm_path, region)

    def _add_vm(self, vm_path, region=DEFAULT_REGION):
        """Append a ``free`` entry for the instance; caller holds the lock."""
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        lines.append(f'{vm_path}@{region}|free\n')
        with open(self.registry_path, 'w') as file:
            file.writelines(lines)

    def delete_vm(self, vm_path, region=DEFAULT_REGION, lock_needed=True):
        """Remove the registry entry for ``vm_path`` in ``region``.

        Only the registry is touched; no AWS call is made here.
        Pass ``lock_needed=False`` only when the lock is already held.
        """
        if lock_needed:
            with self.lock:
                self._delete_vm(vm_path, region)
        else:
            self._delete_vm(vm_path, region)

    def _delete_vm(self, vm_path, region=DEFAULT_REGION):
        """Rewrite the registry without the given entry; caller holds the lock."""
        target = f'{vm_path}@{region}'
        new_lines = []
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        for line in lines:
            registered_vm_path, _ = line.strip().split('|')
            if registered_vm_path != target:
                new_lines.append(line)
        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    def occupy_vm(self, vm_path, pid, region=DEFAULT_REGION, lock_needed=True):
        """Mark ``vm_path`` in ``region`` as owned by process ``pid``.

        Pass ``lock_needed=False`` only when the lock is already held.
        """
        if lock_needed:
            with self.lock:
                self._occupy_vm(vm_path, pid, region)
        else:
            self._occupy_vm(vm_path, pid, region)

    def _occupy_vm(self, vm_path, pid, region=DEFAULT_REGION):
        """Replace the owner field of the matching entry; caller holds the lock."""
        target = f'{vm_path}@{region}'
        new_lines = []
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        for line in lines:
            registered_vm_path, _ = line.strip().split('|')
            if registered_vm_path == target:
                new_lines.append(f'{registered_vm_path}|{pid}\n')
            else:
                new_lines.append(line)
        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

    def check_and_clean(self, lock_needed=True):
        """Reconcile the registry with EC2 instance states and live PIDs.

        Pass ``lock_needed=False`` only when the lock is already held.
        """
        if lock_needed:
            with self.lock:
                self._check_and_clean()
        else:
            self._check_and_clean()

    def _check_and_clean(self):
        """Rewrite the registry from EC2 state and local PIDs; caller holds the lock.

        Per region, one ``describe_instances`` call is made for all registered
        instance IDs, then each entry is handled as follows:

        * terminated / shutting-down instance: entry dropped;
        * stopped instance: entry kept and marked ``free``;
        * otherwise: entry kept; its owner is reset to ``free`` when the
          recorded PID is no longer a running local process.
        """
        # Get active PIDs
        active_pids = {p.pid for p in psutil.process_iter()}

        new_lines = []
        entries_by_region = {}

        with open(self.registry_path, 'r') as file:
            lines = file.readlines()

        # Group (instance_id, owner) pairs by region so that each region
        # needs only one describe_instances call.
        for line in lines:
            vm_path_at_vm_region, pid_str = line.strip().split('|')
            vm_path, vm_region = vm_path_at_vm_region.split("@")
            entries_by_region.setdefault(vm_region, []).append((vm_path, pid_str))

        # Process each region
        for region, vm_info_list in entries_by_region.items():
            ec2_client = boto3.client('ec2', region_name=region)
            instance_ids = [vm_path for vm_path, _ in vm_info_list]

            # Batch describe instances
            try:
                response = ec2_client.describe_instances(InstanceIds=instance_ids)
            except ClientError as e:
                if 'InvalidInstanceID.NotFound' in str(e):
                    # At least one registered ID is unknown to EC2; all of this
                    # region's entries are dropped.
                    # NOTE(review): this also discards entries whose instances
                    # may still exist - confirm that this is intended.
                    logger.info(f"VM not found, releasing instances in region {region}.")
                else:
                    # Any other failure (throttling, credentials, network...)
                    # tells us nothing about the instances, so keep the
                    # entries as they are rather than silently losing them.
                    logger.warning(
                        f"Could not describe instances in region {region} ({e}); "
                        f"keeping registry entries unchanged."
                    )
                    new_lines.extend(
                        f'{vm_path}@{region}|{pid_str}\n'
                        for vm_path, pid_str in vm_info_list
                    )
                continue

            terminated_ids = set()
            stopped_ids = set()

            # Collect states of all instances
            for reservation in response.get('Reservations', []):
                for instance in reservation.get('Instances', []):
                    instance_id = instance.get('InstanceId')
                    instance_state = instance['State']['Name']
                    if instance_state in ['terminated', 'shutting-down']:
                        terminated_ids.add(instance_id)
                    elif instance_state == 'stopped':
                        stopped_ids.add(instance_id)

            # Decide what to write back for each entry of this region
            for vm_path, pid_str in vm_info_list:
                if vm_path in terminated_ids:
                    logger.info(f"VM {vm_path} not found or terminated, releasing it.")
                    continue
                if vm_path in stopped_ids:
                    logger.info(f"VM {vm_path} stopped, mark it as free")
                    new_lines.append(f'{vm_path}@{region}|free\n')
                    continue

                if pid_str == "free" or int(pid_str) in active_pids:
                    new_lines.append(f'{vm_path}@{region}|{pid_str}\n')
                else:
                    # The owning process is gone, so the instance is free again.
                    new_lines.append(f'{vm_path}@{region}|free\n')

        # Writing updated lines back to the registry file
        with open(self.registry_path, 'w') as file:
            file.writelines(new_lines)

        # Instances that exist on AWS but are missing from the registry are
        # deliberately left alone: deleting them could remove instances that
        # belong to another server. Please monitor your instances to avoid
        # additional cost.

    def list_free_vms(self, region=DEFAULT_REGION, lock_needed=True):
        """Return the free instances registered for ``region``.

        Returns:
            A list of ``(instance_id, "free")`` tuples in registry order.
        """
        if lock_needed:
            with self.lock:
                return self._list_free_vms(region)
        else:
            return self._list_free_vms(region)

    def _list_free_vms(self, region=DEFAULT_REGION):
        """Read the registry and collect free entries; caller holds the lock."""
        free_vms = []
        with open(self.registry_path, 'r') as file:
            lines = file.readlines()
        for line in lines:
            vm_path_at_vm_region, pid_str = line.strip().split('|')
            vm_path, vm_region = vm_path_at_vm_region.split("@")
            if pid_str == "free" and vm_region == region:
                free_vms.append((vm_path, pid_str))

        return free_vms

    def get_vm_path(self, region=DEFAULT_REGION):
        """Return an instance ID in ``region`` reserved for the current process.

        The first call in this process (tracked by the class attribute
        ``checked_and_cleaned``) reconciles the registry first. A free
        registered instance is reused when one exists; otherwise a new
        instance is launched, registered and occupied.
        """
        with self.lock:
            if not AWSVMManager.checked_and_cleaned:
                AWSVMManager.checked_and_cleaned = True
                self._check_and_clean()

        with self.lock:
            free_vms_paths = self._list_free_vms(region)
            if free_vms_paths:
                # Choose the first free virtual machine
                chosen_vm_path = free_vms_paths[0][0]
                self._occupy_vm(chosen_vm_path, os.getpid(), region)
                return chosen_vm_path

        # No free virtual machine available, generate a new one
        logger.info("No free virtual machine available. Generating a new one, which would take a while...☕")

        # Launch without holding the registry lock: starting an instance can
        # take far longer than the lock's 60 s timeout, and the launch itself
        # does not touch the registry file.
        new_vm_path = _allocate_vm(region, self.aws_ami)
        with self.lock:
            self._add_vm(new_vm_path, region)
            self._occupy_vm(new_vm_path, os.getpid(), region)
        return new_vm_path
