import argparse
import os
from argparse import ArgumentParser
from pathlib import Path
import numpy as np
import pandas as pd
import torch

from models import NAMES_TO_CHECKPOINTS

try:
    import wandb
except ImportError:
    wandb = None


def _str2bool(value):
    """Convert a command-line string such as ``"False"``/``"1"`` to a bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True`` because
    every non-empty string is truthy. This converter parses the text instead.
    The empty string maps to ``False``, as it did with ``type=bool``.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _to_torch_dtype(value):
    """Convert ``"float32"``, ``"torch.float32"``, ``"fp32"``, ... to a ``torch.dtype``.

    The ``--fp16`` default is already a ``torch.dtype``. Without this converter,
    a value given on the command line would stay a plain string.
    """
    if isinstance(value, torch.dtype):
        return value
    name = str(value).strip()
    if name.startswith('torch.'):
        name = name[len('torch.'):]
    name = {'fp16': 'float16', 'fp32': 'float32', 'bf16': 'bfloat16'}.get(name, name)
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"unknown torch dtype {value!r}")
    return dtype


def _build_parser():
    """Create the ArgumentParser holding every command-line option of the experiment."""
    parser = ArgumentParser()
    parser.add_argument("-d", "--dataset", required=True, help='dataset name.', nargs='+',
                        choices=['sst2', 'dbpedia', 'agnews', 'trec', 'seq_language'])
    parser.add_argument('-m', '--models', '--inference_models', default=list(NAMES_TO_CHECKPOINTS.keys()),
                        nargs='*',
                        help='Specify which models to run on instead of full NAMES_TO_CHECKPOINTS.',
                        choices=list(NAMES_TO_CHECKPOINTS.keys()),
                        )
    parser.add_argument("--seed", help='Seed for reproducibility.', type=int, default=[59], nargs='+')
    parser.add_argument("--num_shots", type=int, help='number of examples for ICL.', default=[0], nargs='+')
    # examples selection
    parser.add_argument("--examples_selection_method", required=True, nargs='+',
                        help="method for selecting examples for ICL.")
    parser.add_argument("--example_ids", type=int, nargs="+",
                        help="ids of the train samples to use as examples for ICL.")
    parser.add_argument("--examples_path",
                        help="specify path to .json file where the retrieved examples are stored.")
    # prediction methods
    parser.add_argument("--prediction_method", default=['direct'], nargs='+',
                        choices=["direct", "channel", "calibrate"],
                        help="Method of prediction on test inputs. "
                             "It is recommended to run Channel and Calibrate methods with setting labels_loss=True."
                        )
    parser.add_argument("--labels_loss", action='store_true',
                        help="Whether to calculate loss over the whole sequence or only on the label tokens.")
    # inference args
    parser.add_argument("--eval_batch_size", type=int, default=16,
                        help="Batch size for inference.")
    parser.add_argument("--precision", choices=['fp16', 'fp32'], default='fp16',
                        help='floating point precision for inference model.')
    # hf args
    parser.add_argument("--cache_dir", help="Path to huggingface cache")
    parser.add_argument("--local_files_only", action='store_true',
                        help="turn this on if you want to make sure that you do not download the same weights from HF "
                             "hub again to another path occasionally.")
    parser.add_argument("--num_templates", type=int, help='number of randomly generated templates.', default=10)
    parser.add_argument("--templates_path",
                        help="Path to a *.json file containing pre-determined set of templates.")
    parser.add_argument("--template_seed", type=int, default=59,
                        help='Seed for generating random templates.',
                        )
    # infrastructure args
    parser.add_argument("--save_dir", default=".", help="Where to save the results.")
    # Accepts `--use_wandb`, `--use_wandb False` and `--no-use_wandb`; the old
    # plain-string option turned "False" into a truthy value.
    parser.add_argument("--use_wandb", type=_str2bool, nargs='?', const=True, default=True,
                        help="Write --no-use_wandb to disable WandB.")
    parser.add_argument("--no-use_wandb", dest="use_wandb", action="store_false",
                        help="Disable WandB logging (same as --use_wandb False).")
    parser.add_argument("--wandb_entity", default=None)
    parser.add_argument("--wandb_project", default='ExamplesSelection')
    parser.add_argument("--device_map", default="auto")
    parser.add_argument("--force_rerun", type=_str2bool, default=False)
    parser.add_argument("--method_boost_type", type=str, default='None')
    parser.add_argument("--peft_lr", type=float, default=0.0035)
    parser.add_argument("--peft_steps", type=int, default=50)
    parser.add_argument("--training_num_shots", type=int, default=-1)

    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_r", type=int, default=8)

    parser.add_argument("--num_virtual_tokens", type=int, default=8)

    parser.add_argument("--collate_fn_name", type=str, default='standard')

    parser.add_argument("--alpha", type=float, default=0.01)
    parser.add_argument("--epsilon", type=float, default=0.5)

    parser.add_argument("--normalization_epsilon", type=float, default=1e-8)
    parser.add_argument("--num_of_iter", type=int, default=20)
    parser.add_argument("--enhancement_batch_size", type=int, default=4)
    parser.add_argument("--number_of_seeds", type=int, default=3)
    parser.add_argument("--attack_regularization_value", type=float, default=0.0)
    parser.add_argument("--token_update_strategy", type=str, default='all_to_all')
    parser.add_argument("--enhancement_function", type=str, default='pgd_with_projection')
    parser.add_argument("--use_non_casual_mask", type=int, default=0)
    parser.add_argument("--update_tokens", type=str,
                        default='input')  # 'input', 'input_mask', 'output_mask', 'all_masks', 'input_and_all_masks'
    parser.add_argument("--hp_cross_validation", type=_str2bool, default=False)
    parser.add_argument("--display_enhancment_loss", type=_str2bool, default=True)
    # Default is a torch.dtype; CLI strings are converted to match (torch.float32, torch.float16).
    parser.add_argument("--fp16", type=_to_torch_dtype, default=torch.float16)
    parser.add_argument("--use_different_template_for_eval", type=_str2bool, default=False)
    parser.add_argument("--ICL_loss_pattern", type=str, default='all_labels')
    parser.add_argument("--random_losses_num", type=int, default=1)
    parser.add_argument("--ICL_projection_epsilon_type", type=str, default='token_wise')
    parser.add_argument("--peft_batch_size", type=int, default=4)
    parser.add_argument("--remove_train_example_from_icl", type=int, default=0)
    parser.add_argument("--use_context_for_instruciton", type=int, default=0)
    parser.add_argument("--train_model", type=int, default=1)
    parser.add_argument("--hpt", type=_str2bool, default=False)
    parser.add_argument("--peft_weighted_loss_type", type=str, default='none')
    parser.add_argument("--peft_weighted_loss_decay_factor", type=float, default=0.7)
    parser.add_argument("--use_train_for_eval", type=int, default=0)
    parser.add_argument("--context_seperator_after_length", type=int, default=1)
    parser.add_argument("--context_seperator_before_length", type=int, default=0)
    parser.add_argument("--peft_epochs", type=int, default=-1)
    parser.add_argument("--decay_projection", type=int, default=0)
    parser.add_argument("--decay_projection_base", type=float, default=0.01)
    parser.add_argument("--seq_langauge_len", type=int, default=10)
    parser.add_argument("--seq_langauge_each_seq_len", type=int, default=4)
    parser.add_argument("--seq_langauge_output_len", type=int, default=1)
    parser.add_argument("--seq_langauge_classes", type=int, default=20)
    parser.add_argument("--used_loss_tokens", type=str, default='answer')
    parser.add_argument("--force_same_masks", type=int, default=0)
    parser.add_argument("--use_learnable_format", type=int, default=0)
    parser.add_argument("--num_of_learnable_format", type=int, default=8)
    parser.add_argument("--format_epsilon", type=float, default=-1)
    parser.add_argument("--ICL_projection_epsilon_multiplier", type=float, default=1)
    parser.add_argument("--ICL_projection_format_epsilon_multiplier", type=float, default=1)
    parser.add_argument("--number_of_learnable_tokens", type=int, default=0)
    parser.add_argument("--add_instruction", type=int, default=0)
    parser.add_argument("--use_random_instruction", type=int, default=0)
    return parser


def _postprocess_args(args):
    """Apply derived settings and per-method presets to parsed ``args`` in place.

    The statement order matters. For example, the IPT preset resets
    ``peft_weighted_loss_type`` only after the decay factor has been parsed
    from it.
    """
    if args.method_boost_type == 'LORA':
        args.fp16 = torch.float32

    # Loss types like '<name>_decay_<f>' / '<name>_equal_<f>' carry the factor
    # as their last '_'-separated field.
    if 'decay' in args.peft_weighted_loss_type or 'equal' in args.peft_weighted_loss_type:
        args.peft_weighted_loss_decay_factor = float(args.peft_weighted_loss_type.split('_')[-1])

    # An explicit epoch budget overrides the step budget.
    if args.peft_epochs != -1:
        args.peft_steps = -1

    if args.method_boost_type == 'ICL':
        # ICL is run as frozen-model prompt tuning that uses the context as the instruction.
        args.method_boost_type = 'prompt_tuning'
        args.train_model = False
        args.use_random_instruction = 0
        args.use_context_for_instruciton = 1
    elif args.method_boost_type == 'IPT':
        args.number_of_learnable_tokens = 0
        args.add_instruction = 1
        args.use_context_for_instruciton = 1
        args.ICL_loss_pattern = 'last_label'
        args.peft_weighted_loss_type = 'none'
    elif args.method_boost_type == 'our_prompt_tuning':
        args.use_context_for_instruciton = 1
        args.use_random_instruction = 0
    elif args.method_boost_type == 'prompt_tuning':
        args.use_context_for_instruciton = 0
    elif args.method_boost_type == 'prefix_tuning':
        args.use_context_for_instruciton = 0
        args.use_random_instruction = 0
    elif args.method_boost_type == 'LORA':
        args.use_random_instruction = 0

    if args.used_loss_tokens == 'input_and_answer':
        args.update_tokens = 'all_masks'

    if args.add_instruction:
        args.number_of_learnable_tokens = 0

    # -1 means "reuse the main epsilon for the format tokens".
    if args.format_epsilon == -1:
        args.format_epsilon = args.epsilon
    return args


def parse_args(argv=None):
    """Parse the experiment's command-line options and apply derived settings.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, which is the same as the original zero-argument call.

    Returns:
        The post-processed ``argparse.Namespace``. The parsed namespace is
        printed together with ``use_non_casual_mask``.
    """
    args = _postprocess_args(_build_parser().parse_args(argv))

    print(args)
    print('========')
    print(args.use_non_casual_mask)
    print('========')
    return args


def get_results_torch(save_dir, name="results"):
    """Load the ``torch.save``-serialized results stored at ``save_dir/name``.

    Returns:
        Whatever object ``torch.load`` restores from the file, or an empty dict
        when no such file exists yet.
    """
    results_file = Path(save_dir, name)
    if not results_file.exists():
        return {}
    # NOTE: torch.load unpickles; only point this at result files you wrote yourself.
    return torch.load(results_file)


def get_results_pd(save_dir, force_rerun=False, name="results.csv"):
    """Load the CSV results table from ``save_dir/name``.

    Args:
        save_dir: Directory that holds the results file.
        force_rerun: When truthy, ignore any existing file and start from an
            empty table so every run is recomputed.
        name: File name of the CSV inside ``save_dir``.

    Returns:
        The table read with ``pd.read_csv`` if the file exists and
        ``force_rerun`` is falsy. Otherwise an empty DataFrame with the
        standard result columns.
    """
    res_path = Path(save_dir, name)
    if res_path.exists() and not force_rerun:
        return pd.read_csv(res_path)
    # Nothing recorded yet (or a rerun was requested): start from an empty
    # table with the standard schema.
    return pd.DataFrame(columns=["dataset", "model", "seed", "example_selection_method", "n_shots",
                                 "example_ids", "prediction_method", "batch_size", "precision",
                                 "template_seed",
                                 "template", "score",
                                 ])


def find_current_run(config: dict, results: pd.DataFrame) -> list:
    """For a given setup, find existing runs (if any) and return their scores.

    Args:
        config: Mapping of column name to the value a matching run must have.
        results: Results table with at least a ``"score"`` column.

    Returns:
        The ``score`` values of every row whose columns equal ``config``, in
        table order. The list is empty if no row matches, or if ``config``
        uses a column the table has never recorded.
    """
    if not set(config).issubset(results.columns):
        # A key that was never written cannot match any stored run.
        return []

    results_values = results[list(config)]
    target = pd.Series(config)
    # In pandas, None/NaN never compare equal, so a config value of None would
    # otherwise never match a stored missing value (e.g. after a CSV round-trip).
    # Treat "both missing" as a match.
    matches = (results_values == target) | (results_values.isna() & target.isna())
    found_runs = results.loc[matches.all(axis=1)]
    return found_runs["score"].tolist()


def save_results_torch(res_obj, save_dir, name):
    """Serialize ``res_obj`` with ``torch.save`` to ``save_dir/name``.

    ``save_dir`` (and any missing parents) is created first if needed.
    """
    os.makedirs(save_dir, exist_ok=True)
    destination = Path(save_dir) / name
    torch.save(res_obj, destination)


def save_results_pd(res_df, run_config, template, score, name="results.csv", save_dir="."):
    """Append one run to the results table and write the whole table to CSV.

    Args:
        res_df: Current results table.
        run_config: Column values describing the run. It is not modified.
        template: Template used for the run; stored as ``str(template)``.
        score: Score obtained by the run.
        name: CSV file name inside ``save_dir``.
        save_dir: Output directory, created if missing.

    Returns:
        A new DataFrame equal to ``res_df`` plus the appended row.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Build a fresh row rather than updating run_config in place. Otherwise the
    # caller's dict would silently gain "template"/"score", and reusing it in
    # find_current_run would also filter on a stale score.
    row = {**run_config, "template": str(template), "score": score}
    res_df = pd.concat([res_df, pd.DataFrame([row])], ignore_index=True)
    res_df.to_csv(Path(save_dir, name), index=False)

    return res_df
