import argparse
import collections
import json
import os
import re
from pathlib import Path
from open_clip.src.open_clip_train.distributed import world_info_from_env
from open_clip.src.open_clip_train.main import main

# Set before any training code runs so that wandb sees it.
# NOTE(review): presumably this is the wandb service wait timeout in seconds — confirm against the wandb docs.
os.environ["WANDB__SERVICE_WAIT"] = "300"  # Solving WANDB connection issues

def prepare_filename(filename):
    """Return the data-source string for ``filename``.

    Values whose string form starts with ``s3://`` are wrapped in a
    ``pipe:aws s3 cp <uri> -`` command string; anything else is returned
    as its plain string form.
    """
    name = str(filename)
    is_s3_uri = name.startswith("s3://")
    return f"pipe:aws s3 cp {name} -" if is_s3_uri else name


def split_filename(pattern, filename):
    """Split ``filename`` in two at the first match of ``pattern``.

    Args:
        pattern: Compiled regular expression; its ``search`` method is used.
        filename: Value to split. It is converted with ``str`` first.

    Returns:
        A ``(prefix, suffix)`` tuple: ``prefix`` is everything before the
        start of the match and ``suffix`` is everything from there on.

    Raises:
        ValueError: If ``pattern`` does not match ``filename``.
    """
    filename = str(filename)
    pattern_match = pattern.search(filename)
    if pattern_match is None:
        # Without this guard the failure surfaces as an opaque
        # AttributeError on ``None.start()``.
        raise ValueError(
            f"Cannot split {filename!r}: it does not match {pattern.pattern!r}"
        )
    pos = pattern_match.start()
    return filename[:pos], filename[pos:]


def get_input_shards(data_dir, weights):
    """Expand ``data_dir`` into a ``::``-joined shard string plus matching weights.

    Three kinds of input are handled:

    * Several locations joined by ``::``: each one is converted with
      ``path_or_cloudpath`` and expanded recursively with its own weight.
    * A single ``.tar`` file: returned through ``prepare_filename``.
    * A directory: ``.tar`` files are grouped by the part of their name that
      precedes the trailing digits, and each group becomes one
      ``{first..last}.tar`` brace range. Sub-directories containing a
      ``meta.json`` are expanded from its ``output_shard_count`` and
      ``output_shard_format`` entries; other sub-directories are expanded
      recursively.

    Args:
        data_dir: Path-like object exposing ``suffix`` and ``iterdir``. Its
            string form may contain ``::`` to list several locations.
        weights: ``None``, or a string with one weight per ``::``-separated
            location.

    Returns:
        A ``(data_str, weights)`` tuple. ``data_str`` joins every expanded
        component with ``::``. ``weights`` stays ``None`` when no weights were
        given; otherwise it holds one weight per expanded component, again
        joined with ``::``.

    Raises:
        ValueError: If the number of ``::``-separated weights differs from
            the number of ``::``-separated data locations.
    """
    # Handle multiple directories
    if "::" in str(data_dir):
        split_data_dir = str(data_dir).split("::")
        data_dirs = [path_or_cloudpath(subdir) for subdir in split_data_dir]
        if weights is None:
            split_weights = [None] * len(split_data_dir)
        else:
            split_weights = weights.split("::")
            if len(split_weights) != len(split_data_dir):
                raise ValueError(
                    f"Got {len(split_weights)} weights for "
                    f"{len(split_data_dir)} data directories."
                )

        input_strs_and_weights = [
            get_input_shards(subdir, weight)
            for (subdir, weight) in zip(data_dirs, split_weights)
        ]

        input_strs, input_weights = zip(*input_strs_and_weights)
        input_strs = "::".join(input_strs)
        if weights is not None:
            weights = "::".join(input_weights)
        return input_strs, weights

    # Handle raw shards
    if data_dir.suffix == ".tar":
        return prepare_filename(data_dir), weights

    # Handle folders
    data_str_components = []
    prefix_map = collections.defaultdict(list)
    pattern = re.compile(r"\d+$")  # Sequence of digits at the end of the string
    for file_or_subdir in data_dir.iterdir():
        if file_or_subdir.suffix == ".tar":
            shard = file_or_subdir.with_suffix("")
            prefix, suffix = split_filename(pattern, shard)
            prefix_map[prefix].append(suffix)
        elif file_or_subdir.is_dir():
            # If the folder is generated by the resharder, the metadata file contains how many shards there are.
            metadata_file = file_or_subdir / "meta.json"
            if metadata_file.exists():
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)
                shard_count = metadata["output_shard_count"]
                shard_format = metadata["output_shard_format"]
                first_shard = shard_format.format(0).replace(".tar", "")
                last_shard = shard_format.format(shard_count - 1).replace(".tar", "")
                filename = f"{{{first_shard}..{last_shard}}}.tar"
                subfolder_str = prepare_filename(file_or_subdir / filename)
                data_str_components.append(subfolder_str)
            else:
                sub_data_strs, _ = get_input_shards(file_or_subdir, weights)
                # A sub-directory without shards yields "", and "".split("::")
                # is [""]; skip it so no empty component ends up in the result.
                if sub_data_strs:
                    data_str_components.extend(sub_data_strs.split("::"))

    for prefix in sorted(prefix_map):
        suffixes = prefix_map[prefix]
        last_tar = max(int(suffix) for suffix in suffixes)
        # The zero-padding width is taken from the first suffix collected for this prefix.
        number_of_zeros = len(suffixes[0])
        # The range always starts at shard 0 and ends at the largest index found.
        filename = f"{{{0:0{number_of_zeros}d}..{last_tar:0{number_of_zeros}d}}}.tar"
        filename = prepare_filename(prefix + filename)
        data_str_components.append(filename)
    data_str = "::".join(data_str_components)
    if weights is not None:
        # Every expanded component of this directory shares the same weight.
        weights = "::".join([weights for _ in data_str_components])
    return data_str, weights


