import argparse
import sys
import os
import torch
import torch.distributed as dist
import datetime

# Make the directory two levels above this file's directory importable.
# NOTE(review): the `BackdoorObjectDetection.*` imports below presumably rely
# on this when the script is launched directly — confirm against the repo layout.
here = os.path.dirname(__file__)  # directory this file lives in
project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))  # <here>/../.. as an absolute path
sys.path.append(project_root)  # appended, so entries already on sys.path are searched first

from BackdoorObjectDetection.bd_models.models.build import build_model
from BackdoorObjectDetection.bd_models.utils.load_utils import initialize_loaders

def str2bool(v):
    """Convert a command-line token to a boolean.

    Intended for use as an argparse ``type=`` callable.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive.

    Returns:
        ``True`` for "yes", "true", "t", "y", "1";
        ``False`` for "no", "false", "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: If the string is none of the accepted
            spellings.
    """
    # Already a boolean: nothing to interpret.
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False

    raise argparse.ArgumentTypeError("Boolean value expected.")

def write_log(metrics, save_path, file_name, epoch):
    """Append one row of metrics to a CSV log, creating it on first use.

    The first call for a given file writes a header line
    (``epoch,<key1>,<key2>,...``) taken from the keys of ``metrics``; every
    call then appends one data row (``<epoch>,<value1>,<value2>,...``).

    Args:
        metrics: Mapping of column name to value. Keys are joined into the
            header and therefore must be strings; values are converted with
            ``str()``.
        save_path: Directory that holds the log file. It is created
            (including parents) when it does not exist yet.
        file_name: Name of the CSV file inside ``save_path``.
        epoch: Value written in the leading ``epoch`` column.

    Note:
        Rows are written in the iteration order of the ``metrics`` passed to
        that call; they are not re-aligned to an existing header. Values are
        written verbatim, without CSV quoting or escaping.
    """
    # The target directory may not exist yet (e.g. a per-evaluator
    # sub-directory); without this, open() would raise FileNotFoundError.
    # An empty save_path means "current directory", which needs no creation.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    log_path = os.path.join(save_path, file_name)

    # Decide before opening: opening in append mode creates the file.
    needs_header = not os.path.exists(log_path)

    with open(log_path, 'a') as f:
        if needs_header:
            f.write("epoch," + ",".join(metrics.keys()) + "\n")
        f.write(f"{epoch}," + ",".join(str(value) for value in metrics.values()) + "\n")

def test_model(args, model_wrapper, evaluator, save_path, is_test=True, distributed=False, rank=0, world_size=1):
    """Run the evaluators of one split and log their metrics.

    Args:
        args: Accepted for signature compatibility; not used here.
        model_wrapper: Object whose ``current_epoch`` is passed to each
            evaluator and (plus one) written as the epoch of the log row.
        evaluator: Iterable of entries indexed by ``'type'``, ``'name'``,
            ``'evaluator'`` and ``'loader'``.
        save_path: Base directory; each evaluator logs under
            ``save_path/<name>``.
        is_test: Selects entries with ``type == 'test'`` when true,
            ``type == 'val'`` otherwise.
        distributed: When true, the CUDA cache is emptied and a barrier is
            entered after every evaluator entry.
        rank: Process rank; only rank 0 writes and prints the metrics.
        world_size: Accepted for signature compatibility; not used here.
    """
    prefix = f'[Rank {rank}] [TrainModel]'
    print(f'{prefix} Starting model evaluation')

    # Only entries whose 'type' equals this string are evaluated.
    type_str = 'test' if is_test else 'val'

    for entry in evaluator:
        if entry['type'] == type_str:
            name = entry["name"]
            print(f'{prefix} Evaluating on {name} test loader')

            loader = entry['loader']
            eval_save_path = os.path.join(save_path, name)
            metrics = entry['evaluator'].evaluate(
                loader,
                loader.dataset.bbox_return_format,
                eval_save_path,
                model_wrapper.current_epoch,
            )

            # Only the main process persists and reports the metrics.
            if rank == 0:
                write_log(
                    metrics,
                    eval_save_path,
                    f"{type_str}_metrics.csv",
                    model_wrapper.current_epoch + 1,
                )
                print(f'{prefix} {name} Test Metrics: {metrics}')

        # Deliberately outside the type check: this runs once per entry,
        # whether or not the entry was evaluated.
        if distributed:
            torch.cuda.empty_cache()
            dist.barrier()

