import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu, avg_pool2d
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms

import os
import os.path
from collections import OrderedDict
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sn
import pandas as pd
import random
import pdb
import argparse,time
import math
from copy import deepcopy

## Define ResNet18 model
def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1):
    """Return the output length of a convolution along one spatial dimension.

    Args:
        Lin: input length along the dimension.
        kernel_size: kernel extent along the dimension.
        stride: convolution stride.
        padding: zero padding added on each side.
        dilation: spacing between kernel taps.

    Returns:
        The floored output length as a Python ``int``.
    """
    covered = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    steps = covered / float(stride) + 1
    return int(np.floor(steps))

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
def conv7x7(in_planes, out_planes, stride=1):
    """Build a bias-free 7x7 ``nn.Conv2d`` with padding 1 and the given stride."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=7,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions that records its conv inputs.

    On every forward pass the tensor fed to ``conv1`` is stored in
    ``self.act['conv_0']`` and the tensor fed to ``conv2`` in
    ``self.act['conv_1']`` so callers can read the activations afterwards.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, track_running_stats=False)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, track_running_stats=False)

        # A 1x1 conv + BN projection is only needed when the residual branch
        # changes resolution or width; otherwise the shortcut is the identity.
        if stride != 1 or in_planes != out_planes:
            shortcut_layers = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes, track_running_stats=False),
            ]
        else:
            shortcut_layers = []
        self.shortcut = nn.Sequential(*shortcut_layers)

        self.act = OrderedDict()
        self.count = 0

    def _record(self, tensor):
        """Store ``tensor`` under 'conv_0' or 'conv_1', alternating via ``self.count``."""
        slot = self.count % 2
        self.act['conv_{}'.format(slot)] = tensor
        self.count = slot + 1

    def forward(self, x):
        self._record(x)
        hidden = relu(self.bn1(self.conv1(x)))
        self._record(hidden)
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return relu(residual)

class ResNet(nn.Module):
    """ResNet backbone with one bias-free linear head per task.

    ``taskcla`` is iterated as ``(task_index, n_classes)`` pairs; ``forward``
    returns a list with one logits tensor per pair. The input is viewed as
    ``(batch, 3, 84, 84)`` and cached in ``self.act['conv_in']``.
    """

    def __init__(self, block, num_blocks, taskcla, nf):
        super().__init__()
        self.in_planes = nf

        # Stem
        self.conv1 = conv3x3(3, nf * 1, 2)
        self.bn1 = nn.BatchNorm2d(nf * 1, track_running_stats=False)

        # Four residual stages; every stage after the first halves the map.
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)

        # One classifier per task, all reading the same flattened feature.
        self.taskcla = taskcla
        feature_dim = nf * 8 * block.expansion * 9
        self.linear = torch.nn.ModuleList()
        for _, n_classes in self.taskcla:
            self.linear.append(nn.Linear(feature_dim, n_classes, bias=False))
        self.act = OrderedDict()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        images = x.view(x.size(0), 3, 84, 84)
        self.act['conv_in'] = images

        feat = relu(self.bn1(self.conv1(images)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
        feat = avg_pool2d(feat, 2)
        feat = feat.view(feat.size(0), -1)

        return [self.linear[t](feat) for t, _ in self.taskcla]

def ResNet18(taskcla, nf=32):
    """Build a ``ResNet`` of ``BasicBlock``s with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, taskcla, nf)

def get_model(model):
    """Return an independent deep copy of ``model``'s state dict."""
    snapshot = model.state_dict()
    return deepcopy(snapshot)

def set_model_(model,state_dict):
    """Load a deep copy of ``state_dict`` into ``model`` in place; returns ``None``."""
    weights = deepcopy(state_dict)
    model.load_state_dict(weights)

def adjust_learning_rate(optimizer, epoch, args):
    """Reset every param group's lr to ``args.lr`` at epoch 1, otherwise divide it by ``args.lr_factor``."""
    first_epoch = (epoch == 1)
    for group in optimizer.param_groups:
        if first_epoch:
            group['lr'] = args.lr
            continue
        group['lr'] /= args.lr_factor

