import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms

import os
import os.path
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np

import random
import pdb
import argparse,time
import math
from copy import deepcopy
# Relative threshold used by grad_proj_cond(): for each layer and previous task,
# the regime is '1' when ||projection of the current gradient onto that task's
# subspace|| <= Eplison_1 * ||gradient||, otherwise '2'. With 0.0 the regime is
# '1' only when the projection norm is exactly zero.
Eplison_1 = 0.0
## AlexNet model definition (preceded by a conv output-size helper)
def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1):
    """Return the output length of a convolution along one spatial axis.

    Uses the standard formula
    ``floor((Lin + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1``
    and always returns a Python ``int``.
    """
    effective_span = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    # floor(a / s + 1) == floor(a / s) + 1, and // is floor division.
    return int(effective_span // stride + 1)

class AlexNet(nn.Module):
    """AlexNet-style multi-head network for 32x32 inputs with 3 channels.

    Three bias-free conv layers and two bias-free fully connected layers, each
    followed by BatchNorm (``track_running_stats=False``), ReLU and dropout;
    the conv layers are additionally followed by 2x2 max pooling. One bias-free
    linear head is created per entry of ``taskcla``.

    Bookkeeping attributes used by the subspace/representation code:
      - ``map``: input spatial size of conv1..conv3, then the input feature
        count of fc1 and fc2.
      - ``ksize``: kernel size of conv1..conv3.
      - ``in_channel``: input channel count of conv1..conv3.
      - ``act``: the tensors fed into conv1, conv2, conv3, fc1 and fc2 during
        the most recent ``forward`` call.
    """
    def __init__(self,taskcla):
        # taskcla: iterable of (task_index, n_classes) pairs; one head per pair.
        super(AlexNet, self).__init__()
        # NOTE: the order in which modules are created below fixes the order of
        # named_parameters(); other code selects parameters by position (k < 15).
        self.act=OrderedDict()  # layer name -> input tensor from the last forward()
        self.map =[]  # conv input spatial sizes, then fc input feature counts
        self.ksize=[]  # conv kernel sizes
        self.in_channel =[]  # conv input channel counts
        self.map.append(32)
        # conv1: 3 -> 64 channels, 4x4 kernel (32 -> 29), then max-pool (-> 14)
        self.conv1 = nn.Conv2d(3, 64, 4, bias=False)
        self.bn1 = nn.BatchNorm2d(64, track_running_stats=False)
        s=compute_conv_output_size(32,4)
        s=s//2
        self.ksize.append(4)
        self.in_channel.append(3)
        self.map.append(s)
        # conv2: 64 -> 128 channels, 3x3 kernel (14 -> 12), then max-pool (-> 6)
        self.conv2 = nn.Conv2d(64, 128, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(128, track_running_stats=False)
        s=compute_conv_output_size(s,3)
        s=s//2
        self.ksize.append(3)
        self.in_channel.append(64)
        self.map.append(s)
        # conv3: 128 -> 256 channels, 2x2 kernel (6 -> 5), then max-pool (-> 2)
        self.conv3 = nn.Conv2d(128, 256, 2, bias=False)
        self.bn3 = nn.BatchNorm2d(256, track_running_stats=False)
        s=compute_conv_output_size(s,2)
        s=s//2
        self.smid=s  # final spatial size after conv3 + pooling
        self.ksize.append(2)
        self.in_channel.append(128)
        self.map.append(256*self.smid*self.smid)  # flattened fc1 input size
        self.maxpool=torch.nn.MaxPool2d(2)
        self.relu=torch.nn.ReLU()
        self.drop1=torch.nn.Dropout(0.2)  # used after conv1 and conv2
        self.drop2=torch.nn.Dropout(0.5)  # used after conv3, fc1 and fc2

        self.fc1 = nn.Linear(256*self.smid*self.smid,2048, bias=False)
        self.bn4 = nn.BatchNorm1d(2048, track_running_stats=False)
        self.fc2 = nn.Linear(2048,2048, bias=False)
        self.bn5 = nn.BatchNorm1d(2048, track_running_stats=False)
        self.map.extend([2048])  # fc2 input size
        
        self.taskcla = taskcla
        # One bias-free classification head per task.
        self.fc3=torch.nn.ModuleList()
        for t,n in self.taskcla:
            self.fc3.append(torch.nn.Linear(2048,n,bias=False))
        
    def forward(self, x):
        """Run the shared trunk and every task head.

        Records each layer's input in ``self.act`` (references to the live
        tensors of this pass, not copies). Returns a list with one logits
        tensor per ``taskcla`` entry, in ``taskcla`` order; entry j comes from
        head ``fc3[t]`` where ``t`` is the j-th task index.
        """
        bsz = deepcopy(x.size(0))  # batch size (deepcopy of an int is just a copy)
        self.act['conv1']=x
        x = self.conv1(x)
        x = self.maxpool(self.drop1(self.relu(self.bn1(x))))

        self.act['conv2']=x
        x = self.conv2(x)
        x = self.maxpool(self.drop1(self.relu(self.bn2(x))))

        self.act['conv3']=x
        x = self.conv3(x)
        x = self.maxpool(self.drop2(self.relu(self.bn3(x))))

        x=x.view(bsz,-1)  # flatten to (bsz, 256*smid*smid)
        self.act['fc1']=x
        x = self.fc1(x)
        x = self.drop2(self.relu(self.bn4(x)))

        self.act['fc2']=x        
        x = self.fc2(x)
        x = self.drop2(self.relu(self.bn5(x)))
        y=[]
        for t,i in self.taskcla:
            y.append(self.fc3[t](x))
            
        return y

def get_model(model):
    """Take a snapshot of ``model``'s state.

    Returns a plain ``dict`` mapping each ``state_dict`` key to a detached
    clone of its tensor, so later updates to the model do not alter the
    snapshot.
    """
    return {key: tensor.detach().clone() for key, tensor in model.state_dict().items()}

def set_model_(model,state_dict):
    """Load a deep copy of ``state_dict`` into ``model`` in place; returns ``None``."""
    private_copy = deepcopy(state_dict)  # keep the caller's snapshot independent
    model.load_state_dict(private_copy)
    return None

def adjust_learning_rate(optimizer, epoch, args):
    """Set every param group's learning rate.

    On ``epoch == 1`` the rate is reset to ``args.lr``; on any other epoch the
    current rate is divided by ``args.lr_factor``.
    """
    reset_to_base = epoch == 1
    for group in optimizer.param_groups:
        group['lr'] = args.lr if reset_to_base else group['lr'] / args.lr_factor

def train(args, model, device, x,y, optimizer,criterion, task_id, lr):
    """Run one epoch of plain (unprojected) training on a single task.

    Args:
        args: namespace providing ``batch_size_train``.
        model: network whose forward returns a list of per-task outputs.
        device: device the data and batch indices are moved to.
        x, y: all training inputs and labels of the task.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``output[task_id]`` and the labels.
        task_id: index into the model's output list.
        lr: unused; kept for call-site compatibility.
    """
    model.train()
    # Shuffle sample order with NumPy's global RNG (seeded by the caller).
    perm = np.arange(x.size(0))
    np.random.shuffle(perm)
    perm = torch.LongTensor(perm).to(device)
    # Move the dataset once, not on every batch.
    x, y = x.to(device), y.to(device)
    batch_size = args.batch_size_train
    for start in range(0, len(perm), batch_size):
        # Slicing clips at the end, so the final partial batch is included.
        batch_idx = perm[start:start + batch_size]
        data, target = x[batch_idx], y[batch_idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[task_id], target)
        loss.backward()
        optimizer.step()

def train_projected(args,model,pre_model, memory, task_name_list, device,x,y,optimizer,criterion,feature_mat,feature_mat_weight, task_id, lr):
    """Run one epoch of gradient-projected training on a task after the first.

    For every mini-batch:
      1. Compute the task loss on ``output[task_id]``.
      2. Optionally (``args.lambda_gp_ql > 0`` and ``task_id > 0``) add a
         differentiable penalty on the part of each protected weight gradient
         that lies in ``feature_mat`` (see ``_gradient_projection_penalty``).
      3. Backpropagate, then remove from each protected weight gradient its
         component ``grad @ feature_mat_weight[l]`` and, when ``task_id != 0``,
         zero the gradients of 1-D parameters among the first 15.
      4. Take an optimizer step.

    Args:
        args: namespace providing ``batch_size_train`` and ``lambda_gp_ql``.
        model: network whose forward returns a list of per-task outputs.
        pre_model, memory, task_name_list, lr: unused; kept for call-site
            compatibility.
        device: device the data and batch indices are moved to.
        x, y: all training inputs and labels of the task.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``output[task_id]`` and the labels.
        feature_mat: per protected layer, a square matrix used by the penalty.
        feature_mat_weight: per protected layer, the square matrix used for
            the hard gradient projection.
        task_id: index into the model's output list.
    """
    model.train()
    # Shuffle sample order with NumPy's global RNG (seeded by the caller).
    perm = np.arange(x.size(0))
    np.random.shuffle(perm)
    perm = torch.LongTensor(perm).to(device)
    # Move the dataset once, not on every batch.
    x, y = x.to(device), y.to(device)
    batch_size = args.batch_size_train
    for start in range(0, len(perm), batch_size):
        batch_idx = perm[start:start + batch_size]
        data, target = x[batch_idx], y[batch_idx]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[task_id], target)
        # ---- Gradient-projection regulariser (onto the protected subspace) ----
        if args.lambda_gp_ql > 0 and task_id > 0:
            penalty = _gradient_projection_penalty(model, loss, feature_mat)
            if penalty is not None:
                loss = loss + args.lambda_gp_ql * penalty
        loss.backward()
        _project_gradients(model, feature_mat_weight, task_id)
        optimizer.step()


def _gradient_projection_penalty(model, loss, feature_mat):
    # Differentiable sum over protected weights of mean((dL/dW reshaped to [out, F]) @ feature_mat[l]) ** 2.
    grad_params = []
    U_layers = []
    for k_param, (name, p) in enumerate(model.named_parameters()):
        # Same selection as the hard projection in _project_gradients.
        if 'weight' in name and 'bn' not in name and k_param < 15 and len(p.size()) != 1:
            grad_params.append(p)
            U_layers.append(feature_mat[len(U_layers)])  # subspace matrix for this layer
    if not grad_params:
        return None
    # create_graph=True so the penalty itself can be backpropagated.
    g_list = torch.autograd.grad(loss, grad_params,
                                 create_graph=True, retain_graph=True, allow_unused=True)
    proj_reg = 0.0
    for g, p, U in zip(g_list, grad_params, U_layers):
        if g is None:
            g = torch.zeros_like(p)
        g_flat = g.reshape(g.size(0), -1)  # [out, F]
        proj_reg += torch.mm(g_flat, U).pow(2).mean()
    return proj_reg


def _project_gradients(model, feature_mat_weight, task_id):
    # In place: grad -= grad @ feature_mat_weight[l] for multi-dim params among the first 15; freeze 1-D ones after task 0.
    kk = 0
    with torch.no_grad():
        for k, (_, params) in enumerate(model.named_parameters()):
            if k >= 15:
                continue
            if len(params.size()) != 1:
                if params.grad is not None:
                    sz = params.grad.size(0)
                    correction = torch.mm(params.grad.reshape(sz, -1), feature_mat_weight[kk])
                    params.grad.sub_(correction.view(params.size()))
                kk += 1  # keep layer/matrix alignment even if a grad is missing
            elif task_id != 0 and params.grad is not None:
                params.grad.fill_(0)

def test(args, model, device, x, y, criterion, task_id):
    model.eval()
    total_loss = 0
    total_num = 0 
    correct = 0
    r=np.arange(x.size(0))
    np.random.shuffle(r)
    r=torch.LongTensor(r).to(device)
    with torch.no_grad():
        # Loop batches
        for i in range(0,len(r),args.batch_size_test):
            if i+args.batch_size_test<=len(r): b=r[i:i+args.batch_size_test]
            else: b=r[i:]
            x, y = x.to(device), y.to(device)
            data = x[b]
            data, target = data.to(device), y[b].to(device)
            output = model(data)
            loss = criterion(output[task_id], target)
            pred = output[task_id].argmax(dim=1, keepdim=True) 
            
            correct    += pred.eq(target.view_as(pred)).sum().item()
            total_loss += loss.data.cpu().numpy().item()*len(b)
            total_num  += len(b)

    acc = 100. * correct / total_num
    final_loss = total_loss / total_num
    return final_loss, acc


import torch
from torch.utils.data import DataLoader, TensorDataset

def compute_fisher_diag(model,
                        criterion,
                        inputs,
                        targets,
                        task_id: int = 0,
                        batch_size: int = 64,
                        n_batches: int = 20,
                        wanted_layers=None,
                        device='cuda'):
    """Estimate a diagonal Fisher (mean squared mini-batch gradient) per parameter.

    Iterates over up to ``n_batches`` shuffled mini-batches of exactly
    ``batch_size`` samples (an incomplete final batch is dropped). For each
    batch it backpropagates the loss and accumulates the element-wise square
    of each selected parameter's gradient. The sum is divided by the number of
    batches actually processed.

    Note: this squares the *mini-batch* gradient, not per-sample gradients.

    Args:
        model: network; its output may be a tensor or a list/tuple of heads,
            in which case ``out[task_id]`` is used.
        criterion: loss function ``criterion(logits, targets)``.
        inputs, targets: tensors wrapped in a ``TensorDataset``.
        task_id: head index used when the model returns several outputs.
        batch_size: samples per mini-batch.
        n_batches: maximum number of mini-batches to use.
        wanted_layers: optional collection of parameter names to restrict to.
        device: device the model and batches are moved to.

    Returns:
        dict ``name -> tensor`` (parameter shape), in ``named_parameters()``
        order. All zeros when no full batch could be formed.

    Side effects: moves ``model`` to ``device``, puts it in eval mode and
    leaves the last batch's gradients in ``.grad``.
    """
    from itertools import islice

    model = model.to(device)
    model.eval()

    params = {n: p for n, p in model.named_parameters()
              if p.requires_grad and (wanted_layers is None or n in wanted_layers)}
    fisher_acc = {n: torch.zeros_like(p, device=device) for n, p in params.items()}

    loader = DataLoader(TensorDataset(inputs, targets),
                        batch_size=batch_size, shuffle=True, drop_last=True)

    cnt = 0
    with torch.enable_grad():
        for x_mb, y_mb in islice(loader, max(0, n_batches)):
            x_mb, y_mb = x_mb.to(device), y_mb.to(device)

            model.zero_grad(set_to_none=True)
            out = model(x_mb)
            # The model may return a list of heads; pick the one for this task.
            if isinstance(out, (list, tuple)):
                out = out[task_id]
            loss = criterion(out, y_mb)
            loss.backward()

            for n, p in params.items():
                if p.grad is not None:
                    fisher_acc[n] += p.grad.detach() ** 2
            cnt += 1

    # Average over the batches actually seen (may be fewer than n_batches).
    denom = max(cnt, 1)
    return {n: acc / denom for n, acc in fisher_acc.items()}

def get_representation_and_gradient_Taylor(args, net, device, optimizer, criterion, task_name, memory, task_id, x, y=None):
    '''
    Collect per-layer representation matrices and loss gradients of ``net``.

    Up to 125 random samples of ``x`` go through one forward/backward pass.
    The forward pass fills ``net.act`` (for AlexNet: the inputs of conv1,
    conv2, conv3, fc1, fc2); the backward pass fills the parameters' ``.grad``.

    Args:
        args, task_name, memory: unused; kept for call-site compatibility.
        net: model exposing ``act``, ``map``, ``ksize`` and ``in_channel`` and
            returning a list of per-task outputs.
        device: device the data is moved to.
        optimizer: only used to clear gradients before the backward pass.
        criterion: loss applied to ``net(x)[task_id]``.
        task_id: index of the output used for the loss.
        x, y: task inputs and labels. ``y`` is required despite its default.

    Returns:
        (mat_list, grad_list)
        - mat_list: one NumPy matrix per entry of ``net.map``. The first three
          (conv) layers give im2col matrices of shape
          (ksize*ksize*in_channel, s*s*bsz); later layers give the transposed
          activations of the first ``bsz`` samples.
        - grad_list: NumPy gradients of non-BN weights. 4-D weights are
          flattened to (out, -1). Weights containing 'fc3' are handled by the
          ``k_linear`` rule below.
    '''
    # Draw up to 125 random samples using NumPy's global RNG.
    perm = np.arange(x.size(0))
    np.random.shuffle(perm)
    sample_idx = torch.LongTensor(perm).to(device)[0:125]
    x, y = x.to(device), y.to(device)
    example_data, target = x[sample_idx], y[sample_idx]

    # Samples used per layer for the representation matrix; only the first
    # len(net.map) entries are read.
    batch_list = 3 * [2*12, 100, 100, 125, 125]

    # Single forward/backward pass: fills net.act and the parameters' .grad.
    optimizer.zero_grad()
    example_out = net(example_data)
    loss = criterion(example_out[task_id], target)
    loss.backward()

    # ---- gradients: flatten every non-BN weight to 2-D ----
    grad_list = []
    k_linear = 0
    for name, param in net.named_parameters():
        if 'weight' not in name or 'bn' in name:
            continue
        if len(param.shape) == 4:  # conv weight -> (out, in*kh*kw)
            grad = param.grad.detach().cpu().numpy()
            grad_list.append(grad.reshape(grad.shape[0], -1))
        elif 'fc3' in name:
            # NOTE(review): k_linear only advances after a match, so a head
            # gradient is collected only when task_id == 0 (head fc3.0).
            # Preserved as-is; other heads' .grad may be None and is never read.
            if k_linear == task_id:
                grad_list.append(param.grad.detach().cpu().numpy())
                k_linear += 1
        else:
            grad_list.append(param.grad.detach().cpu().numpy())

    # ---- representations ----
    # Keys are read after the forward pass so a never-run net also works.
    act_key = list(net.act.keys())
    mat_list = []
    for i in range(len(net.map)):
        bsz = batch_list[i]
        act = net.act[act_key[i]].detach().cpu().numpy()
        if i < 3:  # convolution: build the im2col matrix
            ksz = net.ksize[i]
            s = compute_conv_output_size(net.map[i], ksz)
            if act.shape[0] < bsz:
                raise IndexError('layer {}: need {} samples, got {}'.format(i, bsz, act.shape[0]))
            mat = np.zeros((ksz * ksz * net.in_channel[i], s * s * bsz))
            # Column kk*s*s + ii*s + jj holds the flattened (C, ksz, ksz)
            # patch of sample kk at position (ii, jj); fill all kk at once.
            for ii in range(s):
                for jj in range(s):
                    patches = act[:bsz, :, ii:ksz + ii, jj:ksz + jj].reshape(bsz, -1)
                    mat[:, ii * s + jj::s * s] = patches.T
            mat_list.append(mat)
        else:  # fully connected layers: (features, samples)
            mat_list.append(act[0:bsz].transpose())

    return mat_list, grad_list


def update_SGP(args, model, device, memory, task_name, mat_list, grad_list, threshold, task_id, feature_list=None, importance_list=None):
    """Update the per-layer basis (GPM/SGP) and the importance of each basis vector.

    When ``feature_list`` is empty (first task), each layer's basis is the
    top-``r`` left singular vectors of its activation matrix, where ``r`` is
    the number of cumulative energy ratios below ``threshold[i]``.

    Otherwise, for each layer:
      - the activation is split into its projection onto the old basis and a
        residual;
      - old basis vectors (energy ``u^T A A^T u``) and new residual singular
        directions (energy ``S**2``) are ranked together. The top ones, up to
        ``threshold[i]`` of the total energy and at most the input dimension,
        are stored in ``memory[task_name][str(i)]['space_list']``;
      - ``feature_list[i]`` is extended with the residual directions needed to
        reach ``threshold[i]``, capped at the input dimension.

    Importances are rescaled as ``((c+1)*s) / (c*s + max(s))`` with
    ``c = args.scale_coff``. Old importances are added and clipped to [0, 1].

    Args:
        args: namespace providing ``scale_coff``.
        model, grad_list, task_id: unused; kept for call-site compatibility.
        device: device for the tensors stored in ``memory``.
        memory: nested dict; ``memory[task_name][str(i)]`` must exist.
        task_name: key of the current task in ``memory``.
        mat_list: per-layer NumPy activation matrices (features x samples).
        threshold: per-layer energy thresholds.
        feature_list, importance_list: bases/importances from earlier tasks.
            They are updated in place when provided; ``None`` (default) starts
            from fresh empty lists.

    Returns:
        (feature_list, importance_list)
    """
    # Fresh lists per call: a shared mutable default leaked state across calls.
    if feature_list is None:
        feature_list = []
    if importance_list is None:
        importance_list = []

    if not feature_list:
        # After the first task
        for i, activation in enumerate(mat_list):
            U, S, _ = np.linalg.svd(activation, full_matrices=False)
            energy = S ** 2
            sval_ratio = energy / energy.sum()
            r = np.sum(np.cumsum(sval_ratio) < threshold[i])

            feature_list.append(U[:, 0:r])
            memory[task_name][str(i)]['space_list'] = torch.tensor(U[:, 0:r], dtype=torch.float32, device=device)

            importance = ((args.scale_coff + 1) * S[0:r]) / (args.scale_coff * S[0:r] + max(S[0:r]))
            importance_list.append(importance)
    else:
        for i, activation in enumerate(mat_list):
            old_basis = feature_list[i]
            r_old = old_basis.shape[1]  # number of old GPM bases

            # Energy of the current activations along each old basis vector.
            R2 = np.dot(activation, activation.transpose())
            delta = np.array([np.dot(np.dot(basis, R2), basis) for basis in old_basis.transpose()])

            _, S1, _ = np.linalg.svd(activation, full_matrices=False)
            sval_total = (S1 ** 2).sum()

            # Projected representation (Eq-4)
            act_proj = np.dot(np.dot(old_basis, old_basis.transpose()), activation)
            Uc, Sc, _ = np.linalg.svd(act_proj, full_matrices=False)
            importance_new_on_old = np.dot(np.dot(old_basis.transpose(), Uc[:, 0:r_old]) ** 2, Sc[0:r_old] ** 2)
            importance_new_on_old = np.sqrt(importance_new_on_old)

            # Residual not captured by the old basis; criteria (Eq-5)
            act_hat = activation - act_proj
            U, S, _ = np.linalg.svd(act_hat, full_matrices=False)
            sval_hat = (S ** 2).sum()
            sval_ratio = (S ** 2) / sval_total
            accumulated_sval = (sval_total - sval_hat) / sval_total
            sigma = S ** 2

            # Rank old-basis energies (delta) and residual energies (sigma) together.
            stack = np.hstack((delta, sigma))
            stack_index = np.argsort(stack)[::-1]
            stack = np.sort(stack)[::-1]

            # Keep the strongest directions until threshold[i] of the total
            # energy is reached, capped at the input dimension.
            r1 = 0
            accumulated_sval1 = 0
            for value in stack:
                if accumulated_sval1 >= threshold[i] * sval_total:
                    break
                accumulated_sval1 += value
                r1 += 1
                if r1 == activation.shape[0]:
                    break

            # Save the selected space for this task.
            U_list = np.hstack((old_basis, U))
            memory[task_name][str(i)]['space_list'] = torch.tensor(U_list[:, stack_index[:r1]], dtype=torch.float32, device=device)

            # Number of new residual directions needed to reach the threshold.
            r = 0
            for ratio in sval_ratio:
                if accumulated_sval >= threshold[i]:
                    break
                accumulated_sval += ratio
                r += 1

            if r == 0:
                print('Skip Updating GPM for layer: {}'.format(i+1))
                # Update importances of the old bases only.
                importance = importance_new_on_old
                importance = ((args.scale_coff + 1) * importance) / (args.scale_coff * importance + max(importance))
                importance[0:r_old] = np.clip(importance[0:r_old] + importance_list[i][0:r_old], 0, 1)
                importance_list[i] = importance
                continue

            # Update the GPM basis and its importances.
            Ui = np.hstack((old_basis, U[:, 0:r]))
            importance = np.hstack((importance_new_on_old, S[0:r]))
            importance = ((args.scale_coff + 1) * importance) / (args.scale_coff * importance + max(importance))
            importance[0:r_old] = np.clip(importance[0:r_old] + importance_list[i][0:r_old], 0, 1)

            if Ui.shape[1] > Ui.shape[0]:
                feature_list[i] = Ui[:, 0:Ui.shape[0]]
                importance_list[i] = importance[0:Ui.shape[0]]
            else:
                feature_list[i] = Ui
                importance_list[i] = importance

    print('-'*40)
    print('Gradient Constraints Summary')
    print('-'*40)
    for i, basis in enumerate(feature_list):
        print('Layer {} : {}/{}'.format(i+1, basis.shape[1], basis.shape[0]))
    print('-'*40)
    return feature_list, importance_list


def importance_loss(model, memory, device, task_name_list, task_id, feature_list, importance_list, args=None):
    """Score each stored basis vector with a Fisher/Hessian-weighted quadratic term.

    Protected layers are parameters of ``model`` whose name contains 'weight'
    but not 'bn', whose index in ``named_parameters()`` is < 15 and that are
    not 1-D. For protected layer ``l`` and each basis vector ``u`` (a column of
    ``feature_list[l]``)::

        score(u) = mean over tasks t <= task_id of
                   0.5 * sum_{o,f} H_t[o, f] * ((U_t U_t^T u)_f) ** 2

    Here ``U_t = memory[task_name_list[t]][str(l)]['space_list']`` and ``H_t``
    is that slot's ``'hess_diag'`` reshaped to (out_channels, F). Scores are
    rescaled as ``((c+1)*s) / (c*s + max(s))`` with ``c = args.scale_coff_Taylor``.

    Args:
        model: network whose parameters define the protected layers.
        memory: per-task, per-layer dict with 'space_list' and 'hess_diag'.
        device: device used for the computation.
        task_name_list: task names indexed by task id.
        task_id: current task id; tasks 0..task_id are averaged.
        feature_list: per-layer NumPy basis matrices (F x r).
        importance_list: unused; kept for call-site compatibility.
        args: namespace providing ``scale_coff_Taylor``. If omitted, the
            module-level global ``args`` is used (the previous implicit
            behaviour); a NameError is raised if neither exists.

    Returns:
        list of 1-D float32 CPU tensors, one per protected layer (empty tensor
        for a layer with no basis vectors).
    """
    if args is None:
        # Backward compatibility: this function used to read a global ``args``.
        args = globals().get('args')
        if args is None:
            raise NameError("importance_loss needs `args` with `scale_coff_Taylor`: "
                            "pass it explicitly or define a module-level `args`")
    coff = args.scale_coff_Taylor

    importance_all = []
    kk = 0
    for k, (pname, param) in enumerate(model.named_parameters()):
        if not ('weight' in pname and 'bn' not in pname and k < 15 and len(param.size()) != 1):
            continue
        # All basis vectors of this layer as columns: [F, r].
        basis = torch.tensor(feature_list[kk], dtype=torch.float32, device=device)
        F_dim, r = basis.shape
        if r == 0:
            importance_all.append(torch.zeros(0, dtype=torch.float32))
            kk += 1
            continue
        out_ch = param.shape[0]
        quad_acc = torch.zeros(r, dtype=torch.float32, device=device)
        for t_idx in range(task_id + 1):
            slot = memory[task_name_list[t_idx]][str(kk)]
            U = slot['space_list']
            H = slot['hess_diag'].to(device)
            proj2 = (U @ (U.t() @ basis)).pow(2)  # [F, r]
            # sum_{o,f} H[o,f] * proj2[f,j] == (sum_o H[o,:]) @ proj2[:, j]
            h_col = H.reshape(out_ch, F_dim).sum(dim=0)  # [F]
            quad_acc = quad_acc + 0.5 * (h_col @ proj2)
        scores = quad_acc / max(1, task_id + 1)
        # Rescale so the largest score maps to 1.
        scores_norm = ((coff + 1) * scores) / (coff * scores + scores.max())
        importance_all.append(scores_norm.cpu())
        kk += 1
    return importance_all



def grad_proj_cond(args, net, x, y, memory, task_name, task_id, task_name_list, device, optimizer, criterion):
    '''
    Decide, per layer, how the current task's gradient relates to earlier tasks.

    Computes the loss gradient of ``net`` on up to 125 random samples. For
    every previous task and every collected weight gradient ``G`` it projects
    ``G`` onto that task's stored subspace ``U`` (``G @ U @ U^T``) and writes:
      - ``memory[task_name][str(i)]['regime'][task_index]``: '1' when the
        projection norm is <= ``Eplison_1 * ||G||``, else '2';
      - ``memory[task_name][str(i)]['selected_task']``: ``[0]`` when
        ``task_id == 1``; otherwise the indices (a tensor) of the previous
        tasks with the largest projection norms (at most 2).

    Args:
        args: unused; kept for call-site compatibility.
        net: model returning a list of per-task outputs.
        x, y: current task inputs and labels.
        memory: nested dict holding earlier tasks' 'space_list' tensors and the
            current task's 'regime' dicts.
        task_name: key of the current task in ``memory``.
        task_id: current task index (earlier tasks are 0..task_id-1).
        task_name_list: task names indexed by task id.
        device: device the data and gradients live on.
        optimizer: only used to clear gradients before the backward pass.
        criterion: loss applied to ``net(x)[task_id]``.

    Side effects: overwrites the parameters' ``.grad`` and ``net.act``.
    '''
    # Gradient of the current task before training, on up to 125 samples.
    perm = np.arange(x.size(0))
    np.random.shuffle(perm)
    sample_idx = torch.LongTensor(perm).to(device)[0:125]
    x, y = x.to(device), y.to(device)
    example_data, target = x[sample_idx], y[sample_idx]

    optimizer.zero_grad()
    example_out = net(example_data)
    loss = criterion(example_out[task_id], target)
    loss.backward()

    grad_list = []  # gradient of each protected layer, flattened to 2-D
    k_linear = 0
    for name, param in net.named_parameters():
        if 'weight' not in name or 'bn' in name:
            continue
        if len(param.shape) == 4:
            grad = param.grad.detach().to(device)
            grad_list.append(grad.reshape(grad.shape[0], -1))
        elif 'fc3' in name:
            # NOTE(review): k_linear only advances after a match, so for
            # task_id > 0 no head gradient is collected. Later code relies on
            # this: only the shared layers have stored subspaces. Other heads'
            # .grad may be None and is never read.
            if k_linear == task_id:
                grad_list.append(param.grad.detach().to(device))
                k_linear += 1
        else:
            grad_list.append(param.grad.detach().to(device))

    # Project onto each previous task's subspace.
    projection_norms_per_task = []
    for task_index in range(task_id):
        projection_norms = []
        for i, grad in enumerate(grad_list):
            space = memory[task_name_list[task_index]][str(i)]['space_list']
            projection_norm = torch.norm(torch.mm(grad, torch.mm(space, space.t())))
            projection_norms.append(projection_norm)
            # Regime 1: (almost) no overlap with the previous task's subspace.
            regime = '1' if projection_norm <= Eplison_1 * torch.norm(grad) else '2'
            memory[task_name][str(i)]['regime'][task_index] = regime
        projection_norms_per_task.append(projection_norms)

    if task_id == 1:
        for i in range(len(grad_list)):
            memory[task_name][str(i)]['selected_task'] = [0]
    elif projection_norms_per_task:
        top_k = 2
        for layer in range(len(grad_list)):
            layer_norms = torch.stack([norms[layer] for norms in projection_norms_per_task]).to(device)
            # Previous tasks whose subspace captures most of the gradient.
            _, idx = torch.topk(layer_norms, min(top_k, layer_norms.numel()), largest=True)
            memory[task_name][str(layer)]['selected_task'] = idx


def main(args):
    tstart=time.time()
    ## Device Setting 
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print (device)
    ## setup seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    # torch.backends.cudnn.deterministic = True
    # torch.set_float32_matmul_precision('high')

    ## Load CIFAR100 DATASET
    from dataloader import cifar100 as cf100
    data,taskcla,inputsize=cf100.get(pc_valid=args.pc_valid)

    acc_matrix=np.zeros((10,10))
    criterion = torch.nn.CrossEntropyLoss()

    task_id = 0
    task_list = []
    task_name_list = []
    memory = {}
    for k,ncla in taskcla:
        # specify threshold hyperparameter
        threshold = np.array([args.gpm_eps] * 5) + task_id * np.array([args.gpm_eps_inc] * 5)

        task_name = data[k]['name']
        task_name_list.append(task_name)
        print('*'*100)
        print('Task {:2d} ({:s})'.format(k,data[k]['name']))
        print('*'*100)
        xtrain=data[k]['train']['x']
        ytrain=data[k]['train']['y']
        xvalid=data[k]['valid']['x']
        yvalid=data[k]['valid']['y']
        xtest =data[k]['test']['x']
        ytest =data[k]['test']['y']
        task_list.append(k)

        lr = args.lr 
        best_loss=np.inf
        print ('-'*40)
        print ('Task ID :{} | Learning Rate : {}'.format(task_id, lr))
        print ('-'*40)
        
        if task_id==0:
            model = AlexNet(taskcla).to(device)
            # print ('Model parameters ---')
            # for k_t, (m, param) in enumerate(model.named_parameters()):
            #     print (k_t,m,param.shape)
            # print ('-'*40)

            memory[task_name] = {}

            # print_log('Model parameters ---', log)
            kk = 0
            for k_t, (m, param) in enumerate(model.named_parameters()):
                if 'weight' in m and 'bn' not in m:
                    # print_log((k_t, m, param.shape), log)
                    memory[task_name][str(kk)] = {
                        'space_list': {},
                        'grad_list': {},
                        'regime': {},
                    }
                    kk += 1

            best_model=get_model(model)
            feature_list =[]
            importance_list = []
            space_list_all = []
            optimizer = optim.SGD(model.parameters(), lr=lr)
            for epoch in range(1, args.n_epochs+1):
                # Train
                clock0=time.time()
                train(args, model, device, xtrain, ytrain, optimizer, criterion, k, lr)
                clock1=time.time()
                tr_loss,tr_acc = test(args, model, device, xtrain, ytrain,  criterion, k)
                print('Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% | time={:5.1f}ms |'.format(epoch,\
                                                            tr_loss,tr_acc, 1000*(clock1-clock0)),end='')
                # Validate
                valid_loss,valid_acc = test(args, model, device, xvalid, yvalid,  criterion, k)
                print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, valid_acc),end='')
                # Adapt lr
                if valid_loss<best_loss:
                    best_loss=valid_loss
                    best_model=get_model(model)
                    patience=args.lr_patience
                    print(' *',end='')
                else:
                    patience-=1
                    if patience<=0:
                        lr/=args.lr_factor
                        print(' lr={:.1e}'.format(lr),end='')
                        if lr<args.lr_min:
                            print()
                            break
                        patience=args.lr_patience
                        adjust_learning_rate(optimizer, epoch, args)
                print()
            set_model_(model,best_model)
            # Test
            print ('-'*40)
            test_loss, test_acc = test(args, model, device, xtest, ytest,  criterion, k)
            print('Test: loss={:.3f} , acc={:5.1f}%'.format(test_loss,test_acc))
            # Memory and Importance Update  
            # mat_list = get_representation_matrix(model, device, xtrain, ytrain)
            mat_list, grad_list = get_representation_and_gradient_Taylor(args, model, device, optimizer, criterion, task_name, memory, k, xtrain, ytrain)
            feature_list, importance_list = update_SGP(args, model, device, memory, task_name, mat_list, grad_list, threshold, task_id, feature_list, importance_list)
            wanted = {'conv1.weight',
                      'conv2.weight',
                      'conv3.weight',
                      'fc1.weight',
                      'fc2.weight'}
            fisher = compute_fisher_diag(model, criterion,
                                         inputs=xtrain,
                                         targets=ytrain,
                                         batch_size=64,
                                         n_batches=60,
                                         wanted_layers=wanted,
                                         device=device,
                                         task_id=task_id)



            for kk_h, (pname, fdiag) in enumerate(fisher.items()):
                memory[task_name][str(kk_h)]['hess_diag'] = fdiag.clone().cpu()
            importance_all = importance_loss(model, memory, device, task_name_list, task_id, feature_list, importance_list)

        else:

            memory[task_name] = {}
            kk = 0
            # print_log("reinit the scale for each task", log)
            for k_t, (m, params) in enumerate(model.named_parameters()):
                # create the saved memory
                if 'weight' in m and 'bn' not in m:
                    memory[task_name][str(kk)] = {
                        'space_list': {},
                        'grad_list': {},
                        'space_mat_list': {},
                        'scale1': {},
                        'scale2': {},
                        'regime': {},
                        'selected_task': {},
                        'ratio': {},
                    }
                    kk += 1
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
            feature_mat = []
            feature_mat_weight = []
            grad_proj_cond(args, model, xtrain, ytrain, memory, task_name, task_id, task_name_list, device, optimizer,
                           criterion)
            # Projection Matrix Precomputation

            for i in range(len(model.act)):
                Uf=torch.Tensor(np.dot(feature_list[i],np.dot(np.diag(importance_all[i]),feature_list[i].transpose()))).to(device)
                Uf.requires_grad = False
                feature_mat.append(Uf)

            for i in range(len(model.act)):
                Uf2 =torch.Tensor(np.dot(feature_list[i],np.dot(np.diag(importance_list[i]),feature_list[i].transpose()))).to(device)
                Uf2.requires_grad = False
                feature_mat_weight.append(Uf2)
            # print ('-'*40)
            for epoch in range(1, args.n_epochs+1):
                # Train 
                clock0=time.time()
                if epoch == 1:
                    pre_model = AlexNet(taskcla).to(device)
                    pre_model.load_state_dict(deepcopy(model.state_dict()))
                train_projected(args, model,pre_model, memory, task_name_list, device,xtrain, ytrain,optimizer,criterion,feature_mat,feature_mat_weight,  k, lr)

                clock1=time.time()
                tr_loss, tr_acc = test(args, model, device, xtrain, ytrain,criterion,k)
                print('Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% | time={:5.1f}ms |'.format(epoch,\
                                                        tr_loss, tr_acc, 1000*(clock1-clock0)),end='')
                # Validate
                valid_loss,valid_acc = test(args, model, device, xvalid, yvalid, criterion,k)
                print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, valid_acc),end='')
                # Adapt lr
                if valid_loss<best_loss:
                    best_loss=valid_loss
                    best_model=get_model(model)
                    patience=args.lr_patience
                    print(' *',end='')
                else:
                    patience-=1
                    if patience<=0:
                        lr/=args.lr_factor
                        print(' lr={:.1e}'.format(lr),end='')
                        if lr<args.lr_min:
                            print()
                            break
                        patience=args.lr_patience
                        adjust_learning_rate(optimizer, epoch, args)
                print()
            set_model_(model,best_model)
            # Test 
            test_loss, test_acc = test(args, model, device, xtest, ytest,  criterion,k)
            print('Test: loss={:.3f} , acc={:5.1f}%'.format(test_loss,test_acc))  
            # Memory and Importance Update
            mat_list, grad_list = get_representation_and_gradient_Taylor(args, model, device, optimizer, criterion,
                                                                         task_name, memory, task_id, xtrain, ytrain)
            feature_list, importance_list = update_SGP (args, model, device, memory, task_name, mat_list, grad_list, threshold, task_id, feature_list, importance_list)
            wanted = {'conv1.weight',
                      'conv2.weight',
                      'conv3.weight',
                      'fc1.weight',
                      'fc2.weight'}
            fisher = compute_fisher_diag(model, criterion,
                                         inputs=xtrain,
                                         targets=ytrain,
                                         batch_size=64,
                                         n_batches=60,
                                         wanted_layers=wanted,
                                         device=device,
                                         task_id=task_id)


            for kk_h, (pname, fdiag) in enumerate(fisher.items()):
                memory[task_name][str(kk_h)]['hess_diag'] = fdiag.clone().cpu()
            importance_all = importance_loss(model, memory, device, task_name_list, task_id, feature_list, importance_list)
        # save accuracy 
        jj = 0

        for ii in np.array(task_list)[0:task_id+1]:
            acc = 0.0
            xtest =data[ii]['test']['x']
            ytest =data[ii]['test']['y']
            for i in range(10):
                _, acc_s = test(args, model, device, xtest, ytest,criterion,ii)
                acc += acc_s
            acc_matrix[task_id,jj] = acc/10
            jj +=1
        print('Accuracies =')
        for i_a in range(task_id+1):
            print('\t',end='')
            for j_a in range(acc_matrix.shape[1]):
                print('{:5.1f}% '.format(acc_matrix[i_a,j_a]),end='')
            print()
        # update task id 
        task_id +=1
    print('-'*50)

    print ('Final Avg Accuracy: {:5.2f}%'.format(acc_matrix[-1].mean())) 
    bwt=np.mean((acc_matrix[-1]-np.diag(acc_matrix))[:-1]) 
    print ('Backward transfer: {:5.2f}%'.format(bwt))
    print('[Elapsed time = {:.1f} ms]'.format((time.time()-tstart)*1000))
    print('-'*50)



