import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms

import os
import os.path
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np

import random
import pdb
import argparse,time
import math
from copy import deepcopy

## Define LeNet model
# Five identical coefficients (1e-5 each).
# NOTE(review): nothing in the visible part of this file reads `lambda_nuc`;
# presumably a per-layer nuclear-norm regularisation weight left over from an
# earlier experiment -- confirm against other modules before removing.
lambda_nuc = [1e-5, 1e-5, 1e-5, 1e-5, 1e-5]
def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1):
    """Return the output length of a convolution/pooling window along one axis.

    Args:
        Lin: input length along the axis.
        kernel_size: window size.
        stride: step between consecutive windows.
        padding: zero padding added to each side.
        dilation: spacing between window elements.

    Returns:
        The number of window positions, as a Python ``int``.
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    padded_length = Lin + 2 * padding
    positions = (padded_length - effective_kernel) / float(stride) + 1
    return int(np.floor(positions))

class LeNet(nn.Module):
    """LeNet-style CNN with two conv layers, two FC layers and one head per task.

    Besides the layers, the module keeps bookkeeping that the representation
    extraction code reads:

    * ``act``        -- ordered dict holding the *input* of conv1, conv2, fc1
                        and fc2 from the most recent forward pass.
    * ``map``        -- input size per layer (spatial side for the conv
                        layers, flattened width for the FC layers).
    * ``ksize``      -- kernel size of each conv layer.
    * ``in_channel`` -- input channel count of each conv layer.

    Args:
        taskcla: iterable of ``(task_index, n_classes)`` pairs; one bias-free
            linear head with ``n_classes`` outputs is created per pair, and
            ``task_index`` is used to index the heads in ``forward``.
    """

    def __init__(self,taskcla):
        super().__init__()
        self.act = OrderedDict()
        self.map = [32]          # conv1 sees a 32x32 input
        self.ksize = []
        self.in_channel = []

        # --- conv block 1: 5x5 conv (same padding) followed by 3x3/2 pooling
        self.conv1 = nn.Conv2d(3, 20, 5, bias=False, padding=2)
        side = compute_conv_output_size(32, 5, 1, 2)
        side = compute_conv_output_size(side, 3, 2, 1)
        self.ksize.append(5)
        self.in_channel.append(3)
        self.map.append(side)

        # --- conv block 2: same layout, 20 -> 50 channels
        self.conv2 = nn.Conv2d(20, 50, 5, bias=False, padding=2)
        side = compute_conv_output_size(side, 5, 1, 2)
        side = compute_conv_output_size(side, 3, 2, 1)
        self.ksize.append(5)
        self.in_channel.append(20)
        self.smid = side
        flat_width = 50 * self.smid * self.smid
        self.map.append(flat_width)

        # --- parameter-free layers shared by the blocks
        self.maxpool = torch.nn.MaxPool2d(3, 2, padding=1)
        self.relu = torch.nn.ReLU()
        self.drop1 = torch.nn.Dropout(0.01)
        self.drop2 = torch.nn.Dropout(0.01)
        self.lrn = torch.nn.LocalResponseNorm(4, 0.001/9.0, 0.75, 1)

        # --- fully connected trunk
        self.fc1 = nn.Linear(flat_width, 800, bias=False)
        self.fc2 = nn.Linear(800, 500, bias=False)
        self.map.append(800)

        # --- one output head per task
        self.taskcla = taskcla
        self.fc3 = torch.nn.ModuleList(
            torch.nn.Linear(500, n_out, bias=False) for _, n_out in self.taskcla
        )

    def forward(self, x):
        """Run the network and return a list with one logits tensor per task.

        The input of every trunk layer is stored in ``self.act`` as a side
        effect (keys ``conv1``, ``conv2``, ``fc1``, ``fc2``).
        """
        batch = x.size(0)

        self.act['conv1'] = x
        x = self.relu(self.conv1(x))
        x = self.lrn(x)
        x = self.drop1(x)
        x = self.maxpool(x)

        self.act['conv2'] = x
        x = self.relu(self.conv2(x))
        x = self.lrn(x)
        x = self.drop1(x)
        x = self.maxpool(x)

        x = x.reshape(batch, -1)
        self.act['fc1'] = x
        x = self.drop2(self.relu(self.fc1(x)))

        self.act['fc2'] = x
        x = self.drop2(self.relu(self.fc2(x)))

        return [self.fc3[t](x) for t, _ in self.taskcla]

def init_weights(m):
    """Kaiming-uniform initialise the weight of a plain Linear/Conv2d module.

    Intended for ``model.apply(init_weights)``. Only modules whose exact type
    is ``nn.Linear`` or ``nn.Conv2d`` are touched (subclasses are skipped);
    every other module is left unchanged.
    """
    if type(m) not in (nn.Linear, nn.Conv2d):
        return
    torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')

def get_model(model):
    """Return a detached snapshot of ``model.state_dict()`` as a plain dict.

    Every tensor is cloned, so later updates to the model do not alter the
    snapshot.
    """
    return {key: tensor.detach().clone() for key, tensor in model.state_dict().items()}

def set_model_(model,state_dict):
    """Load a deep copy of ``state_dict`` into ``model`` in place; returns None."""
    snapshot = deepcopy(state_dict)
    model.load_state_dict(snapshot)

def adjust_learning_rate(optimizer, epoch, args):
    """Update the learning rate of every parameter group in place.

    At ``epoch == 1`` the rate is reset to ``args.lr``; at any other epoch the
    current rate is divided by ``args.lr_factor``.
    """
    reset = (epoch == 1)
    for group in optimizer.param_groups:
        if reset:
            group['lr'] = args.lr
            continue
        group['lr'] /= args.lr_factor

def train(args, model, device, x,y, optimizer,criterion, task_id, lr):
    """Train ``model`` for one epoch on ``(x, y)`` without gradient projection.

    Samples are visited in a fresh random order (NumPy RNG) in mini-batches of
    ``args.batch_size_train``. The loss is the criterion on the ``task_id``
    head plus ``args.lambda_g`` times the gradient-flatness regulariser over
    the four trunk weight tensors.

    Args:
        args: namespace providing ``batch_size_train`` and ``lambda_g``.
        model: network returning a list of per-task outputs.
        device: device the data and index tensors are moved to.
        x, y: full input and target tensors of the task.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss callable ``criterion(output, target)``.
        task_id: index of the output head to train.
        lr: unused; kept for call compatibility.
    """
    model.train()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)

    # Loop invariants: move the dataset once and build the layer set once
    # instead of repeating both for every mini-batch.
    x, y = x.to(device), y.to(device)
    wanted = {'conv1.weight', 'conv2.weight', 'fc1.weight', 'fc2.weight'}
    batch_size = args.batch_size_train

    # Loop batches (slicing past the end simply yields the shorter tail batch)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        data, target = x[idx], y[idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[task_id], target)
        # ---- First-order flatness regularization ----
        flat_reg = gradient_flatness_reg(loss, model, wanted)
        loss = loss + args.lambda_g * flat_reg
        loss.backward()
        optimizer.step()

def train_projected(args,model,memory, task_name_list, device,x,y,optimizer,criterion,feature_mat,feature_mat_weight, task_id, lr):
    """Train one epoch with flatness/projection regularisers and projected gradients.

    For every mini-batch:

    1. the task loss on head ``task_id`` is augmented with ``args.lambda_g``
       times the gradient-flatness regulariser;
    2. if ``args.lambda_gp_ql > 0`` and ``task_id > 0``, the mean squared
       projection of the gradients onto ``feature_mat`` is added, weighted by
       ``(task_id + 1) * args.lambda_gp_ql``;
    3. after ``backward``, the gradient of each multi-dimensional parameter
       among the first four named parameters has its component along
       ``feature_mat_weight`` removed; one-dimensional parameters among the
       first four get a zero gradient when ``task_id != 0``.

    Args:
        args: namespace providing ``batch_size_train``, ``lambda_g`` and
            ``lambda_gp_ql``.
        model: network returning a list of per-task outputs.
        memory, task_name_list, lr: unused; kept for call compatibility.
        device: device the data and index tensors are moved to.
        x, y: full input and target tensors of the task.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss callable ``criterion(output, target)``.
        feature_mat: per-layer square matrices used by the regulariser.
        feature_mat_weight: per-layer square matrices used to project the
            gradients before the optimizer step.
        task_id: index of the output head being trained.
    """
    model.train()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)

    # ---- loop invariants, computed once instead of per mini-batch ----
    x, y = x.to(device), y.to(device)
    wanted = {'conv1.weight', 'conv2.weight', 'fc1.weight', 'fc2.weight'}
    batch_size = args.batch_size_train

    # Only the first four named parameters are constrained; split them into
    # matrix-like (projected) and one-dimensional (frozen after task 0) ones.
    leading = [p for k, (_, p) in enumerate(model.named_parameters()) if k < 4]
    matrix_params = [p for p in leading if len(p.size()) != 1]
    vector_params = [p for p in leading if len(p.size()) == 1]
    weight_projectors = [feature_mat_weight[j] for j in range(len(matrix_params))]

    use_proj_reg = bool(args.lambda_gp_ql > 0 and task_id > 0 and matrix_params)
    reg_bases = [feature_mat[j] for j in range(len(matrix_params))] if use_proj_reg else []

    # Loop batches (slicing past the end simply yields the shorter tail batch)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        data, target = x[idx], y[idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[task_id], target)

        # ---- First-order flatness regularization ----
        flat_reg = gradient_flatness_reg(loss, model, wanted)
        loss = loss + args.lambda_g * flat_reg

        # ---- Gradient-projection regulariser (onto protected subspace) ----
        if use_proj_reg:
            g_list = torch.autograd.grad(loss, matrix_params,
                                         create_graph=True, retain_graph=True, allow_unused=True)
            proj_reg = 0.0
            for g, p, U in zip(g_list, matrix_params, reg_bases):
                if g is None:
                    g = torch.zeros_like(p)
                g_flat = g.reshape(g.size(0), -1)
                proj_reg += torch.mm(g_flat, U).pow(2).mean()
            loss = loss + (task_id+1) * args.lambda_gp_ql * proj_reg

        loss.backward()

        # ---- Gradient projections ----
        for p, projector in zip(matrix_params, weight_projectors):
            rows = p.grad.data.size(0)
            p.grad.data = p.grad.data - torch.mm(p.grad.data.view(rows, -1),
                                                 projector).view(p.size())
        if task_id != 0:
            for p in vector_params:
                p.grad.data.fill_(0)

        optimizer.step()



def gradient_flatness_reg(loss, model, wanted_layers):
    """Return the sum over selected parameters of the mean squared gradient.

    Gradients of ``loss`` are taken with ``create_graph=True`` so the result
    can itself be back-propagated. Parameters whose name is in
    ``wanted_layers`` but which do not contribute to ``loss`` count as zero.

    Args:
        loss: scalar tensor attached to the autograd graph.
        model: module whose ``named_parameters()`` are filtered.
        wanted_layers: container of parameter names to include.
    """
    selected = []
    for pname, tensor in model.named_parameters():
        if pname in wanted_layers:
            selected.append(tensor)

    raw_grads = torch.autograd.grad(loss, selected, create_graph=True, allow_unused=True)

    total = 0
    for grad, tensor in zip(raw_grads, selected):
        if grad is None:
            grad = torch.zeros_like(tensor)
        total = total + (grad ** 2).mean()
    return total


def test(args, model, device, x, y, criterion, task_id):
    """Evaluate ``model`` on ``(x, y)`` for one task head.

    Args:
        args: namespace providing ``batch_size_test``.
        model: network returning a list of per-task outputs.
        device: device the data and index tensors are moved to.
        x, y: full input and target tensors.
        criterion: loss callable ``criterion(output, target)``.
        task_id: index of the output head to evaluate.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples.

    Raises:
        ZeroDivisionError: if ``x`` holds no samples.
    """
    model.eval()
    total_loss = 0
    total_num = 0
    correct = 0

    order = np.arange(x.size(0))
    # The shuffle does not change the metrics, but it is kept so the NumPy
    # RNG stream (and with it later training order) stays reproducible.
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)

    # Move the dataset once rather than once per mini-batch.
    x, y = x.to(device), y.to(device)
    batch_size = args.batch_size_test

    with torch.no_grad():
        # Loop batches (slicing past the end simply yields the tail batch)
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            data, target = x[idx], y[idx]
            logits = model(data)[task_id]
            loss = criterion(logits, target)
            pred = logits.argmax(dim=1, keepdim=True)

            correct    += pred.eq(target.view_as(pred)).sum().item()
            total_loss += loss.item() * len(idx)
            total_num  += len(idx)

    acc = 100. * correct / total_num
    final_loss = total_loss / total_num
    return final_loss, acc

def get_representation_matrix(net, device, x, y=None):
    """Build one representation matrix per trunk layer from a random sample.

    Up to 125 randomly chosen samples of ``x`` are pushed through ``net`` so
    that ``net.act`` holds their layer inputs. For the two conv layers every
    ``ksize x ksize`` input patch (after zero padding of 2) becomes one
    column; for the FC layers the transposed activations are used directly.

    Args:
        net: model exposing ``act``, ``map``, ``ksize`` and ``in_channel``.
            It is run in whatever train/eval mode it is currently in.
        device: device used for the forward pass.
        x: input tensor of the task.
        y: unused; accepted for signature compatibility (may be ``None``).

    Returns:
        List of NumPy arrays, one per entry of ``net.map``.
    """
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    sample_idx = order[0:125]  # Take 125 random samples
    # `y` is never needed here, so it is not touched (the previous
    # unconditional `y.to(device)` crashed with the default `y=None`).
    x = x.to(device)
    example_data = x[sample_idx]

    # Collect activations by forward pass. Without it the matrices below were
    # built from whatever batch happened to go through the network last.
    with torch.no_grad():
        net(example_data)

    batch_list = [2*12, 100, 125, 125]  # samples used per layer
    pad = 2
    p1d = (2, 2, 2, 2)
    mat_list = []
    act_key = list(net.act.keys())

    for i in range(len(net.map)):
        bsz = batch_list[i]
        if i < 2:
            # Conv layer: unfold the padded input into patch columns.
            ksz = net.ksize[i]
            s = compute_conv_output_size(net.map[i], ksz, 1, pad)
            mat = np.zeros((ksz*ksz*net.in_channel[i], s*s*bsz))
            act = F.pad(net.act[act_key[i]], p1d, "constant", 0).detach().cpu().numpy()

            col = 0
            for kk in range(bsz):
                for ii in range(s):
                    for jj in range(s):
                        mat[:, col] = act[kk, :, ii:ksz+ii, jj:ksz+jj].reshape(-1)
                        col += 1
            mat_list.append(mat)
        else:
            # FC layer: features x samples.
            act = net.act[act_key[i]].detach().cpu().numpy()
            mat_list.append(act[0:bsz].transpose())

    print('-'*30)
    print('Representation Matrix')
    print('-'*30)
    for i, mat in enumerate(mat_list):
        print ('Layer {} : {}'.format(i+1, mat.shape))
    print('-'*30)
    return mat_list


def get_representation_and_gradient_Taylor(args, net, device, optimizer, criterion, task_name, memory, task_id, x,
                                           y=None):
    '''
    Collect per-layer representation matrices and weight gradients.

    Up to 125 random samples are pushed through ``net``; the loss on head
    ``task_id`` is back-propagated once. Gradients are gathered for every
    parameter whose name contains ``weight`` (and not ``bn``): conv weights
    are flattened to 2-D, ``fc3`` contributes only the head of ``task_id``.
    Representation matrices are built from ``net.act`` exactly as for the
    plain representation extraction (patch columns for the two conv layers,
    transposed activations for the FC layers).

    Args:
        args, task_name, memory: unused; kept for call compatibility.
        net: model exposing ``act``, ``map``, ``ksize`` and ``in_channel``.
        device: device used for the forward/backward pass.
        optimizer: only used to clear existing gradients.
        criterion: loss callable ``criterion(output, target)``.
        task_id: index of the output head whose loss is differentiated.
        x: input tensor of the task.
        y: target tensor; required despite the ``None`` default.

    Returns:
        ``(mat_list, grad_list)`` -- lists of NumPy arrays.

    Raises:
        ValueError: if ``y`` is ``None``.
    '''
    if y is None:
        raise ValueError('y (targets) is required to compute the loss gradients')

    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    sample_idx = order[0:125]  # Take 125 random samples
    x, y = x.to(device), y.to(device)

    example_data, target = x[sample_idx], y[sample_idx]

    batch_list = 3*[2 * 12, 100, 125, 125]  # samples used per layer
    pad = 2
    p1d = (2, 2, 2, 2)

    mat_list = []
    grad_list = []

    # Forward/backward pass: fills net.act and the parameter gradients.
    optimizer.zero_grad()
    example_out = net(example_data)
    loss = criterion(example_out[task_id], target)
    loss.backward()

    # Read the keys only after the forward pass so they exist even when the
    # network has never been run before.
    act_key = list(net.act.keys())

    head_idx = 0  # position of the current fc3 head among the fc3 weights
    for name, params in net.named_parameters():
        if 'weight' not in name or 'bn' in name:
            continue
        if len(params.shape) == 4:
            grad = params.grad.data.detach().cpu().numpy()
            grad_list.append(grad.reshape(grad.shape[0], -1))
        elif 'fc3' in name:
            # Count every head, not only matching ones; otherwise the head of
            # any task_id > 0 could never be reached. Heads of other tasks do
            # not take part in the loss and are skipped.
            if head_idx == task_id:
                grad_list.append(params.grad.data.detach().cpu().numpy())
            head_idx += 1
        else:
            grad_list.append(params.grad.data.detach().cpu().numpy())

    for i in range(len(net.map)):
        bsz = batch_list[i]
        if i < 2:
            # Conv layer: unfold the padded input into patch columns.
            ksz = net.ksize[i]
            s = compute_conv_output_size(net.map[i], ksz, 1, pad)
            mat = np.zeros((ksz * ksz * net.in_channel[i], s * s * bsz))
            act = F.pad(net.act[act_key[i]], p1d, "constant", 0).detach().cpu().numpy()

            col = 0
            for kk in range(bsz):
                for ii in range(s):
                    for jj in range(s):
                        mat[:, col] = act[kk, :, ii:ksz + ii, jj:ksz + jj].reshape(-1)
                        col += 1
            mat_list.append(mat)
        else:
            # FC layer: features x samples.
            act = net.act[act_key[i]].detach().cpu().numpy()
            mat_list.append(act[0:bsz].transpose())

    return mat_list, grad_list





from torch.utils.data import DataLoader, TensorDataset

def compute_fisher_diag(model,
                        criterion,
                        inputs,
                        targets,
                        task_id: int = 0,
                        batch_size: int = 64,
                        n_batches: int = 20,
                        wanted_layers=None,
                        device='cuda'):
    """Estimate a diagonal Fisher-style importance from squared batch gradients.

    The model is moved to ``device`` and switched to eval mode (it is not
    switched back). Shuffled mini-batches of ``(inputs, targets)`` are run
    through it; after each backward pass the squared parameter gradients are
    accumulated. The accumulated sum is averaged over the number of batches
    that were actually processed.

    Args:
        model: network; if it returns a list/tuple, element ``task_id`` is used.
        criterion: loss callable ``criterion(output, target)``.
        inputs, targets: tensors wrapped in a ``TensorDataset``.
        task_id: index of the output head for multi-head models.
        batch_size: mini-batch size; incomplete trailing batches are dropped.
        n_batches: upper bound on the number of mini-batches to use.
        wanted_layers: optional container of parameter names to restrict to;
            ``None`` selects every parameter with ``requires_grad``.
        device: device used for the computation.

    Returns:
        Dict mapping parameter name to a tensor of the parameter's shape.
        All-zero tensors are returned when no full batch is available.
    """
    model = model.to(device)
    model.eval()

    params = {
        name: p for name, p in model.named_parameters()
        if p.requires_grad and (wanted_layers is None or name in wanted_layers)
    }
    fisher_acc = {name: torch.zeros_like(p, device=device) for name, p in params.items()}

    loader = DataLoader(TensorDataset(inputs, targets),
                        batch_size=batch_size, shuffle=True, drop_last=True)

    processed = 0
    with torch.enable_grad():
        for x_mb, y_mb in loader:
            x_mb, y_mb = x_mb.to(device), y_mb.to(device)

            model.zero_grad(set_to_none=True)
            out = model(x_mb)
            if isinstance(out, (list, tuple)):
                out = out[task_id]
            criterion(out, y_mb).backward()

            for name, p in params.items():
                if p.grad is not None:
                    fisher_acc[name] += (p.grad.detach() ** 2)
            processed += 1
            if processed >= n_batches:
                break

    # Average over the batches actually seen: the loader may run out before
    # `n_batches` is reached, and dividing by `n_batches` would then shrink
    # the estimate by the missing fraction.
    denom = max(processed, 1)
    return {name: acc / denom for name, acc in fisher_acc.items()}


def update_SGP_2(args, model, device, memory, task_name, mat_list, grad_list, threshold, task_id, feature_list=None,
               importance_list=None):
    """Update the per-layer basis and its importance after a task.

    First task (``feature_list`` empty): for each layer the leading left
    singular vectors of the representation matrix are kept while the
    cumulative squared-singular-value ratio stays below ``threshold[i]``;
    their importance is a scaled function of the singular values.

    Later tasks: the new representation is split into its part inside the
    existing basis and the residual. Residual directions are appended until
    the captured energy reaches ``threshold[i]``; importances of old
    directions are accumulated and clipped to ``[0, 1]``.

    Side effects: ``memory[task_name][str(i)]['space_list']`` is written for
    every layer (and ``'grad_list'`` for later tasks), and a summary is
    printed. Lists passed in are updated in place and returned.

    Args:
        args: namespace providing ``scale_coff``.
        model, task_id: unused; kept for call compatibility.
        device: device of the tensors stored in ``memory``.
        memory: nested dict ``memory[task_name][str(layer)]`` of dicts.
        task_name: key of the current task in ``memory``.
        mat_list: per-layer representation matrices (NumPy).
        grad_list: per-layer gradients (only read for later tasks).
        threshold: per-layer energy thresholds.
        feature_list: per-layer basis matrices, or ``None``/empty for the
            first task.
        importance_list: per-layer importance vectors, or ``None``/empty.

    Returns:
        ``(feature_list, importance_list)``.
    """
    # Fresh lists per call: mutable defaults would be shared between calls
    # and silently turn a second "first task" call into an update.
    if feature_list is None:
        feature_list = []
    if importance_list is None:
        importance_list = []

    print('Threshold: ', threshold)
    if not feature_list:
        # After First Task
        for i, activation in enumerate(mat_list):
            U, S, _ = np.linalg.svd(activation, full_matrices=False)
            # criteria (Eq-1)
            sval_total = (S ** 2).sum()
            sval_ratio = (S ** 2) / sval_total
            r = np.sum(np.cumsum(sval_ratio) < threshold[i])  # +1

            feature_list.append(U[:, 0:r])

            memory[task_name][str(i)]['space_list'] = torch.tensor(U[:, 0:r], dtype=torch.float32, device=device)

            importance = ((args.scale_coff + 1) * S[0:r]) / (args.scale_coff * S[0:r] + max(S[0:r]))
            importance_list.append(importance)

    else:
        for i, activation in enumerate(mat_list):
            gradient = grad_list[i]
            memory[task_name][str(i)]['grad_list'] = torch.tensor(gradient, dtype=torch.float32, device=device)

            # Energy of the new representation along each existing basis vector.
            R2 = np.dot(activation, activation.transpose())
            delta = np.array([np.dot(np.dot(vec.transpose(), R2), vec)
                              for vec in feature_list[i].transpose()])

            _, S1, _ = np.linalg.svd(activation, full_matrices=False)
            sval_total = (S1 ** 2).sum()
            # Projected Representation (Eq-4)
            act_proj = np.dot(np.dot(feature_list[i], feature_list[i].transpose()), activation)
            r_old = feature_list[i].shape[1]  # old GPM bases
            Uc, Sc, _ = np.linalg.svd(act_proj, full_matrices=False)
            importance_new_on_old = np.dot(np.dot(feature_list[i].transpose(), Uc[:, 0:r_old]) ** 2,
                                           Sc[0:r_old] ** 2)  ## r_old no of elm s**2 fmt
            importance_new_on_old = np.sqrt(importance_new_on_old)

            act_hat = activation - act_proj
            U, S, _ = np.linalg.svd(act_hat, full_matrices=False)
            # criteria (Eq-5)
            sval_hat = (S ** 2).sum()
            sval_ratio = (S ** 2) / sval_total
            accumulated_sval = (sval_total - sval_hat) / sval_total
            sigma = S ** 2

            # Stack delta and sigma in one list, then sort in descending order.
            stack = np.hstack((delta, sigma))
            stack_index = np.argsort(stack)[::-1]
            stack = np.sort(stack)[::-1]

            # Select the most important directions until the threshold is met.
            r1 = 0
            accumulated_sval1 = 0
            for energy in stack:
                if accumulated_sval1 < threshold[i] * sval_total:
                    accumulated_sval1 += energy
                    r1 += 1
                    if r1 == activation.shape[0]:
                        break
                else:
                    break

            # Save the corresponding task-specific space.
            U_list = np.hstack((feature_list[i], U))
            sel_index = stack_index[:r1]
            memory[task_name][str(i)]['space_list'] = torch.tensor(U_list[:, sel_index], dtype=torch.float32, device=device)

            r = 0
            for ii in range(sval_ratio.shape[0]):
                if accumulated_sval < threshold[i]:
                    accumulated_sval += sval_ratio[ii]
                    r += 1
                else:
                    break
            if r == 0:
                print('Skip Updating GPM for layer: {}'.format(i + 1))
                # update importances
                importance = importance_new_on_old
                importance = ((args.scale_coff + 1) * importance) / (args.scale_coff * importance + max(importance))
                importance[0:r_old] = np.clip(importance[0:r_old] + importance_list[i][0:r_old], 0, 1)
                importance_list[i] = importance  # update importance
                continue
            # update GPM
            Ui = np.hstack((feature_list[i], U[:, 0:r]))
            # update importance
            importance = np.hstack((importance_new_on_old, S[0:r]))
            importance = ((args.scale_coff + 1) * importance) / (args.scale_coff * importance + max(importance))
            importance[0:r_old] = np.clip(importance[0:r_old] + importance_list[i][0:r_old], 0, 1)

            # Never keep more basis vectors than the space has dimensions.
            if Ui.shape[1] > Ui.shape[0]:
                feature_list[i] = Ui[:, 0:Ui.shape[0]]
                importance_list[i] = importance[0:Ui.shape[0]]
            else:
                feature_list[i] = Ui
                importance_list[i] = importance

    print('-' * 40)
    print('Gradient Constraints Summary')
    print('-' * 40)
    for i, basis in enumerate(feature_list):
        print('Layer {} : {}/{}'.format(i + 1, basis.shape[1], basis.shape[0]))
    print('-' * 40)
    return feature_list, importance_list



def importance_loss(model, memory, device, task_name_list, task_id, feature_list, importance_list,
                    scale_coff_taylor=None):
    """Score every basis vector of every constrained layer.

    For each multi-dimensional parameter among the first four named
    parameters of ``model`` and each column ``b`` of the layer's basis, the
    score is the average over tasks ``0..task_id`` of
    ``0.5 * sum(H * (U U^T b)^2)``, where ``U`` is the task's stored
    ``space_list`` and ``H`` its stored ``hess_diag`` viewed as
    ``(out_channels, features)``. Scores of a layer are then rescaled with
    ``(a + 1) * s / (a * s + max(s))``.

    Args:
        model: network whose ``named_parameters()`` define the layer order.
        memory: nested dict ``memory[task_name][str(layer)]`` holding
            ``'space_list'`` and ``'hess_diag'`` tensors.
        device: device used for the computation.
        task_name_list: task names; entries ``0..task_id`` are read.
        task_id: index of the latest task to include.
        feature_list: per-layer basis matrices (features x rank).
        importance_list: unused; kept for call compatibility.
        scale_coff_taylor: rescaling coefficient ``a``. When ``None`` the
            module-level ``args.scale_coff_Taylor`` is used, which only
            exists when the file is run as a script.

    Returns:
        List with one CPU float tensor of rescaled scores per layer.
    """
    if scale_coff_taylor is None:
        scale_coff_taylor = args.scale_coff_Taylor

    n_tasks = task_id + 1
    importance_all = []
    kk = 0
    for k, (pname, param) in enumerate(model.named_parameters()):
        if k >= 4 or len(param.size()) == 1:
            continue

        out_ch = param.shape[0]
        # Convert the basis once per layer instead of once per column.
        basis_mat = torch.tensor(feature_list[kk], dtype=torch.float32).to(device)
        rank = basis_mat.shape[1]

        # Per-task subspace and reshaped curvature, fetched once per layer.
        per_task = []
        for t_idx in range(n_tasks):
            slot = memory[task_name_list[t_idx]][str(kk)]
            U = slot['space_list']
            H_full = slot['hess_diag'].to(device).view(out_ch, U.size(0))
            per_task.append((U, H_full))

        layer_scores = []
        for col in range(rank):
            basis = basis_mat[:, col:col + 1].contiguous()
            quad_acc = 0.0
            for U, H_full in per_task:
                proj = U @ (U.t() @ basis)
                # The same squared projection applies to every output row.
                proj_full = proj.pow(2).t().expand(out_ch, U.size(0))
                quad_acc += 0.5 * (H_full * proj_full).sum()
            layer_scores.append(float(quad_acc / max(1, n_tasks)))

        scores = torch.tensor(layer_scores, dtype=torch.float32, device=device)
        scores_norm = ((scale_coff_taylor + 1) * scores) / (scale_coff_taylor * scores + scores.max())

        importance_all.append(scores_norm.cpu())
        kk += 1

    return importance_all




def main(args):
    """Run the full continual-learning experiment described by ``args``.

    Tasks from the project's CIFAR-100 superclass loader are learned one
    after another. The first task is trained with ``train``; every later task
    with ``train_projected`` using projection matrices built from the basis
    and importance values of the preceding tasks. After each task the
    representation basis, its importance, and a diagonal Fisher estimate are
    refreshed, and accuracy on all tasks seen so far is recorded and printed.
    Final average accuracy and backward transfer are printed at the end.
    """
    tstart=time.time()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print (device)

    # Seed NumPy and torch; cudnn autotuning is disabled as well.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False

    # Five predefined permutations of the 20 tasks; selected by args.t_order.
    task_order = [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]),
                  np.array([15, 12, 5, 9, 7, 16, 18, 17, 1, 0, 3, 8, 11, 14, 10, 6, 2, 4, 13, 19]),
                  np.array([17, 1, 19, 18, 12, 7, 6, 0, 11, 15, 10, 5, 13, 3, 9, 16, 4, 14, 2, 8]),
                  np.array([11, 9, 6, 5, 12, 4, 0, 10, 13, 7, 14, 3, 15, 16, 8, 1, 2, 19, 18, 17]),
                  np.array([6, 14, 0, 11, 12, 17, 13, 4, 9, 1, 7, 19, 8, 10, 3, 15, 18, 5, 2, 16])]


    # Project-local loader; the first call carves out a validation split,
    # the second one is only used for its test sets.
    from dataloader import cifar100_superclass as data_loader
    data, taskcla = data_loader.cifar100_superclass_python(task_order[args.t_order], group=5, validation=True)
    test_data,_   = data_loader.cifar100_superclass_python(task_order[args.t_order], group=5)
    print (taskcla)

    # acc_matrix[i, j]: accuracy on the j-th seen task after learning task i.
    acc_matrix=np.zeros((20,20))
    criterion = torch.nn.CrossEntropyLoss()
    task_list = []
    task_name_list = []
    memory = {}
    task_id = 0
    for k,ncla in taskcla:

        # Per-layer energy threshold, raised by gpm_eps_inc for every task.
        threshold = np.array([args.gpm_eps] * 4) + task_id*np.array([args.gpm_eps_inc] * 4)
     
        print('*'*100)
        print('Task {:2d} ({:s})'.format(k,data[k]['name']))
        print('*'*100)
        task_name = data[k]['name']+'-'+str(k)
        task_name_list.append(task_name)
        xtrain=data[k]['train']['x']
        ytrain=data[k]['train']['y']
        xvalid=data[k]['valid']['x']
        yvalid=data[k]['valid']['y']
        xtest =test_data[k]['test']['x']
        ytest =test_data[k]['test']['y']
        print(ytrain.shape,yvalid.shape,ytest.shape)
        task_list.append(k)

        # Learning rate and best validation loss restart for every task.
        lr = args.lr 
        best_loss=np.inf
        print ('-'*40)
        print ('Task ID :{} | Learning Rate : {}'.format(task_id, lr))
        print ('-'*40)
        
        if task_id==0:
            # ---- First task: build the model and train without projection ----
            model = LeNet(taskcla).to(device)

            model.apply(init_weights)
            memory[task_name] = {}


            # One memory slot per weight tensor, keyed by its running index.
            kk = 0
            for k_t, (m, param) in enumerate(model.named_parameters()):
                if 'weight' in m and 'bn' not in m:
                    # print_log((k_t, m, param.shape), log)
                    memory[task_name][str(kk)] = {
                        'space_list': {},
                        'grad_list': {},
                        'regime': {},
                    }
                    kk += 1

            best_model=get_model(model)
            feature_list =[]
            importance_list = []
            optimizer = optim.SGD(model.parameters(), lr=lr)

            for epoch in range(1, args.n_epochs+1):
                # Train
                clock0=time.time()
                train(args, model, device, xtrain, ytrain, optimizer, criterion, k, lr)
                clock1=time.time()
                tr_loss,tr_acc = test(args, model, device, xtrain, ytrain,  criterion, k)
                print('Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% | time={:5.1f}ms |'.format(epoch,\
                                                            tr_loss,tr_acc, 1000*(clock1-clock0)),end='')
                # Validate
                valid_loss,valid_acc = test(args, model, device, xvalid, yvalid,  criterion, k)
                print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, valid_acc),end='')
                # Adapt lr: keep the best weights; after lr_patience epochs
                # without improvement divide lr by lr_factor, and stop once
                # it falls below lr_min.
                # NOTE(review): `patience` is first assigned in the improving
                # branch; this relies on the first epoch's validation loss
                # comparing below np.inf.
                if valid_loss<best_loss:
                    best_loss=valid_loss
                    best_model=get_model(model)
                    patience=args.lr_patience
                    print(' *',end='')
                else:
                    patience-=1
                    if patience<=0:
                        lr/=args.lr_factor
                        print(' lr={:.1e}'.format(lr),end='')
                        if lr<args.lr_min:
                            print()
                            break
                        patience=args.lr_patience
                        adjust_learning_rate(optimizer, epoch, args)
                print()
            # Restore the weights with the best validation loss.
            set_model_(model,best_model)
            # Test
            print ('-'*40)
            test_loss, test_acc = test(args, model, device, xtest, ytest,  criterion, k)
            print('Test: loss={:.3f} , acc={:5.1f}%'.format(test_loss,test_acc))
            # Build the initial basis and importance from this task's data.
            mat_list, grad_list = get_representation_and_gradient_Taylor(args, model, device, optimizer, criterion,
                                                                         task_name, memory, k, xtrain, ytrain)
            feature_list, importance_list = update_SGP_2(args, model, device, memory, task_name, mat_list, grad_list,
                                                       threshold, task_id, feature_list, importance_list)

            wanted = {'conv1.weight',
                      'conv2.weight',
                      'fc1.weight',
                      'fc2.weight'}
            fisher = compute_fisher_diag(model, criterion,
                                         inputs=xtrain,
                                         targets=ytrain,
                                         batch_size=64,
                                         n_batches=80,
                                         wanted_layers=wanted,
                                         device=device,
                                         task_id=task_id)


            print(fisher['conv1.weight'].shape)
            # Stored under 'hess_diag'; the enumeration index is used as the
            # layer key, matching the slots created above.
            for kk_h, (pname, fdiag) in enumerate(fisher.items()):
                memory[task_name][str(kk_h)]['hess_diag'] = fdiag.clone().cpu()
            importance_all = importance_loss(model, memory, device, task_name_list, task_id, feature_list,
                                             importance_list)


        else:
            # ---- Later tasks: train with projected gradients ----
            memory[task_name] = {}
            kk = 0

            for k_t, (m, params) in enumerate(model.named_parameters()):

                if 'weight' in m and 'bn' not in m:
                    memory[task_name][str(kk)] = {
                        'space_list': {},
                        'grad_list': {},
                        'space_mat_list': {},
                        'scale1': {},
                        'scale2': {},
                        'regime': {},
                        'selected_task': {},
                        'ratio': {},
                    }
                    kk += 1

            # A fresh optimizer per task, starting again from args.lr.
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
            feature_mat = []
            feature_mat_weight = []

            # U diag(importance_all) U^T per layer: used by the projection
            # regulariser inside train_projected.
            for i in range(len(model.act)):
                Uf=torch.Tensor(np.dot(feature_list[i],np.dot(np.diag(importance_all[i]),feature_list[i].transpose()))).to(device)
                Uf.requires_grad = False

                feature_mat.append(Uf)

            # U diag(importance_list) U^T per layer: used to project the
            # gradients before each optimizer step.
            for i in range(len(model.act)):
                Uf2 = torch.Tensor(
                    np.dot(feature_list[i], np.dot(np.diag(importance_list[i]), feature_list[i].transpose()))).to(
                    device)

                Uf2.requires_grad = False
                feature_mat_weight.append(Uf2)

            print ('-'*40)
            for epoch in range(1, args.n_epochs+1):
                # Train 
                clock0=time.time()
                train_projected(args, model, memory, task_name_list, device, xtrain, ytrain, optimizer, criterion,
                                feature_mat, feature_mat_weight, k, lr)
                clock1=time.time()
                tr_loss, tr_acc = test(args, model, device, xtrain, ytrain,criterion,k)
                print('Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% | time={:5.1f}ms |'.format(epoch,\
                                                        tr_loss, tr_acc, 1000*(clock1-clock0)),end='')

                # Validate
                valid_loss,valid_acc = test(args, model, device, xvalid, yvalid, criterion,k)
                print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, valid_acc),end='')

                # Adapt lr (same schedule as for the first task).
                if valid_loss<best_loss:
                    best_loss=valid_loss
                    best_model=get_model(model)
                    patience=args.lr_patience
                    print(' *',end='')
                else:
                    patience-=1
                    if patience<=0:
                        lr/=args.lr_factor
                        print(' lr={:.1e}'.format(lr),end='')
                        if lr<args.lr_min:
                            print()
                            break
                        patience=args.lr_patience
                        adjust_learning_rate(optimizer, epoch, args)
                print()
            # Restore the weights with the best validation loss.
            set_model_(model,best_model)
            # Test 
            test_loss, test_acc = test(args, model, device, xtest, ytest,  criterion,k)
            print('Test: loss={:.3f} , acc={:5.1f}%'.format(test_loss,test_acc))  

            # Extend the basis and importance with this task's data.
            mat_list, grad_list = get_representation_and_gradient_Taylor(args, model, device, optimizer, criterion,
                                                                         task_name, memory, task_id, xtrain, ytrain)

            feature_list, importance_list = update_SGP_2(args, model, device, memory, task_name, mat_list, grad_list,
                                                        threshold, task_id, feature_list, importance_list)
            wanted = {'conv1.weight',
                      'conv2.weight',
                      'fc1.weight',
                      'fc2.weight'}
            fisher = compute_fisher_diag(model, criterion,
                                         inputs=xtrain,
                                         targets=ytrain,
                                         batch_size=64,
                                         n_batches=80,
                                         wanted_layers=wanted,
                                         device=device,
                                         task_id=task_id)

            # Inspect the diagonal estimate for conv1.weight
            print(fisher['conv1.weight'].shape)  # same shape as conv1.weight, i.e. torch.Size([20, 3, 5, 5])
            for kk_h, (pname, fdiag) in enumerate(fisher.items()):
                memory[task_name][str(kk_h)]['hess_diag'] = fdiag.clone().cpu()
            print("len_fisher", len(fisher))
            importance_all = importance_loss(model, memory, device, task_name_list, task_id, feature_list,
                                             importance_list)

        # save accuracy
        # Every task seen so far is evaluated 10 times and the mean is stored.
        jj = 0 
        for ii in task_order[args.t_order][0:task_id+1]:
            acc = 0.0
            xtest =test_data[ii]['test']['x']
            ytest =test_data[ii]['test']['y']
            for i in range(10):
                _, acc_s = test(args, model, device, xtest, ytest, criterion, ii)
                acc += acc_s
            acc_matrix[task_id, jj] = acc / 10
            jj +=1
        print('Accuracies =')
        for i_a in range(task_id+1):
            print('\t',end='')
            for j_a in range(acc_matrix.shape[1]):
                print('{:5.1f}% '.format(acc_matrix[i_a,j_a]),end='')
            print()
        task_id +=1
    print('-'*50)
    # Final metrics are read from the last row of the 20x20 accuracy matrix.
    print ('Final Avg Accuracy: {:5.2f}%'.format(acc_matrix[-1].mean())) 
    # Backward transfer: final accuracy minus accuracy right after learning,
    # averaged over all tasks except the last.
    bwt=np.mean((acc_matrix[-1]-np.diag(acc_matrix))[:-1]) 
    print ('Backward transfer: {:5.2f}%'.format(bwt))
    print('[Elapsed time = {:.1f} ms]'.format((time.time()-tstart)*1000))
    print('-'*50)


if __name__ == "__main__":
    # Command-line entry point: parse the hyper-parameters, echo them, and
    # run the experiment. `args` is deliberately module-level because
    # `importance_loss` falls back to it for its scaling coefficient.

    # Training parameters
    parser = argparse.ArgumentParser(description='CIFAR-100 Superclass with GPM')
    parser.add_argument('--batch_size_train', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--batch_size_test', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--n_epochs', type=int, default=50, metavar='N',
                        help='number of training epochs/task (default: 50)')
    parser.add_argument('--seed', type=int, default=0, metavar='S',
                        help='random seed (default: 0)')
    # NOTE(review): --pc_valid is not read anywhere in this file.
    parser.add_argument('--pc_valid',default=0.05,type=float,
                        help='fraction of training data used for validation')
    parser.add_argument('--t_order', type=int, default=0, metavar='TOD',
                        help='index of the predefined task order to use, 0-4 (default: 0)')

    # Optimizer parameters
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    # NOTE(review): --momentum is parsed but the SGD optimizers in main() are
    # created without momentum.
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--lr_min', type=float, default=1e-5, metavar='LRM',
                        help='minimum lr rate (default: 1e-5)')
    parser.add_argument('--lr_patience', type=int, default=6, metavar='LRP',
                        help='hold before decaying lr (default: 6)')
    parser.add_argument('--lr_factor', type=int, default=2, metavar='LRF',
                        help='lr decay factor (default: 2)')
    # SGP/GPM specific
    parser.add_argument('--scale_coff', type=int, default=3, metavar='SCF',
                        help='scale coefficient (default: 3)')
    # float, not int: the default itself is a float and fractional values
    # must be accepted from the command line.
    parser.add_argument('--scale_coff_Taylor', type=float, default=1e1, metavar='SCFT',
                        help='importance coefficient (default: 1e1)')
    parser.add_argument('--gpm_eps', type=float, default=0.98, metavar='EPS',
                        help='threshold (default: 0.98)')
    parser.add_argument('--gpm_eps_inc', type=float, default=0.001, metavar='EPSI',
                        help='threshold increment per task (default: 0.001)')
    # Regularization
    parser.add_argument('--lambda_g', type=float, default=1e-1, metavar='LG',
                        help='coefficient for first-order gradient flatness regularization (default: 1e-1)')
    parser.add_argument('--lambda_gp_ql', type=float, default=1e-3, metavar='LGP',
                        help='coefficient for the gradient-projection regularizer (default: 1e-3)')
    args = parser.parse_args()
    print('='*100)
    print('Arguments =')
    for arg in vars(args):
        print('\t'+arg+':',getattr(args,arg))
    print('='*100)

    main(args)



