import os
import sys
from pathlib import Path
from azure.storage.blob import BlobServiceClient

class ArtifactsSyncHandler:
    """Mirror experiment artifacts and results between the user's home
    directory and Azure Blob Storage.

    Local layout (both roots live under the user's home directory):

    * ``~/exp_artifacts/<exp_name>/<file>`` <-> blob ``<exp_name>/<file>`` in
      the ``exp-artifacts`` container.
    * ``~/exp_results/<relative/path>`` <-> blob ``<relative/path>`` in the
      ``exp-results`` container, typically ``<exp_name>/<test_task>/...``.

    Uploads never overwrite existing blobs; downloads never overwrite existing
    local files.
    """

    ARTIFACTS_CONTAINER = 'exp-artifacts'
    RESULTS_CONTAINER = 'exp-results'
    AZURE_BLOB_URL = "https://aminefileshare.blob.core.windows.net"
    # A results sub-folder (below a test task) is only downloaded when its
    # own name mentions one of these formats.
    RESULT_FOLDER_FORMATS = ('json', 'csv')

    def __init__(self, key: str) -> None:
        """Create the blob service client and resolve the local storage roots.

        Args:
            key: Credential handed to ``BlobServiceClient`` for the account at
                ``AZURE_BLOB_URL``.
        """
        self.client = BlobServiceClient(account_url=self.AZURE_BLOB_URL, credential=key)

        # Path.home() resolves the home directory on every platform, unlike
        # os.environ['HOME'], which is missing on Windows.
        home = Path.home()
        self.exp_artifacts_path = home / 'exp_artifacts'
        self.exp_results_path = home / 'exp_results'

    # ------------------------------------------------------------------ #
    # Downloads
    # ------------------------------------------------------------------ #
    @staticmethod
    def _download_blob(container_client, blob_name: str, file_path: Path) -> None:
        """Stream blob *blob_name* from *container_client* into *file_path*."""
        print(f"Downloading file {file_path}")
        try:
            with file_path.open(mode="wb") as download_file:
                # readinto streams straight to disk instead of buffering the
                # whole blob in memory like readall() does.
                container_client.download_blob(blob_name).readinto(download_file)
        except BaseException:
            # Don't leave a truncated file behind: later runs treat an existing
            # local file as already downloaded and would never retry it.
            file_path.unlink(missing_ok=True)
            raise

    def download_exp_artifacts(self, exp_name: str):
        """Download all artifact blobs stored directly under ``<exp_name>/``.

        Does nothing if the local experiment folder already exists, or if no
        blob is found. Nested virtual folders are reported and skipped.

        NOTE: if a download fails midway, the experiment folder remains and
        later calls will report the experiment as already present.
        """
        exp_folder = self.exp_artifacts_path / exp_name
        if exp_folder.exists():
            print(f"Experiment {exp_name} already exists in local storage")
            return

        files_to_download = self.get_exp_artifacts_blob_names(exp_name)
        if not files_to_download:
            print(f"No artifacts found for experiment {exp_name}")
            return

        exp_folder.mkdir(parents=True)
        container_client = self.client.get_container_client(container=self.ARTIFACTS_CONTAINER)
        for blob_name in files_to_download:
            if blob_name.endswith('/'):
                # Delimited listings return virtual sub-folders as prefixes
                # ending in '/'; they have no content to download.
                print(f"Skipping sub-folder {blob_name}")
                continue
            file_path = exp_folder / blob_name.rsplit('/', 1)[-1]
            self._download_blob(container_client, blob_name, file_path)

    @classmethod
    def _is_valid_results_folder(cls, folder_blob_name: str) -> bool:
        """Return True if the last segment of *folder_blob_name* names a supported format."""
        folder_name = folder_blob_name.rstrip('/').rsplit('/', 1)[-1]
        return any(fmt in folder_name for fmt in cls.RESULT_FOLDER_FORMATS)

    def _download_results_file(self, container_client, blob_name: str) -> None:
        """Download one results blob to its mirrored local path unless it already exists."""
        file_path = self.exp_results_path / blob_name
        if file_path.exists():
            print(f"File {blob_name} already exists in local storage")
            return
        file_path.parent.mkdir(parents=True, exist_ok=True)
        self._download_blob(container_client, blob_name, file_path)

    def _download_results_folder(self, container_client, folder_blob_name: str) -> None:
        """Download the files directly inside a results sub-folder (e.g. ``.../json/``)."""
        if not self._is_valid_results_folder(folder_blob_name):
            print(f"Error: {folder_blob_name} is not a valid file")
            return

        (self.exp_results_path / folder_blob_name).mkdir(parents=True, exist_ok=True)
        for sub_blob_name in self.get_prefix_blob_names(folder_blob_name, self.RESULTS_CONTAINER):
            if sub_blob_name.endswith('/'):
                # Only one folder level below a test task is mirrored.
                print(f"Skipping nested folder {sub_blob_name}")
                continue
            self._download_results_file(container_client, sub_blob_name)

    def dowload_exp_results(self, exp_name: str):
        """Download the results of *exp_name* into ``exp_results_path``.

        For each entry under ``<exp_name>/`` (normally one per test task), its
        files are downloaded, and its sub-folders are downloaded one level deep
        when their name mentions a supported format. Local files that already
        exist are left untouched.
        """
        test_tasks = self.get_prefix_blob_names(f'{exp_name}/', self.RESULTS_CONTAINER)
        if not test_tasks:
            print(f"No results found for experiment {exp_name}")
            return

        (self.exp_results_path / exp_name).mkdir(parents=True, exist_ok=True)

        container_client = self.client.get_container_client(container=self.RESULTS_CONTAINER)
        for test_task in test_tasks:
            for blob_name in self.get_prefix_blob_names(test_task, self.RESULTS_CONTAINER):
                if blob_name.endswith('/'):
                    self._download_results_folder(container_client, blob_name)
                else:
                    self._download_results_file(container_client, blob_name)

    # Correctly spelled alias; the original name is kept for existing callers.
    download_exp_results = dowload_exp_results

    # ------------------------------------------------------------------ #
    # Uploads / sync
    # ------------------------------------------------------------------ #
    def res_path_to_blob_name(self, path: Path):
        """Map a local path under ``exp_results_path`` to its blob name.

        Directories get a trailing '/', matching how delimited listings name
        virtual folders.
        """
        blob_name = path.relative_to(self.exp_results_path).as_posix()
        if path.is_dir():
            blob_name += '/'
        return blob_name

    def _sync_results_folder(self, folder_path: Path) -> int:
        """Recursively upload files under *folder_path* that are missing remotely.

        Returns:
            The number of files uploaded.
        """
        prefix = self.res_path_to_blob_name(folder_path)
        existing_blobs = set(self.get_prefix_blob_names(prefix, self.RESULTS_CONTAINER))
        upload_cnt = 0
        for item in folder_path.iterdir():
            if item.is_dir():
                upload_cnt += self._sync_results_folder(item)
                continue
            file_blob_name = self.res_path_to_blob_name(item)
            if file_blob_name not in existing_blobs:
                self.upload_blob_file(self.RESULTS_CONTAINER, item, file_blob_name)
                print(f"Uploaded {file_blob_name} to Azure Blob Storage")
                upload_cnt += 1
        return upload_cnt

    def sync_exp_results(self):
        """Upload local result files that are missing from the results container.

        Walks every ``<exp>/<test_task>/`` folder recursively. Loose files that
        sit directly next to the test-task folders are skipped.

        Returns:
            ``(updated_tasks, uploaded_files)``: the number of test-task folders
            with at least one new upload, and the total number of files uploaded.
            ``updated_tasks`` is also stored on ``self.exp_updat_cnt``.
        """
        self.exp_updat_cnt = 0
        total_upload_cnt = 0
        if not self.exp_results_path.is_dir():
            print(f"No local results folder at {self.exp_results_path}")
            return self.exp_updat_cnt, total_upload_cnt

        for exp in self.exp_results_path.iterdir():
            if not exp.is_dir():
                continue
            for test_task in exp.iterdir():
                if not test_task.is_dir():
                    print(f"Skipping {test_task} (not a test task folder)")
                    continue
                print(f"Syncing {exp.name} results for {test_task.name}")
                task_upload_cnt = self._sync_results_folder(test_task)
                # Accumulate across tasks; a per-task count decides "updated".
                total_upload_cnt += task_upload_cnt
                if task_upload_cnt > 0:
                    self.exp_updat_cnt += 1

        return self.exp_updat_cnt, total_upload_cnt

    def sync_exp_artifacts(self):
        """Upload local artifact files that are missing from the artifacts container.

        Experiment folders whose name ends with ``merged`` are skipped, as are
        non-directory entries and sub-folders inside an experiment.

        Returns:
            ``(uploaded_exp_cnt, total_exp_cnt)``: the number of experiments
            with at least one file actually uploaded, and the number of
            experiments examined.
        """
        total_cnt = 0
        upload_cnt = 0
        exp_cnt = 0
        total_exp_cnt = 0
        exp_dirs = self.exp_artifacts_path.iterdir() if self.exp_artifacts_path.is_dir() else ()
        for exp_artifact_path in exp_dirs:
            exp_name = exp_artifact_path.name
            if not exp_artifact_path.is_dir() or exp_name.endswith('merged'):
                continue

            total_exp_cnt += 1
            print(f"Syncing {exp_name} artifacts")
            existing_files = set(self.get_exp_artifacts_blob_names(exp_name))
            uploaded_any = False
            for file in exp_artifact_path.iterdir():
                if not file.is_file():
                    # Artifacts are stored flat as <exp_name>/<file>.
                    print(f"Skipping {file} (not a file)")
                    continue
                total_cnt += 1
                blob_name = self.create_blob_name(exp_name, file)
                if blob_name in existing_files:
                    print(f"File {blob_name} already exists in Azure Blob Storage")
                    continue
                self.upload_blob_file(self.ARTIFACTS_CONTAINER, file, blob_name)
                upload_cnt += 1
                uploaded_any = True

            if uploaded_any:
                exp_cnt += 1

        print(f"Uploaded {upload_cnt} files out of {total_cnt} files is done!")

        return exp_cnt, total_exp_cnt

    def upload_blob_file(self, container_name: str, file_path: Path, blob_name: str):
        """Upload *file_path* as *blob_name* in *container_name*.

        ``overwrite=False`` makes the SDK refuse to replace an existing blob
        (it raises instead), so callers check existing names first.
        """
        container_client = self.client.get_container_client(container=container_name)
        with file_path.open(mode="rb") as data:
            container_client.upload_blob(name=blob_name, data=data, overwrite=False)

    def create_blob_name(self, exp_name: str, file_path: Path, subfolder_list=None):
        """Build ``<exp_name>/[<sub>/...]<file name>`` for *file_path*."""
        return '/'.join([exp_name, *(subfolder_list or ()), file_path.name])

    def _upload_folder_files(self, container_name: str, folder_path: Path) -> None:
        """Upload every regular file directly inside *folder_path* as ``<folder name>/<file>``."""
        exp_name = folder_path.name
        for file in folder_path.iterdir():
            if not file.is_file():
                # A directory cannot be opened for upload.
                print(f"Skipping {file} (not a file)")
                continue
            blob_name = self.create_blob_name(exp_name, file)
            self.upload_blob_file(container_name, file, blob_name)

    def upload_exp_artifacts(self, exp_artifact_path: Path):
        """Upload all files of one local artifacts folder (fails if a blob already exists)."""
        self._upload_folder_files(self.ARTIFACTS_CONTAINER, exp_artifact_path)
        print(f'Uploading {exp_artifact_path.name} artifacts to Azure Blob Storage is done!')

    def upload_exp_results(self, exp_results_path: Path):
        """Upload the files directly inside one local results folder (fails if a blob already exists)."""
        self._upload_folder_files(self.RESULTS_CONTAINER, exp_results_path)
        print(f'Uploading {exp_results_path.name} results to Azure Blob Storage is done!')

    # ------------------------------------------------------------------ #
    # Listing
    # ------------------------------------------------------------------ #
    def get_prefix_blob_names(self, prefix: str, container_name: str):
        """List the names one level below *prefix* in *container_name*.

        Because a '/' delimiter is used, entries are either blob names or
        virtual folder prefixes (those end with '/').
        """
        container_client = self.client.get_container_client(container=container_name)
        return [blob.name for blob in container_client.walk_blobs(name_starts_with=prefix, delimiter='/')]

    def get_exp_artifacts_blob_names(self, exp_name: str):
        """List the artifact blob names directly under ``<exp_name>/``."""
        return self.get_prefix_blob_names(f'{exp_name}/', self.ARTIFACTS_CONTAINER)
    

