from sklearn.metrics import accuracy_score
from utils import *


# Command-line interface for the noisy-label baseline experiments.
# NOTE: main() overwrites dataset, symmetric_flip_prob, feature_type and
# encoder_name on the parsed namespace while sweeping its hard-coded grids.
parser = argparse.ArgumentParser(
    description='Train a good baseline model on Noisy CIFAR-10/100 dataset',
)

# Which CIFAR variant to load.
parser.add_argument(
    '--dataset',
    default='cifar10',
    type=str,
    choices=['cifar10', 'cifar100'],
    help='Choose the dataset',
)

# How the training labels are corrupted.
parser.add_argument(
    '--noise_mode',
    default='anti-diagonal',
    type=str,
    choices=[
        'clean_label',
        'aggre_label',
        'rand_1_label',
        'rand_2_label',
        'rand_3_label',
        'worst_label',
        'symmetric_flip_label',
        'promix_100_label',
        'promix_400_label',
        'anti-diagonal',
    ],
    help='Noise mode for labels: random or human',
)

# Flip probability handed to the noise-matrix generators.
parser.add_argument(
    '--symmetric_flip_prob',
    default=0.99,
    type=float,
    help='Probability of symmetric label flipping',
)

# Representation fed to the linear classifier.
parser.add_argument(
    '--feature_type',
    default='foundation_model',
    type=str,
    choices=[
        'original',
        'transfer_learning',
        'contrastive_learning',
        'foundation_model',
    ],
    help='feature type for training linear model',
)

# Backbone used to compute the features (None is used with raw pixels).
parser.add_argument(
    '--encoder_name',
    default="dinov2_vitl14",
    type=str,
    choices=[
        None,
        'resnet18',
        'resnet34',
        'resnet50',
        'dinov2_vits14',
        'dinov2_vitb14',
        'dinov2_vitl14',
        'dinov2_vitg14',
    ],
)

# Disabled option, kept for reference:
# parser.add_argument('--num_epochs', type=int, default=20,
#                     help='Number of epochs/iterations for training')
parser.add_argument(
    "--batch_size",
    type=int,
    default=64,
    help="Batch size used during feature extraction.",
)

# Disabled option (optimizer choice: sgd / adam / sgd with lr scheduler):
# parser.add_argument('--optimizer', type=str, default='lbfgs', choices=['lbfgs'])

args = parser.parse_args()


def _build_transform(feature_type):
    """Return the torchvision preprocessing pipeline for ``feature_type``.

    Raises:
        ValueError: if ``feature_type`` is not one of the supported values.
    """
    if feature_type in ('original', 'contrastive_learning'):
        # Images stay at their native resolution (no upscaling needed).
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    if feature_type in ('transfer_learning', 'foundation_model'):
        # Upscale to 224x224 because these encoders were pre-trained on ImageNet.
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
    raise ValueError(f"Unsupported feature type: {feature_type!r}")


def _load_datasets(dataset_name, dataset_path, transform):
    """Return ``(trainset, testset, num_classes)`` for CIFAR-10 or CIFAR-100.

    The data is downloaded into ``dataset_path`` when it is not already there.

    Raises:
        ValueError: if ``dataset_name`` is neither 'cifar10' nor 'cifar100'.
    """
    if dataset_name == 'cifar10':
        dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
    elif dataset_name == 'cifar100':
        dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")

    trainset = dataset_cls(root=dataset_path, train=True, download=True, transform=transform)
    testset = dataset_cls(root=dataset_path, train=False, download=True, transform=transform)
    return trainset, testset, num_classes


