import os
import fire
import random
import numpy as np
import torch
from torch.optim import AdamW
from transformers import TrainingArguments, Trainer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup, TrainerCallback
from transformers.trainer_utils import get_last_checkpoint
import optuna
import shutil
import json

import sys; sys.path.append("src")
from utils import print_trainable_parameters, prepare_glue
from peft import C3AConfig, LoraConfig, VeraConfig, BOFTConfig, IA3Config, get_peft_model
from optuna_pruner import MultiplePruner

# Maps each PEFT-based ``mode`` name used in ``main`` to the peft config class
# that is instantiated for it.
CONFIG_DICT = dict(
    c3a=C3AConfig,
    lora=LoraConfig,
    vera=VeraConfig,
    boft=BOFTConfig,
    ia3=IA3Config,
)

def _finished_run_checkpoint(output_dir):
    """Return the checkpoint to evaluate from a finished earlier run in ``output_dir``.

    Looks at the newest ``checkpoint-*`` sub-directory. If its saved trainer
    state shows the run reached ``max_steps``, returns the run's best
    checkpoint (``best_model_checkpoint`` in ``trainer_state.json``), falling
    back to the newest checkpoint when the best one is not on disk.

    Returns ``None`` when nothing reusable exists (missing directory, no
    checkpoint, unreadable state, or an interrupted run) — the caller should
    then train from scratch.
    """
    if not os.path.isdir(output_dir):
        return None
    last_checkpoint = get_last_checkpoint(output_dir)
    if last_checkpoint is None:
        return None
    try:
        # "trainer_state.json" is the file Trainer writes its TrainerState to.
        with open(os.path.join(last_checkpoint, "trainer_state.json"), encoding="utf-8") as f:
            state = json.load(f)
    except (OSError, ValueError):  # JSONDecodeError is a ValueError
        return None
    if state.get("global_step", 0) < state.get("max_steps", 0):
        return None  # interrupted run: retrain rather than test a partial model
    best_checkpoint = state.get("best_model_checkpoint")
    if best_checkpoint:
        # The stored path is the one used at training time; resolve it relative
        # to the current output_dir so moved result directories still work.
        candidate = os.path.join(output_dir, os.path.basename(os.path.normpath(best_checkpoint)))
        if os.path.isdir(candidate):
            return candidate
    return last_checkpoint