def main():
    """Command-line entry point.

    ``python <script> <exp_name>`` downloads that experiment's artifacts and
    results. With no argument, local results and artifacts missing from Blob
    Storage are uploaded and a summary is printed. The storage credential is
    read from the ``DOG`` environment variable.
    """
    credential = os.environ.get('DOG')
    if not credential:
        # Fail with an actionable message instead of a bare KeyError traceback.
        sys.exit("Error: the DOG environment variable must hold the Azure Blob Storage credential")
    artifact_sync_handler = ArtifactsSyncHandler(credential)

    if len(sys.argv) > 1:
        exp_name = sys.argv[1]
        artifact_sync_handler.download_exp_artifacts(exp_name)
        artifact_sync_handler.dowload_exp_results(exp_name)
        print(f'Downloaded {exp_name} artifacts to local storage')
        return

    updated_tasks, uploaded_result_files = artifact_sync_handler.sync_exp_results()
    art_uploaded_exp_cnt, total_art = artifact_sync_handler.sync_exp_artifacts()
    print('Done!')

    print('-' * 50)
    print('-' * 50)
    print('Summary:')
    # sync_exp_results counts test-task folders with new uploads and uploaded
    # files; neither value is an experiment count.
    print(f'Uploaded {uploaded_result_files} result files across {updated_tasks} test tasks')
    print(f'Uploaded {art_uploaded_exp_cnt} out of {total_art} experiments artifacts')
    print('-' * 50)


if __name__ == "__main__":
    main()