import argparse
import sys
import os
import torch
import torch.distributed as dist
import datetime

# Resolve the directory this file lives in, then the directory two levels
# above it, and append that directory to sys.path.
# NOTE(review): presumably this is what makes the `BackdoorObjectDetection`
# package imports below resolvable when the file is run as a script -- confirm
# against the repository layout.
here = os.path.dirname(__file__)  
project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
sys.path.append(project_root)


from BackdoorObjectDetection.bd_models.models.build import build_model
from BackdoorObjectDetection.bd_models.utils.load_utils import initialize_loaders

def write_log(metrics, save_path, file_name, position):
    """Append one row of metrics to a CSV log, emitting the header when needed.

    The header row is ``position`` followed by the keys of ``metrics``; each
    data row is ``position`` followed by the ``str()`` of every metric value.
    Fields are written through the ``csv`` module, so a key or value that
    contains a comma, quote or newline (for example a stringified list) is
    quoted instead of silently shifting the columns. Rows made of plain
    values come out exactly as a bare ``",".join`` would produce them.

    The header is written when the file does not exist yet or is empty.

    Args:
        metrics: Mapping of metric name to value; its iteration order
            defines the column order.
        save_path: Directory that holds the log file.
        file_name: Name of the CSV file inside ``save_path``.
        position: Value written in the first column of the data row.
    """
    # Function-scope import so this block does not depend on a file-level
    # import being added elsewhere.
    import csv

    log_path = os.path.join(save_path, file_name)
    needs_header = not os.path.exists(log_path) or os.path.getsize(log_path) == 0

    # newline='' hands line-ending control to the csv writer (as the csv docs
    # require); os.linesep keeps the platform-native line ending.
    with open(log_path, 'a', newline='') as f:
        writer = csv.writer(f, lineterminator=os.linesep)
        if needs_header:
            writer.writerow(["position", *metrics.keys()])
        # Explicit str() so that None is logged as "None" (csv.writer would
        # otherwise write an empty field for None).
        writer.writerow([str(position), *(str(value) for value in metrics.values())])

def test_model(args, model_wrapper, evaluator, save_path, distributed=False, rank=0, world_size=1):
    """Run every evaluator entry of type 'test' and log its metrics from rank 0.

    Each entry of ``evaluator`` is read as a dict through the keys 'type',
    'position', 'name', 'evaluator' and 'loader'. Test entries whose position
    is the string 'none' are skipped. For the remaining test entries the
    entry's evaluator is run on the entry's loader, and on rank 0 the returned
    metrics are appended to ``test_metrics_multi_pos.csv`` in ``save_path``
    and printed.

    Args:
        args: Not used by this function.
        model_wrapper: Object whose ``current_epoch`` is forwarded to the
            evaluator.
        evaluator: Iterable of evaluator entries (see above).
        save_path: Directory forwarded to the evaluator and used for the
            metrics CSV.
        distributed: When true, the CUDA cache is emptied and all ranks are
            synchronised after every entry that was not skipped.
        rank: Rank of this process; only rank 0 writes and prints metrics.
        world_size: Not used by this function.
    """
    print(f'[Rank {rank}] [TrainModel] Starting model evaluation')

    for entry in evaluator:
        is_test_entry = entry['type'] == 'test'

        # A test entry at position 'none' is dropped entirely; note that this
        # also bypasses the per-entry synchronisation at the bottom of the loop.
        if is_test_entry and entry['position'] == 'none':
            print(f'[Rank {rank}] [TrainModel] Skipping evaluation for {entry["name"]} at position {entry["position"]}')
            continue

        if is_test_entry:
            print(f'[Rank {rank}] [TrainModel] Evaluating on {entry["name"]} test loader position {entry["position"]}')
            runner = entry['evaluator']
            loader = entry['loader']
            metrics = runner.evaluate(
                loader,
                loader.dataset.bbox_return_format,
                save_path,
                model_wrapper.current_epoch,
            )

            # Only the main process persists and reports the metrics
            if rank == 0:
                write_log(metrics, save_path, "test_metrics_multi_pos.csv", entry['position'])
                print(f'[Rank {rank}] [TrainModel] {entry["name"]} Test Metrics: {metrics} at position {entry["position"]}')

        # Reached for non-test entries as well as evaluated test entries
        if distributed:
            torch.cuda.empty_cache()
            dist.barrier()