def train_model(args, model_wrapper, train_loader, evaluator, save_path, distributed=False, rank=0, world_size=1):
    """Train from ``model_wrapper.current_epoch`` up to ``model_wrapper.epochs``.

    Each epoch trains once, logs the losses, runs the 'val' evaluators and
    saves the model. Logging and saving happen on rank 0 only.

    Args:
        args: Forwarded to ``test_model``.
        model_wrapper: Provides ``current_epoch``, ``epochs``,
            ``train_one_epoch(loader, epoch)`` and ``__save_model__()``.
        train_loader: Training data loader. In distributed mode its
            ``sampler.set_epoch(epoch)`` is called at the start of each epoch.
        evaluator: Evaluator entries forwarded to ``test_model``.
        save_path: Directory for ``losses.csv``; also forwarded to
            ``test_model``.
        distributed: Enables sampler epoch updates, CUDA cache clearing and
            barriers between phases.
        rank: Process rank.
        world_size: Forwarded to ``test_model``.
    """
    on_main_process = rank == 0

    for epoch in range(model_wrapper.current_epoch, model_wrapper.epochs):
        display_epoch = epoch + 1  # 1-based epoch number used in logs

        print(f'[Rank {rank}] [TrainModel] Starting epoch {display_epoch}/{model_wrapper.epochs}')

        if distributed:
            train_loader.sampler.set_epoch(epoch)
            torch.cuda.empty_cache()

        # Training pass for this epoch.
        losses = model_wrapper.train_one_epoch(train_loader, epoch)

        # Persist and report the losses from the main process only.
        if on_main_process:
            write_log(losses, save_path, "losses.csv", display_epoch)
            print(f'[Rank {rank}] [TrainModel] Epoch {display_epoch} losses: {losses}')

        if distributed:
            dist.barrier()

        # Validation pass (is_test=False selects the 'val' evaluators).
        test_model(
            args,
            model_wrapper,
            evaluator,
            save_path,
            is_test=False,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
        )

        if on_main_process:
            model_wrapper.__save_model__()

        # Other ranks wait here while rank 0 saves.
        if distributed:
            dist.barrier()

def parse_args(distributed=False, rank=0, world_size=1):
    """Parse the command line and prepare the output directory.

    Besides the parsed options, the returned namespace carries two derived
    attributes: ``save_path`` (``record_path`` joined with ``save_dir``) and
    ``multi_position`` (always ``None``).

    Args:
        distributed: Accepted for signature compatibility; not used here.
        rank: Process rank. Used in messages; only rank 0 creates the save
            directory and writes ``args.txt`` into it.
        world_size: Accepted for signature compatibility; not used here.

    Returns:
        argparse.Namespace: The parsed and augmented arguments.

    Raises:
        FileNotFoundError: If ``--bd_base_path`` does not exist.
    """
    parser = argparse.ArgumentParser(description="BD Model Training")
    add = parser.add_argument
    required_str = {"required": True, "type": str}

    # Dataset / attack options
    add("--data_attack", help="Attack name (e.g., 'oda_align_random', 'oda_align_fixed', 'oda_tba', 'oda_uba', 'oda_suba')", **required_str)
    add("--test_attack", help="Attack name for testing (e.g., 'rma', 'oda_targeted', 'oda_untargeted')", **required_str)
    add("--trigger_position", help="Trigger position (e.g., 'random', 'center', 'high', 'low')", **required_str)
    add("--trigger_type", help="Trigger type (e.g., 'square')", **required_str)
    add("--use_p_ratio", required=True, type=str2bool, help="Whether to use poisoning ratio (true/false/yes/no/0/1)")
    add("--p_ratio", type=int, help="Poisoning ratio (e.g., 1, 5, 10, 20, ..., 100)")
    add("--dataset", help="Dataset name", **required_str)
    add("--bd_base_path", help="Base path for dataset", **required_str)
    add("--batch_size", default=16, type=int, help="Batch size total (will be divided by world_size if distributed)")
    add("--num_workers", default=4, type=int, help="Number of data loader workers")

    # Model options
    add("--model", help="Model name", **required_str)
    add("--model_config_path", help="Model config path", **required_str)

    # Output options
    add("--record_path", help="Record path", **required_str)
    add("--save_dir", help="Directory to save the model checkpoints and logs", **required_str)

    args = parser.parse_args()

    # Fail early when the dataset base path is missing.
    if not os.path.exists(args.bd_base_path):
        raise FileNotFoundError(f"[Rank {rank}] [TrainModel] The specified bd_base_path does not exist: {args.bd_base_path}")

    # Derived attributes: output directory is record_path/save_dir.
    args.save_path = os.path.join(args.record_path, args.save_dir)
    args.multi_position = None

    # Every rank gets the namespace; only rank 0 touches the filesystem.
    if rank != 0:
        return args

    print(f'[Rank {rank}] [TrainModel] Creating save path: {args.save_path}')
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Record the full configuration next to the run outputs, one "key: value" per line.
    args_dump_path = os.path.join(args.save_path, "args.txt")
    with open(args_dump_path, 'w') as f:
        f.writelines(f"{name}: {setting}\n" for name, setting in vars(args).items())

    return args