def path_or_cloudpath(s):
    """Convert the string ``s`` into a ``CloudPath`` or a local ``Path``.

    A ``CloudPath`` is returned only when ``cloudpathlib`` can be imported
    and ``s`` begins with a ``scheme://`` prefix. In every other case,
    including an ``ImportError`` raised while building the cloud path,
    the result is ``Path(s)``.
    """
    cloud_path = None
    try:
        from cloudpathlib import CloudPath

        if re.match(r"^\w+://", s) is not None:
            cloud_path = CloudPath(s)
    except ImportError:
        # Best effort: fall through to a plain local path below.
        cloud_path = None
    if cloud_path is None:
        return Path(s)
    return cloud_path


def parse_args(argv=None):
    """Build the command-line parser for FLYT training and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding the parsed arguments.
    """
    parser = argparse.ArgumentParser()

    # General configuration params
    parser.add_argument(
        "--exp_name", type=str, default=None, help="Name of the experiment for logging."
    )
    parser.add_argument(
        "--wandb_project_name",
        type=str,
        default="flyt",
        help="Name of the project if logging with wandb.",
    )
    parser.add_argument(
        "--report_to_wandb",
        default=False,
        action="store_true",
        help="If True, report to wandb.",
    )
    parser.add_argument(
        "--workers", type=int, default=4, help="Number of workers for open_clip."
    )
    parser.add_argument(
        "--num_checkpoints",
        type=int,
        default=5,
        help="Number of times we save checkpoints during training.",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument(
        "--accum_freq",
        type=int,
        default=1,
        help="Update the model every --accum_freq steps.",
    )
    parser.add_argument(
        "--log_every_n_steps",
        type=int,
        default=100,
        help="Log every n steps to tensorboard/console/wandb.",
    )
    parser.add_argument(
        "--resume",
        default="latest",
        type=str,
        help="Path to checkpoint to resume from (default: latest checkpoint in the training directory).",
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    parser.add_argument("--save_frequency", type=int, default=0)
    parser.add_argument(
        "--output_dir",
        type=path_or_cloudpath,
        required=True,
        help="Path to directory where outputs will be stored.",
    )

    # General training params
    parser.add_argument(
        "--model",
        type=str,
        default="ViT-B-32",
        help='Name of the vision backbone to use',
    )
    parser.add_argument(
        "--upstream_data_dir",
        type=path_or_cloudpath,
        required=True,
        help='Path to directory where the upstream data is stored. Multiple paths can be used, separated by "::".',
    )
    parser.add_argument(
        "--downstream_data_dir",
        type=path_or_cloudpath,
        required=True,
        help="Path to directory where the downstream data is stored. See flyt/data.py for usage.",
    )
    parser.add_argument(
        "--data_weights",
        type=str,
        default=None,
        help=(
            "When using multiple data sources with webdataset and sampling with replacement, which weight to use for sampling the different data sources. "
            "Similar to --upstream_data_dir, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) "
            "By default, datapoints are sampled uniformly regardless of the dataset sizes."
        ),
    )
    # Task names are plain "::"-separated identifiers, not filesystem paths,
    # so they are kept as a string instead of being converted to a path.
    parser.add_argument(
        "--downstream_task_names",
        type=str,
        default='imagenet',
        help='List of task names, separated by "::". See flyt.data.UpAndDownstreamDataSet for usage.',
    )
    parser.add_argument(
        "--downstream_data_weights",
        type=str,
        default=None,
        help=(
            "When using multiple downstream data sources with webdataset and sampling with replacement, which weight to use for sampling the different data sources. "
            "Similar to --downstream_task_names, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) "
            "By default, datapoints are sampled uniformly regardless of the dataset sizes."
        ),
    )
    parser.add_argument(
        "--upstream_batch_size",
        type=int,
        default=4096
    )
    parser.add_argument(
        "--downstream_batch_size",
        type=int,
        default=3072
    )
    parser.add_argument(
        "--reference_learning_rate",
        default=5e-5,
        type=float,
        help='Reference model learning rate'
    )
    parser.add_argument(
        "--scoring_learning_rate",
        default=1e-3,
        type=float,
        help='Scoring model learning rate'
    )
    # The training length is given either as a sample count or as an
    # iteration count, never both.
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument(
        "--reference_num_samples",
        default=None,
        type=int,
        help='Total number of reference samples. Mutually exclusive with n_iterations.'
    )
    group.add_argument(
        "--n_iterations",
        default=5000,
        type=int,
        help='Number of iterations. Mutually exclusive with reference_num_samples.'
    )
    parser.add_argument(
        "--warmup",
        default=100,
        type=int,
        help='Number of warmup steps'
    )
    parser.add_argument(
        "--scaler_scale",
        type=float,
        default=1024,
        help="Grad scaler init scale"
    )
    parser.add_argument(
        "--reference_pretrained",
        default='',
        type=str,
        help="For the reference model, use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--scoring_pretrained",
        default='',
        type=str,
        help="For the scoring model base, use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--full_scoring_pretrained",
        default=None,
        type=str,
        help="For the scoring model, use a pretrained model weights. Unlike scoring_pretrained, this includes the head.",
    )
    parser.add_argument(
        "--update_reference_model",
        default=False,
        action="store_true",
        help='Update the reference model during the run.'
    )
    parser.add_argument(
        "--downstream_logit_scale",
        type=float,
        default=None,
        help="Logit scale for downstream data.")
    parser.add_argument(
        "--downstream_clip_loss",
        action="store_true",
        default=False,
        help="Use downstream CLIP loss instead of CE.",
    )
    parser.add_argument(
        "--log_pre_update",
        action="store_true",
        default=False,
        help="Log the loss before updating the reference model. Provides information, but requires an additional forward pass.",
    )
    parser.add_argument(
        "--datacomp_eval_dir",
        type=str,
        default=None,
        help="DataComp evalsets directory. Required for the classnames.txt and zeroshot_classification_templates.txt files.",
    )
    parser.add_argument(
        "--dataset_weighted",
        action="store_true",
        default=False,
        help="Use balanced downstream dataset weighting when using multiple downstream tasks.",
    )

    # FLYT specific
    parser.add_argument(
        "--train_full_scoring",
        default=False,
        action="store_true",
        help='Train the entire scoring model (do not freeze the base)'
    )
    parser.add_argument(
        "--n_scoring_layers",
        default=1,
        type=int,
        help='How many GLU layers in the scoring model.'
    )
    parser.add_argument(
        "--hidden_dim_mod", type=float, default=1.0,
        help="The hidden dimension of the GLU layer will be hidden_dim_mod times the input dimension of the GLU layer."
    )
    parser.add_argument(
        "--cos_sim_init",
        default=False,
        action="store_true",
        help='Initialize the GLU head to (approximately) compute the cosine similarity'
    )

    # M-FLYT specific
    parser.add_argument(
        "--parquet_dir",
        default=None,
        type=str,
        help="Parquet directory containing input scores for M-FLYT",
    )
    parser.add_argument(
        '--parquet_fields',
        type=str,
        nargs='+',
        default=None,
        help='Input fields to load from parquet_dir.'
    )

    args = parser.parse_args(argv)
    return args


def train_flyt():
    """Parse the command line, launch training via ``main`` and check the result.

    The parsed options are translated into the argument list passed to
    ``main``. Global batch sizes are divided by ``world_size * accum_freq``
    to obtain per-GPU values, and the total number of reference samples is
    divided by ``num_checkpoints`` to obtain the per-epoch sample count.
    After ``main`` returns, rank 0 verifies that ``epoch_latest.pt`` exists
    under ``<output_dir>/<exp_name>/checkpoints``.

    Raises:
        AssertionError: On rank 0, if the final checkpoint is not found.
    """
    args = parse_args()
    _, rank, world_size = world_info_from_env()
    if rank == 0:
        print(f"World size is {world_size}.")

    upstream_data, weights = get_input_shards(args.upstream_data_dir, args.data_weights)

    per_gpu_upstream_batch_size = args.upstream_batch_size // (world_size * args.accum_freq)
    per_gpu_downstream_batch_size = args.downstream_batch_size // (world_size * args.accum_freq)
    if args.reference_num_samples is None:
        args.reference_num_samples = args.n_iterations * args.upstream_batch_size

    # The experiment name is handed to ``main`` in string form (so a missing
    # --exp_name becomes "None"). The same string is reused below when
    # locating the final checkpoint, so both always refer to one directory.
    exp_name = f"{args.exp_name}"

    main_args = [
        "--save-frequency",
        f"{args.save_frequency}",
        "--ddp-static-graph",
        "--gather-with-grad",
        "--train-data",
        f"{upstream_data}",
        "--downstream_data_dir",
        f"{args.downstream_data_dir}",
        "--train-num-samples",
        f"{args.reference_num_samples // args.num_checkpoints}",
        "--warmup",
        f"{args.warmup}",
        "--dataset-type",
        "webdataset",
        "--precision",
        f"{args.precision}",
        "--workers",
        f"{args.workers}",
        "--model",
        f"{args.model}",
        "--batch-size",
        f"{per_gpu_upstream_batch_size}",
        "--downstream_batch_size",
        f"{per_gpu_downstream_batch_size}",
        "--epochs",
        f"{args.num_checkpoints}",
        "--lr",
        f"{args.reference_learning_rate}",
        "--scoring_lr",
        f"{args.scoring_learning_rate}",
        "--logs",
        f"{args.output_dir}",
        "--name",
        exp_name,
        "--seed",
        f"{args.seed}",
        "--accum-freq",
        f"{args.accum_freq}",
        "--log-every-n-steps",
        f"{args.log_every_n_steps}",
        "--save-most-recent",
        "--resume",
        f"{args.resume}",
        "--n_scoring_layers",
        f"{args.n_scoring_layers}",
        "--hidden_dim_mod",
        f"{args.hidden_dim_mod}",
        "--scaler_scale",
        f"{args.scaler_scale}",
        "--dataset-resampled",
        "--downstream_task_names",
        f"{args.downstream_task_names}",
        "--flyt"
    ]
    if args.report_to_wandb:
        main_args.extend([
            "--report-to",
            "wandb",
            "--wandb-project-name",
            f"{args.wandb_project_name}",
        ])

    # Optional arguments that carry a value are forwarded only when set.
    if weights is not None:
        main_args.extend(["--train-data-upsampling-factors", weights])
    if args.parquet_dir is not None:
        main_args.extend(['--parquet_dir', f'{args.parquet_dir}'])
    if args.downstream_data_weights is not None:
        main_args.extend(["--downstream-data-upsampling-factors", f'{args.downstream_data_weights}'])
    if args.downstream_logit_scale is not None:
        main_args.extend(['--downstream_logit_scale', f'{args.downstream_logit_scale}'])
    if args.parquet_fields is not None:
        main_args.append("--parquet_fields")
        main_args.extend(args.parquet_fields)
    if args.reference_pretrained:
        main_args.extend(['--pretrained', f'{args.reference_pretrained}'])
    if args.scoring_pretrained:
        main_args.extend(['--scoring_pretrained', f'{args.scoring_pretrained}'])
    if args.full_scoring_pretrained:
        main_args.extend(['--full_scoring_pretrained', f'{args.full_scoring_pretrained}'])
    if args.datacomp_eval_dir:
        main_args.extend(['--datacomp_eval_dir', f'{args.datacomp_eval_dir}'])

    # Boolean switches are forwarded only when enabled.
    if args.update_reference_model:
        main_args.append("--update_reference_model")
    if args.train_full_scoring:
        main_args.append("--train_full_scoring")
    if args.cos_sim_init:
        main_args.append("--cos_sim_init")
    if args.downstream_clip_loss:
        main_args.append("--downstream_clip_loss")
    if args.log_pre_update:
        main_args.append("--log_pre_update")
    if args.dataset_weighted:
        main_args.append("--dataset_weighted")
    print(main_args)
    success = main(main_args)

    if rank == 0:
        if success == -1:
            # Only a message is printed here; execution continues and the
            # checkpoint check below decides whether the run fails.
            print("Error running training. Exiting.")

        final_checkpoint = args.output_dir / exp_name / "checkpoints" / "epoch_latest.pt"
        assert (
            final_checkpoint.exists()
        ), f"Did not find the checkpoint at {final_checkpoint}"


# Script entry point: start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    train_flyt()