def train(args, model, device, x,y, optimizer,criterion, task_id):
    """Run one epoch of plain SGD-style training on a single task head.

    Args:
        args: namespace providing ``batch_size_train``.
        model: module whose forward returns a sequence of per-task outputs.
        device: device the data and index tensor are moved to.
        x, y: full input and target tensors for the task.
        optimizer: optimizer stepped once per mini-batch.
        criterion: loss applied to ``output[task_id]`` and the batch targets.
        task_id: index of the head to train.
    """
    model.train()

    # Shuffle with NumPy so the sample order follows np.random's state.
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)

    # Move the dataset once instead of on every mini-batch.
    x, y = x.to(device), y.to(device)

    batch_size = args.batch_size_train
    for start in range(0, len(order), batch_size):
        # Slicing clamps at the end, so the last batch may be shorter.
        batch = order[start:start + batch_size]
        optimizer.zero_grad()
        output = model(x[batch])
        loss = criterion(output[task_id], y[batch])
        loss.backward()
        optimizer.step()

def train_projected(args,model,device,x,y,optimizer,criterion,feature_mat, feature_mat_weight,task_id):
    """Run one training epoch with a gradient-projection penalty and projected updates.

    For every mini-batch:
      1. the task loss is computed on ``output[task_id]``;
      2. if ``args.lambda_gp_ql > 0`` and ``task_id > 0``, the mean squared
         projection of each 4-D weight gradient onto ``feature_mat[layer]`` is
         added to the loss (this needs a differentiable gradient, hence
         ``create_graph=True``);
      3. after ``backward``, each 4-D parameter gradient ``g`` is replaced by
         ``g - g @ feature_mat_weight[layer]`` and, for ``task_id != 0``, the
         gradients of all 1-D parameters are zeroed;
      4. the optimizer steps.

    Args:
        args: namespace providing ``batch_size_train`` and ``lambda_gp_ql``.
        model: module whose forward returns a sequence of per-task outputs.
        device: device the data and index tensor are moved to.
        x, y: full input and target tensors for the task.
        optimizer: optimizer stepped once per mini-batch.
        criterion: loss applied to ``output[task_id]`` and the batch targets.
        feature_mat: per-layer square matrices used by the loss penalty,
            ordered like the model's 4-D weights in ``named_parameters()``.
        feature_mat_weight: per-layer square matrices used to project the
            gradients before the step, same ordering.
        task_id: index of the head to train.
    """
    model.train()

    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)

    # Move the dataset once instead of on every mini-batch.
    x, y = x.to(device), y.to(device)

    # The set of regularised parameters does not change between batches.
    use_penalty = args.lambda_gp_ql > 0 and task_id > 0
    conv_weights = []
    if use_penalty:
        conv_weights = [p for name, p in model.named_parameters()
                        if 'weight' in name and len(p.size()) == 4]
    penalty_scale = (task_id * 0.01 + 1) * args.lambda_gp_ql

    batch_size = args.batch_size_train
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        optimizer.zero_grad()
        output = model(x[batch])
        loss = criterion(output[task_id], y[batch])

        # ---- Gradient-projection regulariser (onto protected subspace) ----
        if conv_weights:
            grads = torch.autograd.grad(loss, conv_weights,
                                        create_graph=True, retain_graph=True, allow_unused=True)
            proj_reg = 0.0
            for layer, grad in enumerate(grads):
                if grad is None:
                    # Parameter unused by this loss: its projection is zero.
                    continue
                projected = torch.mm(grad.view(grad.size(0), -1), feature_mat[layer])
                proj_reg += projected.pow(2).mean()
            loss = loss + penalty_scale * proj_reg

        loss.backward()

        # ---- Project the actual update direction before stepping ----
        layer = 0
        for _, param in model.named_parameters():
            if len(param.size()) == 4:
                rows = param.grad.data.size(0)
                flat_grad = param.grad.data.view(rows, -1)
                param.grad.data = param.grad.data - torch.mm(
                    flat_grad, feature_mat_weight[layer]).view(param.size())
                layer += 1
            elif len(param.size()) == 1 and task_id != 0:
                # 1-D parameters are frozen after the first task.
                param.grad.data.fill_(0)

        optimizer.step()

def test(args, model, device, x, y, criterion, task_id):
    model.eval()
    total_loss = 0
    total_num = 0 
    correct = 0
    r=np.arange(x.size(0))
    np.random.shuffle(r)
    r=torch.LongTensor(r).to(device)
    with torch.no_grad():
        # Loop batches
        for i in range(0,len(r),args.batch_size_test):
            if i+args.batch_size_test<=len(r): b=r[i:i+args.batch_size_test]
            else: b=r[i:]
            x, y = x.to(device), y.to(device)
            data = x[b]
            data, target = data.to(device), y[b].to(device)
            output = model(data)
            loss = criterion(output[task_id], target)
            pred = output[task_id].argmax(dim=1, keepdim=True) 
            
            correct    += pred.eq(target.view_as(pred)).sum().item()
            total_loss += loss.data.cpu().numpy().item()*len(b)
            total_num  += len(b)

    acc = 100. * correct / total_num
    final_loss = total_loss / total_num
    return final_loss, acc