def _select_device():
    """Pick CUDA if available, then MPS, then CPU, and return the torch device."""
    try:
        if torch.cuda.is_available():
            print()
            print(torch.cuda.get_device_name(torch.cuda.current_device()))
            return torch.device("cuda:0")
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")
    except AttributeError:
        # Fallback for torch builds where one of the backend probes is missing.
        return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def _load_feature_extractor(feature_type, encoder_name, dataset_name):
    """Build the pre-trained encoder for a non-'original' feature type.

    Raises:
        ValueError: if the feature type / encoder combination is not supported
            (previously this surfaced later as an unbound-variable error).
    """
    if feature_type == 'foundation_model':
        return torch.hub.load('facebookresearch/dinov2', encoder_name)

    if feature_type == 'transfer_learning':
        resnet_builders = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
        }
        if encoder_name not in resnet_builders:
            raise ValueError(f"Unsupported transfer-learning encoder: {encoder_name!r}")
        extractor = resnet_builders[encoder_name](pretrained=True)
        extractor.fc = nn.Identity()  # Replace the classification layer with an identity function
        return extractor

    if feature_type == 'contrastive_learning':
        # Load onto the CPU first so a checkpoint saved from a GPU run can be
        # restored on any host; the caller moves the model to the target device.
        checkpoint = torch.load(
            f'trained model/ckpt_{dataset_name}_{encoder_name}.pth',
            map_location='cpu',
        )
        model_state = checkpoint['model']
        # Strip the 'module.' fragment from parameter names before loading.
        state_dict = {key.replace('module.', ''): model_state[key] for key in model_state}
        extractor = Encoder_cl(name=encoder_name)
        extractor.load_state_dict(state_dict, strict=False)
        return extractor

    raise ValueError(f"Unsupported feature type: {feature_type!r}")


def _encode_split(feature_extractor, loader, device, split_name):
    """Run ``feature_extractor`` over ``loader`` and return ``(features, labels)``.

    Features are flattened per sample and moved to numpy batch by batch;
    labels are collected as numpy arrays as well.
    """
    feature_batches, label_batches = [], []
    for batch_number, (inputs, labels) in enumerate(loader, start=1):
        inputs = inputs.to(device)
        with torch.no_grad():
            features = feature_extractor(inputs)
            features = torch.flatten(features, start_dim=1).cpu().numpy()  # Flatten the features
        feature_batches.append(features)
        label_batches.append(labels.numpy())
        if batch_number % 50 == 0:  # progress message every 50 batches
            print(f'Batch {batch_number}/{len(loader)} of {split_name} data processed.')
    return np.concatenate(feature_batches, axis=0), np.concatenate(label_batches, axis=0)


def _compute_features(feature_type, encoder_name, dataset_name, batch_size, dataset_path='./data'):
    """Load the dataset and turn it into numpy features with clean labels.

    Returns:
        ``(X_train, y_train, X_test, y_test, num_classes)`` where the label
        arrays hold the dataset's own (un-noised) labels.
    """
    transform = _build_transform(feature_type)
    trainset, testset, num_classes = _load_datasets(dataset_name, dataset_path, transform)

    device = _select_device()
    print("Using device:", device)
    print()

    if feature_type == 'original':
        # Raw pixels: pull each split as a single batch and flatten the images.
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset), shuffle=False)
        testloader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=False)

        X_train, y_train = next(iter(trainloader))
        X_test, y_test = next(iter(testloader))

        X_train = X_train.view(X_train.size(0), -1).numpy()
        X_test = X_test.view(X_test.size(0), -1).numpy()
        return X_train, y_train.numpy(), X_test, y_test.numpy(), num_classes

    # Encoder features are extracted in batches (otherwise, may run out of memory).
    feature_extractor = _load_feature_extractor(feature_type, encoder_name, dataset_name)
    feature_extractor.to(device)
    feature_extractor.eval()

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    X_train, y_train = _encode_split(feature_extractor, trainloader, device, 'train')
    X_test, y_test = _encode_split(feature_extractor, testloader, device, 'test')
    return X_train, y_train, X_test, y_test, num_classes


def _corrupt_labels(clean_labels, noise_mode, num_classes, flip_prob):
    """Return a noisy copy of ``clean_labels`` sampled from a transition matrix.

    Only 'symmetric_flip_label' and 'anti-diagonal' inject noise here. For any
    other mode an unchanged copy is returned.
    NOTE(review): the human-noise modes offered on the command line are not
    loaded anywhere in this function, so they train on clean labels -- confirm
    whether that is intended.
    """
    noisy_labels = clean_labels.copy()

    if noise_mode == 'symmetric_flip_label':
        T = generate_symmetric_noise_matrix(num_classes, flip_prob)
        print("T: ", T)
    elif noise_mode == "anti-diagonal":
        T = generate_anti_diag_noise_matrix(num_classes, flip_prob)
        T = normalize_rows_to_one(T)
    else:
        return noisy_labels

    classes = range(num_classes)
    # Row T[cl] is used as the sampling distribution for a sample whose clean label is cl.
    for i, cl in enumerate(clean_labels):
        noisy_labels[i] = np.random.choice(classes, p=T[cl])
    return noisy_labels


