import logging
import numpy as np

import torch
from torch.nn import functional as F

from tqdm import tqdm
from methods.base import BaseLearner
from utils.toolkit import tensor2numpy
from utils.toolkit import FisherComputer


class EWC_LoRA(BaseLearner):
    """LoRA-based learner regularized with an EWC-style penalty.

    After each task, a Fisher-based importance estimate is computed for every
    (new LoRA A, new LoRA B) parameter pair. It is folded into
    ``self.omega_W`` with exponential decay ``self.gamma``.

    While training later tasks, the product ``B @ A`` of each new LoRA pair is
    penalized by ``ewc_weight / 2 * sum(omega * (B @ A) ** 2)``.
    """

    def __init__(self, args):
        super().__init__(args)

        self.topk = 1
        # One importance tensor per (new LoRA A, new LoRA B) pair, in parameter
        # registration order; empty until the first after_task().
        self.omega_W = []
        # Decay applied to the previously accumulated importance.
        self.gamma = 0.9
        self.ewc_weight = args["ewc_weight"]
        self.increment = args['increment']
        # Number of importance updates performed so far; the EWC penalty is
        # only applied once this is non-zero.
        self.count_updates = 0

    def _new_lora_params(self):
        """Return ``(a_params, b_params)``: network parameters flagged
        ``_is_new_a`` / ``_is_new_b``, each list in registration order."""
        params = list(self.network.parameters())
        new_a = [p for p in params if getattr(p, '_is_new_a', False)]
        new_b = [p for p in params if getattr(p, '_is_new_b', False)]
        return new_a, new_b

    def after_task(self):
        """Finish the current task: update importance weights, then merge LoRA.

        Steps:
        1. Compute a Fisher estimate with ``FisherComputer`` on
           ``self.train_loader``.
        2. For each new LoRA pair ``i``, set
           ``omega_W[i] = gamma * omega_W[i] + fisher_W[i]``, or just
           ``fisher_W[i]`` on the first update.
        3. Call ``self.network.accumulate_and_reset_lora()``.
        """
        super().after_task()

        # Compute Fisher Information Matrix
        print("=== Update Importance Matrix ===")
        self.count_updates += 1
        fisher = FisherComputer(self.cur_task, self.network, self.train_loader,
                                self.increment, F.cross_entropy, self.device)
        fisher_W = fisher.compute(max_batches=None)

        new_a_params, new_b_params = self._new_lora_params()
        num_pairs = min(len(new_a_params), len(new_b_params))
        previous = self.omega_W
        # NOTE(review): assumes fisher_W is ordered like the new LoRA pairs — confirm
        # against FisherComputer.
        if previous:
            self.omega_W = [self.gamma * previous[i] + fisher_W[i]
                            for i in range(num_pairs)]
        else:
            self.omega_W = [fisher_W[i] for i in range(num_pairs)]

        # Merge LoRA and reset new LoRA
        self.network.accumulate_and_reset_lora()

    def _train_function(self, task, train_loader, optimizer, scheduler):
        """Train for ``self.run_epoch`` epochs on the current task's samples.

        Per batch:
        - Samples with a label below ``self.known_classes`` are dropped.
        - The remaining labels are shifted down by ``known_classes``.
        - Cross-entropy is computed on the logits from
          ``self.network(image=..., use_new=True)``.
        - After at least one importance update, the EWC penalty on the new
          LoRA pairs is added to the loss.

        Args:
            task: Not used by this implementation.
            train_loader: Iterable of ``(_, inputs, targets)`` batches.
            optimizer: Stepped once per non-empty batch.
            scheduler: Stepped once per epoch.
        """
        use_ewc = self.count_updates != 0
        if use_ewc:
            # These do not change during training, so resolve them once instead
            # of scanning all parameters and copying omegas on every batch.
            new_a_params, new_b_params = self._new_lora_params()
            lora_pairs = list(zip(new_a_params, new_b_params))
            omega = [w.to(device=self.device, dtype=torch.float32) for w in self.omega_W]

        info = ''
        prog_bar = tqdm(range(self.run_epoch))
        for epoch in prog_bar:
            self.network.train()
            losses = 0.
            correct, total = 0, 0

            for _, inputs, targets in train_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                mask = (targets >= self.known_classes).nonzero().view(-1)
                if mask.numel() == 0:
                    # No current-task samples: mean cross-entropy over an empty
                    # batch is NaN, which would corrupt the loss statistics.
                    continue
                inputs = torch.index_select(inputs, 0, mask)
                targets = torch.index_select(targets, 0, mask) - self.known_classes

                # current loss
                logits = self.network(image=inputs, use_new=True)['logits']
                loss = F.cross_entropy(logits, targets)

                # regularization loss
                if use_ewc:
                    ewc_loss = 0.
                    for idx, (p_a, p_b) in enumerate(lora_pairs):
                        # assumes p_b @ p_a has the same shape as omega[idx] — TODO confirm
                        delta_W = p_b @ p_a
                        ewc_loss += torch.sum(omega[idx] * (delta_W ** 2))
                    loss = loss + self.ewc_weight / 2. * ewc_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses += loss.item()
                _, preds = torch.max(logits, dim=1)
                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
                total += len(targets)

            scheduler.step()
            if total > 0:
                train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
            else:
                train_acc = 0.0

            info = 'Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}'.format(
                self.cur_task, epoch + 1, self.run_epoch, losses / len(train_loader), train_acc)
            prog_bar.set_description(info)

        if info:
            logging.info(info)

    def freeze_network(self):
        """Enable gradients only for the current task's classifier and the new LoRA weights.

        A parameter stays trainable when its name lies under
        ``classifier_pool.<task_id>`` or contains one of the ``lora_new_*``
        keys. Every other parameter is frozen, and the trainable names are
        printed.
        """
        try:
            task_id = self.network.module._cur_task
        except AttributeError:
            task_id = self.network._cur_task

        classifier_key = f"classifier_pool.{task_id}"
        lora_keys = ("lora_new_A_k", "lora_new_A_v", "lora_new_B_k", "lora_new_B_v")
        print("Parameters to be updated:")
        for name, param in self.network.named_parameters():
            # Match the task index as a whole path component; a plain substring
            # test would let "classifier_pool.1" also match "classifier_pool.10".
            is_classifier = f"{classifier_key}." in name or name.endswith(classifier_key)
            param.requires_grad_(is_classifier or any(key in name for key in lora_keys))
            if param.requires_grad:
                print(f"[Trainable] {name}")
