import logging
import os
import random
import sys

import numpy as np
import torch
from tqdm import tqdm


def count_parameters(model, trainable=False):
    """Return the number of scalar elements held by *model*'s parameters.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable: if True, count only parameters whose ``requires_grad`` is set.
    """
    params = model.parameters()
    if trainable:
        params = filter(lambda param: param.requires_grad, params)
    return sum(param.numel() for param in params)


def tensor2numpy(x):
    """Return *x* as a NumPy array on the host.

    The tensor is detached from autograd and moved to the CPU if it lives on
    any other device (CUDA, MPS, ...); the previous ``is_cuda`` check missed
    non-CUDA accelerators. For a CPU tensor the result shares memory with *x*.
    """
    return x.detach().cpu().numpy()


def target2onehot(targets, n_classes):
    """Return a float one-hot matrix of shape (len(targets), n_classes).

    Row ``i`` has a 1 in column ``targets[i]``; the result is placed on the
    same device as *targets*.
    """
    column_index = targets.long().view(-1, 1)
    encoded = torch.zeros(targets.shape[0], n_classes, device=targets.device)
    return encoded.scatter_(dim=1, index=column_index, value=1.)


def makedirs(path):
    """Create directory *path* (with missing parents) unless the path already exists.

    ``exist_ok=True`` closes the race between the existence check and the
    creation: if another process creates the directory in between,
    ``os.makedirs`` no longer raises ``FileExistsError``. An existing non-directory
    path is still left untouched, as before.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def make_logdir(args):
    """Create and return the log directory for a run described by *args*.

    The path is ``logs/<dataset>-<model_name>-<total_sessions>t``, with a
    ``debug`` sub-directory appended when ``args['debug']`` is truthy.
    """
    base_dir = 'logs/{}-{}-{}t'.format(args['dataset'], args['model_name'], args['total_sessions'])
    logdir = os.path.join(base_dir, 'debug') if args['debug'] else base_dir
    makedirs(logdir)
    return logdir


def setup_logging(logfilename, no_ckp=False):
    """Reset the root logger to log INFO records to ``<logfilename>.log`` and stdout.

    Args:
        logfilename: path prefix of the log file; ``'.log'`` is appended.
        no_ckp: when False (default), a directory named exactly *logfilename*
            is also created (presumably for checkpoints — TODO confirm with
            callers).
    """
    # Remove AND close previous handlers: removing alone leaves a prior
    # FileHandler's file open (fd leak on every re-configuration).
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()
    if not no_ckp:
        makedirs(logfilename)
    # Ensure the directory holding the .log file exists; otherwise FileHandler
    # raises FileNotFoundError (reachable when no_ckp=True).
    log_dir = os.path.dirname(logfilename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
        handlers=[
            logging.FileHandler(filename=logfilename + '.log'),
            logging.StreamHandler(sys.stdout)])


def set_device(args):
    """Replace the device ids in ``args['device']`` with ``torch.device`` objects, in place.

    An id of ``-1`` (int or string) selects the CPU; any other id ``k`` selects
    ``cuda:k``. Returns None; the converted list is written back to *args*.
    """
    devices = []
    for device_id in args['device']:
        # Compare the individual id (the old code compared the whole list, so
        # -1 was never recognised and produced the invalid 'cuda:-1').
        if device_id in (-1, '-1'):
            devices.append(torch.device('cpu'))
        else:
            devices.append(torch.device('cuda:{}'.format(device_id)))
    args['device'] = devices


def set_random(args):
    """Seed Python, NumPy and torch RNGs from ``args['seed']`` and pin cuDNN settings.

    Python's ``random`` and NumPy's global RNG were previously left unseeded,
    so any NumPy/`random`-based sampling was not reproducible across runs.
    cuDNN is switched to deterministic mode with benchmark autotuning disabled.
    """
    seed = int(args['seed'])  # torch.manual_seed also int()-converts its input
    random.seed(seed)
    # NumPy's legacy global seeding only accepts 0 <= seed < 2**32.
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def print_args(args):
    """Emit one INFO log line ``'<key>: <value>'`` per entry of *args*, in iteration order."""
    rendered = ('{}: {}'.format(name, setting) for name, setting in args.items())
    for line in rendered:
        logging.info(line)


def format_elapsed_time(start_time, end_time):
    """Render ``end_time - start_time`` (seconds) as ``'<h>h <m>m <s>s'`` with truncated parts."""
    duration = end_time - start_time
    whole_hours, within_hour = divmod(duration, 3600)
    whole_minutes = within_hour // 60
    return '{}h {}m {}s'.format(int(whole_hours), int(whole_minutes), int(duration % 60))


def split_images_labels(imgs):
    """Split ``(image, label, ...)`` entries (e.g. ``ImageFolder.imgs``) into two arrays.

    Only the first two fields of each entry are used. *imgs* is iterated once,
    so one-shot iterators are supported. Returns ``(images, labels)`` as NumPy arrays.
    """
    pairs = [(entry[0], entry[1]) for entry in imgs]
    images = np.array([image for image, _ in pairs])
    labels = np.array([label for _, label in pairs])
    return images, labels


def accuracy(y_pred, y_true, known_classes, increment=10):
    """Compute total, per-group, old-class and new-class accuracy in percent.

    Args:
        y_pred: predicted labels (array-like of ints).
        y_true: ground-truth labels (array-like of ints), same length as y_pred.
        known_classes: labels ``< known_classes`` are "old", the rest "new".
        increment: width of each reported label group ``'<lo>-<hi>'``.

    Returns:
        dict with keys ``'total'``, one ``'<lo>-<hi>'`` key per group up to and
        including the group containing ``max(y_true)``, then ``'old'`` and
        ``'new'``. Values are rounded to 2 decimals; an empty subset scores 0.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    assert len(y_pred) == len(y_true), 'Data length error.'

    def _pct(idxes):
        # Accuracy over the given indices; 0 for an empty subset instead of a
        # division by zero producing NaN.
        if len(idxes) == 0:
            return 0
        return np.around((y_pred[idxes] == y_true[idxes]).sum() * 100 / len(idxes), decimals=2)

    all_acc = {}
    all_acc['total'] = _pct(np.arange(len(y_true)))

    # Grouped accuracy. The +1 makes the group starting at max(y_true) appear
    # (e.g. max label 20 with increment 10 yields a '20-29' entry).
    upper = int(np.max(y_true)) + 1 if len(y_true) else 0
    for class_id in range(0, upper, increment):
        idxes = np.where(np.logical_and(y_true >= class_id, y_true < class_id + increment))[0]
        label = '{}-{}'.format(str(class_id).rjust(2, '0'), str(class_id + increment - 1).rjust(2, '0'))
        all_acc[label] = _pct(idxes)

    # Old / new split around known_classes.
    all_acc['old'] = _pct(np.where(y_true < known_classes)[0])
    all_acc['new'] = _pct(np.where(y_true >= known_classes)[0])

    return all_acc