def parse_args():
    """Build the command-line parser and return the parsed namespace.

    Options:
        --record_path (required, str): the record path to evaluate.
        --recursive (flag): run on the folders inside the record path.
        --force_inference (flag): run inference even if metrics already exist.
    """
    cli = argparse.ArgumentParser(description="BD Model Training")

    # The only option that carries a value
    cli.add_argument("--record_path", required=True, type=str, help="Record path")

    # Boolean switches, registered in the order they appear in --help
    switches = (
        ("--recursive", "Whether to recursively run the folders in the record path"),
        ("--force_inference", "Whether to force inference on the dataset even if metrics already exist"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)

    return cli.parse_args()

def _load_recorded_args(path):
    """Rebuild an argument namespace from the ``args.txt`` file inside ``path``.

    Every line containing a ':' is split at the first ':' into a key and a
    value; both are stripped and stored as a (string) attribute. Lines
    without ':' are ignored. Afterwards ``multi_position``, ``use_p_ratio``
    and ``use_lambda`` are overwritten with fixed values, and ``num_workers``
    / ``batch_size`` are converted to ``int`` when present.

    Raises:
        FileNotFoundError: If ``args.txt`` does not exist in ``path``.
        ValueError: If ``num_workers`` or ``batch_size`` is not an integer.
    """
    args_file = os.path.join(path, 'args.txt')
    if not os.path.exists(args_file):
        raise FileNotFoundError(f"[TestModel] The specified args file does not exist: {args_file}")

    recorded = argparse.Namespace()
    with open(args_file, 'r') as f:
        for line in f:
            key, separator, value = line.partition(':')
            if separator:
                setattr(recorded, key.strip(), value.strip())

    # These settings are always overridden, whatever args.txt contained
    recorded.multi_position = ['low', 'high', 'both', 'random']
    recorded.use_p_ratio = None
    recorded.use_lambda = None

    # Values read from the file are strings; these two are converted to int
    for name in ('num_workers', 'batch_size'):
        if not hasattr(recorded, name):
            continue
        raw = getattr(recorded, name)
        try:
            setattr(recorded, name, int(raw))
        except ValueError as err:
            # Chain the original conversion error so the traceback shows it as the cause
            raise ValueError(f"[TestModel] {name} must be an integer, got {raw}") from err

    return recorded


def evaluate_mtsd(args, path, device, distributed=False, rank=0, world_size=1, local_rank=0):
    """Rebuild the model recorded in ``path`` and run its test evaluators.

    The run configuration is read from ``<path>/args.txt``, the model is
    built with ``build_model`` (using the fixed dataset name 'mtsd_meta'),
    the evaluators are obtained from ``initialize_loaders``, and the result
    is handed to ``test_model`` with ``path`` as the save directory.

    Args:
        args: Command-line namespace; forwarded unchanged to ``test_model``.
        path: Record directory containing ``args.txt``.
        device: Device forwarded to ``build_model``.
        distributed: Whether a process group is active; when true, all ranks
            are synchronised before evaluation starts.
        rank: Rank of this process.
        world_size: Total number of processes.
        local_rank: Local rank forwarded to ``build_model``.

    Raises:
        FileNotFoundError: If ``args.txt`` does not exist in ``path``.
        ValueError: If the recorded ``num_workers`` or ``batch_size`` is not
            an integer.
    """
    run_args = _load_recorded_args(path)

    # Build the Model
    model_wrapper = build_model(
        run_args.model,
        'mtsd_meta',
        run_args.model_config_path,
        device,
        path,
        distributed=distributed,
        local_rank=local_rank,
    )

    # Build the evaluators (the first returned value is not needed here)
    _, evaluators = initialize_loaders(
        run_args,
        model_wrapper,
        run_args.model,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    for entry in evaluators:
        print(f'[TestModel] Evaluator: {entry["name"]} with position {entry["position"]}')

    if distributed:
        dist.barrier()

    test_model(args, model_wrapper, evaluators, path, distributed=distributed, rank=rank, world_size=world_size)

def _evaluate_record_dir(args, path, device, distributed, rank, world_size, local_rank):
    """Evaluate a single record directory unless it must be skipped.

    The directory is skipped (with a message) when it has no
    ``checkpoint.pth``, or when ``test_metrics_multi_pos.csv`` already exists
    and ``args.force_inference`` is not set. With ``force_inference`` an
    existing metrics file is removed first, by rank 0 only so that several
    processes do not race on deleting the same file.
    """
    checkpoint_path = os.path.join(path, 'checkpoint.pth')
    if not os.path.exists(checkpoint_path):
        print(f'[TestModel] No checkpoint found in {path}, skipping this directory.')
        return

    metrics_file = os.path.join(path, 'test_metrics_multi_pos.csv')
    if os.path.exists(metrics_file):
        if not args.force_inference:
            print(f'[TestModel] Skipping inference for {path} as metrics already exist and force_inference is not set.')
            return
        if rank == 0:
            print(f'[TestModel] Removing existing metrics file: {metrics_file}')
            os.remove(metrics_file)

    # Evaluate the MTSD dataset
    evaluate_mtsd(args, path, device, distributed=distributed, rank=rank, world_size=world_size, local_rank=local_rank)


def main():
    """Entry point: set up the (optionally distributed) run and evaluate records.

    Distributed settings are taken from the RANK / WORLD_SIZE / LOCAL_RANK
    environment variables when all three are present; otherwise the script
    runs as a single process. With ``--recursive`` every sub-directory of
    ``--record_path`` is evaluated, except names starting with 'baseline' or
    ending in '_OLD' / '_OLD2'; without it, ``--record_path`` itself is
    evaluated. In both modes the distributed settings are forwarded to the
    evaluation.
    """
    using_torchrun = all(
        var in os.environ for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    )

    if using_torchrun:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        distributed = world_size > 1
    else:
        # Fallback for plain python execution (debug / CPU run)
        rank = local_rank = 0
        world_size = 1
        distributed = False

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(local_rank)
        device = torch.device(local_rank)
    else:
        device = torch.device("cpu")

    if distributed:
        # torchrun already sets MASTER_ADDR / MASTER_PORT, so just use env://.
        # NCCL needs CUDA devices; fall back to gloo for CPU-only launches.
        backend = "nccl" if cuda_available else "gloo"
        dist.init_process_group(backend=backend, init_method="env://")

    print(f'[Rank {rank}] [TrainModel] Distributed training: {distributed}, World size: {world_size}, Local rank: {local_rank}')
    print(f'[Rank {rank}] [TrainModel] Device set to: {device}')

    # Parse the script's own command-line arguments
    args = parse_args()

    if not args.recursive:
        print(f'[TestModel] Running on the record path: {args.record_path}')
        _evaluate_record_dir(args, args.record_path, device, distributed, rank, world_size, local_rank)
        return

    print(f'[TestModel] Running recursively in the record path: {args.record_path}')

    # os.listdir returns entries in arbitrary order; sorting guarantees that
    # every rank walks the directories in the same sequence, which the
    # barriers below rely on.
    entries = sorted(os.listdir(args.record_path))
    for i, entry in enumerate(entries):
        # Computed up front so the skip messages below can report it
        full_path = os.path.join(args.record_path, entry)

        print(f'[TestModel] Processing directory {entry} (#{i + 1}/{len(entries)})')

        # If the directory starts with baseline, skip it
        if entry.startswith('baseline'):
            print(f'[TestModel] Skipping baseline directory: {full_path}')
            continue

        # If the directory ends with _OLD or _OLD2, skip it
        if entry.endswith(('_OLD', '_OLD2')):
            print(f'[TestModel] Skipping old directory: {full_path}')
            continue

        print(f'[TestModel] Full path: {full_path}')

        if os.path.isdir(full_path):
            print(f'[TestModel] Running on the full path: {full_path}')
            _evaluate_record_dir(args, full_path, device, distributed, rank, world_size, local_rank)

        # Keep all ranks in lockstep before moving on to the next entry
        if distributed:
            dist.barrier()

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