def compute_fisher_diag(model,
                        criterion,
                        inputs,
                        targets,
                        task_id: int = 0,
                        batch_size: int = 64,
                        n_batches: int = 20,
                        wanted_layers=None,
                        device='cuda'):
    """Estimate a diagonal Fisher-style score from squared mini-batch gradients.

    For up to ``n_batches`` shuffled mini-batches, the gradient of the batch
    loss is computed and its element-wise square accumulated per parameter.
    The accumulator is averaged over the number of batches actually
    processed; incomplete trailing batches are dropped by the loader, so that
    number can be smaller than ``n_batches``.

    Args:
        model: network to differentiate; it is moved to ``device`` and put in
            eval mode.
        criterion: loss applied to the (selected) model output and targets.
        inputs, targets: tensors wrapped in a ``TensorDataset``.
        task_id: head index used when the model returns a list or tuple.
        batch_size: mini-batch size.
        n_batches: maximum number of mini-batches to use.
        wanted_layers: optional container of parameter names to restrict the
            estimate to; ``None`` means every trainable parameter.
        device: device used for the computation.

    Returns:
        Dict mapping parameter name to a tensor shaped like that parameter.
        Entries are all zeros when no full batch could be drawn.
    """
    model = model.to(device)
    model.eval()

    params = {n: p for n, p in model.named_parameters()
              if p.requires_grad and (wanted_layers is None or n in wanted_layers)}

    fisher_acc = {n: torch.zeros_like(p, device=device) for n, p in params.items()}

    loader = DataLoader(TensorDataset(inputs, targets),
                        batch_size=batch_size, shuffle=True, drop_last=True)

    processed = 0
    with torch.enable_grad():
        for x_mb, y_mb in loader:
            x_mb, y_mb = x_mb.to(device), y_mb.to(device)

            model.zero_grad(set_to_none=True)
            out = model(x_mb)
            # model may return a list of heads; pick the one for current task
            if isinstance(out, (list, tuple)):
                out = out[task_id]
            loss = criterion(out, y_mb)
            loss.backward()

            for n, p in params.items():
                if p.grad is not None:
                    fisher_acc[n] += (p.grad.detach() ** 2)
            processed += 1
            if processed >= n_batches:
                break

    # Do not leave the last batch's gradients attached to the model.
    model.zero_grad(set_to_none=True)

    # Average over what was really accumulated, not the requested maximum.
    denom = max(processed, 1)
    fisher_diag = {n: acc / denom for n, acc in fisher_acc.items()}
    return fisher_diag

def _collect_conv_patches(act, n_samples, ksz, stride, out_size):
    """Unroll ksz x ksz patches of ``act[:n_samples]`` into matrix columns.

    ``act`` is a NumPy array indexed ``[sample, channel, row, col]``. Columns
    are ordered sample-major, then output row, then output column; each
    column is the flattened ``(channel, ksz, ksz)`` patch.
    """
    mat = np.zeros((ksz * ksz * act.shape[1], out_size * out_size * n_samples))
    col = 0
    for n in range(n_samples):
        for ii in range(out_size):
            for jj in range(out_size):
                mat[:, col] = act[n, :, stride * ii:ksz + stride * ii,
                                  stride * jj:ksz + stride * jj].reshape(-1)
                col += 1
    return mat


