import os
import sys
import torch
import logging
import argparse
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from classification.trainer.jcgel_conv_cls import JCGConv_Cls_Trainer
from classification.trainer.jcgel_resnet_cls import JCGResNet_Cls_Trainer

# Module-level default device: CUDA when torch reports it available, else CPU.
# NOTE(review): `device` (and `logger` below) are not referenced anywhere else in
# the visible part of this file — confirm whether the trainers rely on them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Force synchronous CUDA kernel launches so that CUDA errors surface at the call
# that caused them (debugging aid).
# NOTE(review): this usually slows GPU training noticeably, and it is assigned
# after torch.cuda.is_available() has already run above — confirm it still takes
# effect for this process, and whether it should stay enabled outside debugging.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


import traceback, io, torch.distributed as dist

def ddp_try(fn):
    """Run ``fn()`` and, on failure, print its traceback tagged with this process's rank.

    Intended as a wrapper around the training entry point under DDP, where the
    interleaved output of several ranks otherwise makes it hard to tell which
    process failed. Outside an initialised process group the rank is reported as 0.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        Whatever ``fn()`` returns.

    Raises:
        Re-raises any ``Exception`` raised by ``fn`` after reporting it.
    """
    try:
        return fn()
    except Exception:
        in_dist = dist.is_available() and dist.is_initialized()
        rank = dist.get_rank() if in_dist else 0
        # format_exc() renders the exception currently being handled to a string.
        print(
            f"\n[Rank {rank}] Uncaught exception:\n{traceback.format_exc()}",
            flush=True,
        )

        if in_dist:
            # Best-effort synchronisation before propagating. Any error from the
            # barrier itself is ignored so it cannot mask the original exception.
            # NOTE(review): if peer ranks are not also failing, this barrier may
            # block until the process-group timeout expires.
            try:
                dist.barrier()
            except Exception:
                pass
        raise

def cleanup_ddp():
    """Tear down the default torch.distributed process group if one is active.

    Does nothing when torch.distributed is unavailable or was never initialised.
    Otherwise all ranks first meet at a barrier, then the group is destroyed.
    The barrier is best-effort: any exception it raises is ignored.
    """
    import torch.distributed as dist

    active = dist.is_available() and dist.is_initialized()
    if not active:
        return

    # Presumably lets every rank finish its collective work before the group
    # disappears; a failing barrier must not prevent the teardown below.
    try:
        dist.barrier()
    except Exception:
        pass
    dist.destroy_process_group()