def main():
    """Script entry point: set up the device, then train and test the model.

    Distributed mode is enabled when ``RANK``, ``WORLD_SIZE`` and
    ``LOCAL_RANK`` are all present in the environment and ``WORLD_SIZE`` is
    greater than one; otherwise the script runs as a single process.

    Raises:
        EnvironmentError: If CUDA is not available.
    """
    # Step 1: Detect whether these launcher variables are all set.
    torchrun_vars = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    using_torchrun = all(name in os.environ for name in torchrun_vars)

    if not using_torchrun:
        # Fallback for plain python execution (single process)
        rank, local_rank, world_size = 0, 0, 1
        distributed = False
    else:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        distributed = world_size > 1

    prefix = f'[Rank {rank}] [TrainModel]'

    def sync():
        """Enter a barrier when running distributed; no-op otherwise."""
        if distributed:
            dist.barrier()

    # Step 2: Bind this process to its GPU; the script refuses to run without CUDA.
    if not torch.cuda.is_available():
        raise EnvironmentError("This script requires at least one GPU.")
    torch.cuda.set_device(local_rank)
    device = torch.device(local_rank)

    if distributed:
        # torchrun already sets MASTER_ADDR / MASTER_PORT, so just use env://
        dist.init_process_group(backend="nccl", init_method="env://")

    print(f'{prefix} Distributed training: {distributed}, World size: {world_size}, Local rank: {local_rank}')
    print(f'{prefix} Device set to: {device}')

    if distributed:
        print(f'{prefix} Initializing process group for distributed training')
        dist.barrier()

    # Step 3: Parse the command-line arguments
    args = parse_args(distributed=distributed, rank=rank, world_size=world_size)
    sync()

    # Step 4: Build the model
    print(f'{prefix} Building model with config: {args.model_config_path}')
    model_wrapper = build_model(
        args.model,
        args.dataset,
        args.model_config_path,
        device,
        args.save_path,
        distributed=distributed,
        local_rank=local_rank,
    )

    print(f'{prefix} Model built successfully: {model_wrapper}')
    print(f'{prefix} Model current epoch: {model_wrapper.current_epoch}')
    sync()

    # Step 5: Create the training loader and the evaluators
    print(f'{prefix} Initializing dataset')
    train_loader, evaluators = initialize_loaders(
        args,
        model_wrapper,
        args.model,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )
    sync()

    # Step 6: Train the model (includes per-epoch validation)
    print(f'{prefix} Starting training process')
    train_model(
        args,
        model_wrapper,
        train_loader,
        evaluators,
        args.save_path,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    # Step 7: Run the evaluators of type 'test'
    print(f'{prefix} Evaluating model on clean test loader')
    test_model(
        args,
        model_wrapper,
        evaluators,
        args.save_path,
        is_test=True,
        distributed=distributed,
        rank=rank,
        world_size=world_size,
    )

    # Step 8: Cleanup
    if distributed:
        dist.destroy_process_group()

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
