import os
import copy
import torch
import torch.optim
import torch.nn as nn
import torch.utils.data
from collections import OrderedDict, Counter

import utils
import unlearn
import pruner
from trainer import validate
import evaluation
import numpy as np

import arg_parser
from termcolor import cprint
import time


def main():
    """Run the unlearning experiment for one or more seeds and report metrics.

    For every seed this builds the forget/retain splits from the marked
    training set, then either restores a cached unlearned model (``--resume``)
    or applies the pruning mask from ``args.mask`` and runs the selected
    unlearning method. It then evaluates accuracy and the forget-efficacy
    membership-inference attack (MIA), saves the checkpoint, and writes the
    metrics to a per-run text file in the current working directory.

    ``results`` holds, in order: UA (100 - forget accuracy), MIA forget
    efficacy (confidence * 100), RA (retain accuracy), TA (test accuracy) and
    the wall-clock time of the unlearning step in seconds. The time stays 0
    when the model was restored from a checkpoint instead of being unlearned
    in this run.
    """
    args = arg_parser.parse_args()

    if torch.cuda.is_available():
        torch.cuda.set_device(int(args.gpu))
        device = torch.device(f"cuda:{int(args.gpu)}")
    else:
        device = torch.device("cpu")

    os.makedirs(args.save_dir, exist_ok=True)
    # A seed of -1 selects the fixed three-seed sweep.
    seeds = [2, 10, 100] if args.seed == -1 else [args.seed]

    summary_lines = []
    for seed in seeds:
        results = [0, 0, 0, 0, 0]
        utils.setup_seed(seed)
        # The output file name below is built from args.seed, so keep it in
        # sync with the seed of the current iteration.
        args.seed = seed

        # Prepare model and datasets.
        (
            model,
            train_loader_full,
            val_loader,
            test_loader,
            marked_loader,
        ) = utils.setup_model_dataset(args)
        # Use the device chosen above instead of an unconditional .cuda(),
        # which fails on machines without CUDA.
        model.to(device)

        (
            forget_dataset,
            retain_dataset,
            forget_loader,
            retain_loader,
        ) = _build_forget_retain_splits(args, seed, marked_loader, train_loader_full)

        unlearn_data_loaders = OrderedDict(
            retain=retain_loader,
            forget=forget_loader,
            val=val_loader,
            test=test_loader,
            retain_dataset=retain_dataset,
        )

        criterion = nn.CrossEntropyLoss()

        evaluation_result = None
        checkpoint = None
        if args.resume:
            checkpoint = unlearn.load_unlearn_checkpoint(model, device, args)

        if checkpoint is not None:
            model, evaluation_result = checkpoint
        else:
            mask_state = torch.load(args.mask, map_location=device)
            if "state_dict" in mask_state:
                mask_state = mask_state["state_dict"]
            current_mask = pruner.extract_mask(mask_state)
            pruner.prune_model_custom(model, current_mask)
            pruner.check_sparsity(model)

            # "retrain" keeps the freshly initialised (pruned) weights; every
            # other method starts from the weights stored with the mask.
            if args.unlearn != "retrain":
                model.load_state_dict(mask_state, strict=True)

            unlearn_method = unlearn.get_unlearn_method(args.unlearn)
            unlearn_start_time = time.time()
            unlearn_method(unlearn_data_loaders, model, criterion, args)
            elapsed = time.time() - unlearn_start_time
            _print_red_on_cyan(f"Time: {elapsed}")
            results[-1] = elapsed
            unlearn.save_unlearn_checkpoint(model, None, args)

        if evaluation_result is None:
            evaluation_result = {}

        if "accuracy" not in evaluation_result:
            accuracy = {}
            for name, loader in unlearn_data_loaders.items():
                # The "retain_dataset" entry is a dataset, not a loader.
                if "dataset" in name:
                    continue
                utils.dataset_convert_to_test(loader.dataset, args)
                val_acc = validate(loader, model, criterion, args)
                accuracy[name] = val_acc
                print(f"{name} acc: {val_acc}")
                if "forget" in name:
                    _print_red_on_cyan(f"UA: {100-val_acc}")
                elif "retain" in name:
                    _print_red_on_cyan(f"RA: {val_acc}")
                elif "test" in name:
                    _print_red_on_cyan(f"TA: {val_acc}")

            evaluation_result["accuracy"] = accuracy
            unlearn.save_unlearn_checkpoint(model, evaluation_result, args)

        # Fill the summary from the stored accuracies so that a resumed run
        # reports the cached metrics instead of zeros.
        accuracy = evaluation_result["accuracy"]
        if "forget" in accuracy:
            results[0] = 100 - accuracy["forget"]
        if "retain" in accuracy:
            results[2] = accuracy["retain"]
        if "test" in accuracy:
            results[3] = accuracy["test"]

        # Drop entries written under keys this script no longer uses.
        for deprecated in ("MIA", "SVC_MIA", "SVC_MIA_forget"):
            evaluation_result.pop(deprecated, None)

        # Forget-efficacy MIA:
        #   in distribution:     retain
        #   out of distribution: test
        #   target:              (, forget)
        if "SVC_MIA_forget_efficacy" not in evaluation_result:
            utils.dataset_convert_to_test(retain_dataset, args)
            # NOTE(review): loaders (not datasets) are passed here, unlike in
            # the accuracy loop above; presumably the helper accepts both --
            # confirm against utils.dataset_convert_to_test.
            utils.dataset_convert_to_test(forget_loader, args)
            utils.dataset_convert_to_test(test_loader, args)

            # The shadow "train" set is a prefix of the retain set with as
            # many samples as the test set; clamp so a retain set smaller
            # than the test set cannot produce out-of-range indices.
            shadow_size = min(len(test_loader.dataset), len(retain_dataset))
            shadow_train = torch.utils.data.Subset(
                retain_dataset, list(range(shadow_size))
            )
            shadow_train_loader = torch.utils.data.DataLoader(
                shadow_train,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=4,
                pin_memory=True,
            )

            evaluation_result["SVC_MIA_forget_efficacy"] = evaluation.SVC_MIA(
                shadow_train=shadow_train_loader,
                shadow_test=test_loader,
                target_train=None,
                target_test=forget_loader,
                model=model,
            )
            unlearn.save_unlearn_checkpoint(model, evaluation_result, args)

        results[1] = evaluation_result["SVC_MIA_forget_efficacy"]["confidence"] * 100

        # A "training privacy" MIA (key "SVC_MIA_training_privacy") is not
        # evaluated by this script.

        unlearn.save_unlearn_checkpoint(model, evaluation_result, args)
        print(">>>>>>>>>> Seed we use is: ", seed)
        summary = "\t".join(["{:.3f}".format(x) for x in results])
        _print_red_on_cyan(summary)
        summary_lines.append(summary)

        result_path = (
            f'{args.unlearn}_{args.dataset}_{args.forgetting_mode}'
            f'_{args.p}_{args.smooth_rate}_{args.seed}.txt'
        )
        with open(result_path, 'w') as f:
            # results[4] is the unlearning time; reading it from `results`
            # (rather than the timing locals) keeps this valid on the resume
            # path, where no unlearning was timed.
            f.write(
                f'{results[0]}, {results[1]}, {results[2]}, {results[3]}, {results[4]}'
            )

    for line in summary_lines:
        _print_red_on_cyan(line)


