import os
import warnings
from dataclasses import dataclass, field
import torch
from datetime import timedelta
from transformers import HfArgumentParser
from transformers import set_seed
from accelerate import InitProcessGroupKwargs
from accelerate import Accelerator
from accelerate.utils import gather_object


from load import load_sft_dataset, load_pretrain_model_tokenizer
from preprocess import preprocess_dataset
from trainer import train
from predict import predict_token_train, predict_token_eval, predict_option_eval
from utils import format_run_name, save_tables, save_results, load_tables

# Raise the process-group collective timeout to 5400 s (90 min). In `main`,
# non-main ranks block in `accelerator.wait_for_everyone()` while the main
# process loads data and runs evaluation alone; the long timeout is presumably
# there so those waits do not trip the default timeout — confirm.
process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=5400))
# Single module-level Accelerator, used by `main` and handed to the project's
# `load_pretrain_model_tokenizer` and `train` helpers.
accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])

# Silence warnings whose message starts with this text (the `message` argument
# of `warnings.filterwarnings` is a regex matched at the start of the message).
# NOTE(review): presumably emitted from inside the project's training code on
# every run; confirm the source if this filter is ever revisited.
warnings.filterwarnings(
    "ignore", message="You passed a tokenizer with `padding_side` not equal to `right`"
)


@dataclass
class ScriptArguments:
    model_name: str = field(metadata={"help": "Name of the weak model to use"})
    dataset_name: str = field(metadata={"help": "Name of the dataset to use"})
    strong_model_name: str = field(metadata={"help": "Name of the strong model to use"})
    is_easy_to_hard: bool = field(
        default=False, metadata={"help": "Whether to use easy-to-hard sampling"}
    )
    adaboost_rounds: int = field(
        default=3, metadata={"help": "Number of AdaBoost rounds"}
    )
    num_epochs: int = field(default=10, metadata={"help": "Number of training epochs"})
    learning_rate: float = field(
        default=5e-5, metadata={"help": "Learning rate for training"}
    )
    train_batch_size: int = field(
        default=8, metadata={"help": "Batch size for training"}
    )
    pred_batch_size: int = field(
        default=8, metadata={"help": "Batch size for prediction"}
    )
    num_proc: int = field(
        default=4, metadata={"help": "Number of processes for data loading"}
    )
    model_max_length: int = field(
        default=512, metadata={"help": "Maximum length of the model"}
    )
    # Setups
    is_token_based_error: bool = field(
        default=True, metadata={"help": "Whether to use token-based error"}
    )
    is_weight_by_token: bool = field(
        default=True, metadata={"help": "Whether to weight by token"}
    )
    is_completion_only: bool = field(
        default=True, metadata={"help": "Whether to use completion only"}
    )
    # Samplings
    probability_bias: float = field(
        default=0.0, metadata={"help": "Additive bias for probability when sampling"}
    )
    token_prob_window_size: int = field(
        default=1,
        metadata={"help": "Sliding window size for smoothing token probability"},
    )
    # Combinations
    logits_top_k: int = field(
        default=16, metadata={"help": "Number of top values in logits to combine"}
    )
    is_combine_probs: bool = field(
        default=True, metadata={"help": "Whether to combine probabilities"}
    )
    is_top_k_pooling: bool = field(
        default=False, metadata={"help": "Whether to use top-k pooling"}
    )
    # Testing
    test_limit: int = field(
        default=-1, metadata={"help": "whether to run on a limited number of examples"}
    )
    # Seed
    seed: int = field(default=42, metadata={"help": "Random seed"})
    w2s_folder: str = field(
        default="/fs/cml-projects/E2H/w2s_hack", metadata={"help": "Path to the folder"}
    )
    temp_folder: str = field(
        default="/tmp/ensemw2s", metadata={"help": "Temporary folder path"}
    )
    grp_name: str = field(
        default="E2H", metadata={"help": "Name of the group of experiments"}
    )


def _sync_tables(train_dataset, eval_dataset, transfer_dataset, sargs):
    """Share the main process's tables with every process.

    The main process writes its datasets to ``sargs.temp_folder``. Every rank
    then waits at a barrier and reloads the tables from that folder, so all
    ranks continue with identical data. The dataset arguments are ignored on
    non-main processes and may be ``None`` there.

    Returns:
        ``(train_dataset, eval_dataset, transfer_dataset)`` as reloaded by
        ``load_tables``.
    """
    if accelerator.is_main_process:
        save_tables(
            train_dataset,
            eval_dataset,
            transfer_dataset,
            sargs,
            sargs.temp_folder,
        )
    accelerator.wait_for_everyone()
    return load_tables(sargs, sargs.temp_folder)


def _score_with_weak_model(model, train_dataset, eval_dataset, transfer_dataset, t, sargs):
    """Score all three splits with a weak (or pretrained) model at round ``t``.

    Must be called on the main process only. ``train_dataset`` is updated
    first, and the updated table is then passed to the eval/transfer
    predictions.

    Returns:
        The updated ``(train_dataset, eval_dataset, transfer_dataset)``.
    """
    train_dataset = predict_token_train(model, train_dataset, t=t, sargs=sargs)
    eval_dataset = predict_token_eval(
        model, train_dataset, eval_dataset, t=t, sargs=sargs
    )
    eval_dataset = predict_option_eval(
        model, train_dataset, eval_dataset, t=t, sargs=sargs
    )
    transfer_dataset = predict_option_eval(
        model, train_dataset, transfer_dataset, t=t, sargs=sargs
    )
    return train_dataset, eval_dataset, transfer_dataset


