import atexit
import os

import torch
from avalanche.evaluation.metrics import accuracy_metrics, forgetting_metrics
from avalanche.logging import InteractiveLogger, TextLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.supervised import Naive

from configs import default
from utils.buffers import CustomClassBalancedBuffer
from utils.data import load_benchmark
from utils.plugins import ExperienceReplay
from utils.training import get_free_gpu_idx, load_model, set_seed

# Read the run configuration and expose the values used throughout the script.
args = default.get_args()

DATASET, MODEL = args.dataset, args.model
SEED, AUGMENTATION = args.seed, args.augmentation

# Seed before the benchmark and model are created.
set_seed(SEED)
print(
    f"Running experiment with dataset: {DATASET}, model: {MODEL}, "
    f"seed: {SEED}, augmentation: {AUGMENTATION}"
)

# Benchmark streams plus the transforms/class count returned alongside them.
benchmark, n_classes, train_transform, eval_transform = load_benchmark(
    benchmark_name=DATASET,
    augmentation=AUGMENTATION,
    seed=SEED,
)

# Use the GPU index reported by get_free_gpu_idx() when CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{get_free_gpu_idx()}")
else:
    device = torch.device("cpu")

# load_model returns the model together with its training components and the
# memory size used for the replay buffer.
(
    model,
    optimizer,
    early_stopping_plugin,
    scheduler_plugin,
    criterion,
    mem_size,
) = load_model(benchmark_name=DATASET, model_name=MODEL)

# Replay memory and the plugin that draws from it during training.
buffer = CustomClassBalancedBuffer(max_size=mem_size, adaptive_size=True)
custom_replay_plugin = ExperienceReplay(buffer, args=args, mem_batch_size=128)

# Loggers and evaluation plugin.
# The text log is appended to ./text_logs/; create the directory first so a
# fresh checkout does not fail with FileNotFoundError on open().
os.makedirs('./text_logs', exist_ok=True)
log_file = open(f'./text_logs/ER_{DATASET}_seed{SEED}_log.txt', 'a')
# The logger holds this handle for the whole run, so a `with` block does not
# fit here; make sure it is flushed and closed when the interpreter exits.
atexit.register(log_file.close)

interactive_logger = InteractiveLogger()
text_logger = TextLogger(log_file)

# Accuracy is tracked per epoch, per experience and per stream; forgetting per
# experience and per stream. Results go to both loggers.
eval_plugin = EvaluationPlugin(
    accuracy_metrics(epoch=True, experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True),
    loggers=[interactive_logger, text_logger],
    strict_checks=False)

# Plugins attached to the strategy: replay, LR scheduling and early stopping.
strategy_plugins = [custom_replay_plugin, scheduler_plugin, early_stopping_plugin]

# Avalanche's Naive strategy; replay behaviour comes from the plugin list.
cl_strategy = Naive(
    # what is trained and how
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    # batch sizes and training length
    train_mb_size=128,
    eval_mb_size=128,
    train_epochs=200,
    # periodic evaluation and reporting
    eval_every=1,
    evaluator=eval_plugin,
    # extensions and placement
    plugins=strategy_plugins,
    device=device,
)

# One entry per training experience: the metrics dict returned by eval().
results = []
print('Starting experiment...')
for train_exp, valid_exp in zip(benchmark.train_stream, benchmark.valid_stream):
    task_id = train_exp.current_experience
    print(f"Start of experience: {task_id}")
    print(f"Current Classes: {train_exp.classes_in_this_experience}")

    # Train on the current experience, validating on its paired experience.
    #   val_exp=None          -> the training set is used to update the memory buffer
    #   memory_transform=None -> no transformation is applied to the memory
    #                            buffer during training
    res = cl_strategy.train(
        train_exp,
        eval_streams=[valid_exp],
        val_exp=None,
        memory_transform=None,
    )
    print('Training completed')

    # Evaluate on every test experience up to and including the current one.
    print('Evaluation on the test set...')
    seen_test_experiences = benchmark.test_stream[:task_id + 1]
    results.append(cl_strategy.eval(seen_test_experiences))

# Final report: for each evaluation checkpoint (one per trained experience),
# print per-experience accuracy, per-experience forgetting for the earlier
# experiences, and the stream-level averages.
print()
for task_id, dict_result in enumerate(results):
    for task_id_test in range(task_id + 1):
        # Experience ids in metric names are zero-padded to three digits
        # (Exp000, Exp010, ...); format the id instead of prefixing a literal
        # "00", which breaks from the tenth experience onwards.
        acc_key = f"Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp{task_id_test:03d}"
        print(f'Task {task_id_test} - Accuracy: {dict_result[acc_key]:.3f}')
    # Forgetting is only reported for experiences before the current one;
    # range(0) is empty, so the first checkpoint prints nothing here.
    for task_id_prev in range(task_id):
        forgetting_key = f"ExperienceForgetting/eval_phase/test_stream/Task000/Exp{task_id_prev:03d}"
        print(f'Task {task_id_prev} - Forgetting: {dict_result[forgetting_key]:.3f}')
    print(f'Avg. Accuracy: {dict_result["Top1_Acc_Stream/eval_phase/test_stream/Task000"]:.3f}')
    print(f'Avg. Forgetting: {dict_result["StreamForgetting/eval_phase/test_stream"]:.3f}')
    print()