def accuracy_all(y_pred, y_true):
    """Return the share of matching predictions in percent, rounded to 2 decimals.

    Inputs may be NumPy arrays or torch tensors (tensors are brought to the CPU
    and converted to NumPy first).
    """
    assert len(y_pred) == len(y_true), 'Data length error.'

    def _as_numpy(values):
        return values.cpu().numpy() if hasattr(values, 'cpu') else values

    predictions, labels = _as_numpy(y_pred), _as_numpy(y_true)
    n_correct = (predictions == labels).sum()
    return np.around(n_correct * 100 / len(labels), decimals=2)


class FisherComputer:
    """Estimate an empirical Fisher for the "new" LoRA weight deltas of a model.

    One zero buffer shaped like ``B.weight @ A.weight`` is allocated for every
    module exposing ``lora_new_B_k``/``lora_new_A_k`` (key) and/or
    ``lora_new_B_v``/``lora_new_A_v`` (value). :meth:`compute` fills them from
    the gradients the modules expose as ``delta_w_k_new_grad`` /
    ``delta_w_v_new_grad`` (presumably set by hooks the model installs when
    called with ``register_hook=True`` — TODO confirm).
    """

    def __init__(self, task_id, network, dataloader, increment, criterion, device=torch.device('cpu')):
        self.model = network.to(device)
        self.dataloader = dataloader
        self.increment = increment
        self.criterion = criterion
        self.device = device

        self.task_id = task_id
        self.fisher_W = []
        self._init_fisher_storage()

    def compute(self, max_batches=None):
        """Accumulate and return the empirical Fisher estimate.

        Each batch adds ``grad ** 2 * batch_size`` per buffer, where ``grad`` is
        the gradient of ``criterion(logits, local_targets)``; the total is then
        divided by the number of samples seen.

        Args:
            max_batches: stop after this many batches; a falsy value (None or 0)
                processes the whole dataloader.

        Returns:
            The list ``self.fisher_W`` (one tensor per key/value buffer). All
            zeros, with a warning, if no samples were processed.
        """
        # Start from fresh zero buffers so a repeated call does not add onto
        # the previously normalized estimate.
        self.fisher_W = []
        self._init_fisher_storage()

        self.model.eval()
        num_samples = 0

        for i, (_, inputs, targets) in enumerate(tqdm(self.dataloader, desc="Computing Fisher")):
            if max_batches and i >= max_batches:
                break
            # Empirical Fisher: gradients of the loss on the observed labels.
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.model.zero_grad()
            logits = self.model.forward(inputs, use_new=True, register_hook=True)['logits']
            # Shift global labels to task-local indices; presumably the logits
            # cover only the current task's `increment` classes — TODO confirm.
            targets = targets - self.task_id * self.increment
            loss = self.criterion(logits, targets)
            loss.backward()

            batch_size = inputs.size(0)
            num_samples += batch_size

            # NOTE(review): assumes modules exposing delta_w_*_new_grad appear in
            # the same order and number as those given buffers in
            # _init_fisher_storage; a mismatch misaligns or overruns fisher_W.
            idx = 0
            for module in self.model.modules():
                if hasattr(module, 'delta_w_k_new_grad'):
                    grad_k = module.delta_w_k_new_grad
                    if grad_k is not None:
                        self.fisher_W[idx] += (grad_k.detach() ** 2) * batch_size
                    idx += 1
                if hasattr(module, 'delta_w_v_new_grad'):
                    grad_v = module.delta_w_v_new_grad
                    if grad_v is not None:
                        self.fisher_W[idx] += (grad_v.detach() ** 2) * batch_size
                    idx += 1

        if num_samples == 0:
            # Avoid 0/0 -> NaN buffers when the loader is empty.
            logging.warning('FisherComputer.compute: no samples processed; returning zero Fisher.')
            return self.fisher_W

        self.fisher_W = [fw / num_samples for fw in self.fisher_W]

        return self.fisher_W

    def _init_fisher_storage(self):
        """Append one zero buffer per LoRA key/value pair, shaped like ``B.weight @ A.weight``."""
        # The product is only needed for its shape/dtype/device, so keep it out
        # of the autograd graph.
        with torch.no_grad():
            for module in self.model.modules():
                if hasattr(module, 'lora_new_B_k') and hasattr(module, 'lora_new_A_k'):
                    delta_w_k_new = module.lora_new_B_k.weight @ module.lora_new_A_k.weight
                    self.fisher_W.append(torch.zeros_like(delta_w_k_new))  # Key
                if hasattr(module, 'lora_new_B_v') and hasattr(module, 'lora_new_A_v'):
                    delta_w_v_new = module.lora_new_B_v.weight @ module.lora_new_A_v.weight
                    self.fisher_W.append(torch.zeros_like(delta_w_v_new))  # Value


