import numpy as np
import helper
import time
import os
from copy import deepcopy
import wandb
from helper import Args


def train_and_infer(model, train_dataloaders, test_dataloaders, trainer):
    """Train ``model`` on a sequence of tasks, then compute and save run metrics.

    Settings are read from ``Args().get_args()``: ``model``, ``target_modules``,
    ``logger``, ``verbose``, ``n_epochs``, ``forward_transfer`` and ``use_wandb``.

    For each task, in order:

    * ``args.model == "vit"``: a fresh ``deepcopy`` of the (never modified)
      base ``model`` is passed to ``trainer.train_loop``. When
      ``args.forward_transfer`` is set, ``lora_builder.build_lora_model`` is
      called with the accumulated ``peft_model`` and that fine-tuned copy.
      ``peft_model`` is created once before the first task, as a deepcopy of
      ``model`` with ``add_dynamic_lora`` applied.
      ``lora_builder.freeze_parameters_for_task`` is then called for the task,
      and the resulting ``peft_model`` is passed to ``trainer.finetune_lora``.
    * otherwise: from the second task on,
      ``lora_builder.add_new_subnet_for_new_task`` extends ``model`` before it
      is passed to ``trainer.train_loop``. When ``args.forward_transfer`` is
      set, ``trainer.finetune_lora`` is also run on ``model``, again only from
      the second task on.

    The 7-tuple returned by the last trainer call for a task (``train_loop``
    or ``finetune_lora``) is unpacked and recorded. After all tasks:

    * the mean sequential accuracy is logged to wandb if ``args.use_wandb``;
    * stability and plasticity are computed;
    * everything is written by ``save_results`` into
      ``./results/<unix-timestamp>``.

    Args:
        model: Base model to train.
        train_dataloaders: Iterable of per-task training dataloaders, in task order.
        test_dataloaders: Passed unchanged to ``trainer.train_loop`` /
            ``trainer.finetune_lora``.
        trainer: Object exposing ``train_loop``, ``finetune_lora`` and
            ``lora_builder``.

    Returns:
        ``peft_model`` when ``args.model == "vit"``, otherwise the (possibly
        extended) ``model``.

    Raises:
        ValueError: If ``train_dataloaders`` yields no tasks.
    """
    args = Args().get_args()
    is_vit = args.model == "vit"

    accuracies_class_il = []    # one per-task class-IL accuracy list per trained task
    sequential_accuracies = []  # overall class-IL accuracy after each trained task
    confusion_matrices_task = []
    task_predicted_counts = None
    all_probabilities = all_predictions = all_labels = None
    peft_model = None
    finetuned_model = None

    if is_vit:
        peft_model = deepcopy(model)
        trainer.lora_builder.add_dynamic_lora(peft_model, args.target_modules)
        helper.log_and_print(f"Model after dynamic LoRA addition: {peft_model}", args.logger, False)

    for task_id, train_dataloader in enumerate(train_dataloaders):
        task_no = task_id + 1
        helper.log_and_print(f"Training for task {task_no}", args.logger, args.verbose)
        if task_id > 0 and not is_vit:
            model = trainer.lora_builder.add_new_subnet_for_new_task(model)
        # Count only after this task's subnet has been added, so the figure
        # describes the model that is actually trained for this task.
        num_params = sum(p.numel() for p in model.parameters())
        helper.log_and_print(f"Number of total parameters at task {task_no}: {num_params}", args.logger, args.verbose)

        helper.log_and_print(f"Training task {task_no} for {args.n_epochs} epochs.", args.logger, args.verbose)

        if is_vit:
            # Each task starts from a fresh copy of the base model; ``model``
            # itself is never reassigned on this path.
            finetuned_model = deepcopy(model)
            model_to_train = finetuned_model
        else:
            model_to_train = model

        results = trainer.train_loop(
            model=model_to_train,
            task_id=task_id,
            train_dataloader=train_dataloader,
            test_dataloaders=test_dataloaders
        )

        if args.forward_transfer:
            if task_id > 0 and not is_vit:
                helper.log_and_print(f"Beginning LoRA training for task {task_no} for {args.n_epochs} epochs!", args.logger, args.verbose)
                results = trainer.finetune_lora(
                    model=model,
                    task_id=task_id,
                    train_dataloader=train_dataloader,
                    test_dataloaders=test_dataloaders
                )
            elif is_vit:
                # Both are always set on the vit path; the assert documents it.
                assert peft_model is not None and finetuned_model is not None
                helper.log_and_print(f"Adding LoRA for task {task_no}!", args.logger, args.verbose)
                peft_model = trainer.lora_builder.build_lora_model(peft_model, finetuned_model, args.target_modules, task_id)
                trainer.lora_builder.freeze_parameters_for_task(peft_model, args.target_modules, task_id)
                helper.log_and_print(f"Model after LoRA addition: {peft_model}", args.logger, False)

                # Print trainable parameters
                num_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
                helper.log_and_print(f"Number of trainable parameters after LoRA addition: {num_params}", args.logger, args.verbose)

                results = trainer.finetune_lora(
                    model=peft_model,
                    task_id=task_id,
                    train_dataloader=train_dataloader,
                    test_dataloaders=test_dataloaders
                )

        (
            accuracy_class_il,
            accuracies_per_task_class_il,
            _,  # per-task task-IL accuracies (not used here)
            _,  # class-IL confusion matrix (not used here)
            confusion_matrix_task,
            (all_probabilities, all_predictions, all_labels),
            task_predicted_counts,
        ) = results

        accuracies_class_il.append(accuracies_per_task_class_il)
        sequential_accuracies.append(accuracy_class_il)
        confusion_matrices_task.append(confusion_matrix_task)

    if not accuracies_class_il:
        raise ValueError("train_dataloaders is empty: no task was trained, so there are no results to save.")

    if args.use_wandb:
        wandb.log({"Sequential Acc": np.mean(sequential_accuracies)})

    # Stability: mean per-task class-IL accuracy after the final task.
    stability = np.mean(accuracies_class_il[-1])
    # Plasticity: mean of the accuracy on task k right after training task k.
    # NOTE(review): the final task's diagonal entry is excluded, as in the
    # original definition; confirm this is the intended metric.
    diagonal = [row[k] for k, row in enumerate(accuracies_class_il[:-1])]
    # A single-task run has an empty diagonal; return NaN explicitly instead of
    # letting np.mean emit an empty-slice RuntimeWarning.
    plasticity = np.mean(diagonal) if diagonal else np.float64(np.nan)

    # NOTE(review): second-resolution timestamp; two runs that start within the
    # same second share (and, via exist_ok=True, overwrite) this directory.
    output_dir = os.path.join("./results", str(int(time.time())))
    os.makedirs(output_dir, exist_ok=True)

    save_results(
        accuracies_class_il=accuracies_class_il,
        task_predicted_counts=task_predicted_counts,
        confusion_matrix_task_last=confusion_matrices_task[-1],
        all_probabilities=all_probabilities,
        all_predictions=all_predictions,
        all_labels=all_labels,
        sequential_accuracies=sequential_accuracies,
        stability=stability,
        plasticity=plasticity,
        output_dir=output_dir,
    )

    # NOTE(review): on the vit path with forward_transfer disabled, peft_model
    # never goes through build_lora_model/finetune_lora, so the returned model
    # only carries the initial dynamic LoRA. Confirm this is intended.
    return peft_model if is_vit else model