def build_arguments(argv=None):
    """Define and parse the command-line options for classification training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Passing an explicit list makes
            the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace holding one attribute per option defined below.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value fails its type / choices check.
    """
    parser = argparse.ArgumentParser()

    # DEVICE / DISTRIBUTED
    # Required options carry no `default=`: argparse never applies a default to
    # a required option, so one would only be misleading.
    parser.add_argument(
        "--device_idx",
        type=str,
        required=True,
        help="set GPU device, e.g. cuda:0, cuda:1, ...",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Avoid using CUDA when available"
    )
    parser.add_argument(
        "--n_gpu",
        type=int,
        default=0,
        required=False,
        help="number of available gpu",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )

    # SET TASK
    parser.add_argument(
        "--task",
        choices=["longtail", "bias", "cls"],
        type=str,
        help="Task to run",
        required=True,
    )

    # DATASETS
    parser.add_argument(
        "--dataset",
        type=str,
        choices=[
            "color_rot_mnist",
            "cifar10",
            "cifar100",
            "eurosat",
            "imagenet",
            "flowers102",
            "stanfordcars",
            "food101",
            "caltech101",
            "oxfordiiitpet",
            "aircraft",
            "stl10",
        ],
        required=True,
        help="Choose Dataset",
    )

    # SET MODEL
    parser.add_argument(
        "--model_type",
        type=str,
        choices=[
            "cgeconv",
            "cgeresnet",  # Color and Rotation Equiv.
        ],
        required=True,
        help="Choose model type",
    )

    # SET TRAINING
    parser.add_argument(
        "--num_workers",
        default=4,
        type=int,
        help="number of workers for loading data",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of training mini-batch size",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of training mini-batch size for multi GPU training",
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=128,
        required=False,
        help="Set number of evaluation mini-batch size",
    )
    parser.add_argument(
        "--num_epoch",
        type=int,
        default=60,
        required=False,
        help="Set number of epoch size",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=0,
        required=False,
        help="Set maximum number of training steps",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        required=False,
        help="Save model checkpoint iteration interval",
    )
    # NOTE(review): no default, so args.logging_steps is None unless the flag is
    # given — confirm the trainers handle None before using it as an interval.
    parser.add_argument(
        "--logging_steps",
        type=int,
        required=False,
        help="Update tb_writer iteration interval",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        required=False,
        help="set seed",
    )

    parser.add_argument(
        "--lr_rate", default=1e-4, type=float, required=False, help="Set learning rate"
    )
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        required=False,
        help="Set weight decay",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--num_sampling",
        type=int,
        default=1,
        required=False,
        help="Set samples for reparameterization trick",
    )

    # SET WANDB (Weights & Biases logging)
    parser.add_argument(
        "--project_name",
        type=str,
        required=True,
        help="set Weights & Biases project name",
    )
    parser.add_argument(
        "--entity",
        type=str,
        required=True,
        help="set Weights & Biases entity (user or team name)",
    )

    parser.add_argument(
        "--latent_dim",
        type=int,
        default=10,
        required=False,
        help="set prior dimension z",
    )
    parser.add_argument(
        "--beta",
        type=float,
        default=1.0,
        required=False,
        help="Set hyper-parameter beta",
    )

    # MODEL TRAINING AND EVALUATION
    parser.add_argument("--do_train", action="store_true", help="Do training")
    parser.add_argument("--do_eval", action="store_true", help="Do evaluation")

    # Equivariance-related settings consumed by the trainers.
    # NOTE(review): c_rot / g_rot presumably count color and spatial rotations,
    # and n_flip the number of flips — TODO confirm against the trainer classes.
    parser.add_argument("--c_rot", type=int, default=3, required=False)
    parser.add_argument("--g_rot", type=int, default=4, required=False)
    parser.add_argument("--n_flip", type=int, default=0, required=False)
    parser.add_argument("--temperature", type=float, default=0.01, required=False)
    parser.add_argument("--normalization", action="store_true")
    parser.add_argument("--soft", action="store_true")

    return parser.parse_args(argv)




def main():
    """Parse arguments, optionally set up DDP, build the trainer and run it.

    DDP mode is enabled when ``LOCAL_RANK`` is present in the environment or
    ``--local_rank`` is passed with a value other than -1. In DDP mode:
    - this process is pinned to its local GPU;
    - an NCCL process group is initialised;
    - training runs under ``ddp_try`` so a failure is reported with its rank;
    - the process group is torn down via ``cleanup_ddp`` after a successful run.

    The trainer is ``JCGResNet_Cls_Trainer`` for ``--model_type cgeresnet`` and
    ``JCGConv_Cls_Trainer`` otherwise.
    """
    args = build_arguments()

    # The LOCAL_RANK environment variable overrides --local_rank when present.
    if "LOCAL_RANK" in os.environ:
        args.local_rank = int(os.environ["LOCAL_RANK"])

    is_ddp = args.local_rank != -1

    if is_ddp:
        # Select this rank's GPU before creating the NCCL group so its
        # communicator is bound to that device rather than the default one.
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl")
        if dist.get_rank() == 0:
            print(f"Running in DDP mode. World size: {dist.get_world_size()}")
    elif torch.cuda.is_available() and not args.no_cuda:
        print("Running in single-GPU mode.")
    else:
        print("Running in CPU mode.")

    if args.model_type == 'cgeresnet':
        trainer = JCGResNet_Cls_Trainer(args)
    else:
        trainer = JCGConv_Cls_Trainer(args)

    if is_ddp:
        ddp_try(trainer.run)
        # Only reached on success: ddp_try re-raises on failure, and the
        # exception then propagates without this extra barrier + teardown.
        # cleanup_ddp is a no-op if the trainer already destroyed the group.
        cleanup_ddp()
    else:
        trainer.run()


if __name__ == "__main__":
    main()


