def get_representation_matrix_ResNet18 (net, device, x, y=None): 
    """Build the per-layer representation (patch) matrices for a ResNet18.

    A forward pass over 100 randomly chosen samples fills the activation
    caches of ``net`` and its blocks. For each of the 17 3x3 convolutions the
    cached input is zero-padded by 1 and unrolled into patch columns; for the
    three strided blocks that own a 1x1 shortcut convolution an additional
    1x1 patch matrix is built from the unpadded input and placed right after
    that block's second convolution.

    Args:
        net: ResNet18 exposing ``act['conv_in']`` and, per block,
            ``act['conv_0']`` / ``act['conv_1']``.
        device: device used for the forward pass.
        x: input tensor to sample from.
        y: unused; accepted for call compatibility and may be ``None``.

    Returns:
        List of NumPy matrices, one per convolution, in layer order.
    """
    # Collect activations by forward pass
    net.eval()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    sample_idx = order[0:100]  # ns=100 examples
    example_data = x.to(device)[sample_idx]
    # Only the cached activations are needed, so skip building a graph.
    with torch.no_grad():
        net(example_data)

    act_list = [net.act['conv_in']]
    for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
        for block in (stage[0], stage[1]):
            act_list.append(block.act['conv_0'])
            act_list.append(block.act['conv_1'])

    # Number of samples unrolled per layer.
    batch_list = [10, 10, 10, 10, 10, 10, 10, 10, 50, 50, 50, 100, 100, 100, 100, 100, 100]  # scaled
    stride_list = [2, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1]

    pad = 1
    ksz = 3
    sc_list = [5, 9, 13]       # layers whose input also feeds a 1x1 shortcut conv
    p1d = (1, 1, 1, 1)
    mat_list = []
    mat_sc_list = []
    for i, (bsz, st) in enumerate(zip(batch_list, stride_list)):
        # Spatial size is read from the activation itself rather than from a
        # table tied to one particular network width / input size.
        map_size = act_list[i].shape[2]

        s = compute_conv_output_size(map_size, ksz, st, pad)
        padded = F.pad(act_list[i], p1d, "constant", 0).detach().cpu().numpy()
        mat_list.append(_collect_conv_patches(padded, bsz, ksz, st, s))

        if i in sc_list:
            s = compute_conv_output_size(map_size, 1, st)
            raw = act_list[i].detach().cpu().numpy()
            mat_sc_list.append(_collect_conv_patches(raw, bsz, 1, st, s))

    # Interleave: each shortcut matrix follows the block's second conv.
    mat_final = []  # list containing GPM Matrices
    ik = 0
    for i, mat in enumerate(mat_list):
        mat_final.append(mat)
        if i in [6, 10, 14]:
            mat_final.append(mat_sc_list[ik])
            ik += 1

    print('-'*30)
    print('Representation Matrix')
    print('-'*30)
    for i, mat in enumerate(mat_final):
        print ('Layer {} : {}'.format(i+1, mat.shape))
    print('-'*30)
    return mat_final

def _rescale_importance(values, scale_coff):
    """Map non-negative scores to (0, 1] via ``(a+1)v / (a v + max(v))``."""
    return ((scale_coff + 1) * values) / (scale_coff * values + max(values))


