from ray import tune
from argparse import Namespace
import logging
import math
import random
import sys
import os
import numpy as np
import torch
from fairseq import (
    checkpoint_utils,
    distributed_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from fairseq.data import iterators
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer


@metrics.aggregate("train")
def train(args, trainer, task, epoch_itr):
    """Run one training epoch, then validate on every configured subset.

    Args:
        args: namespace providing ``fix_batches_to_gpus``, ``curriculum``,
            ``update_freq``, ``valid_subset`` and (optionally) ``tpu``.
        trainer: object whose ``begin_epoch`` and ``train_step`` are called.
        task: forwarded unchanged to ``validate``.
        epoch_itr: epoch iterator that yields the batches for this epoch.

    Returns:
        A ``(valid_losses, should_stop)`` tuple. ``valid_losses`` is the list
        returned by ``validate``; ``should_stop`` is always ``False`` here.
    """
    # Batches are shuffled only once the upcoming epoch index exceeds
    # ``args.curriculum``.
    batches = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )

    # ``args.update_freq`` holds one entry per epoch; epochs past the end of
    # the list reuse its last entry.
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Each item produced below is a group of ``update_freq`` batches that is
    # handed to a single ``train_step`` call.
    grouped = iterators.GroupedIterator(batches, update_freq)
    if getattr(args, "tpu", False):
        # NOTE(review): ``tpu_data_loader`` is not defined or imported in the
        # visible part of this file -- confirm where it comes from before
        # enabling ``tpu``.
        grouped = tpu_data_loader(args, grouped)

    trainer.begin_epoch(epoch_itr.epoch)
    valid_subsets = args.valid_subset.split(",")
    should_stop = False

    # The return value of ``train_step`` is not used; a ``None`` result
    # (OOM, overflow, ...) simply moves on to the next group.
    for samples in grouped:
        trainer.train_step(samples)

    valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
    print(valid_losses)
    return valid_losses, should_stop


def get_training_stats(stats):
    """Add the elapsed wall time to ``stats`` and return it.

    The value is read from the ``"wall"`` meter of the ``"default"`` metrics
    aggregator and rounded to zero decimal places. ``stats`` is modified in
    place and also returned.
    """
    wall_meter = metrics.get_meter("default", "wall")
    elapsed = wall_meter.elapsed_time
    stats["wall"] = round(elapsed, 0)
    return stats