def _print_red_on_cyan(x):
    """Print ``x`` in red on a cyan background via ``termcolor.cprint``."""
    return cprint(x, "red", "on_cyan")


def _make_loader(dataset, batch_size, seed=1, shuffle=True, ratio=1):
    """Build a DataLoader over ``dataset``, optionally oversampled.

    When ``ratio`` is greater than 1 the dataset is concatenated with itself
    ``int(ratio + 1)`` times. The global seed is reset via
    ``utils.setup_seed(seed)`` right before the loader is created.
    """
    if ratio > 1:
        dataset = torch.utils.data.ConcatDataset([dataset] * int(ratio + 1))
    utils.setup_seed(seed)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        shuffle=shuffle,
    )


def _select_marked(marked_dataset, data_attr, label_attr, forget):
    """Return a deep copy of ``marked_dataset`` restricted to one split.

    Forget samples carry negative labels. With ``forget=True`` only those
    samples are kept and their labels are decoded with ``-label - 1``; with
    ``forget=False`` only samples with non-negative labels are kept, labels
    unchanged. Assumes the ``data_attr`` / ``label_attr`` attributes support
    boolean-mask indexing (e.g. NumPy arrays or tensors).
    """
    subset = copy.deepcopy(marked_dataset)
    labels = getattr(subset, label_attr)
    mask = labels < 0 if forget else labels >= 0
    setattr(subset, data_attr, getattr(subset, data_attr)[mask])
    kept_labels = labels[mask]
    setattr(subset, label_attr, -kept_labels - 1 if forget else kept_labels)
    return subset


