import os
import sys
import time
import random
import pickle
import argparse
import numpy as np
import pandas as pd
from numpy import linalg as LA
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
from tinyimagenet import TinyImageNet

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.func import grad, vmap, functional_call
from torchvision import datasets, transforms, models

def evaluate(model, test_loader, device, use_noisy_labels=False):
    """Measure the mean cross-entropy loss and accuracy of `model` over a loader.

    Args:
        model: Network to score. It is switched to eval mode and left in it.
        test_loader: Iterable yielding `(index, inputs, true_labels, noisy_labels)`
            batches; the index entry is ignored here.
        device: Device the inputs and the selected labels are moved to.
        use_noisy_labels: `False` (default) scores against the true labels,
            anything else scores against the noisy labels.

    Returns:
        `(loss, accuracy)`, each averaged over the number of samples seen.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for _, inputs, true_labels, noisy_labels in tqdm(test_loader):
            labels = true_labels if use_noisy_labels == False else noisy_labels
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            # The criterion returns a batch mean, so weight it by the batch
            # size to obtain a per-sample average at the end.
            batch_loss = criterion(logits, labels).item()
            loss_sum += batch_loss * inputs.size(0)

            predictions = logits.max(1).indices
            num_seen += labels.size(0)
            num_correct += predictions.eq(labels).sum().item()

    return loss_sum / num_seen, num_correct / num_seen
    
def cost_model(input_, target, params, buffers, model):
    """Cross-entropy loss of `model` on a single un-batched sample.

    The model is invoked through `functional_call`, so the forward pass uses
    the tensors in `params` / `buffers` in place of the module's own state.

    Args:
        input_: One input sample without a batch dimension.
        target: The label for that sample, without a batch dimension.
        params: Mapping of parameter name to tensor used for the forward pass.
        buffers: Mapping of buffer name to tensor used for the forward pass.
        model: Module whose forward computation is evaluated.

    Returns:
        The cross-entropy loss for this one sample.
    """
    # Add the batch dimension of size one that the model and loss operate on.
    batched_input = torch.unsqueeze(input_, 0)
    batched_target = torch.unsqueeze(target, 0)

    logits = functional_call(model, (params, buffers), batched_input)
    return torch.nn.functional.cross_entropy(logits, batched_target)

def compute_similarity(task_vector, sample_grads, method="normed_proj", temperature=1.0):
    """Score every per-sample gradient against the direction of `task_vector`.

    Args:
        task_vector: Tensor giving the target direction; it is flattened and
            L2-normalised before use.
        sample_grads: Either a tensor whose first dimension indexes samples, or
            a sequence of per-sample tensors. Each sample must have as many
            elements as `task_vector`.
        method: Name ending in ``"cos"`` (cosine similarity) or ``"proj"``
            (length of the projection onto the normalised task vector). A
            ``"normed"`` prefix additionally applies a softmax over samples.
        temperature: Divisor applied to the scores before the softmax; only
            used when `method` starts with ``"normed"``.

    Returns:
        1-D tensor with one score per sample. For ``"normed*"`` methods the
        scores are softmax weights that sum to one.

    Raises:
        ValueError: If `method` ends with neither ``"cos"`` nor ``"proj"``.
    """
    if method.endswith("cos"):
        use_cosine = True
    elif method.endswith("proj"):
        use_cosine = False
    else:
        raise ValueError(
            f"Unsupported similarity method {method!r}; expected a name ending in 'cos' or 'proj'."
        )

    tv_flat = task_vector.flatten()
    normed_tv_flat = torch.nn.functional.normalize(tv_flat, p=2, dim=0)

    num_samples = len(sample_grads)
    if num_samples == 0:
        score_tensor = torch.empty(0, device=tv_flat.device)
    else:
        if torch.is_tensor(sample_grads):
            grads_flat = sample_grads.reshape(num_samples, -1)
        else:
            grads_flat = torch.stack([g.flatten() for g in sample_grads])

        if use_cosine:
            grads_flat = torch.nn.functional.normalize(grads_flat, p=2, dim=1)

        # One matrix-vector product replaces a Python loop that synchronised
        # with the device once per sample. The result is detached and cast to
        # the default dtype, as the previous preallocated score buffer was.
        score_tensor = (grads_flat @ normed_tv_flat).detach().to(dtype=torch.get_default_dtype())

    if method.startswith("normed"):
        return torch.softmax(score_tensor / temperature, dim=0)
    return score_tensor
    
def compute_task_vector(current_model, reference_model, mimic_layer_name):
    """Return `reference - current` for a single state-dict entry.

    Args:
        current_model: Model being trained.
        reference_model: Model whose weights define the target.
        mimic_layer_name: Key into both models' `state_dict()`.

    Returns:
        Tensor pointing from the current entry towards the reference entry.
    """
    reference_weights = reference_model.state_dict()[mimic_layer_name]
    current_weights = current_model.state_dict()[mimic_layer_name]
    return reference_weights - current_weights
    
def objective_fn_matrix(neg_grad, task_vector, w, lambd, norm_way="l2"):
    """Build the subset-selection objective for matrix-shaped gradients.

    For every sample `i` the slice `neg_grad[i]` is scaled by `w[i]`, and its
    Frobenius distance to `task_vector` is taken. The distances are summed and
    a penalty `lambd * ||w||` is added.

    NOTE(review): `cp` is not imported in the visible part of this file;
    presumably `import cvxpy as cp` is intended - confirm.

    Args:
        neg_grad: Array indexed as `[sample, row, column]`.
        task_vector: Matrix each scaled slice is compared with.
        w: Per-sample weight variable.
        lambd: Strength of the penalty on `w`.
        norm_way: ``"l2"`` or ``"l1"``, selecting the norm used for the penalty.

    Returns:
        The objective expression.

    Raises:
        UnboundLocalError: If `norm_way` is neither ``"l2"`` nor ``"l1"``.
    """
    residual_norms = [
        cp.norm(neg_grad[sample, :, :] * w[sample] - task_vector, 'fro')
        for sample in range(neg_grad.shape[0])
    ]

    if norm_way == "l2":
        penalty = cp.norm2(w)
    elif norm_way == "l1":
        penalty = cp.norm1(w)

    return cp.sum(cp.hstack(residual_norms)) + lambd * penalty

def objective_fn_vector(neg_grad, task_vector, w, lambd, norm_way="l2"):
    """Build the subset-selection objective for vector-shaped gradients.

    The objective is the L2 distance between the weighted combination
    `w @ neg_grad` and `task_vector`, plus a penalty `lambd * ||w||`.

    NOTE(review): `cp` is not imported in the visible part of this file;
    presumably `import cvxpy as cp` is intended - confirm.

    Args:
        neg_grad: Per-sample gradients, combined through `w @ neg_grad`.
        task_vector: Vector the weighted combination should match.
        w: Per-sample weight variable.
        lambd: Strength of the penalty on `w`.
        norm_way: ``"l2"`` or ``"l1"``, selecting the norm used for the penalty.

    Returns:
        The objective expression, or `None` for any other `norm_way`.
    """
    if norm_way == "l2":
        penalty_norm = cp.norm2
    elif norm_way == "l1":
        penalty_norm = cp.norm1
    else:
        return None

    return cp.norm2((w @ neg_grad) - task_vector) + lambd * penalty_norm(w)
    
def add_noise_to_labels(labels, num_class, noise_ratio=0.1):
    """Replace a random fraction of labels with a different class.

    Positions are drawn with `random.sample` and the replacement offsets with
    `np.random.randint`, so results depend on both global random states.

    Args:
        labels: Sequence of integer class labels; it is copied, not modified.
        num_class: Number of classes; new labels are taken modulo this value.
        noise_ratio: Fraction of labels to corrupt; the count is truncated
            towards zero.

    Returns:
        `(noisy_labels, noisy_indices)`: a NumPy array of labels after
        corruption, and the list of positions that were changed.
    """
    clean = np.array(labels)
    total = len(clean)
    flip_count = int(noise_ratio * total)

    noisy_indices = random.sample(range(total), flip_count)
    noisy_labels = clean.copy()

    for position in noisy_indices:
        # An offset in [1, num_class) guarantees the new label differs from
        # the original one after the modulo.
        offset = np.random.randint(1, num_class)
        noisy_labels[position] = (clean[position] + offset) % num_class

    return noisy_labels, noisy_indices

def get_per_sample_gradients(model, inputs, targets):
    """Compute the gradient of the loss for each sample of a batch separately.

    `cost_model` is differentiated with respect to its `params` argument and
    the resulting function is mapped over the leading dimension of `inputs`
    and `targets` with `vmap`.

    Args:
        model: Module whose trainable parameters are differentiated.
        inputs: Batch of inputs; dimension 0 indexes samples.
        targets: Batch of labels; dimension 0 indexes samples.

    Returns:
        Mapping from parameter name to that parameter's per-sample gradients.
    """
    # Only parameters that currently require grad are differentiated.
    trainable_params = {
        name: tensor.detach()
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    }
    # NOTE(review): buffers normally do not require grad, so this dict is
    # presumably empty in practice - confirm that is intended.
    tracked_buffers = {
        name: tensor.detach()
        for name, tensor in model.named_buffers()
        if tensor.requires_grad
    }

    # argnums=2 selects `params`; inputs and targets are mapped over dim 0
    # while params, buffers and the model are shared across samples.
    single_sample_grad = grad(cost_model, argnums=2)
    batched_grad = vmap(single_sample_grad, in_dims=(0, 0, None, None, None))

    return batched_grad(inputs, targets, trainable_params, tracked_buffers, model)

def solve_subset_selection(ref, inputs, num_datapoint, lambd_value, norm_way, device):
    """Solve for per-sample weights that make the gradients match `ref`.

    The objective is built by `objective_fn_matrix` when `ref` is 2-D and by
    `objective_fn_vector` when it is 1-D, then minimised without constraints
    using the MOSEK solver.

    NOTE(review): `cp` is not imported in the visible part of this file;
    presumably `import cvxpy as cp` is intended - confirm.

    Args:
        ref: Target array (1-D or 2-D) the weighted gradients should match.
        inputs: Per-sample gradients passed to the objective.
        num_datapoint: Number of weights to solve for.
        lambd_value: Non-negative strength of the penalty on the weights.
        norm_way: ``"l2"`` or ``"l1"``, forwarded to the objective.
        device: Device the resulting tensor is moved to.

    Returns:
        Tensor with the solved weights on `device`.

    Raises:
        UnboundLocalError: If `ref` is neither 1-D nor 2-D.
    """
    weights_var = cp.Variable(num_datapoint)
    reg_strength = cp.Parameter(nonneg=True)
    reg_strength.value = lambd_value

    # The problem is solved unconstrained; box/simplex constraints on the
    # weights were previously considered here but are not applied.
    no_constraints = []

    ref_rank = len(ref.shape)
    if ref_rank == 2:
        objective = objective_fn_matrix(inputs, ref, weights_var, reg_strength, norm_way)
        problem = cp.Problem(cp.Minimize(objective), constraints=no_constraints)
    elif ref_rank == 1:
        objective = objective_fn_vector(inputs, ref, weights_var, reg_strength, norm_way)
        problem = cp.Problem(cp.Minimize(objective), constraints=no_constraints)

    problem.solve(solver=cp.MOSEK, verbose=False)

    return torch.Tensor(weights_var.value).to(device)

def gradient_calibration(model, per_sample_weights, per_sample_grads, mimic_layer_name, calibrate_mimic_layer_only=True):
    """Write re-weighted per-sample gradients into `param.grad`.

    Only parameters with `requires_grad` set are touched.

    Args:
        model: Module whose parameters receive the new `.grad` tensors.
        per_sample_weights: 1-D tensor with one weight per sample.
        per_sample_grads: Mapping from parameter name to per-sample gradients
            with the sample dimension first.
        mimic_layer_name: Name of the parameter that is always re-weighted.
        calibrate_mimic_layer_only: When `True`, every other trainable
            parameter receives the plain mean of its per-sample gradients;
            otherwise all trainable parameters are re-weighted.
    """
    mimic_only = calibrate_mimic_layer_only == True

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        grads = per_sample_grads[name]

        if mimic_only and name != mimic_layer_name:
            param.grad = grads.mean(dim=0)
            continue

        # Shape the weights as (num_samples, 1, ..., 1) so they broadcast
        # against gradients of any parameter rank.
        broadcast_shape = [len(per_sample_weights)] + [1] * len(param.shape)
        weights_column = per_sample_weights.reshape(broadcast_shape)
        param.grad = (weights_column * grads).sum(dim=0)
                
# Custom Dataset class with index tracking and label noise addition
class IndexedDataset(Dataset):
    """Wrap a named image dataset so items carry their index and a noisy label.

    Each item is `(index, data, true_target, noisy_target)`. When
    `noise_ratio > 0`, `add_noise_to_labels` corrupts that fraction of the
    labels; otherwise the noisy labels are the dataset's own targets.
    """

    def __init__(self, root, name, train=True, transform=None, download=False, noise_ratio=0.0):
        # Several of the wrapped datasets take a split name instead of a bool.
        split = "train" if train == True else "test"

        if name == "dtd":
            self.dataset = datasets.DTD(root=root, split=split, transform=transform, download=download)
            # Targets are gathered by iterating over every sample once.
            self.dataset.targets = np.array([label for _, label in self.dataset])
        elif name == "stl10":
            self.dataset = datasets.STL10(root=root, split=split, transform=transform, download=download)
            self.dataset.targets = self.dataset.labels
        elif name == "cifar10":
            self.dataset = datasets.CIFAR10(root=root, train=train, transform=transform, download=download)
        elif name == "flower102":
            self.dataset = datasets.Flowers102(root=root, split=split, transform=transform, download=download)
            self.dataset.targets = np.array([label for _, label in self.dataset])
        elif name == "country211":
            self.dataset = datasets.Country211(root=root, split=split, transform=transform, download=download)
        elif name == "cifar100":
            self.dataset = datasets.CIFAR100(root=root, train=train, transform=transform, download=download)
        elif name == "tinyimagenet":
            # `download` is not forwarded for this dataset.
            self.dataset = TinyImageNet(f"{root}/tiny-imagenet-200", split=split, transform=transform, imagenet_idx=True)
        elif name == "pet":
            pet_split = "trainval" if train == True else "test"
            self.dataset = datasets.OxfordIIITPet(root=root, split=pet_split, target_types="category", transform=transform, download=download)
            self.dataset.targets = [self.dataset[i][1] for i in range(len(self.dataset))]
        elif name == "svhn":
            self.dataset = datasets.SVHN(root=root, split=split, transform=transform, download=download)
            self.dataset.targets = self.dataset.labels

        self.true_labels = self.dataset.targets

        # These two datasets use a hard-coded class count; every other one
        # reads it from the wrapped dataset's `classes` attribute.
        fixed_class_counts = {"flower102": 102, "svhn": 10}
        if name in fixed_class_counts:
            self.num_class = fixed_class_counts[name]
        else:
            self.num_class = len(self.dataset.classes)

        if noise_ratio > 0:
            self.noise_ratio = noise_ratio
            self._apply_noise_to_labels()
        else:
            # No corruption requested: noisy labels alias the clean targets.
            self.noise_ratio = 0
            self.noisy_labels = self.dataset.targets
            self.noisy_indices = []

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return `(index, data, true_target, noisy_target)` for one sample."""
        data, true_target = self.dataset[index]
        return index, data, true_target, self.noisy_labels[index]

    def _apply_noise_to_labels(self):
        """Fill `noisy_labels` / `noisy_indices` from the true labels."""
        noisy_labels, noisy_indices = add_noise_to_labels(self.true_labels, self.num_class, self.noise_ratio)
        self.noisy_labels = noisy_labels
        self.noisy_indices = noisy_indices