def save_results(accuracies_class_il, task_predicted_counts, confusion_matrix_task_last,
                 all_probabilities, all_predictions, all_labels, sequential_accuracies,
                 stability, plasticity, output_dir):
    """Write every end-of-run artifact into ``output_dir`` via ``helper`` savers.

    The savers run one after another in a fixed order. Each receives
    ``output_dir`` as a keyword argument. Nothing is returned.

    Args:
        accuracies_class_il: Passed to ``helper.save_task_performance``.
        task_predicted_counts: Passed to ``helper.save_task_probabilities``.
        confusion_matrix_task_last: Passed to ``helper.save_confusion_matrix``.
        all_probabilities: First argument of ``helper.save_calibration``.
        all_predictions: Second argument of ``helper.save_calibration``.
        all_labels: Third argument of ``helper.save_calibration``.
        sequential_accuracies: Passed to ``helper.save_sequential_performance``.
        stability: First argument of ``helper.save_stability_plasticity``.
        plasticity: Second argument of ``helper.save_stability_plasticity``.
        output_dir: Destination directory forwarded to every saver.
    """
    # Each entry defers one save call, so the ``helper`` attribute is looked
    # up only when that step runs. Iterating in order keeps the write sequence.
    pending_saves = (
        lambda: helper.save_task_performance(accuracies_class_il, output_dir=output_dir),
        lambda: helper.save_task_probabilities(task_predicted_counts, output_dir=output_dir),
        lambda: helper.save_confusion_matrix(confusion_matrix_task_last, output_dir=output_dir),
        lambda: helper.save_calibration(
            all_probabilities, all_predictions, all_labels, output_dir=output_dir
        ),
        lambda: helper.save_sequential_performance(sequential_accuracies, output_dir=output_dir),
        lambda: helper.save_stability_plasticity(stability, plasticity, output_dir=output_dir),
    )
    for run_save in pending_saves:
        run_save()