def update_SGP(args, model, device, memory, task_name, mat_list, threshold, task_id, feature_list=None,
               importance_list=None):
    """Update the per-layer bases and their importances after a task.

    First task (``feature_list`` empty): for each layer the leading left
    singular vectors of the representation matrix, counted while the
    cumulative squared-singular-value ratio stays below ``threshold[i]``, are
    kept and stored; importances are the rescaled singular values.

    Later tasks: the new representation is split into the part inside the
    existing basis and the residual. Existing directions get an importance
    from the projected part, new directions are taken from the residual's SVD
    until the threshold is met, and accumulated importances of old directions
    are clipped to ``[0, 1]``. A per-task basis chosen from old and new
    directions by descending energy is written to
    ``memory[task_name][str(i)]['space_list']``.

    Args:
        args: namespace providing ``scale_coff``.
        model: unused; kept for call compatibility.
        device: device for the tensors stored in ``memory``.
        memory: nested dict ``memory[task_name][str(layer)]`` updated in place.
        task_name: key of the current task in ``memory``.
        mat_list: per-layer representation matrices (NumPy arrays).
        threshold: per-layer energy thresholds.
        task_id: unused; kept for call compatibility.
        feature_list: per-layer bases from previous tasks. A passed-in list is
            updated in place; ``None`` starts a fresh list.
        importance_list: per-layer importances matching ``feature_list``;
            same in-place / ``None`` semantics.

    Returns:
        ``(feature_list, importance_list)``.
    """
    # Fresh lists per call: a mutable default would leak state between calls.
    if feature_list is None:
        feature_list = []
    if importance_list is None:
        importance_list = []

    print('Threshold: ', threshold)
    if not feature_list:
        # After First Task
        for i, activation in enumerate(mat_list):
            U, S, _ = np.linalg.svd(activation, full_matrices=False)

            sval_total = (S ** 2).sum()
            sval_ratio = (S ** 2) / sval_total
            r = np.sum(np.cumsum(sval_ratio) < threshold[i])  # +1
            # update GPM
            feature_list.append(U[:, 0:r])
            # save into memory
            memory[task_name][str(i)]['space_list'] = torch.tensor(U[:, 0:r], dtype=torch.float32, device=device)

            # update importance (Eq-2)
            importance_list.append(_rescale_importance(S[0:r], args.scale_coff))

    else:
        for i, activation in enumerate(mat_list):
            old_basis = feature_list[i]
            r_old = old_basis.shape[1]  # old GPM bases

            # Energy of the new representation along every old direction.
            R2 = np.dot(activation, activation.transpose())
            delta = np.array([np.dot(np.dot(space, R2), space)
                              for space in old_basis.transpose()])

            _, S1, _ = np.linalg.svd(activation, full_matrices=False)
            sval_total = (S1 ** 2).sum()

            act_proj = np.dot(np.dot(old_basis, old_basis.transpose()), activation)
            Uc, Sc, _ = np.linalg.svd(act_proj, full_matrices=False)
            importance_new_on_old = np.dot(np.dot(old_basis.transpose(), Uc[:, 0:r_old]) ** 2,
                                           Sc[0:r_old] ** 2)  ## r_old no of elm s**2 fmt
            importance_new_on_old = np.sqrt(importance_new_on_old)

            act_hat = activation - act_proj
            U, S, _ = np.linalg.svd(act_hat, full_matrices=False)

            sval_hat = (S ** 2).sum()
            sval_ratio = (S ** 2) / sval_total
            accumulated_sval = (sval_total - sval_hat) / sval_total
            sigma = S ** 2

            # Rank old (delta) and new (sigma) directions together by energy.
            stack = np.hstack((delta, sigma))
            stack_index = np.argsort(stack)[::-1]
            stack_sorted = stack[stack_index]

            # Take directions until the energy threshold is reached, capped
            # at the feature dimension.
            r1 = 0
            accumulated_sval1 = 0
            for energy in stack_sorted:
                if accumulated_sval1 >= threshold[i] * sval_total:
                    break
                accumulated_sval1 += energy
                r1 += 1
                if r1 == activation.shape[0]:
                    break

            # Save the selected per-task space.
            U_list = np.hstack((old_basis, U))
            sel_index = stack_index[:r1]
            memory[task_name][str(i)]['space_list'] = torch.tensor(U_list[:, sel_index], dtype=torch.float32,
                                                                   device=device)

            # Number of residual directions needed to reach the threshold.
            r = 0
            for ratio in sval_ratio:
                if accumulated_sval >= threshold[i]:
                    break
                accumulated_sval += ratio
                r += 1

            if r == 0:
                print('Skip Updating GPM for layer: {}'.format(i + 1))
                # update importances
                importance = _rescale_importance(importance_new_on_old, args.scale_coff)
                importance[0:r_old] = np.clip(importance[0:r_old] + importance_list[i][0:r_old], 0, 1)
                importance_list[i] = importance  # update importance
                continue

            # update GPM
            Ui = np.hstack((old_basis, U[:, 0:r]))
            # update importance
            importance = _rescale_importance(np.hstack((importance_new_on_old, S[0:r])), args.scale_coff)
            importance[0:r_old] = np.clip(importance[0:r_old] + importance_list[i][0:r_old], 0, 1)

            if Ui.shape[1] > Ui.shape[0]:
                # Never keep more directions than the feature dimension.
                feature_list[i] = Ui[:, 0:Ui.shape[0]]
                importance_list[i] = importance[0:Ui.shape[0]]
            else:
                feature_list[i] = Ui
                importance_list[i] = importance

    print('-' * 40)
    print('Gradient Constraints Summary')
    print('-' * 40)
    for i, basis in enumerate(feature_list):
        print('Layer {} : {}/{}'.format(i + 1, basis.shape[1], basis.shape[0]))
    print('-' * 40)
    return feature_list, importance_list

