from torch.optim.lr_scheduler import StepLR
from train import schedulers
import torch
from tqdm import tqdm
import helper
from train.eval import task_il_eval, class_il_eval
import wandb
from torch.cuda.amp import GradScaler, autocast
from helper import Args


class Trainer:
    """Train and evaluate a model one task at a time.

    Hyper-parameters are read from ``Args().get_args()`` when the trainer is
    constructed. The attributes used here are: ``lr``, ``n_epochs``,
    ``n_epochs_lora``, ``scheduler``, ``n_classes_per_task``,
    ``use_torch_amp``, ``weight_renorm``, ``forward_transfer``, ``model``,
    ``use_wandb``, ``logger`` and ``verbose``.
    """

    def __init__(self, criterion, lora_builder, device):
        """
        Args:
            criterion: Loss callable, invoked as ``criterion(outputs, mapped_labels)``.
            lora_builder: Provides ``build_lora_model`` and
                ``remove_old_subnet_from_prev_task``. It is used after training
                when forward transfer is enabled.
            device: Device that every input batch is moved to.
        """
        self.criterion = criterion
        self.args = Args().get_args()
        self.lora_builder = lora_builder
        self.device = device

    def set_scheduler(self, optimizer):
        """Return the learning-rate scheduler selected by ``args.scheduler``.

        ``"cosine"`` builds ``schedulers.CosineSchedule`` over ``args.n_epochs``
        epochs. Any other value falls back to ``StepLR(step_size=22, gamma=0.7)``.
        """
        if self.args.scheduler == "cosine":
            return schedulers.CosineSchedule(optimizer, K=self.args.n_epochs)
        return StepLR(optimizer, step_size=22, gamma=0.7)

    def evaluate(self, model, test_dataloaders, task_id):
        """Run Task-IL and Class-IL evaluation on every task seen so far.

        Only ``test_dataloaders[: task_id + 1]`` are evaluated. The results are
        printed/logged via ``helper.log_and_print``. They are also sent to
        wandb when ``args.use_wandb`` is set.

        Returns:
            Tuple ``(accuracy_cil, acc_per_task_cil, acc_per_task_til,
            conf_matrix, conf_matrix_task, (all_probs, all_preds, all_labels),
            task_predicted_counts)``. Every item except ``acc_per_task_til``
            comes from ``class_il_eval``.
        """
        accuracy_til, acc_per_task_til = task_il_eval(
            model, self.args.n_classes_per_task, test_dataloaders[: task_id + 1], self.device, self.args.model
        )

        (
            accuracy_cil,
            acc_per_task_cil,
            conf_matrix,
            conf_matrix_task,
            (all_probs, all_preds, all_labels),
            task_predicted_counts,
        ) = class_il_eval(
            model, self.args.n_classes_per_task, test_dataloaders[: task_id + 1],
            task_id, self.device, self.args.model, self.args.use_torch_amp
        )

        helper.log_and_print(f"Task {task_id + 1} - TASK-IL Test Accuracy: {accuracy_til:.2f}%", self.args.logger,
                             self.args.verbose)
        helper.log_and_print(f"Accuracy per task TASK-IL: {acc_per_task_til}", self.args.logger, self.args.verbose)
        helper.log_and_print(f"Task {task_id + 1} - Class-IL Test Accuracy: {accuracy_cil:.2f}%", self.args.logger,
                             self.args.verbose)
        helper.log_and_print(f"Accuracy per task Class-IL: {acc_per_task_cil}", self.args.logger, self.args.verbose)

        if self.args.use_wandb:
            wandb.log(
                {
                    "Task-IL Test Accuracy": accuracy_til,
                    "Class-IL Test Accuracy": accuracy_cil,
                    "Accuracy per task TASK-IL": acc_per_task_til,
                    "Accuracy per task Class-IL": acc_per_task_cil,
                }
            )

        return (accuracy_cil, acc_per_task_cil, acc_per_task_til, conf_matrix,
                conf_matrix_task, (all_probs, all_preds, all_labels), task_predicted_counts)

    @staticmethod
    def weight_renormalization(model, maximum_neuron_activations, task_id, target_max=5):
        """Rescale the weight and bias of ``model.fcs[task_id]`` in place.

        Both tensors are multiplied by ``target_max / normalizer``. The
        ``normalizer`` is ``helper.max_excluding_outliers_iqr`` of the
        collected per-sample maximum outputs.

        Args:
            model: Model exposing an indexable ``fcs`` collection of layers that
                have ``weight`` and ``bias`` (presumably one head per task).
            maximum_neuron_activations: 1-D tensor of per-sample maximum outputs.
            task_id: Index of the head to rescale.
            target_max: Numerator of the scale factor (defaults to the original 5).

        Returns:
            The same ``model`` object.
        """
        normalizer = helper.max_excluding_outliers_iqr(
            maximum_neuron_activations.cpu().detach()
        )
        scale = target_max / normalizer
        head = model.fcs[task_id]
        head.weight.data = scale * head.weight.data
        head.bias.data = scale * head.bias.data
        return model

    @staticmethod
    def _stack_activations(chunks, device):
        """Concatenate per-batch activation chunks; empty float tensor if none."""
        if not chunks:
            return torch.tensor([], device=device)
        return torch.cat(chunks)

    def train_loop(self, model, task_id, train_dataloader, test_dataloaders):
        """Train ``model`` on one task without LoRA, then evaluate it (or hand off to LoRA).

        Only parameters with ``requires_grad`` are optimized, using Adam plus
        the scheduler from ``set_scheduler``. Mixed precision is used when
        ``args.use_torch_amp`` is set.

        Returns:
            ``None`` when ``task_id > 0``, ``args.forward_transfer`` is set and
            ``args.model != "vit"``. In that case the LoRA builder is applied
            and evaluation is left to the caller. Otherwise returns the tuple
            from ``evaluate``.
        """
        model.train()
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=self.args.lr)
        scheduler = self.set_scheduler(optimizer)
        # If self.args.use_torch_amp is True, use torch cuda amp
        scaler = GradScaler() if self.args.use_torch_amp else None
        n_classes = self.args.n_classes_per_task
        last_epoch = self.args.n_epochs - 1
        # Per-batch max outputs from the final epoch. They are kept in a list and
        # concatenated once, instead of re-copying a growing tensor every batch.
        activation_chunks = []

        for epoch in range(self.args.n_epochs):
            running_loss = 0.0
            with tqdm(train_dataloader, desc=f"Epoch {epoch + 1}") as t:
                for step, (images, labels) in enumerate(t, start=1):
                    images, labels = images.to(self.device), labels.to(self.device)
                    optimizer.zero_grad()
                    # Also clears grads of parameters the optimizer does not own.
                    model.zero_grad()
                    # Map global labels into the task-local range [0, n_classes).
                    mapped_labels = labels % n_classes

                    if scaler is not None:
                        with autocast():
                            outputs = model(images, task_id, use_lora=False, training=True)
                            loss = self.criterion(outputs, mapped_labels)

                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        outputs = model(images, task_id, use_lora=False, training=True)
                        loss = self.criterion(outputs, mapped_labels)
                        loss.backward()
                        optimizer.step()

                    running_loss += loss.item()
                    # Use our own step counter: tqdm refreshes ``t.n`` only
                    # periodically, so ``t.n + 1`` can lag the batches seen.
                    t.set_description(
                        f"Epoch {epoch + 1}, Loss: {running_loss / step:.4f}"
                    )

                    if epoch == last_epoch:
                        # float() keeps AMP half-precision outputs in float32.
                        activation_chunks.append(torch.max(outputs, 1)[0].detach().float())
            scheduler.step()

        if task_id > 0 and self.args.forward_transfer and self.args.model != "vit":
            # NOTE(review): the LoRA-wrapped model is only rebound locally and this
            # branch returns None, so callers rely on these builder calls mutating
            # ``model`` in place — confirm against lora_builder.
            model = self.lora_builder.build_lora_model(model)
            model = self.lora_builder.remove_old_subnet_from_prev_task(model)
            return None

        if self.args.weight_renorm:
            maximum_neuron_activations = self._stack_activations(activation_chunks, self.device)
            model = self.weight_renormalization(model, maximum_neuron_activations, task_id)
        return self.evaluate(model, test_dataloaders, task_id)

    def finetune_lora(self, model, task_id, train_dataloader, test_dataloaders):
        """Fine-tune ``model`` with LoRA enabled for ``args.n_epochs_lora`` epochs, then evaluate.

        For non-ViT models every parameter is unfrozen first. When
        ``args.weight_renorm`` is set, the head for ``task_id`` is renormalized
        using the per-sample maximum outputs from the final LoRA epoch.

        Returns:
            The tuple produced by ``evaluate``.
        """
        model.train()
        if self.args.model != "vit":
            # model.parameters() already recurses through every sub-module.
            for param in model.parameters():
                param.requires_grad = True

        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=self.args.lr
        )
        n_classes = self.args.n_classes_per_task
        # The last epoch is measured against the LoRA epoch count (n_epochs_lora).
        last_epoch = self.args.n_epochs_lora - 1
        activation_chunks = []

        for epoch in range(self.args.n_epochs_lora):
            running_loss = 0.0

            with tqdm(train_dataloader, desc=f"LoRA Epoch {epoch + 1}") as t:
                for step, (images, labels) in enumerate(t, start=1):
                    images, labels = images.to(self.device), labels.to(self.device)
                    optimizer.zero_grad()
                    outputs = model(images, task_id, use_lora=True, training=False)
                    mapped_labels = labels % n_classes
                    loss = self.criterion(outputs, mapped_labels)

                    if epoch == last_epoch:
                        activation_chunks.append(torch.max(outputs, 1)[0].detach().float())

                    loss.backward()
                    optimizer.step()

                    running_loss += loss.item()
                    t.set_description(
                        f"LoRA Epoch {epoch + 1}, Loss: {running_loss / step:.4f}"
                    )

        if self.args.weight_renorm:
            print('Weights renormalizing after LoRA training!')
            maximum_neuron_activations = self._stack_activations(activation_chunks, self.device)
            model = self.weight_renormalization(model, maximum_neuron_activations, task_id)

        return self.evaluate(model, test_dataloaders, task_id)
