import os
import time
import argparse
from pathlib import Path

import torch
import numpy as np
from t5_continual import T5ContinualLearner

def setup_environment(cache_dir: str):
    """Point the Hugging Face cache environment variables at ``cache_dir``.

    Both ``HF_HOME`` and ``TRANSFORMERS_CACHE`` are set to the same value.
    A falsy ``cache_dir`` (such as an empty string) leaves ``os.environ``
    untouched.

    NOTE(review): any library that reads these variables at import time
    (possibly pulled in by the module-level ``t5_continual`` import) will not
    observe values set here; confirm whether the explicit ``cache_dir``
    passed to the learner is what actually takes effect.
    """
    if not cache_dir:
        return
    for var_name in ("HF_HOME", "TRANSFORMERS_CACHE"):
        os.environ[var_name] = cache_dir

def prepare_output_dir(save_dir: Path, save_name: str) -> Path:
    """Create (if needed) and return the run directory ``save_dir / save_name``.

    Missing parent directories are created as well, and an already existing
    directory is reused rather than treated as an error.
    """
    target = save_dir.joinpath(save_name)
    target.mkdir(exist_ok=True, parents=True)
    return target

def compute_eval_interval(num_epochs: int) -> int:
    """Return how many epochs to wait between evaluations.

    Runs of up to 50 epochs evaluate every epoch, runs of up to 200 epochs
    every 5 epochs, and anything longer every 10 epochs.
    """
    # (inclusive upper bound on num_epochs, interval to use) in ascending order.
    thresholds = ((50, 1), (200, 5))
    for ceiling, interval in thresholds:
        if num_epochs <= ceiling:
            return interval
    return 10

def report_metrics(start_time: float, output_dir: Path):
    """Print elapsed wall-clock time, CUDA memory usage, and the output location.

    Memory figures come from ``torch.cuda.memory_allocated()`` (current) and
    ``torch.cuda.max_memory_allocated()`` (peak), both called without a device
    argument, and are divided by 2**20 -- i.e. they are MiB even though the
    printed label says "MB".
    """
    bytes_per_mib = 2 ** 20
    seconds = time.time() - start_time
    current_mib = torch.cuda.memory_allocated() / bytes_per_mib
    peak_mib = torch.cuda.max_memory_allocated() / bytes_per_mib

    print(f"Elapsed time: {seconds:.2f} seconds")
    print(f"GPU memory allocated: {current_mib:.2f} MB")
    print(f"Peak GPU memory: {peak_mib:.2f} MB")
    print(f"Results saved in {output_dir}")