def importance_loss(model, memory, device, task_name_list, task_id, feature_list, importance_list,
                    scale_coff_taylor=None):
    """Score every basis direction of every conv layer with stored curvature.

    For each 4-D ``weight`` parameter (in ``named_parameters()`` order) and
    each column ``b`` of ``feature_list[layer]``, the score is the mean over
    tasks ``0..task_id`` of ``0.5 * sum(H * p**2)``, where ``p`` is ``b``
    projected onto that task's stored space (``space_list``), broadcast over
    output channels, and ``H`` is that task's ``hess_diag`` viewed as
    ``(out_channels, -1)``. Scores are then rescaled per layer with
    ``(a + 1) s / (a s + max(s))``.

    Args:
        model: network whose 4-D weights define the layers.
        memory: ``memory[task_name][str(layer)]`` with ``'space_list'`` and
            ``'hess_diag'`` tensors.
        device: device used for the computation.
        task_name_list: task names indexed by task position.
        task_id: index of the most recent task (inclusive).
        feature_list: per-layer basis matrices, one column per direction.
        importance_list: unused; kept for call compatibility.
        scale_coff_taylor: rescaling coefficient ``a``. When ``None`` it is
            read from a module-level ``args.scale_coff_Taylor`` if one exists,
            otherwise 5 is used.

    Returns:
        List with one 1-D CPU float tensor of rescaled scores per layer.
    """
    if scale_coff_taylor is None:
        # Historical behaviour: the coefficient came from the script-level
        # ``args``; fall back to a fixed value when imported as a module.
        scale_coff_taylor = getattr(globals().get('args'), 'scale_coff_Taylor', 5)

    n_tasks = max(1, task_id + 1)
    importance_all = []
    layer = 0
    for pname, param in model.named_parameters():
        if 'weight' not in pname or len(param.size()) != 4:
            continue

        # Columns are the basis directions of this layer.
        bases = torch.tensor(feature_list[layer], dtype=torch.float32, device=device)
        out_ch = param.shape[0]
        scores = torch.zeros(bases.shape[1], dtype=torch.float32, device=device)

        for t_idx in range(task_id + 1):
            slot = memory[task_name_list[t_idx]][str(layer)]
            U = slot['space_list']
            H = slot['hess_diag'].to(device)

            # Squared projection of every direction onto this task's space.
            proj_sq = (U @ (U.t() @ bases)).pow(2)
            # The projection is shared by all output channels, so the
            # curvature can be summed over channels first.
            curvature = H.reshape(out_ch, -1).sum(dim=0)
            scores += 0.5 * (curvature @ proj_sq)

        if scores.numel() == 0:
            # No direction kept for this layer: nothing to rescale.
            importance_all.append(scores.cpu())
            layer += 1
            continue

        scores = scores / n_tasks
        scores_norm = ((scale_coff_taylor + 1) * scores) / (scale_coff_taylor * scores + scores.max())

        importance_all.append(scores_norm.cpu())
        layer += 1

    return importance_all



def _init_task_memory(model, slot_keys):
    """Create one empty slot per non-'bn' ``weight`` parameter, keyed '0', '1', ..."""
    slots = {}
    idx = 0
    for name, _ in model.named_parameters():
        if 'weight' in name and 'bn' not in name:
            slots[str(idx)] = {key: {} for key in slot_keys}
            idx += 1
    return slots


def _fit_task(args, model, device, splits, optimizer, criterion, head, run_epoch):
    """Train one task with validation-driven lr decay and restore the best weights.

    ``run_epoch`` is a zero-argument callable performing one training epoch.
    The learning rate is divided by ``args.lr_factor`` after
    ``args.lr_patience`` epochs without validation improvement, and training
    stops once it drops below ``args.lr_min``.
    """
    xtrain, ytrain = splits['train']['x'], splits['train']['y']
    xvalid, yvalid = splits['valid']['x'], splits['valid']['y']

    lr = args.lr
    best_loss = np.inf
    best_model = get_model(model)
    # Initialised up front so a non-improving first epoch cannot hit an
    # unbound (or stale, from a previous task) counter.
    patience = args.lr_patience

    for epoch in range(1, args.n_epochs+1):
        # Train
        clock0 = time.time()
        run_epoch()
        clock1 = time.time()
        tr_loss, tr_acc = test(args, model, device, xtrain, ytrain, criterion, head)
        print('Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% | time={:5.1f}ms |'.format(
            epoch, tr_loss, tr_acc, 1000*(clock1-clock0)), end='')
        # Validate
        valid_loss, valid_acc = test(args, model, device, xvalid, yvalid, criterion, head)
        print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, valid_acc), end='')
        # Adapt lr
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_model = get_model(model)
            patience = args.lr_patience
            print(' *', end='')
        else:
            patience -= 1
            if patience <= 0:
                lr /= args.lr_factor
                print(' lr={:.1e}'.format(lr), end='')
                if lr < args.lr_min:
                    print()
                    break
                patience = args.lr_patience
                adjust_learning_rate(optimizer, epoch, args)
        print()
    set_model_(model, best_model)


