#!/usr/bin/env python3
import argparse
import datetime
import json
import os
import pickle
import random
import shutil
import sys
import time
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union

import numpy as np
import tomlkit
import torch

try:
    import pytomlpp as toml_modern
    HAS_PYTOMLPP = True
except ImportError:
    HAS_PYTOMLPP = False
    toml_modern = None

try:
    import pynvml
    HAS_PYNVML = True
except ImportError:
    HAS_PYNVML = False
    pynvml = None


# String tags, presumably naming a dataset's task type (confirm against
# callers; the visible code in this file does not reference them).
REGRESSION, BINCLASS, MULTICLASS = 'regression', 'binclass', 'multiclass'


def load_json(path: Union[Path, str]) -> Any:
    """Read and parse the JSON file at ``path``.

    The file is decoded as UTF-8 (the encoding RFC 8259 requires for JSON
    exchanged between systems) rather than the platform's locale default,
    so non-ASCII content loads identically on every OS.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    return json.loads(Path(path).read_text(encoding='utf-8'))


def dump_json(data: Dict[str, Any], file_path: Union[str, Path], indent: int = 2) -> None:
    """Serialise ``data`` as UTF-8 JSON to ``file_path``, followed by a newline.

    ``torch.Tensor``, ``numpy.ndarray`` and numpy scalar values (e.g.
    ``np.int64``, ``np.bool_``) are converted with ``.tolist()`` so training
    statistics can be dumped directly. Non-ASCII characters are written as-is
    (``ensure_ascii=False``).

    Args:
        data: Value to serialise.
        file_path: Destination file; it is overwritten.
        indent: Indentation passed to ``json.dumps``.

    Raises:
        TypeError: if ``data`` contains an object with no JSON representation.
            Nothing is written in that case.
    """
    def to_serializable(obj: Any) -> Any:
        # json calls this hook only for objects it cannot encode natively.
        if isinstance(obj, (torch.Tensor, np.ndarray, np.generic)):
            return obj.tolist()
        # Returning `obj` unchanged would make json meet the same object again
        # and fail with a misleading "Circular reference detected" ValueError.
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    text = json.dumps(data, indent=indent, default=to_serializable, ensure_ascii=False)
    # ensure_ascii=False emits raw non-ASCII characters, so pin UTF-8 instead of
    # relying on the locale encoding (which may be unable to represent them).
    Path(file_path).write_text(text + '\n', encoding='utf-8')


def load_toml(path: Union[Path, str]) -> Any:
    """Parse the TOML file at ``path``.

    Uses ``pytomlpp`` when it was importable at module load and falls back to
    ``tomlkit`` otherwise; the fallback returns a tomlkit document object
    rather than a plain ``dict``.

    The file is decoded as UTF-8, the encoding the TOML specification
    mandates, instead of the platform's locale default.

    Args:
        path: Location of the TOML file.

    Returns:
        The parsed document.
    """
    text = Path(path).read_text(encoding='utf-8')
    parser = toml_modern if HAS_PYTOMLPP else tomlkit
    return parser.loads(text)


def dump_toml(data: Any, path: Union[Path, str]) -> None:
    """Write ``data`` to ``path`` as UTF-8 encoded TOML, overwriting the file.

    Uses ``pytomlpp`` when available (appending a trailing newline) and
    ``tomlkit`` otherwise. Both branches now write UTF-8 explicitly; the
    pytomlpp branch previously used the locale encoding, so non-ASCII config
    values could fail to encode or be written in a non-TOML encoding.

    Args:
        data: Mapping to serialise.
        path: Destination file.
    """
    if HAS_PYTOMLPP:
        Path(path).write_text(toml_modern.dumps(data) + '\n', encoding='utf-8')
    else:
        with open(path, 'w', encoding='utf-8') as f:
            tomlkit.dump(data, f)


def load_pickle(path: Union[Path, str]) -> Any:
    """Unpickle and return the object stored at ``path``.

    Security: unpickling can execute arbitrary code. Only use this on files
    produced by trusted code (e.g. this project's own ``dump_pickle``), never
    on data from an untrusted source.
    """
    payload = Path(path).read_bytes()
    return pickle.loads(payload)


def dump_pickle(data: Any, path: Union[Path, str]) -> None:
    """Pickle ``data`` with the default protocol and write it to ``path``.

    Any existing file at ``path`` is replaced.
    """
    serialized = pickle.dumps(data)
    Path(path).write_bytes(serialized)


def load_config(config_path: Union[str, Path]) -> Dict[str, Any]:
    """Load a TOML configuration file.

    Args:
        config_path: Path to the config file. ``pathlib.Path`` is accepted as
            well as ``str``; the annotation was previously ``str`` even though
            callers pass ``Path`` objects.

    Returns:
        The parsed configuration, as returned by ``load_toml``.

    Raises:
        FileNotFoundError: if nothing exists at ``config_path``.
    """
    path = Path(config_path)
    if not path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    return load_toml(path)


def _nvml_text(value: Any) -> str:
    """Return an NVML string result as ``str``.

    Older ``pynvml`` releases return ``bytes`` from calls such as
    ``nvmlSystemGetDriverVersion``/``nvmlDeviceGetName`` while newer ones
    return ``str``; ``str(value, 'utf-8')`` raises ``TypeError`` on the latter.
    """
    if isinstance(value, bytes):
        return value.decode('utf-8')
    return str(value)


def _collect_devices_environment() -> Dict[str, Any]:
    """Describe the CUDA setup for the run's ``stats.json``.

    Returns ``{}`` unless CUDA is available and ``pynvml`` is importable.
    Otherwise returns ``{'devices': {...}}`` with library versions, the
    driver version and one entry per integer index in
    ``CUDA_VISIBLE_DEVICES``. NVML failures are recorded under
    ``'pynvml_error'`` instead of being raised, since the data is diagnostic.
    """
    environment: Dict[str, Any] = {}
    if not (torch.cuda.is_available() and HAS_PYNVML):
        return environment

    cvd = os.environ.get('CUDA_VISIBLE_DEVICES')
    try:
        pynvml.nvmlInit()
    except Exception as e:
        environment['pynvml_error'] = str(e)
        return environment
    try:
        devices: Dict[Any, Any] = {
            'CUDA_VISIBLE_DEVICES': cvd,
            'torch.version.cuda': torch.version.cuda,
            'torch.backends.cudnn.version()': torch.backends.cudnn.version(),
            'driver': _nvml_text(pynvml.nvmlSystemGetDriverVersion()),
        }
        environment['devices'] = devices
        try:
            devices['torch.cuda.nccl.version()'] = torch.cuda.nccl.version()
        except (AttributeError, RuntimeError):
            devices['torch.cuda.nccl.version()'] = 'N/A'

        if cvd:
            for token in cvd.split(','):
                if not token.strip():
                    continue  # tolerate stray commas such as "0,1,"
                # Non-integer entries (e.g. GPU UUIDs) raise ValueError and are
                # reported via 'pynvml_error' below.
                index = int(token)
                handle = pynvml.nvmlDeviceGetHandleByIndex(index)
                devices[index] = {
                    'name': _nvml_text(pynvml.nvmlDeviceGetName(handle)),
                    'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle).total,
                }
    except Exception as e:
        environment['pynvml_error'] = str(e)
    finally:
        # Balance the successful nvmlInit() above.
        try:
            pynvml.nvmlShutdown()
        except Exception:
            pass  # best-effort cleanup; whatever was collected is still returned
    return environment


def load_config_with_args(
    argv: Optional[List[str]] = None,
) -> Tuple[Dict[str, Any], Path]:
    """Parse CLI arguments, load the TOML config and prepare the output directory.

    Usage: ``SCRIPT FILE [-o DIR] [-f] [--continue]``. The output directory
    defaults to ``FILE`` without its suffix, next to the config file.

    * output exists and ``--force``: it is deleted and recreated;
    * output exists, no ``--continue``: it is backed up and the process exits;
    * output exists, ``--continue`` and a ``DONE`` marker: backed up, exits;
    * output exists and ``--continue``: it is reused;
    * output missing: it is created, including missing parent directories.

    The config and a description of the CUDA environment are then written to
    ``stats.json`` in the output directory.

    Args:
        argv: Arguments to parse; defaults to ``sys.argv[1:]``.

    Returns:
        ``(config, output_dir)`` with ``output_dir`` as an absolute path.

    Raises:
        FileNotFoundError: if the config file does not exist.
        SystemExit: on invalid arguments (including a missing ``--continue``
            when ``SNAPSHOT_PATH`` contains ``CHECKPOINTS_RESTORED``) and in
            the early-exit cases above.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', metavar='FILE')
    parser.add_argument('-o', '--output', metavar='DIR')
    parser.add_argument('-f', '--force', action='store_true')
    parser.add_argument('--continue', action='store_true', dest='continue_')
    if argv is None:
        argv = sys.argv[1:]
    args = parser.parse_args(argv)

    # Explicit checks instead of `assert`, which `python -O` strips.
    snapshot_dir = os.environ.get('SNAPSHOT_PATH')
    if snapshot_dir and Path(snapshot_dir).joinpath('CHECKPOINTS_RESTORED').exists():
        if not args.continue_:
            parser.error('checkpoints were restored from SNAPSHOT_PATH; rerun with --continue')

    config_path = Path(args.config).absolute()
    output_dir = (
        Path(args.output)
        if args.output
        else config_path.parent.joinpath(config_path.stem)
    ).absolute()
    sep = '=' * (8 + max(len(str(config_path)), len(str(output_dir))))
    print(sep, f'Config: {config_path}', f'Output: {output_dir}', sep, sep='\n')

    if not config_path.exists():
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    config = load_toml(config_path)

    if output_dir.exists():
        if args.force:
            print('Removing the existing output and creating a new one...')
            shutil.rmtree(output_dir)
            output_dir.mkdir()
        elif not args.continue_:
            backup_output(output_dir)
            print('Already done!\n')
            sys.exit()
        elif output_dir.joinpath('DONE').exists():
            backup_output(output_dir)
            print('Already DONE!\n')
            sys.exit()
        else:
            print('Continuing with the existing output...')
    else:
        print('Creating the output...')
        # parents=True: a custom `-o` path may point into a not-yet-created tree.
        output_dir.mkdir(parents=True)

    environment = _collect_devices_environment()
    dump_stats({'config': config, 'environment': environment}, output_dir)
    return config, output_dir


def save_config(config: Dict[str, Any], config_path: Union[str, Path]) -> None:
    """Write ``config`` to ``config_path`` as TOML (via ``dump_toml``).

    Args:
        config: Configuration mapping to serialise.
        config_path: Destination file, as ``str`` or ``pathlib.Path``. The
            annotation was previously ``str`` although ``Path`` objects are
            passed and work.
    """
    dump_toml(config, config_path)


def ensure_dir(path: Union[str, Path]) -> Path:
    """Create the directory ``path`` (and any missing parents) if needed.

    Returns:
        ``path`` as a ``pathlib.Path``.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    directory = Path(path)
    os.makedirs(directory, exist_ok=True)
    return directory


def backup_output(output_dir: Union[str, Path], backup_suffix: str = "_backup") -> None:
    """Copy ``output_dir`` to a sibling directory named ``<name><backup_suffix>``.

    Does nothing if ``output_dir`` does not exist. Whatever currently occupies
    the backup location is removed first, so the backup mirrors the current
    contents and no stale files survive. If that is a plain file or a
    symlink, it is unlinked; ``shutil.rmtree`` would fail on it, and
    following a symlink could delete data elsewhere.

    Args:
        output_dir: Directory to back up, as ``str`` or ``Path``.
        backup_suffix: Suffix appended to the directory name for the copy.
    """
    output_dir = Path(output_dir)
    if not output_dir.exists():
        return

    backup_dir = output_dir.with_name(output_dir.name + backup_suffix)
    if backup_dir.is_symlink() or backup_dir.is_file():
        backup_dir.unlink()
    elif backup_dir.exists():
        shutil.rmtree(backup_dir)

    shutil.copytree(output_dir, backup_dir)

def get_device() -> torch.device:
    """Return the preferred torch device: CUDA if available, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        name = 'cuda'
    elif torch.backends.mps.is_available():
        name = 'mps'
    else:
        name = 'cpu'
    return torch.device(name)


def get_environment_info() -> Dict[str, Any]:
    """Collect Python/PyTorch/device details for experiment reports.

    Always includes the Python and torch versions, the device chosen by
    ``get_device()`` and a local (naive) ISO-8601 timestamp. When CUDA is
    available, a ``'cuda'`` section adds the CUDA, cuDNN and NCCL versions.
    If ``pynvml`` is also importable, that section additionally gets the
    driver version, plus device names and total memory for the integer
    indices in ``CUDA_VISIBLE_DEVICES``. NVML failures are recorded under
    ``environment['cuda']['pynvml_error']`` rather than raised.
    """
    def as_text(value: Any) -> str:
        # pynvml returns bytes in older releases and str in newer ones;
        # str(value, 'utf-8') raises TypeError for str, losing all NVML info.
        return value.decode('utf-8') if isinstance(value, bytes) else str(value)

    environment: Dict[str, Any] = {
        'python_version': sys.version,
        'torch_version': torch.__version__,
        'device': str(get_device()),
        'timestamp': datetime.datetime.now().isoformat(),
    }
    if not torch.cuda.is_available():
        return environment

    cvd = os.environ.get('CUDA_VISIBLE_DEVICES')
    cuda: Dict[str, Any] = {
        'CUDA_VISIBLE_DEVICES': cvd,
        'torch.version.cuda': torch.version.cuda,
        'torch.backends.cudnn.version()': torch.backends.cudnn.version(),
    }
    environment['cuda'] = cuda
    try:
        cuda['torch.cuda.nccl.version()'] = torch.cuda.nccl.version()
    except (AttributeError, RuntimeError):
        cuda['torch.cuda.nccl.version()'] = 'N/A'

    if not HAS_PYNVML:
        return environment
    try:
        pynvml.nvmlInit()
    except Exception as e:
        cuda['pynvml_error'] = str(e)
        return environment
    try:
        cuda['driver'] = as_text(pynvml.nvmlSystemGetDriverVersion())
        if cvd:
            cuda['devices'] = {}
            for token in cvd.split(','):
                if not token.strip():
                    continue  # tolerate stray commas such as "0,1,"
                index = int(token)
                handle = pynvml.nvmlDeviceGetHandleByIndex(index)
                cuda['devices'][index] = {
                    'name': as_text(pynvml.nvmlDeviceGetName(handle)),
                    'total_memory': pynvml.nvmlDeviceGetMemoryInfo(handle).total,
                }
    except Exception as e:
        cuda['pynvml_error'] = str(e)
    finally:
        # Balance the successful nvmlInit() above.
        try:
            pynvml.nvmlShutdown()
        except Exception:
            pass  # best-effort cleanup; collected info is still returned
    return environment

def set_seeds(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's RNGs and force deterministic cuDNN.

    CUDA generators are seeded explicitly when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def format_seconds(seconds: float) -> str:
    """Render a duration as ``[D day[s], ]H:MM:SS``.

    ``seconds`` is first rounded to a whole number with the built-in
    ``round`` (ties go to the even neighbour).
    """
    whole_seconds = round(seconds)
    return str(datetime.timedelta(0, whole_seconds))


def dump_stats(stats: dict, output_dir: Union[str, Path], final: bool = False) -> None:
    """Write ``stats`` to ``<output_dir>/stats.json``, creating the directory.

    With ``final=True`` the run is also marked finished:

    * an empty ``DONE`` file is created in ``output_dir``;
    * if ``JSON_OUTPUT_FILE`` is set, ``stats`` is stored in that aggregate
      JSON file under a per-run key. The key is the path of ``output_dir``
      relative to ``PROJECT_DIR`` when that variable is set and contains it;
      otherwise it is the directory name. A missing or unparsable aggregate
      file is treated as ``{}``. NOTE(review): an unparsable file is
      therefore overwritten.
    * if ``SNAPSHOT_PATH`` is set as well, the aggregate file is copied
      there as ``json_output.json``.

    Args:
        stats: Statistics to serialise with ``dump_json``.
        output_dir: Run directory, as ``str`` or ``Path``.
        final: Whether this is the last dump for the run.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    dump_json(stats, output_dir / 'stats.json', indent=4)
    if not final:
        return

    output_dir.joinpath('DONE').touch()

    json_output_env = os.environ.get('JSON_OUTPUT_FILE')
    if not json_output_env:
        return
    json_output_path = Path(json_output_env)

    key = str(output_dir.name)
    project_dir = os.environ.get('PROJECT_DIR')
    if project_dir:
        try:
            key = str(output_dir.relative_to(Path(project_dir)))
        except ValueError:
            pass  # output_dir is not inside PROJECT_DIR; keep the directory name

    try:
        json_data = load_json(json_output_path)
    except (FileNotFoundError, json.decoder.JSONDecodeError):
        json_data = {}
    json_data[key] = stats
    dump_json(json_data, json_output_path)

    snapshot_path = os.environ.get('SNAPSHOT_PATH')
    if snapshot_path:
        shutil.copyfile(json_output_path, Path(snapshot_path) / 'json_output.json')


def create_experiment_stats(
    dataset_info: Dict[str, Any],
    best_params: Dict[str, Any],
    best_value: float,
    test_results: Dict[str, Any],
    study_statistics: Dict[str, Any],
    training_time: Optional[float] = None,
    tuning_time: Optional[float] = None,
    include_environment: bool = True
) -> Dict[str, Any]:
    """Assemble the statistics dictionary for a finished experiment.

    Args:
        dataset_info: Description of the dataset; stored as-is.
        best_params: Selected hyperparameters; stored under ``'hyperparameters'``.
        best_value: Best validation loss, coerced to ``float``.
        test_results: Test metrics, merged into ``'performance'``. A key
            ``'best_validation_loss'`` here would override ``best_value``.
        study_statistics: Tuning-study summary; stored as-is.
        training_time: Training duration in seconds, if measured.
        tuning_time: Tuning duration in seconds, if measured.
        include_environment: Whether to add ``get_environment_info()`` under
            ``'environment'``.

    Returns:
        A new dict. The argument dicts are referenced, not copied. Each
        provided duration appears under ``'timing'`` both in raw seconds and
        formatted with ``format_seconds``.
    """
    stats: Dict[str, Any] = {
        'dataset_info': dataset_info,
        'hyperparameters': best_params,
        'performance': {
            'best_validation_loss': float(best_value),
            **test_results
        },
        'study_statistics': study_statistics,
        'timing': {}
    }

    timing = stats['timing']
    if training_time is not None:
        timing['training_time_seconds'] = training_time
        timing['training_time_formatted'] = format_seconds(training_time)

    if tuning_time is not None:
        timing['tuning_time_seconds'] = tuning_time
        timing['tuning_time_formatted'] = format_seconds(tuning_time)

    if include_environment:
        stats['environment'] = get_environment_info()

    return stats

def print_experiment_summary(
    dataset_name: str,
    dataset_id: int,
    output_dir: str,
    best_model_path: str,
    best_config_path: str,
    results_path: str,
    test_results: Dict[str, Any],
    best_logs: Optional[Dict[str, Any]] = None,
    timing_info: Optional[Dict[str, Any]] = None
) -> None:
    """Print a human-readable end-of-experiment report to stdout.

    Metrics that are absent from ``test_results`` (or are ``None``) are
    skipped instead of raising ``KeyError``. A run that reports only a
    subset of metrics (e.g. only ``test_MSE``, or no MSE) would otherwise
    crash after all results have already been saved.

    Args:
        dataset_name, dataset_id: Identify the dataset in the header.
        output_dir, best_model_path, best_config_path, results_path:
            Locations echoed in the report.
        test_results: Metric values keyed by ``'test_accuracy'``,
            ``'test_f1'``, ``'test_precision'``, ``'test_recall'``,
            ``'test_auc_roc'`` and ``'test_MSE'``; all are optional.
        best_logs: Optional training logs; the minimum of a non-empty
            ``'val_loss'`` sequence is reported.
        timing_info: Optional dict with ``'tuning_time_seconds'`` and/or
            ``'training_time_seconds'``.
    """
    print("\n" + "=" * 60)
    print("Experiment Complete!")
    print("=" * 60)
    print(f"Dataset: {dataset_name} (ID: {dataset_id})")
    print(f"Output Directory: {output_dir}")
    print(f"Best Model: {best_model_path}")
    print(f"Best Config: {best_config_path}")
    print(f"Detailed Results: {results_path}")

    print("\nKey Metrics:")
    # (label, key in test_results, format spec) in display order.
    classification_metrics = (
        ('Test Accuracy', 'test_accuracy', '.4f'),
        ('Test F1-Score', 'test_f1', '.4f'),
        ('Test Precision', 'test_precision', '.4f'),
        ('Test Recall', 'test_recall', '.4f'),
        ('Test AUC-ROC', 'test_auc_roc', '.4f'),
    )
    for label, key, spec in classification_metrics:
        value = test_results.get(key)
        if value is not None:
            print(f"   {label}: {value:{spec}}")

    # len() rather than truthiness so numpy arrays / tensors of losses work too.
    if best_logs and 'val_loss' in best_logs and len(best_logs['val_loss']) > 0:
        print(f"   Best Validation Loss: {min(best_logs['val_loss']):.6f}")

    mse = test_results.get('test_MSE')
    if mse is not None:
        print(f"   Test MSE: {mse:.6f}")

    if timing_info:
        print("\n⏱Timing Statistics:")
        if 'tuning_time_seconds' in timing_info:
            print(f"   Tuning Time: {format_seconds(timing_info['tuning_time_seconds'])}")
        if 'training_time_seconds' in timing_info:
            print(f"   Training Time: {format_seconds(timing_info['training_time_seconds'])}")

def merge_defaults(kwargs: dict, default_kwargs: dict) -> dict:
    """Return ``default_kwargs`` overridden by ``kwargs``, without mutating either.

    The defaults are deep-copied, so nested default values are never shared
    with the result. Values taken from ``kwargs`` are inserted as-is (not
    copied).
    """
    combined = deepcopy(default_kwargs)
    combined.update(kwargs)  # explicit arguments win over defaults
    return combined


def raise_unknown(unknown_what: str, unknown_value: Any) -> NoReturn:
    """Raise ``ValueError`` for an unsupported option.

    Annotated ``NoReturn`` (rather than ``None``) so type checkers know that
    control never continues past a call, e.g. in the fallthrough branch of an
    ``if``/``elif`` dispatch.

    Raises:
        ValueError: always, with message ``"Unknown <what>: <value>"``.
    """
    raise ValueError(f'Unknown {unknown_what}: {unknown_value}')

if __name__ == "__main__":
    # Smoke test for the helpers above. Everything is written inside a fresh
    # temporary directory, so a pre-existing ./test_output is never deleted
    # and nothing is left behind if an intermediate step fails.
    import tempfile

    set_seeds(42)
    device = get_device()
    env_info = get_environment_info()
    print(f"torch {env_info['torch_version']} on {device}")

    with tempfile.TemporaryDirectory() as tmp_root:
        test_dir = ensure_dir(Path(tmp_root) / 'test_output')

        test_data = {
            'test_key': 'test_value',
            'test_number': 42,
            'test_tensor': torch.tensor([1, 2, 3]),
            'test_time': format_seconds(3661.5)
        }
        dump_json(test_data, test_dir / 'test.json')
        assert load_json(test_dir / 'test.json')['test_tensor'] == [1, 2, 3]

        test_config = {
            'model': {'param1': 1, 'param2': 2},
            'training': {'epochs': 100, 'lr': 0.001}
        }
        save_config(test_config, test_dir / 'test.toml')
        loaded_config = load_config(test_dir / 'test.toml')
        assert loaded_config['training']['epochs'] == 100

    print('Self-test passed.')