def _solve_sylvester_cg(B, A, GB, GA, eps=1e-6, tol=1e-6, maxiter=200, verbose=False):
    """
    (B B^T) G + G (A^T A) = GB A + B GA
    B: (m, r)
    A: (r, n)
    GB: (m, r)
    GA: (r, n)
    """
    m, n = B.shape[0], A.shape[1]
    R = GB @ A + B @ GA
    mn = m * n

    def matvec(vec):
        G = vec.view(m, n)
        MG = B @ (B.T @ G)
        GN = (G @ A.T) @ A
        out = MG + GN
        if eps != 0.0:
            out = out + eps * G
        return out.reshape(mn)

    b = R.reshape(mn)

    x_vec = torch.zeros_like(b)
    r_vec = b - matvec(x_vec)
    p = r_vec.clone()
    rsold = torch.dot(r_vec, r_vec)

    for k in range(maxiter):
        Ap = matvec(p)
        alpha = rsold / (torch.dot(p, Ap) + 1e-30)
        x_vec = x_vec + alpha * p
        r_vec = r_vec - alpha * Ap
        rsnew = torch.dot(r_vec, r_vec)
        if verbose:
            print(f"iter={k}, residual={rsnew.sqrt().item():.3e}")
        if torch.sqrt(rsnew) <= tol * torch.sqrt(torch.dot(b, b)):
            break
        beta = rsnew / (rsold + 1e-30)
        p = r_vec + beta * p
        rsold = rsnew

    return x_vec.view(m, n)