def main(args):
    """Build a ``T5ContinualLearner`` from parsed CLI ``args``, train it, save outputs.

    Runs multi-task training when ``args.multitask`` is set, otherwise
    continual training over ``args.task_list``. In both modes the object
    returned by training is written to ``<save_dir>/<save_name>/results_dict.npy``.
    Continual mode additionally writes ``prompts.npy`` from
    ``learner.previous_prompts`` when that attribute is present. Timing and
    CUDA memory statistics are printed at the end.

    Args:
        args: ``argparse.Namespace`` produced by this script's argument parser.
    """
    start_time = time.time()
    setup_environment(args.cache_dir)

    output_dir = prepare_output_dir(Path(args.save_dir), args.save_name)

    learner = T5ContinualLearner(
        model_name=args.model_name,
        # Empty-string CLI defaults are mapped to None for the learner.
        cache_dir=args.cache_dir or None,
        task_list=args.task_list,
        batch_size=args.batch_size,
        select_k_per_class=args.select_k_per_class,
        pre_processed=bool(args.pre_processed),
        prefix_len=args.prefix_len,
        freeze_weights=args.freeze_weights,
        freeze_except=args.freeze_except,
        lr=args.lr,
        seq_len=args.seq_len,
        early_stopping=args.early_stopping,
        prefix_MLP=args.prefix_MLP,
        prefix_path=args.prefix_path or None,
        mlp_layer_norm=args.mlp_layer_norm,
        bottleneck_size=args.bottleneck_size,
        get_test_subset=args.get_test_subset,
        memory_perc=args.memory_perc,
    )

    if not args.get_test_subset:
        print("Skipping test-subset creation")

    if args.multitask:
        print("Running multi-task training")
        results = learner.multi_task_training(num_epochs=args.num_epochs,
                                               save_path=str(output_dir))
    else:
        # Longer runs are evaluated less frequently to save time.
        eval_every = compute_eval_interval(args.num_epochs)
        results = learner.train_continual(
            task_list=args.task_list,
            epochs=args.num_epochs,
            save_path=str(output_dir),
            progressive=args.progressive,
            eval_every_N=eval_every,
            test_eval_after_every_task=args.test_eval_after_every_task,
            data_replay_freq=args.data_replay_freq,
        )

    # Save results for BOTH modes so the multi-task return value is not
    # silently discarded. np.save wraps a non-array (presumably a dict here)
    # in a 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(output_dir / "results_dict.npy", results)

    if not args.multitask:
        # Guard so a finished training run still reaches the metrics report
        # even if the learner exposes no prompts.
        prompts = getattr(learner, "previous_prompts", None)
        if prompts is not None:
            np.save(output_dir / "prompts.npy",
                    prompts.detach().cpu().numpy())
        else:
            print("No previous_prompts on learner; skipping prompts.npy")

    report_metrics(start_time, output_dir)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Continual Prompt Tuning with Gradient Selection"
    )

    def add_default_on_flag(name: str, help_text: str) -> None:
        """Register ``--<name>`` / ``--no_<name>`` writing one boolean dest (default True).

        A bare ``action="store_true", default=True`` flag can never be turned
        off; the ``--no_`` variant makes ``False`` reachable while ``--<name>``
        keeps working for existing command lines.
        """
        group = parser.add_mutually_exclusive_group()
        group.add_argument(f"--{name}", dest=name, action="store_true",
                           help=f"{help_text} (default: on)")
        group.add_argument(f"--no_{name}", dest=name, action="store_false",
                           help=f"Disable --{name}")
        parser.set_defaults(**{name: True})

    parser.add_argument("--save_dir",     type=str,   default=".", help="Base output path")
    parser.add_argument("--save_name",    type=str,   required=True, help="Subfolder name")
    parser.add_argument("--task_list",    nargs="+",  required=True)
    parser.add_argument("--model_name",   type=str,   default="t5-base")
    parser.add_argument("--cache_dir",    type=str,   default="")
    parser.add_argument("--num_epochs",   type=int,   default=5)
    parser.add_argument("--multitask",    action="store_true")
    parser.add_argument("--batch_size",   type=int,   default=8)
    parser.add_argument("--seq_len",      type=int,   default=512)
    parser.add_argument("--prefix_len",   type=int,   default=10)
    parser.add_argument("--prefix_path",  type=str,   default="")
    parser.add_argument("--lr",           type=float, default=0.3)
    parser.add_argument("--memory_perc",  type=float, default=0.01)
    parser.add_argument("--data_replay_freq", type=float, default=-1)
    parser.add_argument("--select_k_per_class", type=int, default=-1)
    parser.add_argument("--pre_processed", action="store_true")
    parser.add_argument("--test_eval_after_every_task", action="store_true")
    add_default_on_flag("progressive", "Pass progressive=True to train_continual")
    parser.add_argument("--freeze_weights", action="store_true")
    parser.add_argument("--freeze_except", type=str, default="xxxxxxx")
    add_default_on_flag("get_test_subset", "Pass get_test_subset=True to the learner")
    add_default_on_flag("early_stopping", "Pass early_stopping=True to the learner")
    parser.add_argument("--prefix_MLP",    type=str,   default="None")
    add_default_on_flag("mlp_layer_norm", "Pass mlp_layer_norm=True to the learner")
    parser.add_argument("--bottleneck_size", type=int, default=800)

    main(parser.parse_args())