def _update_task_memory(args, model, device, memory, task_name, task_name_list, task_id, head,
                        xtrain, ytrain, criterion, threshold, feature_list, importance_list):
    """Refresh bases, importances and stored squared-gradient statistics after a task.

    Returns ``(feature_list, importance_list, importance_all)``.
    """
    mat_list = get_representation_matrix_ResNet18 (model, device, xtrain, ytrain)
    feature_list, importance_list = update_SGP(args, model, device, memory, task_name, mat_list,
                                               threshold, task_id, feature_list, importance_list)

    wanted = {name for name, p in model.named_parameters()
              if len(p.size()) == 4 and 'weight' in name}
    if task_id == 0:
        print(wanted)

    # The head is selected with the task key used for training and testing.
    fisher = compute_fisher_diag(model, criterion,
                                 inputs=xtrain,
                                 targets=ytrain,
                                 batch_size=64,
                                 n_batches=80,
                                 wanted_layers=wanted,
                                 device=device,
                                 task_id=head)

    print(fisher['conv1.weight'].shape)
    for idx, fdiag in enumerate(fisher.values()):
        memory[task_name][str(idx)]['hess_diag'] = fdiag.clone().cpu()
    importance_all = importance_loss(model, memory, device, task_name_list, task_id, feature_list,
                                     importance_list)
    return feature_list, importance_list, importance_all


def main(args):
    """Train all tasks sequentially, report accuracies and plot the accuracy matrix.

    The first task is trained with plain SGD; every later task is trained
    with ``train_projected`` using matrices built from the bases and
    importances gathered so far. After each task the model is evaluated on
    all tasks seen so far.
    """
    tstart = time.time()
    ## Device Setting
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True # unlock this and following code to set a stable seed
    # torch.set_float32_matmul_precision('high')
    ## Load the task sequence (miniimagenet2 loader)
    from dataloader import miniimagenet2 as data_loader
    dataloader = data_loader.DatasetGen(args)
    taskcla = dataloader.taskcla

    # Size the result matrix from the task list instead of assuming 20 tasks.
    n_tasks = len(taskcla)
    acc_matrix = np.zeros((n_tasks, n_tasks))
    criterion = torch.nn.CrossEntropyLoss()

    task_list = []
    task_name_list = []
    memory = {}
    model = None
    feature_list = []
    importance_list = []
    importance_all = []

    for task_id, (k, ncla) in enumerate(taskcla):
        # specify threshold hyperparameter (one entry per projected layer)
        threshold = np.array([0.985] * 20) + task_id * np.array([0.00075] * 20)
        data = dataloader.get(k)
        task_name = data[k]['name']
        task_name_list.append(task_name)

        print('*'*100)
        print('Task {:2d} ({:s})'.format(k, data[k]['name']))
        print('*'*100)
        xtrain = data[k]['train']['x']
        ytrain = data[k]['train']['y']
        xtest = data[k]['test']['x']
        ytest = data[k]['test']['y']
        task_list.append(k)

        lr = args.lr
        print ('-'*40)
        print ('Task ID :{} | Learning Rate : {}'.format(task_id, lr))
        print ('-'*40)

        if task_id == 0:
            model = ResNet18(taskcla, 20).to(device)  # base filters: 20
            memory[task_name] = _init_task_memory(
                model, ('space_list', 'grad_list', 'regime'))
            feature_list = []
            importance_list = []
            optimizer = optim.SGD(model.parameters(), lr=lr)

            def run_epoch():
                train(args, model, device, xtrain, ytrain, optimizer, criterion, k)
        else:
            memory[task_name] = _init_task_memory(
                model, ('space_list', 'grad_list', 'space_mat_list', 'scale1', 'scale2',
                        'regime', 'selected_task', 'ratio'))
            optimizer = optim.SGD(model.parameters(), lr=args.lr)

            # Importance-weighted projectors U diag(w) U^T: one set for the
            # loss penalty, one for the gradient projection.
            feature_mat = []
            for i in range(len(feature_list)):
                Uf = torch.Tensor(np.dot(feature_list[i], np.dot(np.diag(importance_all[i]),
                                                                 feature_list[i].transpose()))).to(device)
                print('Layer {} - Projection Matrix shape: {}'.format(i+1, Uf.shape))
                feature_mat.append(Uf)
            feature_mat_weight = []
            for i in range(len(feature_list)):
                Uf2 = torch.Tensor(np.dot(feature_list[i], np.dot(np.diag(importance_list[i]),
                                                                  feature_list[i].transpose()))).to(device)
                feature_mat_weight.append(Uf2)
            print ('-'*40)

            def run_epoch():
                train_projected(args, model, device, xtrain, ytrain, optimizer, criterion,
                                feature_mat, feature_mat_weight, k)

        _fit_task(args, model, device, data[k], optimizer, criterion, k, run_epoch)

        # Test
        if task_id == 0:
            print ('-'*40)
        test_loss, test_acc = test(args, model, device, xtest, ytest, criterion, k)
        print('Test: loss={:.3f} , acc={:5.1f}%'.format(test_loss, test_acc))

        # Memory Update
        feature_list, importance_list, importance_all = _update_task_memory(
            args, model, device, memory, task_name, task_name_list, task_id, k,
            xtrain, ytrain, criterion, threshold, feature_list, importance_list)

        # Evaluate every task seen so far; each accuracy is the mean of five
        # runs because test() shuffles the evaluation batches.
        for jj, ii in enumerate(np.array(task_list)[0:task_id+1]):
            xtest = data[ii]['test']['x']
            ytest = data[ii]['test']['y']
            test_acc_sum = 0
            for _ in range(5):
                test_loss, test_acc = test(args, model, device, xtest, ytest, criterion, ii)
                test_acc_sum += test_acc
            acc_matrix[task_id, jj] = test_acc_sum/5

        print('Accuracies =')
        for i_a in range(task_id+1):
            print('\t', end='')
            for j_a in range(acc_matrix.shape[1]):
                print('{:5.1f}% '.format(acc_matrix[i_a, j_a]), end='')
            print()

    print('-'*50)
    # Simulation Results
    print ('Task Order : {}'.format(np.array(task_list)))
    print ('Final Avg Accuracy: {:5.2f}%'.format(acc_matrix[-1].mean()))
    bwt = np.mean((acc_matrix[-1]-np.diag(acc_matrix))[:-1])
    print ('Backward transfer: {:5.2f}%'.format(bwt))
    print('[Elapsed time = {:.1f} ms]'.format((time.time()-tstart)*1000))
    print('-'*50)
    # Plots
    labels = [str(i) for i in range(1, n_tasks + 1)]
    df_cm = pd.DataFrame(acc_matrix, index=labels, columns=labels)
    sn.set(font_scale=1.4)
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 10})
    plt.show()