if __name__ == "__main__":
    # Training parameters

    parser = argparse.ArgumentParser(description='10-split CIFAR-100 with LSS')
    parser.add_argument('--pc_valid', default=0.05, type=float,
                        help='fraction of training data used for validation')
    parser.add_argument('--batch_size_train', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--batch_size_test', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--n_epochs', type=int, default=200, metavar='N',
                        help='number of training epochs/task (default: 200)')
    parser.add_argument('--seed', type=int, default=5, metavar='S',
                        help='random seed (default: 5)')

    # Optimizer parameters
    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate (default: 0.05)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--lr_min', type=float, default=1e-4, metavar='LRM',
                        help='minimum lr rate (default: 1e-5)')
    parser.add_argument('--lr_patience', type=int, default=7, metavar='LRP',
                        help='hold before decaying lr (default: 6)')
    parser.add_argument('--lr_factor', type=int, default=2, metavar='LRF',
                        help='lr decay factor (default: 2)')
    # SGP/GPM specific 
    parser.add_argument('--scale_coff', type=int, default=7, metavar='SCF',#7
                        help='importance co-efficeint (default: 7)')
    parser.add_argument('--scale_coff_Taylor', type=int, default=10, metavar='SCF',# 10
                        help='importance co-efficeint (default: 10)')
    parser.add_argument('--gpm_eps', type=float, default=0.97, metavar='EPS',
                        help='threshold (default: 0.97)')
    parser.add_argument('--gpm_eps_inc', type=float, default=0.003, metavar='EPSI',
                        help='threshold increment per task (default: 0.003)')
    # Hessian-regularization
    parser.add_argument('--lambda_gp_ql', type=float, default=1e2, metavar='LG', # 1e2
                        help='coefficient for first‑order gradient flatness regularization')

    args = parser.parse_args()
    print('='*100)
    print('Arguments =')
    for arg in vars(args):
        print('\t'+arg+':',getattr(args,arg))
    print('='*100)

    main(args)