def _build_forget_retain_splits(args, seed, marked_loader, train_loader_full):
    """Split the marked training set into forget/retain datasets and loaders.

    Returns ``(forget_dataset, retain_dataset, forget_loader, retain_loader)``.
    The forget loader is oversampled by ``len(retain) // len(forget)`` (see
    ``_make_loader``).

    Raises:
        ValueError: if no sample is marked for forgetting.
    """
    marked_dataset = marked_loader.dataset
    full_dataset = train_loader_full.dataset

    if args.dataset == "celebA_smile":
        # Column 31 of ``attr`` carries the forget mark (-1); presumably the
        # "Smiling" attribute given the dataset name -- TODO confirm.
        mark_column = marked_dataset.attr[:, 31]
        forget_dataset = torch.utils.data.Subset(
            full_dataset, torch.where(mark_column == -1)[0]
        )
        retain_dataset = torch.utils.data.Subset(
            full_dataset, torch.where(mark_column != -1)[0]
        )
    else:
        if args.dataset == "svhn":
            # The SVHN dataset may expose its labels as `targets` or `labels`.
            data_attr = "data"
            label_attr = "targets" if hasattr(marked_dataset, "targets") else "labels"
            show_label_counts = True
        else:
            # Datasets store their samples either in `data` or in `imgs`.
            data_attr = "data" if hasattr(marked_dataset, "data") else "imgs"
            label_attr = "targets"
            show_label_counts = data_attr == "data"

        forget_dataset = _select_marked(
            marked_dataset, data_attr, label_attr, forget=True
        )
        retain_dataset = _select_marked(
            marked_dataset, data_attr, label_attr, forget=False
        )
        if show_label_counts:
            print(Counter(getattr(forget_dataset, label_attr)))

    print(len(forget_dataset))
    if len(forget_dataset) == 0:
        raise ValueError(
            "No samples are marked for forgetting; cannot build the forget loader."
        )

    # The retain loader is created first so the seeding order matches the
    # order in which the loaders are consumed downstream.
    retain_loader = _make_loader(
        retain_dataset, args.batch_size, seed=seed, shuffle=True
    )
    ratio = len(retain_dataset) // len(forget_dataset)
    forget_loader = _make_loader(
        forget_dataset, args.batch_size, seed=seed, shuffle=True, ratio=ratio
    )
    print(len(retain_dataset))

    # Sanity check: the two splits must partition the full training set.
    assert len(forget_dataset) + len(retain_dataset) == len(full_dataset), (
        "forget and retain splits do not add up to the full training set"
    )
    return forget_dataset, retain_dataset, forget_loader, retain_loader


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