def main(
    model_name_or_path: str = "FacebookAI/roberta-base",
    task: str = "cola",
    max_length: int = 512,
    padding_mode: str = "longest",
    mode: str = "c3a",
    c3a_block_size: int = 768,
    c3a_init_weight: str = False,
    lora_r: int = 8,
    lora_dropout: float = 0.1,
    vera_r: int = 1024,
    boft_m: int = 2,
    boft_b: int = 8,
    n_trial: int = 40,
    eval_only: bool = False,
    head_lr: float = 1e-3,
    other_lr: float = 1e-3,
    results_dir: str = 'results',
    use_bf16: bool = True,
):
    """Tune learning rates with Optuna, then report multi-seed test metrics on a GLUE task.

    Pipeline:
      1. Load ``task`` through ``prepare_glue``.
      2. Unless ``eval_only``, run (or resume) an Optuna study over the
         classifier-head and backbone learning rates, stored in
         ``<results_dir>/<save_id>/optuna.db``, until ``n_trial`` trials have
         finished (completed or pruned).
      3. Fine-tune with the chosen learning rates once per seed, evaluate on the
         test split and write ``results.json`` (metric mean/std and parameter
         counts).

    Args:
        model_name_or_path: Hugging Face model id or local path.
        task: GLUE task name; ``"stsb"`` is handled as regression (1 label).
        max_length, padding_mode: forwarded to ``prepare_glue``.
        mode: one of ``"c3a"``, ``"lora"``, ``"vera"``, ``"boft"``, ``"ia3"``,
            ``"bitfit"``, ``"full"``.
        c3a_block_size, c3a_init_weight: C3A ``block_size`` / ``init_weights``
            (the latter is passed through unchanged; the default is ``False``
            despite the ``str`` annotation).
        lora_r, lora_dropout: LoRA rank (also used as ``lora_alpha``) and dropout.
        vera_r: VeRA rank.
        boft_m, boft_b: BOFT butterfly factor and block size.
        n_trial: number of finished Optuna trials to reach.
        eval_only: skip the search; use the stored study if present, otherwise
            ``head_lr`` / ``other_lr``.
        head_lr, other_lr: fallback learning rates for ``eval_only``.
        results_dir: root directory for checkpoints, the study and results.
        use_bf16: enable ``bf16`` in ``TrainingArguments``.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    ds_related, metric_related, training_hyperparameters, tokenizer = prepare_glue(task, model_name_or_path, max_length, padding_mode)
    # NOTE: positional unpacking of dict values — relies on prepare_glue's key insertion order.
    train_ds, val_ds, test_ds, collate_fn = ds_related.values()
    metric_fn, metric_name, metric_lower_bond = metric_related.values()
    n_epoch, batch_size = training_hyperparameters.values()
    eval_key = f"eval_{metric_name}"
    # Computed once instead of per model build: `unique` scans the whole split.
    num_labels = 1 if task == "stsb" else len(train_ds.unique("labels"))

    def get_model(**kwargs):
        """Build a fresh sequence classifier prepared for ``mode`` (``kwargs`` are ignored)."""
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path,
            device_map='cuda' if torch.cuda.is_available() else 'cpu',
            return_dict=True,
            num_labels=num_labels,
        )
        print_trainable_parameters(model)

        if mode in ["c3a", "lora", "boft", "ia3", "vera"]:
            config_kwargs = {
                "target_modules": ["query", "value"],
                "modules_to_save": ["classifier"],
            }
            if mode == "c3a":
                config_kwargs.update({
                    "block_size": c3a_block_size,
                    "init_weights": c3a_init_weight,
                })
            elif mode == "lora":
                config_kwargs.update({
                    "r": lora_r,
                    "lora_alpha": lora_r,
                    "lora_dropout": lora_dropout,
                })
            elif mode == "vera":
                config_kwargs.update({
                    "r": vera_r,
                    "save_projection": False,
                    "d_initial": 0.1,
                })
            elif mode == "boft":
                config_kwargs.update({
                    "boft_block_size": boft_b,
                    "boft_n_butterfly_factor": boft_m,
                    "boft_dropout": 0.1,
                })
            elif mode == "ia3":
                config_kwargs["target_modules"].extend(["key", "output.dense", "intermediate.dense"])
            config = CONFIG_DICT[mode](**config_kwargs)
            model = get_peft_model(model, config)
        elif mode == "bitfit":
            # BitFit trains bias terms plus the task-specific classification head,
            # matching the PEFT modes above, which keep "classifier" trainable.
            # Loading a base checkpoint with `num_labels` typically creates a freshly
            # initialised head, so freezing its weights would leave it random.
            model.requires_grad_(False)
            for name, param in model.named_parameters():
                if "bias" in name or "classifier" in name:
                    param.requires_grad_(True)
        elif mode == "full":
            model.requires_grad_(True)
        else:
            raise ValueError(f"Unknown mode: {mode}")
        print_trainable_parameters(model)
        return model

    model_name = model_name_or_path.split("/")[-1]
    if mode == "c3a":
        save_id = f'{model_name}-{mode}-init_{c3a_init_weight}-b{c3a_block_size}-{task}'
    elif mode == "vera":
        save_id = f'{model_name}-{mode}-r{vera_r}-{task}'
    elif mode == "boft":
        save_id = f'{model_name}-{mode}-m{boft_m}-b{boft_b}-{task}'
    elif mode == "lora":
        save_id = f'{model_name}-{mode}-r{lora_r}-d{lora_dropout}-{task}'
    elif mode in ["full", "bitfit", "ia3"]:
        save_id = f'{model_name}-{mode}-{task}'
    else:
        raise ValueError(f"Unknown mode: {mode}")
    if use_bf16:
        save_id += "-bf16"

    def build_trainer(model, lr_head, lr_other, output_dir, seed=None, callbacks=None):
        """Create a Trainer with a two-group AdamW (classifier head vs. rest).

        The linear warmup/decay schedule (6% warmup) is created by the Trainer
        itself (``optimizers=(optimizer, None)`` + ``warmup_ratio``), so its length
        matches the number of optimisation steps the Trainer actually runs.
        ``seed`` overrides ``TrainingArguments.seed`` when given.
        """
        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        optimizer = AdamW([
                {"params": [p for n, p in trainable if "classifier" in n], "lr": lr_head},
                {"params": [p for n, p in trainable if "classifier" not in n], "lr": lr_other},
            ],
            weight_decay=0.,
        )
        extra_args = {} if seed is None else {"seed": seed}
        args = TrainingArguments(
            output_dir,
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            bf16=use_bf16,
            num_train_epochs=n_epoch,
            logging_steps=10,
            label_names=["labels"],
            metric_for_best_model=metric_name,
            greater_is_better=True,
            load_best_model_at_end=True,
            save_total_limit=1,
            lr_scheduler_type="linear",
            warmup_ratio=0.06,
            **extra_args,
        )
        return Trainer(
            model=model,
            args=args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            tokenizer=tokenizer,
            compute_metrics=metric_fn,
            data_collator=collate_fn,
            optimizers=(optimizer, None),
            callbacks=callbacks,
        )

    class OptunaPruneCallback(TrainerCallback):
        """After each evaluation, report the running-best metric to ``trial`` and prune if asked."""

        def __init__(self, trial):
            self.trial = trial

        def on_evaluate(self, args, state, control, **kwargs):
            current = kwargs["metrics"][eval_key]
            # `state.best_metric` is refreshed when the checkpoint is saved, which the
            # Trainer does after `on_evaluate`; fold in the current score so the
            # reported value is the true running best (higher is better). Keeping
            # best_metric first means a NaN current score does not replace it.
            score = current if state.best_metric is None else max(state.best_metric, current)
            # Optuna steps are integers; evaluations happen at epoch ends.
            self.trial.report(score, int(round(state.epoch)))
            if self.trial.should_prune():
                raise optuna.TrialPruned()

    def objective(trial):
        """Train one trial with sampled learning rates; return the best validation metric."""
        trial_head_lr = trial.suggest_float("head_lr", 1e-8, 1e-1, log=True)
        trial_other_lr = trial.suggest_float("other_lr", 1e-5, 1e1, log=True)
        save_path = os.path.join(results_dir, save_id, f"run-{trial.number}")

        model = get_model()
        trainer = build_trainer(
            model, trial_head_lr, trial_other_lr, save_path,
            callbacks=[OptunaPruneCallback(trial)],
        )
        try:
            try:
                trainer.train()
            except ValueError as err:  # metric input is nan
                raise optuna.TrialPruned() from err
            # load_best_model_at_end=True, so this scores the best epoch's weights.
            return trainer.evaluate(val_ds)[eval_key]
        finally:
            # Trial checkpoints are scratch space: remove them on success, prune or
            # failure. ignore_errors covers trials stopped before anything was saved.
            shutil.rmtree(save_path, ignore_errors=True)

    def evaluate_different_seeds(params, seeds=(7, 77, 777, 7777, 77777)):
        """Fine-tune once per seed with ``params`` learning rates and score each run on the test split.

        A seed directory that already holds a finished run is reused: its best
        checkpoint is loaded and only evaluated.

        Returns:
            (mean test metric, std of test metric, trainable params, all params)

        Raises:
            ValueError: if ``seeds`` is empty.
        """
        if not seeds:
            raise ValueError("evaluate_different_seeds needs at least one seed")
        test_metrics = []
        param_counts = None
        for seed in seeds:
            # Seed before building the model so the model/adapter initialisation varies per seed.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

            model = get_model()
            if param_counts is None:
                param_counts = print_trainable_parameters(model)

            save_path = os.path.join(results_dir, save_id, f"testrun-{seed}")
            # Pass the seed on: the Trainer otherwise re-seeds with TrainingArguments'
            # default seed, giving every run identical data order and dropout masks.
            trainer = build_trainer(model, params["head_lr"], params["other_lr"], save_path, seed=seed)
            checkpoint = _finished_run_checkpoint(save_path)
            if checkpoint is None:
                trainer.train()
            else:
                trainer._load_from_checkpoint(checkpoint)
            test_metrics.append(trainer.evaluate(test_ds)[eval_key])

        trainable_params, all_param = param_counts
        return np.mean(test_metrics), np.std(test_metrics), trainable_params, all_param

    os.makedirs(os.path.join(results_dir, save_id), exist_ok=True)
    if not eval_only:
        pruner = MultiplePruner([
            optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=1),
            optuna.pruners.ThresholdPruner(lower=metric_lower_bond, n_warmup_steps=5)
        ], pruning_condition="any")

        study = optuna.create_study(
                direction="maximize",
                storage=f"sqlite:///{os.path.join(results_dir, save_id, 'optuna.db')}",
                study_name=save_id,
                pruner=pruner,
                load_if_exists=True,  # resume an interrupted search instead of failing
            )
        finished_trials = study.get_trials(
            deepcopy=False,
            states=(optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED),
        )
        remaining_trials = n_trial - len(finished_trials)
        if remaining_trials > 0:
            # Every trial builds a new model; collect garbage between trials to free it.
            study.optimize(objective, n_trials=remaining_trials, gc_after_trial=True)

        study.trials_dataframe().to_csv(os.path.join(results_dir, save_id, "optuna.csv"))
        params = study.best_trial.params

    else:
        eval_trial_path = os.path.join(results_dir, save_id, "optuna.db")
        if os.path.exists(eval_trial_path):
            study = optuna.load_study(
                study_name=save_id,
                storage=f"sqlite:///{eval_trial_path}",
            )
            params = study.best_trial.params
        else:
            params = {
                "head_lr": head_lr,
                "other_lr": other_lr,
            }
    test_metrics, test_metrics_std, trainable_params, all_param = evaluate_different_seeds(params)

    results = {
        "params": params,
        "test_metrics": {
            "mean": test_metrics,
            "std": test_metrics_std,
        }
    }
    # NOTE(review): "head" is not an accepted mode (save_id raises for it), so this
    # branch always runs; kept for parity with a possible head-only mode.
    if mode != "head":
        results["compression"] = {
            "trainable_parameters": trainable_params,
            "all_parameters": all_param,
            "compression_ratio": 100 * trainable_params / all_param,
        }

    with open(os.path.join(results_dir, save_id, "results.json"), "w") as f:
        json.dump(results, f)
    
if __name__ == "__main__":
    # Let python-fire build the command-line interface from main's signature.
    fire.Fire(component=main)