import argparse
import sys
import os
import torch
import torch.distributed as dist
import datetime

# Make the project importable when this file is executed directly as a script.
# `here` is the directory that contains this file.
here = os.path.dirname(__file__)
# The root is taken to be two directory levels above `here` (made absolute).
# Appending it to sys.path lets the absolute `BackdoorObjectDetection...`
# import below resolve -- presumably that package sits directly under the
# root; confirm against the repository layout.
project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
sys.path.append(project_root)

from BackdoorObjectDetection.defense.utils.loader import get_defense_loader_model

def str2bool(v):
    """Convert a command-line token to ``bool`` (usable as an argparse ``type``).

    Values that are already ``bool`` are returned unchanged. Strings are
    matched case-insensitively against the accepted spellings.

    Args:
        v: A ``bool`` or a string such as ``"yes"``, ``"False"``, ``"1"``.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If the string matches no accepted spelling.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

def write_log(metrics, save_path, file_name, epoch):
    """Append one row of metrics to a CSV log, creating it on first use.

    The first column is ``epoch``; the remaining columns follow the iteration
    order of ``metrics``. A header row is emitted only when the file does not
    exist yet or is empty, so the column set is fixed by the first call.

    Args:
        metrics: Mapping of column name -> value. Values are written via ``str()``.
        save_path: Directory holding the log file; created if it is missing.
        file_name: Name of the CSV file inside ``save_path``.
        epoch: Value written in the leading ``epoch`` column.
    """
    import csv  # local import keeps this block self-contained

    # The target directory may not exist yet (e.g. a per-evaluator sub-directory).
    os.makedirs(save_path, exist_ok=True)
    log_path = os.path.join(save_path, file_name)

    needs_header = not os.path.exists(log_path) or os.path.getsize(log_path) == 0

    # csv.writer quotes values containing commas/newlines (list or tensor
    # reprs, for instance) so every row keeps the same number of columns.
    # newline='' plus an explicit lineterminator yields plain "\n" endings.
    with open(log_path, 'a', newline='') as f:
        writer = csv.writer(f, lineterminator="\n")
        if needs_header:
            writer.writerow(["epoch", *metrics.keys()])
        writer.writerow([epoch, *(str(value) for value in metrics.values())])

def test_model(args, model_wrapper, evaluator, save_path, is_test=True, distributed=False, rank=0, world_size=1):
    """Run every registered evaluator that belongs to the requested split.

    Args:
        args: Parsed command-line namespace (not read by this function).
        model_wrapper: Object exposing ``current_epoch``.
        evaluator: Iterable of dicts with the keys ``'type'``, ``'name'``,
            ``'evaluator'`` and ``'loader'``.
        save_path: Root output directory; each entry uses the sub-directory
            named after it.
        is_test: Select entries of type ``'test'`` when true, ``'val'`` otherwise.
        distributed: When true, the CUDA cache is emptied and all ranks
            synchronise after every entry, whether or not it was evaluated.
        rank: Rank of this process; only rank 0 logs and prints the metrics.
        world_size: Not read by this function.
    """
    print(f'[Rank {rank}] [TrainModel] Starting model evaluation')

    type_str = 'test' if is_test else 'val'

    for entry in evaluator:
        if entry['type'] == type_str:
            name = entry["name"]
            print(f'[Rank {rank}] [TrainModel] Evaluating on {name} test loader')

            eval_save_path = os.path.join(save_path, name)
            run_evaluation = entry['evaluator'].evaluate
            loader = entry['loader']
            metrics = run_evaluation(
                loader,
                loader.dataset.bbox_return_format,
                eval_save_path,
                model_wrapper.current_epoch,
            )

            # Only the main process persists and reports the metrics.
            if rank == 0:
                write_log(metrics, eval_save_path, f"{type_str}_metrics.csv", model_wrapper.current_epoch + 1)
                print(f'[Rank {rank}] [TrainModel] {name} Test Metrics: {metrics}')

        # Executed once per entry on every rank so the barrier counts match.
        if distributed:
            torch.cuda.empty_cache()
            dist.barrier()

def train_model(args, model_wrapper, train_loader, evaluator, save_path, distributed=False, rank=0, world_size=1):
    """Train from ``model_wrapper.current_epoch`` up to ``model_wrapper.epochs``.

    After each epoch the returned losses are appended to ``losses.csv`` (rank 0
    only). The 'test' evaluators run whenever the zero-based epoch index is a
    multiple of ``args.evaluate_every`` and also after the final epoch. Once
    the loop ends, rank 0 calls ``model_wrapper.__save_model__()``.

    Args:
        args: Namespace providing ``evaluate_every``.
        model_wrapper: Object exposing ``current_epoch``, ``epochs``,
            ``train_one_epoch(loader, epoch)`` and ``__save_model__()``.
        train_loader: Training loader; in distributed mode its ``sampler`` must
            provide ``set_epoch``.
        evaluator: Evaluator entries forwarded to ``test_model``.
        save_path: Directory receiving ``losses.csv`` and the evaluation output.
        distributed: Whether a process group is active (enables barriers).
        rank: Rank of this process.
        world_size: Forwarded to ``test_model``.
    """
    def _sync():
        # Collective barrier; a no-op for single-process runs.
        if distributed:
            dist.barrier()

    is_main = rank == 0

    for epoch in range(model_wrapper.current_epoch, model_wrapper.epochs):
        shown_epoch = epoch + 1
        print(f'[Rank {rank}] [TrainModel] Starting epoch {shown_epoch}/{model_wrapper.epochs}')

        if distributed:
            train_loader.sampler.set_epoch(epoch)
            torch.cuda.empty_cache()

        losses = model_wrapper.train_one_epoch(train_loader, epoch)

        if is_main:
            write_log(losses, save_path, "losses.csv", shown_epoch)
            print(f'[Rank {rank}] [TrainModel] Epoch {shown_epoch} losses: {losses}')

        _sync()

        on_schedule = epoch % args.evaluate_every == 0
        if on_schedule or epoch == model_wrapper.epochs - 1:
            test_model(
                args,
                model_wrapper,
                evaluator,
                save_path,
                is_test=True,
                distributed=distributed,
                rank=rank,
                world_size=world_size,
            )

        _sync()

    if is_main:
        model_wrapper.__save_model__()

    _sync()

def parse_args(distributed=False, rank=0, world_size=1):
    """Parse command-line arguments and prepare the run's output directory.

    Besides parsing ``sys.argv`` this function:

    * forces ``args.data_attack = 'baseline'`` and ``args.multi_position = None``;
    * derives ``args.save_path`` as
      ``record_path/save_dir/ft/num_samples_<N>/random_seed_<S>``;
    * on rank 0, creates ``save_path`` -- or, if it already exists without a
      ``checkpoint.pth``, deletes its contents -- and writes ``args.txt`` there.

    Args:
        distributed: Whether a process group is initialised; enables barriers.
        rank: Rank of this process; only rank 0 touches the filesystem.
        world_size: Not read by this function.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        FileNotFoundError: If ``--bd_base_path`` does not exist.
        FileExistsError: If ``save_path`` already contains ``checkpoint.pth``.
        SystemExit: On invalid command-line input (raised by argparse).
    """
    parser = argparse.ArgumentParser(description="BD Model Training")

    # Arguments for the dataset
    parser.add_argument("--test_attack", required=True, type=str, help="Attack name for testing (e.g., 'rma', 'oda_targeted', 'oda_untargeted')")
    parser.add_argument("--trigger_position", required=True, type=str, help="Trigger position (e.g., 'random', 'center', 'high', 'low')")
    parser.add_argument("--trigger_type", required=True, type=str, help="Trigger type (e.g., 'square')")

    parser.add_argument("--dataset", required=True, type=str, help="Dataset name")
    parser.add_argument("--bd_base_path", required=True, type=str, help="Base path for dataset")
    parser.add_argument("--batch_size", default=16, type=int, help="Batch size total (will be divided by world_size if distributed)")
    parser.add_argument("--num_workers", default=4, type=int, help="Number of data loader workers")

    # Arguments for the model
    parser.add_argument("--model", required=True, type=str, help="Model name")
    parser.add_argument("--model_config_path", required=True, type=str, help="Model config path")

    # Arguments for FT training
    parser.add_argument("--record_path", required=True, type=str, help="Record path")
    parser.add_argument("--save_dir", required=True, type=str, help="Directory to save the model checkpoints and logs")

    parser.add_argument("--training_epochs", default=100, type=int, help="Number of training epochs (default: 100)")
    parser.add_argument("--num_training_samples", default=100, type=int, help="Number of training samples (default: 100)")
    parser.add_argument("--evaluate_every", default=5, type=int, help="Evaluate the model every N epochs (default: 5)")
    parser.add_argument("--random_seed", default=42, type=int, help="Random seed for reproducibility (default: 42)")

    args = parser.parse_args()

    # The training loop computes `epoch % evaluate_every`; reject unusable
    # intervals here instead of failing after the first epoch has been trained.
    if args.evaluate_every < 1:
        parser.error("--evaluate_every must be a positive integer")

    # The training data is always clean, so the data attack is fixed to 'baseline'.
    args.data_attack = 'baseline'
    args.multi_position = None

    # Check that bd_base_path exists
    if not os.path.exists(args.bd_base_path):
        raise FileNotFoundError(f"[Rank {rank}] [TrainModel] The specified bd_base_path does not exist: {args.bd_base_path}")

    args.save_path = os.path.join(args.record_path, args.save_dir, 'ft', f'num_samples_{args.num_training_samples}', f'random_seed_{args.random_seed}')
    print(f'[Rank {rank}] [TrainModel] Creating save path: {args.save_path}')

    # Sample the filesystem state on every rank BEFORE rank 0 changes anything.
    # Otherwise a slower rank could observe the directory rank 0 just created,
    # take a different branch and execute a different number of barriers.
    save_path_exists = os.path.exists(args.save_path)
    checkpoint_path = os.path.join(args.save_path, 'checkpoint.pth')
    checkpoint_exists = save_path_exists and os.path.exists(checkpoint_path)

    if distributed:
        dist.barrier()

    if save_path_exists:
        print(f"[Rank {rank}] [TrainModel] Checkpoint path: {checkpoint_path}")
        if checkpoint_exists:
            raise FileExistsError(f"[Rank {rank}] [TrainModel] The specified save_path already exists and contains a checkpoint.pth file. Please choose a different save_dir or delete the existing directory: {args.save_path}")

        print(f"[Rank {rank}] [TrainModel] Warning: The specified save_path already exists but does not contain a checkpoint.pth file. Continuing and overwriting existing files in: {args.save_path}")

        if rank == 0:
            # Walk bottom-up so each directory is already empty when os.rmdir
            # reaches it (os.rmdir refuses non-empty directories).
            for root, dirs, files in os.walk(args.save_path, topdown=False):
                for name in files:
                    file_path = os.path.join(root, name)
                    os.remove(file_path)
                    print(f"[Rank {rank}] [TrainModel] Warning: Removed existing file in save_path: {file_path}")
                for name in dirs:
                    dir_path = os.path.join(root, name)
                    # os.walk reports symlinks to directories under `dirs`;
                    # those must be unlinked, os.rmdir would fail on them.
                    if os.path.islink(dir_path):
                        os.remove(dir_path)
                    else:
                        os.rmdir(dir_path)
                    print(f"[Rank {rank}] [TrainModel] Warning: Removed existing directory in save_path: {dir_path}")
    elif rank == 0:
        os.makedirs(args.save_path, exist_ok=True)

    if rank == 0:
        # Record the effective arguments next to the run's outputs, for both
        # freshly created and re-used directories.
        args_file = os.path.join(args.save_path, "args.txt")
        with open(args_file, 'w') as f:
            for key, value in vars(args).items():
                f.write(f"{key}: {value}\n")

    # Nobody proceeds until rank 0 has finished preparing the directory.
    if distributed:
        dist.barrier()

    return args

def main():
    """Script entry point: set up the runtime, then fine-tune and evaluate.

    When the ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` environment variables
    are all present (as set by torchrun) and ``WORLD_SIZE > 1``, a process
    group is initialised and every stage is synchronised with barriers.
    Otherwise the script runs as a single process on rank 0.
    """
    using_torchrun = all(
        var in os.environ for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    )

    if using_torchrun:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        distributed = world_size > 1
    else:
        # Fallback for plain python execution (debug / CPU run)
        rank = local_rank = 0
        world_size = 1
        distributed = False

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    if distributed:
        print(f'[Rank {rank}] [TrainModel] Initializing process group for distributed training')
        # NCCL only works with CUDA devices; use Gloo so that the CPU fallback
        # above does not fail when launched with several processes.
        backend = "nccl" if cuda_available else "gloo"
        # torchrun already sets MASTER_ADDR / MASTER_PORT, so just use env://
        dist.init_process_group(backend=backend, init_method="env://")

    print(f'[Rank {rank}] [TrainModel] Distributed training: {distributed}, World size: {world_size}, Local rank: {local_rank}')
    print(f'[Rank {rank}] [TrainModel] Device set to: {device}')

    if distributed:
        dist.barrier()

    # Parse the command line and prepare the output directory.
    args = parse_args(distributed=distributed, rank=rank, world_size=world_size)

    if distributed:
        dist.barrier()

    # Build the training loader, the evaluators and the model wrapper.
    train_loader, evaluators, model_wrapper = get_defense_loader_model(
        args, device, distributed=distributed, rank=rank, local_rank=local_rank, world_size=world_size
    )

    # Train the model (evaluation is interleaved by train_model).
    print(f'[Rank {rank}] [TrainModel] Starting training process')
    train_model(args, model_wrapper, train_loader, evaluators, args.save_path, distributed=distributed, rank=rank, world_size=world_size)

    # Cleanup
    if distributed:
        dist.destroy_process_group()

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