def main():
    """Sweep encoders, algorithms, datasets and noise levels; log test accuracy.

    For every combination the training labels are corrupted according to
    ``args.noise_mode``, a classifier is fitted on the (noisy) training
    features, and both argmax and argmin predictions over ``predict_proba``
    are scored on the clean test labels. Results are appended through
    ``write_metrics_to_file``.

    The module-level ``args`` namespace is updated in place (feature_type,
    encoder_name, dataset, symmetric_flip_prob) so that the logged metrics
    carry the configuration of the current run.
    """
    dataset_names = ["cifar10"]
    noise_percentages = {
        "cifar10": [0.91, 0.93, 0.95, 0.97, 0.99],
        "cifar100": [0.99999999, 0.9999, 0.999, 0.997, 0.995, 0.993, 0.991, 0.9, 0.8, 0.6, 0.4, 0.2],
    }
    algorithms = ["Logistic_regression"]  # "SVM", "MLP"
    feature_types_enc_names = [
        ["foundation_model", "dinov2_vitl14"],
        ["original", None],
        ["transfer_learning", "resnet18"],
        ["transfer_learning", "resnet34"],
        ["transfer_learning", "resnet50"],
    ]

    # turn off the warnings (especially for sklearn convergence)
    warnings.filterwarnings('ignore')

    for feature_type, encoder_name in feature_types_enc_names:
        args.feature_type = feature_type
        args.encoder_name = encoder_name

        # Features depend only on (feature type, encoder, dataset), not on the
        # algorithm or the noise level, so compute them once and reuse them.
        # Assumes the fitted models do not modify X in place -- TODO confirm
        # for the estimators returned by get_final_model.
        features_by_dataset = {}

        for algorithm in algorithms:
            for dataset_name in dataset_names:
                for noise_perc in noise_percentages[dataset_name]:
                    args.dataset = dataset_name
                    args.symmetric_flip_prob = noise_perc

                    print()
                    print(f"================ {args.dataset} with {args.noise_mode} (symmetric flip prob = {args.symmetric_flip_prob}) ========================")
                    print(f"Linear model + feature by {args.feature_type}, encoder: {args.encoder_name}")
                    print()

                    if dataset_name not in features_by_dataset:
                        features_by_dataset[dataset_name] = _compute_features(
                            args.feature_type,
                            args.encoder_name,
                            args.dataset,
                            args.batch_size,
                        )
                    X_train, y_clean, X_test, y_test, num_classes = features_by_dataset[dataset_name]

                    # Noise is applied to a fresh copy so the cached clean labels survive.
                    y_train = _corrupt_labels(
                        y_clean, args.noise_mode, num_classes, args.symmetric_flip_prob
                    )

                    best_params = get_best_params(algorithm)
                    final_model = get_final_model(algorithm, best_params)
                    final_model.fit(X_train, y_train)

                    # Score both the most and the least probable class per sample.
                    # NOTE(review): the column index is used directly as the class
                    # label, which presumably relies on every class being present
                    # in y_train -- confirm against the model's class ordering.
                    y_pred_proba = final_model.predict_proba(X_test)
                    y_pred_argmin = np.argmin(y_pred_proba, axis=1)
                    y_pred_argmax = np.argmax(y_pred_proba, axis=1)

                    argmin_accuracy = accuracy_score(y_test, y_pred_argmin)
                    argmax_accuracy = accuracy_score(y_test, y_pred_argmax)
                    print(f"Accuracy on test set with argmin prediction: {argmin_accuracy:.4f}")
                    print(f"Accuracy on test set with argmax prediction: {argmax_accuracy:.4f}")

                    write_metrics_to_file("experiment_results.txt", args, argmin_accuracy, argmax_accuracy, best_params)


# Run the experiment sweep only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()