def load_init_model(arch_name, num_class, device, seed, pretrained=True, linear_probing=True):
    """Load, or build and cache, the initial ViT model for a run.

    The initial model is cached on disk per `(arch_name, num_class, seed)` so
    that repeated runs start from the same classification head. If the cache
    file exists it is loaded; otherwise the backbone is created, its head is
    replaced by a fresh `Linear` layer with `num_class` outputs, and the whole
    model is saved.

    Args:
        arch_name: ``'vit-b'`` or ``'vit-l'``.
        num_class: Number of output classes for the new head.
        device: Unused here; the caller moves the model to its device.
        seed: Only used to name the cache file.
        pretrained: Use ``IMAGENET1K_V1`` weights when true, random
            initialisation otherwise.
        linear_probing: When true together with `pretrained`, freeze every
            parameter whose name does not contain ``"heads.head"``.

    Returns:
        The initial model.

    Raises:
        ValueError: If the model has to be built and `arch_name` is unknown.
    """
    if pretrained:
        sub_folder = "pretrained_linear_probing"
        backbone_weights = 'IMAGENET1K_V1'
    else:
        sub_folder = "train_from_scratch"
        backbone_weights = None

    checkpoint_path = f"./new_saved_models/{sub_folder}/init_{arch_name}_{num_class}classes_seed{seed}.pt"

    if os.path.exists(checkpoint_path):
        # The checkpoint is a fully pickled module, so it must be loaded with
        # weights_only=False; newer PyTorch releases default to True and
        # refuse such files. Unpickling executes code - only load checkpoints
        # produced by this project.
        model = torch.load(checkpoint_path, weights_only=False)
    else:
        if arch_name == 'vit-b':
            model = models.vit_b_16(weights=backbone_weights)
        elif arch_name == 'vit-l':
            model = models.vit_l_16(weights=backbone_weights)
        else:
            raise ValueError(f"Unsupported architecture {arch_name!r}; expected 'vit-b' or 'vit-l'.")

        model.heads.head = torch.nn.Linear(model.heads.head.in_features, num_class)

        # Make sure the cache directory exists before the first save.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model, checkpoint_path)

    if pretrained and linear_probing:
        # Linear probing: train the head only.
        for name, param in model.named_parameters():
            param.requires_grad = "heads.head" in name

    return model
    