def _weak_round(train_dataset, eval_dataset, transfer_dataset, t, sargs):
    """Train the round-``t`` weak model, score every split, and sync tables."""
    weak_model = train(
        sargs.model_name, train_dataset, eval_dataset, t, sargs, accelerator
    )
    if accelerator.is_main_process:
        train_dataset, eval_dataset, transfer_dataset = _score_with_weak_model(
            weak_model, train_dataset, eval_dataset, transfer_dataset, t, sargs
        )
    # Drop the model on every rank, not only the main one, so its memory is
    # released before the next training job starts.
    del weak_model
    torch.cuda.empty_cache()
    return _sync_tables(train_dataset, eval_dataset, transfer_dataset, sargs)


def _strong_round(train_dataset, eval_dataset, transfer_dataset, t, sargs):
    """Train the round-``t`` strong model on the transfer split, evaluate it, and sync."""
    strong_model = train(
        sargs.strong_model_name,
        transfer_dataset,
        eval_dataset,
        t,
        sargs,
        accelerator,
        mode="strong",
    )
    if accelerator.is_main_process:
        eval_dataset = predict_token_eval(
            strong_model, None, eval_dataset, t=t, sargs=sargs, mode="strong"
        )
        eval_dataset = predict_option_eval(
            strong_model, None, eval_dataset, t=t, sargs=sargs, mode="strong"
        )
    # Drop the model on every rank; previously the round-0 strong model was
    # never deleted and non-main ranks kept every model alive until it was
    # overwritten, holding two models in memory during the next `train`.
    del strong_model
    torch.cuda.empty_cache()
    return _sync_tables(train_dataset, eval_dataset, transfer_dataset, sargs)


def main(sargs):
    """Run the weak-to-strong AdaBoost experiment end to end.

    Phases:
      1. The main process loads and preprocesses the train/transfer/eval
         splits, scores them with the pretrained model at round 0, and shares
         the tables with all ranks.
      2. Round 0 "oracle" strong model: trained on the transfer split and
         evaluated on the eval split.
      3. For ``t`` in ``1..sargs.adaboost_rounds``: train a weak model on the
         train split and score every split, then train a strong model on the
         transfer split and evaluate it.
      4. The main process writes the final tables to
         ``<sargs.w2s_folder>/tables`` and calls ``save_results``.

    Training runs on every process; prediction and saving run only on the
    main process, and other ranks receive updated tables via the temp folder.

    Args:
        sargs: Parsed ``ScriptArguments``.
    """
    # Only the main process builds the datasets; the other ranks get them
    # from the first `_sync_tables` call.
    train_dataset = eval_dataset = transfer_dataset = None

    if accelerator.is_main_process:
        train_dataset, transfer_dataset, eval_dataset = load_sft_dataset(sargs)

        pretrain_model, tokenizer = load_pretrain_model_tokenizer(sargs, accelerator)

        train_dataset = preprocess_dataset(
            train_dataset, tokenizer, sargs, is_train=True
        )
        eval_dataset = preprocess_dataset(
            eval_dataset, tokenizer, sargs, is_train=False
        )
        transfer_dataset = preprocess_dataset(
            transfer_dataset, tokenizer, sargs, is_train=False
        )

        # Round 0: score every split with the pretrained model, before any
        # AdaBoost training.
        train_dataset, eval_dataset, transfer_dataset = _score_with_weak_model(
            pretrain_model, train_dataset, eval_dataset, transfer_dataset, 0, sargs
        )
        del pretrain_model
        torch.cuda.empty_cache()
    train_dataset, eval_dataset, transfer_dataset = _sync_tables(
        train_dataset, eval_dataset, transfer_dataset, sargs
    )

    # Oracle strong model (round 0).
    train_dataset, eval_dataset, transfer_dataset = _strong_round(
        train_dataset, eval_dataset, transfer_dataset, 0, sargs
    )

    # AdaBoost main loop.
    for t in range(1, sargs.adaboost_rounds + 1):
        train_dataset, eval_dataset, transfer_dataset = _weak_round(
            train_dataset, eval_dataset, transfer_dataset, t, sargs
        )
        train_dataset, eval_dataset, transfer_dataset = _strong_round(
            train_dataset, eval_dataset, transfer_dataset, t, sargs
        )

    if accelerator.is_main_process:
        save_tables(
            train_dataset,
            eval_dataset,
            transfer_dataset,
            sargs,
            os.path.join(sargs.w2s_folder, "tables"),
        )
        save_results(train_dataset, eval_dataset, transfer_dataset, sargs)
    accelerator.wait_for_everyone()


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    sargs = parser.parse_args_into_dataclasses()[0]

    # Validate flag combinations up front. `parser.error` (print usage, exit
    # with status 2) is used instead of `assert`, which `python -O` strips.
    if sargs.is_weight_by_token and not sargs.is_token_based_error:
        parser.error("Cannot weight by token while not using token-based error")
    if sargs.token_prob_window_size > 1 and not sargs.is_weight_by_token:
        parser.error(
            "Cannot use token probability smoothing while not weighting by token"
        )

    set_seed(sargs.seed)

    results_path = os.path.join(
        sargs.w2s_folder, "results", format_run_name(sargs) + ".csv"
    )
    if os.path.exists(results_path):
        # `accelerator.print` prints only on the main process, so the message
        # is not repeated once per rank.
        accelerator.print("Results already exist, skipping...")
    else:
        main(sargs)
