import os
import random
import numpy as np
import torch
from torch.optim import AdamW
from transformers import AutoModelForImageClassification, TrainingArguments, TrainerCallback, Trainer
from transformers.trainer_utils import get_last_checkpoint
import shutil
import optuna
import json
import fire

import sys; sys.path.append("src")
from utils import print_trainable_parameters, prepare_image_classification_data, image_classification_metric
from peft import C3AConfig, LoraConfig, get_peft_model

# Lookup table from the CLI ``mode`` string to the PEFT config class that is
# instantiated for that mode (only the adapter-based modes appear here).
CONFIG_DICT = dict(c3a=C3AConfig, lora=LoraConfig)

def main(
    model_name_or_path: str = "google/vit-base-patch16-224-in21k",
    dataset_name: str = "pets",
    mode: str = "c3a",
    c3a_block_size: int = 768,
    c3a_init_weight: "bool | str" = False,
    lora_r: int = 16,
    lora_dropout: float = 0.1,
    batch_size: int = 64,
    n_epoch: int = 10,
    n_trial: int = 100,
    results_dir: str = "results",
    eval_only: bool = False,
    head_lr: float = 5e-2,
    other_lr: float = 1e-4,
    weight_decay: float = 5e-4,
    use_bf16: bool = True,
):
    """Tune hyper-parameters with Optuna, then report test accuracy over several seeds.

    Unless ``eval_only`` is set, an Optuna study searches ``head_lr``,
    ``other_lr`` and ``weight_decay`` by training on the train split and
    scoring on the validation split.  The best parameters are then used to
    train one model per seed, each of which is evaluated on the test split.
    The mean/std test accuracy (and, except for ``mode="head"``, the
    trainable-parameter statistics) are written to
    ``<results_dir>/<save_id>/results.json``.

    Args:
        model_name_or_path: Checkpoint passed to
            ``AutoModelForImageClassification.from_pretrained`` and to the
            data-preparation helper.
        dataset_name: Dataset identifier forwarded to
            ``prepare_image_classification_data``.
        mode: One of ``"c3a"``, ``"lora"``, ``"head"`` or ``"full"``.
        c3a_block_size: ``block_size`` given to ``C3AConfig`` (c3a mode only).
        c3a_init_weight: ``init_weights`` given to ``C3AConfig`` (c3a mode
            only); presumably a bool or an initializer name -- confirm
            against the PEFT documentation.
        lora_r: LoRA rank; also used as ``lora_alpha`` (lora mode only).
        lora_dropout: LoRA dropout (lora mode only).
        batch_size: Per-device train and eval batch size.
        n_epoch: Number of training epochs per run.
        n_trial: Number of Optuna trials to run.
        results_dir: Root directory for checkpoints, the study database and
            the result files.
        eval_only: Skip the search.  The best parameters of an existing study
            are used when its database is found, otherwise ``head_lr``,
            ``other_lr`` and ``weight_decay`` are used as given.
        head_lr: Classifier learning rate used when no study is available.
        other_lr: Learning rate of all other trainable parameters used when
            no study is available.
        weight_decay: AdamW weight decay used when no study is available.
        use_bf16: Train in bf16; also appends ``-bf16`` to the run identifier.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Fail fast on an unsupported mode, before any dataset or model is loaded.
    if mode not in ("c3a", "lora", "head", "full"):
        raise ValueError(f"Unknown mode: {mode}")

    train_ds, val_ds, test_ds, collate_fn, label2id, id2label, image_processor = (
        prepare_image_classification_data(dataset_name, model_name_or_path)
    )

    # Identifier that names the output directory and the Optuna study.
    model_name = model_name_or_path.split("/")[-1]
    if mode == "c3a":
        save_id = f'{model_name}-{mode}-init_{c3a_init_weight}-b{c3a_block_size}-{dataset_name}'
    elif mode == "lora":
        save_id = f'{model_name}-{mode}-r{lora_r}-d{lora_dropout}-{dataset_name}'
    else:  # "head" or "full"
        save_id = f'{model_name}-{mode}-{dataset_name}'
    if use_bf16:
        save_id += "-bf16"

    def get_model(**kwargs):
        """Load the pretrained classifier and set it up for the selected ``mode``."""
        model = AutoModelForImageClassification.from_pretrained(
            model_name_or_path,
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
            device_map='cuda' if torch.cuda.is_available() else 'cpu',
        )
        print_trainable_parameters(model)

        if mode in ("c3a", "lora"):
            config_kwargs = {
                "target_modules": ["query", "value"],
                "modules_to_save": ["classifier"],
            }
            if mode == "c3a":
                config_kwargs.update({
                    "block_size": c3a_block_size,
                    "init_weights": c3a_init_weight,
                })
            else:
                config_kwargs.update({
                    "r": lora_r,
                    "lora_alpha": lora_r,
                    "lora_dropout": lora_dropout,
                })
            model = get_peft_model(model, CONFIG_DICT[mode](**config_kwargs))
        elif mode == "head":
            # Freeze everything except the classification head.
            model.requires_grad_(False)
            model.classifier.requires_grad_(True)
        elif mode == "full":
            model.requires_grad_(True)
        else:
            raise ValueError(f"Unknown mode: {mode}")
        print_trainable_parameters(model)
        return model

    def build_optimizer(model, lr_head, lr_other, decay):
        """Return AdamW with separate learning rates for the classifier and the rest."""
        head_params, other_params = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if "classifier" in name:
                head_params.append(param)
            else:
                other_params.append(param)
        return AdamW(
            [
                {"params": head_params, "lr": lr_head},
                {"params": other_params, "lr": lr_other},
            ],
            weight_decay=decay,
        )

    def build_training_args(save_path, **overrides):
        """Return the TrainingArguments shared by search and test runs."""
        return TrainingArguments(
            save_path,
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4,
            per_device_eval_batch_size=batch_size,
            bf16=use_bf16,
            num_train_epochs=n_epoch,
            logging_steps=10,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            label_names=["labels"],
            dataloader_num_workers=4,
            save_total_limit=1,
            save_only_model=True,
            **overrides,
        )

    def build_trainer(model, args, optimizer, callbacks=None):
        """Return a Trainer that trains on the train split and evaluates on validation."""
        return Trainer(
            model=model,
            args=args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            tokenizer=image_processor,
            compute_metrics=image_classification_metric,
            data_collator=collate_fn,
            optimizers=(optimizer, None),
            callbacks=callbacks,
        )

    class OptunaPruneCallback(TrainerCallback):
        """Report the score to Optuna after each evaluation and prune if asked to."""

        def __init__(self, trial, save_path):
            self.trial = trial
            self.save_path = save_path

        def on_evaluate(self, args, state, control, **kwargs):
            # Use the trainer's best metric when one is recorded, otherwise
            # fall back to the accuracy of the evaluation that just finished.
            if state.best_metric is not None:
                current_score = state.best_metric
            else:
                current_score = kwargs["metrics"]["eval_accuracy"]
            self.trial.report(current_score, state.epoch)
            if self.trial.should_prune():
                # The directory may not exist yet, so a missing path must not
                # hide the TrialPruned signal behind a FileNotFoundError.
                shutil.rmtree(self.save_path, ignore_errors=True)
                raise optuna.TrialPruned()

    def objective(trial):
        """Train one Optuna trial and return its validation accuracy."""
        trial_head_lr = trial.suggest_float("head_lr", 1e-3, 1e-1, log=True)
        # NOTE(review): the upper bound of 1e5 is unusually large for a
        # learning rate -- confirm that this search range is intended.
        trial_other_lr = trial.suggest_float("other_lr", 1e-5, 1e5, log=True)
        trial_weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True)

        model = get_model()
        optimizer = build_optimizer(model, trial_head_lr, trial_other_lr, trial_weight_decay)

        save_path = os.path.join(results_dir, save_id, f"run-{trial.number}")
        trainer = build_trainer(
            model,
            build_training_args(save_path),
            optimizer,
            callbacks=[OptunaPruneCallback(trial, save_path)],
        )

        try:
            trainer.train()
        except ValueError as err:  # metric input is nan
            shutil.rmtree(save_path, ignore_errors=True)
            raise optuna.TrialPruned() from err

        eval_acc = trainer.evaluate(val_ds)["eval_accuracy"]
        # Trial checkpoints are not needed once the score is known.
        shutil.rmtree(save_path, ignore_errors=True)
        return eval_acc

    def evaluate_different_seeds(params, seeds=(7, 77, 777, 7777, 77777)):
        """Train (or reload) one model per seed and score each on the test split.

        Returns:
            ``(mean_accuracy, std_accuracy, trainable_params, all_param)``.
        """
        test_metrics = []
        trainable_params, all_param = 0, 0
        for seed in seeds:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

            model = get_model()

            save_path = os.path.join(results_dir, save_id, f"testrun-{seed}")
            # Reuse a previous run only if it actually left a checkpoint.  A
            # bare directory (e.g. from an interrupted run) is trained again.
            last_checkpoint = get_last_checkpoint(save_path) if os.path.isdir(save_path) else None

            # Parameter counts do not depend on the seed; record them once.
            if trainable_params == 0:
                trainable_params, all_param = print_trainable_parameters(model)

            optimizer = build_optimizer(
                model, params["head_lr"], params["other_lr"], params["weight_decay"]
            )
            trainer = build_trainer(model, build_training_args(save_path, seed=seed), optimizer)

            if last_checkpoint is None:
                trainer.train()
            else:
                trainer._load_from_checkpoint(last_checkpoint)
            test_metrics.append(trainer.evaluate(test_ds)["eval_accuracy"])

        # Plain floats keep the summary JSON-serializable.
        return float(np.mean(test_metrics)), float(np.std(test_metrics)), trainable_params, all_param

    run_dir = os.path.join(results_dir, save_id)
    os.makedirs(run_dir, exist_ok=True)
    study_db_path = os.path.join(run_dir, "optuna.db")

    if not eval_only:
        pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=1)
        study = optuna.create_study(
            direction="maximize",
            storage=f"sqlite:///{study_db_path}",
            study_name=save_id,
            pruner=pruner,
            # Resume an existing study instead of failing with
            # DuplicatedStudyError when the database is already present.
            load_if_exists=True,
        )
        study.optimize(objective, n_trials=n_trial)

        study.trials_dataframe().to_csv(os.path.join(run_dir, "optuna.csv"))
        params = study.best_trial.params
    elif os.path.exists(study_db_path):
        study = optuna.load_study(
            study_name=save_id,
            storage=f"sqlite:///{study_db_path}",
        )
        params = study.best_trial.params
    else:
        # No study to read from: use the hyper-parameters given on the CLI.
        params = {
            "head_lr": head_lr,
            "other_lr": other_lr,
            "weight_decay": weight_decay,
        }

    test_metrics, test_metrics_std, trainable_params, all_param = evaluate_different_seeds(params)

    results = {
        "params": params,
        "test_metrics": {
            "mean": test_metrics,
            "std": test_metrics_std,
        }
    }
    if mode != "head":
        results["compression"] = {
            "trainable_parameters": trainable_params,
            "all_parameters": all_param,
            "compression_ratio": 100 * trainable_params / all_param,
        }

    with open(os.path.join(run_dir, "results.json"), "w") as f:
        json.dump(results, f)


# Script entry point: hand `main` to python-fire so it is invoked from the
# command line when this file is run directly.
if __name__ == "__main__":
    fire.Fire(main)