def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation on each subset and collect the checkpoint metric.

    Args:
        args: namespace providing ``fixed_validation_seed`` and
            ``best_checkpoint_metric`` (plus whatever ``get_valid_stats``
            reads).
        trainer: object providing ``get_valid_iterator`` and ``valid_step``.
        task: unused here; kept for signature compatibility.
        epoch_itr: unused here; kept for signature compatibility.
        subsets: iterable of validation subset names.

    Returns:
        A list with one entry per subset: the value stored under
        ``args.best_checkpoint_metric`` in that subset's validation stats.
    """
    seed = args.fixed_validation_seed
    if seed is not None:
        # Re-seed before every validation pass.
        utils.set_torch_seed(seed)

    collected = []
    for subset_name in subsets:
        valid_iterator = trainer.get_valid_iterator(subset_name)
        batches = valid_iterator.next_epoch_itr(shuffle=False)

        # A fresh root aggregator keeps validation metrics out of any other
        # active aggregators (e.g. the train meters).
        with metrics.aggregate(new_root=True) as aggregator:
            for batch in batches:
                trainer.valid_step(batch)

        subset_stats = get_valid_stats(
            args, trainer, aggregator.get_smoothed_values()
        )
        collected.append(subset_stats[args.best_checkpoint_metric])

    return collected


def get_valid_stats(args, trainer, stats):
    """Annotate validation ``stats`` with the update count and best metric.

    ``stats`` is modified in place and returned. ``"num_updates"`` is always
    set. When ``checkpoint_utils.save_checkpoint`` carries a ``best``
    attribute, ``"best_<metric>"`` is set to the better of that stored value
    and the current ``stats[<metric>]``, where "better" means larger if
    ``args.maximize_best_checkpoint_metric`` is truthy and smaller otherwise.
    """
    stats["num_updates"] = trainer.get_num_updates()

    # No best value has been recorded yet: nothing more to add.
    if not hasattr(checkpoint_utils.save_checkpoint, "best"):
        return stats

    metric_name = args.best_checkpoint_metric
    best_key = "best_{0}".format(metric_name)
    if args.maximize_best_checkpoint_metric:
        choose = max
    else:
        choose = min
    stats[best_key] = choose(
        checkpoint_utils.save_checkpoint.best, stats[metric_name]
    )
    return stats


# Baseline argument values. ``GlueTask._setup`` overlays per-trial settings on
# top of these and turns the result into an ``argparse.Namespace`` that is
# handed to fairseq.
cmd_args = {
    "activation_dropout": 0.0,
    "activation_fn": "gelu",
    "adam_betas": "(0.9, 0.98)",
    "adam_eps": 1e-06,
    "add_prev_output_tokens": False,
    "all_gather_list_size": 16384,
    "arch": "roberta_base",
    "attention_dropout": 0.1,
    "best_checkpoint_metric": "accuracy",
    "bf16": False,
    "bpe": None,
    "broadcast_buffers": False,
    "bucket_cap_mb": 25,
    "checkpoint_suffix": "",
    "classification_head_name": "sentence_classification_head",
    "clip_norm": 0.0,
    "cpu": False,
    "criterion": "sentence_prediction",
    "curriculum": 0,
    "data": "RTE-bin/",
    "data_buffer_size": 10,
    "dataset_impl": None,
    "ddp_backend": "c10d",
    "device_id": 0,
    "disable_validation": False,
    "distributed_backend": "nccl",
    "distributed_init_method": None,
    "distributed_no_spawn": False,
    "distributed_port": -1,
    "distributed_rank": 0,
    "distributed_world_size": 1,
    "distributed_wrapper": "DDP",
    "dropout": 0.1,
    "empty_cache_freq": 0,
    "encoder_attention_heads": 12,
    "encoder_embed_dim": 768,
    "encoder_ffn_embed_dim": 3072,
    "encoder_layerdrop": 0,
    "encoder_layers": 12,
    "encoder_layers_to_keep": None,
    "end_learning_rate": 0.0,
    "fast_stat_sync": False,
    "find_unused_parameters": True,
    "fix_batches_to_gpus": False,
    "fixed_validation_seed": None,
    "force_anneal": None,
    "fp16": True,
    "fp16_init_scale": 4,
    "fp16_no_flatten_grads": False,
    "fp16_scale_tolerance": 0.0,
    "fp16_scale_window": 128,
    "init_token": 0,
    "keep_best_checkpoints": -1,
    "keep_interval_updates": -1,
    "keep_last_epochs": -1,
    "localsgd_frequency": 3,
    "log_format": None,
    "log_interval": 100,
    "lr": [2e-05],
    "lr_scheduler": "polynomial_decay",
    "max_epoch": 10,
    "max_positions": 512,
    "max_sentences": 2,
    "max_sentences_valid": 2,
    "max_tokens": 4400,
    "max_tokens_valid": 4400,
    "max_update": 0,
    "maximize_best_checkpoint_metric": True,
    "memory_efficient_bf16": False,
    "memory_efficient_fp16": False,
    "min_loss_scale": 0.0001,
    "min_lr": -1,
    "model_parallel_size": 1,
    "no_epoch_checkpoints": False,
    "no_last_checkpoints": False,
    "no_progress_bar": False,
    "no_save": False,
    "no_save_optimizer_state": False,
    "no_seed_provided": True,
    "no_shuffle": False,
    "nprocs_per_node": 1,
    "num_classes": 2,
    "num_workers": 1,
    "optimizer": "adam",
    "optimizer_overrides": "{}",
    "patience": -1,
    "pooler_activation_fn": "tanh",
    "pooler_dropout": 0.0,
    "power": 1.0,
    "profile": False,
    "quant_noise_pq": 0,
    "quant_noise_pq_block_size": 8,
    "quant_noise_scalar": 0,
    "quantization_config_path": None,
    "regression_target": False,
    "required_batch_size_multiple": 1,
    "reset_dataloader": True,
    "reset_lr_scheduler": False,
    "reset_meters": True,
    "reset_optimizer": True,
    "restore_file": "./checkpoints/roberta.base/model.pt",
    "save_dir": "./checkpoints/original/rte/",
    "save_interval": 1,
    "save_interval_updates": 0,
    "seed": 1,
    "sentence_avg": False,
    "separator_token": 2,
    "shorten_data_split_list": "",
    "shorten_method": "none",
    "skip_invalid_size_inputs_valid_test": False,
    "slowmo_algorithm": "LocalSGD",
    "slowmo_momentum": None,
    "stop_time_hours": 0,
    "task": "sentence_prediction",
    "tensorboard_logdir": "",
    "threshold_loss_scale": 1.0,
    "tokenizer": None,
    "tokens_per_sample": 512,
    "total_num_update": 2036,
    "tpu": False,
    "train_subset": "train",
    "untie_weights_roberta": False,
    "update_freq": [8],
    "use_bmuf": False,
    "use_old_adam": False,
    "user_dir": None,
    "valid_subset": "valid",
    "validate_interval": 1,
    "warmup_updates": 122,
    "weight_decay": 0.1,
}


class GlueTask(tune.Trainable):
    """Ray Tune trainable that trains a fairseq model one epoch per step.

    ``_setup`` builds the fairseq task, model and trainer from the module-level
    ``cmd_args`` defaults overlaid with the trial ``config``. ``_train`` runs
    one epoch through the module-level ``train`` function and reports the
    first validation value as ``mean_accuracy``. ``_save`` / ``_restore``
    write and read fairseq checkpoints in the directory supplied by Tune.
    """

    @staticmethod
    def _optimizer_overrides(opt_name, config):
        """Translate the tuned hyper-parameters in ``config`` to argument values.

        Args:
            opt_name: optimizer name; matched by prefix against ``Adam``,
                ``RAdam``, ``SGD``, ``LARS``, ``Yogi``, ``Lookahead`` and
                ``LAMB``.
            config: trial configuration. For the beta-based optimizers the
                derived ``beta1`` / ``beta2`` values are written back into it.

        Returns:
            A dict of argument overrides for the selected optimizer.

        Raises:
            ValueError: if ``opt_name`` matches none of the known prefixes.
            KeyError: if ``config`` lacks a key the selected optimizer needs.
        """

        def betas_and_eps(prefix):
            # The config carries ``new_beta = 1 - beta``; convert back and
            # record the real betas in the config as well.
            config["beta1"] = 1.0 - config["new_beta1"]
            config["beta2"] = 1.0 - config["new_beta2"]
            return {
                f"{prefix}_betas": f"({config['beta1']},{config['beta2']})",
                f"{prefix}_eps": float(config["eps"]),
            }

        if opt_name.startswith(("Adam", "RAdam")):
            overrides = {
                "optimizer": "adam" if opt_name.startswith("Adam") else "radam"
            }
            overrides.update(betas_and_eps("adam"))
        elif opt_name.startswith(("SGD", "LARS")):
            overrides = {
                "optimizer": "sgd" if opt_name.startswith("SGD") else "lars",
                "momentum": config["momentum"],
            }
            if opt_name.startswith("LARS"):
                overrides["fp16_no_flatten_grads"] = True
        elif opt_name.startswith("Yogi"):
            overrides = {"optimizer": "yogi"}
            overrides.update(betas_and_eps("adam"))
            overrides["initial_accumulator"] = float(config["initial_accumulator"])
        elif opt_name.startswith("Lookahead"):
            overrides = {"optimizer": "lookahead"}
            overrides.update(betas_and_eps("adam"))
            overrides["k"] = config["k"]
            overrides["alpha"] = config["alpha"]
        elif opt_name.startswith("LAMB"):
            overrides = {"optimizer": "lamb"}
            overrides.update(betas_and_eps("lamb"))
            overrides["fp16_no_flatten_grads"] = True
        else:
            raise ValueError(f"unsupported optimizer: {opt_name!r}")

        overrides["lr"] = [config["lr"]]
        return overrides

    def _setup(self, config):
        """Build the task, model and trainer for this trial.

        Args:
            config: trial configuration. Must contain ``args`` (an object with
                an ``optimizer`` attribute), ``lr``, ``restore_file``,
                ``data_dir``, ``num_classes``, ``max_sentences``,
                ``total_num_update``, ``warmup_updates`` and the
                optimizer-specific keys read by ``_optimizer_overrides``.
        """
        print(config)
        self.config = config

        # Work on a per-trial copy so the module-level defaults are not
        # mutated and settings cannot leak from one trial into another.
        trial_args = dict(cmd_args)
        trial_args.update(
            self._optimizer_overrides(config["args"].optimizer, self.config)
        )
        trial_args["restore_file"] = self.config["restore_file"]
        trial_args["data"] = self.config["data_dir"]
        # The defaults use the key ``num_classes``; writing to any other key
        # would leave the default class count in effect.
        trial_args["num_classes"] = self.config["num_classes"]
        trial_args["max_sentences"] = self.config["max_sentences"]
        trial_args["total_num_update"] = self.config["total_num_update"]
        trial_args["warmup_updates"] = self.config["warmup_updates"]
        # ``gpt2_encoder_json`` / ``gpt2_vocab_bpe`` from the config are
        # currently not forwarded.

        args = Namespace(**trial_args)
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)
        task = tasks.setup_task(args)

        # Load valid dataset (training data is loaded by load_checkpoint below)
        for valid_sub_split in args.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

        # Build model and criterion
        model = task.build_model(args)
        criterion = task.build_criterion(args)
        trainer = Trainer(args, task, model, criterion, None)
        # Loads ``args.restore_file`` into the trainer and returns the
        # training iterator; the extra state is not needed here.
        _, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

        self.epoch_itr = epoch_itr
        self.args = args
        self.trainer = trainer
        self.task = task
        self.model = model
        # Defined up front so ``_save`` works even before the first ``_train``.
        self.valid_loss = None

    def _train(self):
        """Train for one epoch and report the first validation value.

        Returns:
            ``{"mean_accuracy": <first validation value>, "early_stop": False}``.
        """
        # train for one epoch
        valid_losses, _ = train(self.args, self.trainer, self.task, self.epoch_itr)

        # only use first validation loss to update the learning rate
        self.trainer.lr_step(self.epoch_itr.epoch, valid_losses[0])
        self.epoch_itr = self.trainer.get_train_iterator(
            self.epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=self.task.has_sharded_data("train"),
        )
        self.valid_loss = valid_losses[0]
        return {"mean_accuracy": valid_losses[0], "early_stop": False}

    def _save(self, checkpoint_dir):
        """Write a fairseq checkpoint into ``checkpoint_dir`` and return it."""
        self.args.save_dir = checkpoint_dir
        checkpoint_utils.save_checkpoint(
            self.args, self.trainer, self.epoch_itr, self.valid_loss
        )
        return checkpoint_dir

    def _restore(self, checkpoint_dir):
        """Reload trainer state from the checkpoint stored in ``checkpoint_dir``."""
        self.args.save_dir = checkpoint_dir
        # ``restore_file`` still names the pretrained model used by ``_setup``;
        # leaving it would reload those weights instead of the saved trial.
        # The default name makes fairseq resolve the file inside ``save_dir``.
        # NOTE(review): the ``reset_*`` flags are left as configured, so
        # optimizer/dataloader/meter state may still be reset -- confirm that
        # this is the intended restore behaviour.
        self.args.restore_file = "checkpoint_last.pt"
        _, epoch_itr = checkpoint_utils.load_checkpoint(self.args, self.trainer)
        self.epoch_itr = epoch_itr


if __name__ == "__main__":
    # Manual smoke test: build a single trial outside of a Tune run and train
    # it for one epoch.
    config = {
        # ``_setup`` reads the optimizer name from ``config["args"].optimizer``.
        "args": Namespace(optimizer="Adam"),
        # For Adam-style optimizers ``_setup`` derives beta = 1 - new_beta and
        # overwrites "beta1" / "beta2" with the result.
        "new_beta1": 0.1,
        "new_beta2": 0.1,
        "beta1": 0.9,
        "beta2": 0.9,
        "lr": 1.0e-5,
        "eps": 1.0e-6,
        "restore_file": "./roberta.base/model.pt",
        "dev_file": "./glue_data/RTE/dev.tsv",
        "data_dir": "RTE-bin/",
        "gpt2_encoder_json": "encoder.json",
        "gpt2_vocab_bpe": "vocab.bpe",
        # Remaining keys that ``_setup`` reads unconditionally; reuse the
        # module-level defaults.
        "num_classes": cmd_args["num_classes"],
        "max_sentences": cmd_args["max_sentences"],
        "total_num_update": cmd_args["total_num_update"],
        "warmup_updates": cmd_args["warmup_updates"],
    }
    # Pass the config to the constructor: depending on the Ray version the
    # Trainable constructor may itself run ``_setup`` with the given config,
    # which would fail on an empty one.
    glue_task = GlueTask(config=config)
    if not hasattr(glue_task, "trainer"):
        # The constructor did not run ``_setup``; do it explicitly.
        glue_task._setup(config)
    glue_task._train()