def parse_arguments(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When `None` (default) the
            arguments are read from `sys.argv`, as before.

    Returns:
        The `argparse.Namespace` with all parsed options.
    """
    parser = argparse.ArgumentParser(description='Argument parser for training script')

    parser.add_argument('--mode', type=str, required=True,
                        choices=['grad-mimic', 'grad-match', 'grad-descent', 'grad-norm', 'agra'],
                        help='Select the mode for training: grad-norm, agra, grad-match, grad-mimic, or grad-descent (baseline)')
    parser.add_argument('--method', type=str, default='normed_proj',
                        choices=['opt', 'cos', 'proj', 'normed_cos', 'normed_proj'],
                        help='Select the method for reweighting gradients: optimization, (normed) cosine similarity, or (normed) projection length')
    parser.add_argument('--dataset_name', type=str, default='cifar10',
                        choices=['svhn', 'dtd', 'pet', 'stl10', 'cifar10', 'flower102', 'country211', 'cifar100', 'tinyimagenet'],
                        help='Name of the dataset to use')
    parser.add_argument('--noisy_level', type=float, default=0.0,
                        help='Percentage of involved noisy labels')
    parser.add_argument('--num_epoch', type=int, default=5,
                        help='Number of epochs to train')
    parser.add_argument('--model_arch', type=str, default='vit-b', choices=['vit-b', 'vit-l'],
                        help='Type of model architecture to use')
    parser.add_argument('--pretrained', default=True, action=argparse.BooleanOptionalAction,
                        help='To load pretrained weights (IMAGENET_V1) or not')
    parser.add_argument('--linear_probing', default=True, action=argparse.BooleanOptionalAction,
                        help='Fine-tune on the top of backbone only')
    parser.add_argument('--mimic_layer', type=str, default='heads.head.weight',
                        help='Specific layer name to use')
    parser.add_argument('--calibrate_mimic_layer_only', default=True, action=argparse.BooleanOptionalAction,
                        help='Only calibrate gradients for mimic layer or not')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='To control the smoothness of selection weights')
    parser.add_argument('--starting_epoch', type=int, default=0,
                        help='Starting epoch for using grad-mimic')
    parser.add_argument('--training_batch_size', type=int, default=32, choices=[16, 32, 64, 128, 256],
                        help='Training batch size')
    parser.add_argument('--dataset_dir', type=str, default='./data',
                        help='Directory path to load data')
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adamw', 'sgd'],
                        help='Optimizer to use')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='Learning rate to train')
    parser.add_argument('--norm_way', type=str, default="l2", choices=["l2", "l1"],
                        help='Regularization approach')
    parser.add_argument('--lambda_value', type=float, default=0.0,
                        help='Regularization parameter')
    parser.add_argument('--cuda_device', type=int, default=0,
                        help='Select the CUDA device ID to use for training')
    parser.add_argument('--seed', type=int, default=123,
                        help='Seed value for random number generator')

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the CLI, build the data and model, train with
    # the selected gradient re-weighting mode and track per-epoch metrics.
    args = parse_arguments()
    print("-----------------------Arguments-----------------------")
    print(f"Mode: {args.mode}")
    print(f"Method: {args.method}")
    print(f"Dataset Name: {args.dataset_name}")
    print(f"Noisy Level: {args.noisy_level}")
    print(f"Number of Epochs: {args.num_epoch}")
    print(f"Model Arch: {args.model_arch}")
    print(f"Pretrained: {args.pretrained}")
    print(f"Linear Probing: {args.linear_probing}")
    print(f"Mimic Layer: {args.mimic_layer}")
    print(f"Calibrate Mimic Layer only: {args.calibrate_mimic_layer_only}")
    print(f"Temperature: {args.temperature}")
    print(f"Starting Epoch: {args.starting_epoch}")
    print(f"Training Batch Size: {args.training_batch_size}")
    print(f"Dataset Directory: {args.dataset_dir}")
    print(f"Optimizer: {args.optimizer}")
    print(f"Learning Rate: {args.learning_rate}")
    print(f"Norm Way: {args.norm_way}")
    print(f"Lambda: {args.lambda_value}")
    print(f"CUDA Device: {args.cuda_device}")
    print(f"Seed: {args.seed}")
    print("-------------------------------------------------------")

    ## These datasets are always trained for 10 epochs ##
    if args.dataset_name in ["dtd", "flower102", "country211"] and args.num_epoch != 10:
        print("** Num of Epoch is incorrect. Change to 10 Epochs. **")
        args.num_epoch = 10

    ## Setup Everything ##
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    test_batch_size = 512
    device = torch.device(f'cuda:{args.cuda_device}' if torch.cuda.is_available() else 'cpu')

    ## Data preprocessing and augmentation ##
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    ## Create instances of IndexedDataset ##
    train_dataset = IndexedDataset(root=args.dataset_dir, name=args.dataset_name, train=True, download=False, transform=transform_train, noise_ratio=args.noisy_level)
    test_dataset = IndexedDataset(root=args.dataset_dir, name=args.dataset_name, train=False, download=False, transform=transform_test, noise_ratio=0.0)
    num_class = train_dataset.num_class

    train_loader = DataLoader(dataset=train_dataset, batch_size=args.training_batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=4)

    ## Load initial model ##
    model = load_init_model(arch_name=args.model_arch, num_class=num_class, device=device, seed=args.seed, pretrained=args.pretrained, linear_probing=args.linear_probing)
    model = model.to(device)

    ## Evaluate the initial model ##
    init_true_train_loss, init_true_train_acc = evaluate(model, train_loader, device)
    print(f'Init. Model -- Train True Loss: {init_true_train_loss:.4f}, Train True Acc: {init_true_train_acc:.4f}')

    init_noisy_train_loss, init_noisy_train_acc = evaluate(model, train_loader, device, use_noisy_labels=True)
    print(f'Init. Model -- Train Noisy Loss: {init_noisy_train_loss:.4f}, Train Noisy Acc: {init_noisy_train_acc:.4f}')

    init_test_loss, init_test_acc = evaluate(model, test_loader, device)
    print(f'Init. Model -- Test Loss: {init_test_loss:.4f}, Test Acc: {init_test_acc:.4f}')

    ## Record learning results ##
    # Slot 0 holds the metrics of the initial model, slot e + 1 those of epoch e.
    learning_results = {
        "true_training_loss": [init_true_train_loss] + [0] * args.num_epoch,
        "noisy_training_loss": [init_noisy_train_loss] + [0] * args.num_epoch,
        "testing_loss": [init_test_loss] + [0] * args.num_epoch,
        "true_training_accuracy": [init_true_train_acc] + [0] * args.num_epoch,
        "noisy_training_accuracy": [init_noisy_train_acc] + [0] * args.num_epoch,
        "testing_accuracy": [init_test_acc] + [0] * args.num_epoch
    }

    ## Record per-sample weights ##
    # A set gives O(1) membership tests; the noisy indices are a list, which
    # made this loop quadratic in the dataset size.
    noisy_index_set = set(train_dataset.noisy_indices)
    result_collection = {
        i: {
            "status": "incorrect" if i in noisy_index_set else "correct",
            "per_sample_weights": [0] * args.num_epoch
        }
        for i in range(len(train_dataset))
    }

    ## Load reference model ##
    if args.mode == "grad-mimic":

        # NOTE(review): the reference checkpoint name is fixed to seed123
        # regardless of --seed; confirm this is intended.
        if args.pretrained and args.linear_probing:
            reference_model_path = f"./new_saved_models/pretrained_linear_probing/{args.model_arch}_{args.dataset_name}_run{args.num_epoch}epochs_grad-descent_noisy0.0_seed123.pt"
        elif args.pretrained and not args.linear_probing:
            reference_model_path = f"./new_saved_models/pretrained_fine_tune_all/{args.model_arch}_{args.dataset_name}_run{args.num_epoch}epochs_grad-descent_noisy0.0_seed123.pt"
        else:
            reference_model_path = f"./new_saved_models/train_from_scratch/{args.model_arch}_{args.dataset_name}_run{args.num_epoch}epochs_grad-descent_noisy0.0_seed123.pt"

        if not os.path.exists(reference_model_path):
            print("** Reference model is not existed. Change to grad-descent's mode. Please run grad-mimic later. **")
            args.mode = "grad-descent"
        else:
            # The checkpoint is a fully pickled module, which newer PyTorch
            # releases refuse to load unless weights_only=False is passed.
            reference_model = torch.load(reference_model_path, weights_only=False).to(device)
            reference_model.eval()
            ## Evaluate the reference model ##
            ref_test_loss, ref_test_acc = evaluate(reference_model, test_loader, device)
            print(f'Ref. Model -- Test Loss: {ref_test_loss:.4f}, Test Acc: {ref_test_acc:.4f}')

    if args.optimizer == "adamw":
        optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-5)
    elif args.optimizer == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=1e-5, momentum=0.9)

    for epoch in tqdm(range(args.num_epoch), total=args.num_epoch):

        print("Train the model")
        model.train()
        running_noisy_loss = 0.0
        running_true_loss = 0.0
        noisy_correct = 0
        true_correct = 0
        total = 0

        # Before `starting_epoch` every mode falls back to plain gradient descent.
        use_reweighting = epoch >= args.starting_epoch

        for indices, inputs, true_labels, noisy_labels in train_loader:

            noisy_labels = noisy_labels.type(torch.int64)
            true_labels = true_labels.type(torch.int64)

            optimizer.zero_grad()
            inputs, true_labels, noisy_labels = inputs.to(device), true_labels.to(device), noisy_labels.to(device)

            ## Grad-Norm algorithm (our baseline) ##
            if args.mode == "grad-norm" and use_reweighting:
                args.method = "norm"

                ## Compute per-sample grad ##
                ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)

                ## Compute gradient norm ##
                # One vectorized norm over the flattened per-sample gradients
                # instead of a Python loop with a device sync per sample.
                specific_layer_grads = ft_per_sample_grads[args.mimic_layer]
                per_sample_weights = torch.linalg.vector_norm(specific_layer_grads.reshape(len(inputs), -1), dim=1).detach()
                per_sample_weights = torch.softmax(per_sample_weights, dim=0)

                ## Record things ##
                for datapoint_index, weight in zip(indices.tolist(), per_sample_weights.tolist()):
                    result_collection[datapoint_index]["per_sample_weights"][epoch] = weight

                ## Calibrate Gradients ##
                gradient_calibration(model, per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)
                optimizer.step()

            ## AGRA algorithm (our competitor) ##
            elif args.mode == "agra" and use_reweighting:
                args.method = "cos"

                ## Grab another batch ##
                # NOTE(review): a fresh DataLoader iterator is created for
                # every training step, which restarts the loader's workers
                # each time. Reusing one iterator would be faster but would
                # change which comparison batches are drawn.
                _, comp_inputs, _, comp_noisy_labels = next(iter(train_loader))
                comp_inputs = comp_inputs.to(device)
                comp_noisy_labels = comp_noisy_labels.type(torch.int64)
                comp_noisy_labels = comp_noisy_labels.to(device)

                ## Computer per-sample grad for both ##
                comp_ft_per_sample_grads = get_per_sample_gradients(model, comp_inputs, comp_noisy_labels)
                ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)

                ## Compute mean gradient on compared batch ##
                comp_specific_layer_grad_mean = comp_ft_per_sample_grads[args.mimic_layer].mean(axis=0)
                specific_layer_grads = ft_per_sample_grads[args.mimic_layer]

                per_sample_weights = compute_similarity(comp_specific_layer_grad_mean, specific_layer_grads, method=args.method)

                ## if similarity is negative, we discard it ##
                # Samples with positive similarity share the weight equally;
                # if none is positive, fall back to a uniform average.
                per_sample_weights = torch.clamp(per_sample_weights, min=0)
                non_zero_count = (per_sample_weights > 0).sum()
                new_per_sample_weights = torch.zeros_like(per_sample_weights)
                if non_zero_count.item() != 0:
                    new_per_sample_weights[per_sample_weights > 0] = 1.0 / non_zero_count.item()
                else:
                    new_per_sample_weights = torch.full((len(inputs), ), 1 / len(inputs), device=device)

                ## Calibrate Gradients ##
                gradient_calibration(model, new_per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)
                optimizer.step()

            ## Grad-Match algorithm (our competitor) ##
            elif args.mode == "grad-match" and use_reweighting:
                args.method = "opt"

                ## Compute per-sample grad ##
                ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)

                ## Compute mean gradient ##
                specific_layer_grad_mean = ft_per_sample_grads[args.mimic_layer].mean(axis=0)
                specific_layer_grads = ft_per_sample_grads[args.mimic_layer]

                ## Solve subset selection problem ##
                specific_layer_grad_mean = specific_layer_grad_mean.cpu().detach().numpy()
                specific_layer_grads = specific_layer_grads.cpu().detach().numpy()
                per_sample_weights = solve_subset_selection(specific_layer_grad_mean, specific_layer_grads, inputs.size(0), args.lambda_value, args.norm_way, device)

                ## Record things ##
                for datapoint_index, weight in zip(indices.tolist(), per_sample_weights.tolist()):
                    result_collection[datapoint_index]["per_sample_weights"][epoch] = weight

                ## Calibrate Gradients ##
                gradient_calibration(model, per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)
                optimizer.step()

            ## Grad-Mimic algorithm (our method) ##
            elif args.mode == "grad-mimic" and use_reweighting:

                ## Compute per-sample grad ##
                ft_per_sample_grads = get_per_sample_gradients(model, inputs, noisy_labels)

                ## Compute task vector and negative gradients ##
                specific_layer_task_vector = compute_task_vector(model, reference_model, args.mimic_layer)
                specific_layer_neg_grads = -ft_per_sample_grads[args.mimic_layer]

                if args.method != "opt":
                    per_sample_weights = compute_similarity(specific_layer_task_vector, specific_layer_neg_grads, args.method, args.temperature)
                else:
                    ## Solve subset selection problem ##
                    specific_layer_task_vector = specific_layer_task_vector.cpu().detach().numpy()
                    specific_layer_neg_grads = specific_layer_neg_grads.cpu().detach().numpy()
                    per_sample_weights = solve_subset_selection(specific_layer_task_vector, specific_layer_neg_grads, inputs.size(0), args.lambda_value, args.norm_way, device)

                ## Record things ##
                for datapoint_index, weight in zip(indices.tolist(), per_sample_weights.tolist()):
                    result_collection[datapoint_index]["per_sample_weights"][epoch] = weight

                ## Calibrate Gradients ##
                gradient_calibration(model, per_sample_weights, ft_per_sample_grads, args.mimic_layer, args.calibrate_mimic_layer_only)
                optimizer.step()

            ## Grad-Descent algorithm (our baseline) ##
            else:
                outputs = model(inputs)
                loss = torch.nn.CrossEntropyLoss()(outputs, noisy_labels)
                loss.backward()
                optimizer.step()

            # Training metrics are measured after the update; no graph is
            # needed for them, so the forward pass runs without autograd.
            with torch.no_grad():
                updated_outputs = model(inputs)

                updated_noisy_loss = torch.nn.CrossEntropyLoss()(updated_outputs, noisy_labels)
                running_noisy_loss += updated_noisy_loss.item() * true_labels.size(0)

                updated_true_loss = torch.nn.CrossEntropyLoss()(updated_outputs, true_labels)
                running_true_loss += updated_true_loss.item() * true_labels.size(0)

                _, predicted = updated_outputs.max(1)

                total += true_labels.size(0)
                noisy_correct += predicted.eq(noisy_labels).sum().item()
                true_correct += predicted.eq(true_labels).sum().item()

        noisy_train_loss = running_noisy_loss / total
        noisy_train_acc = noisy_correct / total

        true_train_loss = running_true_loss / total
        true_train_acc = true_correct / total

        print("Evaluate the model")
        test_loss, test_acc = evaluate(model, test_loader, device)

        print(f'Epoch {epoch + 1}/{args.num_epoch}')
        print(f'Noisy Train Loss: {noisy_train_loss:.4f}, Noisy Train Acc: {noisy_train_acc:.4f}')
        print(f'True Train Loss: {true_train_loss:.4f}, True Train Acc: {true_train_acc:.4f}')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

        learning_results["true_training_loss"][epoch + 1] = true_train_loss
        learning_results["noisy_training_loss"][epoch + 1] = noisy_train_loss
        learning_results["testing_loss"][epoch + 1] = test_loss

        learning_results["true_training_accuracy"][epoch + 1] = true_train_acc
        learning_results["noisy_training_accuracy"][epoch + 1] = noisy_train_acc
        learning_results["testing_accuracy"][epoch + 1] = test_acc

        
## save model and results ##
# Guarded like the training code: `args`, `model` and the result containers
# only exist when the file is run as a script, so importing the module must
# not execute this section.
if __name__ == "__main__":
    if args.pretrained and args.linear_probing:
        sub_folder = "pretrained_linear_probing"
    elif args.pretrained and not args.linear_probing:
        sub_folder = "pretrained_fine_tune_all"
    else:
        sub_folder = "train_from_scratch"

    layer_only = "only" if args.calibrate_mimic_layer_only else "not_only"

    # Create the output directories up front so the writes below cannot fail
    # on a fresh checkout.
    log_dir = f'./new_saved_logs/{sub_folder}'
    model_dir = f'./new_saved_models/{sub_folder}'
    os.makedirs(log_dir, exist_ok=True)

    run_tag = f'run{args.num_epoch}epochs_{args.mode}_noisy{args.noisy_level}_seed{args.seed}'

    if args.mode != "grad-descent":
        # Re-weighting modes store the metrics and the recorded per-sample weights.
        log_prefix = f'{log_dir}/{args.model_arch}_{args.mimic_layer}_{layer_only}_{args.dataset_name}_{args.method}_temp{args.temperature}_{run_tag}'

        with open(f'{log_prefix}_results.pkl', 'wb') as f:
            pickle.dump(learning_results, f)

        with open(f'{log_prefix}_weights.pkl', 'wb') as f:
            pickle.dump(result_collection, f)

    else:
        with open(f'{log_dir}/{args.model_arch}_{args.dataset_name}_{run_tag}_results.pkl', 'wb') as f:
            pickle.dump(learning_results, f)

        # Only a baseline trained on clean labels is saved as a model file;
        # its name matches the reference checkpoint pattern used by grad-mimic.
        if args.noisy_level == 0.0:
            os.makedirs(model_dir, exist_ok=True)
            torch.save(model, f'{model_dir}/{args.model_arch}_{args.dataset_name}_{run_tag}.pt')