if __name__ == "__main__":
    # Command-line entry point. ``args`` stays a module-level name on purpose:
    # importance_loss() reads it when no coefficient is passed explicitly.
    # Help texts use %(default)s so they cannot drift from the real defaults.
    parser = argparse.ArgumentParser(
        description='Sequential miniImageNet training with scaled gradient projection')
    # Training parameters
    parser.add_argument('--batch_size_train', type=int, default=64, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--batch_size_test', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: %(default)s)')
    parser.add_argument('--n_epochs', type=int, default=100, metavar='N',
                        help='number of training epochs/task (default: %(default)s)')
    parser.add_argument('--seed', type=int, default=37, metavar='S',
                        help='random seed (default: %(default)s)')
    parser.add_argument('--pc_valid',default=0.02,type=float,
                        help='fraction of training data used for validation (default: %(default)s)')
    # Optimizer parameters
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: %(default)s)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: %(default)s)')
    parser.add_argument('--lr_min', type=float, default=1e-3, metavar='LRM',
                        help='minimum lr rate (default: %(default)s)')
    parser.add_argument('--lr_patience', type=int, default=5, metavar='LRP',
                        help='hold before decaying lr (default: %(default)s)')
    parser.add_argument('--lr_factor', type=int, default=3, metavar='LRF',
                        help='lr decay factor (default: %(default)s)')
    # SGP/GPM
    parser.add_argument('--scale_coff', type=int, default=5, metavar='SCF',
                        help='importance co-efficient (default: %(default)s)')
    parser.add_argument('--scale_coff_Taylor', type=int, default=5, metavar='SCF',
                        help='importance co-efficient for the curvature-based scores (default: %(default)s)')
    # NOTE(review): main() in this file builds its thresholds from literals
    # and does not read --gpm_eps / --gpm_eps_inc; confirm whether they are
    # consumed elsewhere.
    parser.add_argument('--gpm_eps', type=float, default=0.985, metavar='EPS',
                        help='threshold (default: %(default)s)')
    parser.add_argument('--gpm_eps_inc', type=float, default=0.003, metavar='EPSI',
                        help='threshold increment per task (default: %(default)s)')
    # Gradient-projection regularization
    parser.add_argument('--lambda_gp_ql', type=float, default=1e2, metavar='LG',
                        help='coefficient for the gradient-projection penalty added to the loss '
                             '(default: %(default)s)')

    args = parser.parse_args()
    print('='*100)
    print('Arguments =')
    for arg in vars(args):
        print('\t'+arg+':',getattr(args,arg))
    print('='*100)

    main(args)



