# load pretrained model from checkpoint
from utils.backbone import get_model
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
import torch
import torch.nn as nn
from copy import deepcopy
from utils.model_utils import ResNet_wrapper, Boolean, ResNetTails, ViTTails, Separable, mlp, infonce_lower_bound, ViTWrapper
from tqdm import tqdm
from time import time
import argparse
import random
from copy import deepcopy
from collections import defaultdict
import numpy as np
from utils.vit_loader import ViTLOAD, TV_ViTLOAD

# -0.04548719085
def seed_torch(seed):
    """Seed every RNG this script draws from so runs are reproducible.

    Seeds Python's ``random`` (used for all index shuffling), NumPy, and PyTorch
    (CPU and CUDA), and asks cuDNN for deterministic algorithms.

    Args:
        seed: integer seed applied to all generators.
    """
    # The original seeded NumPy twice and never seeded `random`, so every
    # `random.shuffle` of dataset indices differed between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may pick different convolution algorithms from run to run;
    # keep it off so the deterministic setting is effective.
    torch.backends.cudnn.benchmark = False

def get_args(argv=None):
    """Parse the command-line options of the InfoNCE / probe experiment.

    Args:
        argv: optional list of argument strings. ``None`` (default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``layer``, ``ckpt``, ``train_mode``, ``test_mode``,
        ``class_idx``, ``class_idx_unlearn`` and ``sub_class_name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--layer', type=int, default=5,
                        help='index of the feature layer to probe')
    parser.add_argument('--ckpt', type=str, default='ResNet18_cifar10_class_4_5000_retrain0.pth',
                        help='checkpoint path; model and dataset are inferred from its name')

    # Only these modes are implemented by get_infonce; reject typos at parse time
    # instead of failing with an undefined-variable error deep inside training.
    parser.add_argument('--train_mode', type=str, required=True, choices=['mi', 'acc'])
    parser.add_argument('--test_mode', type=str, default='class',
                        choices=['class', 'sub_class', 'sample'])

    parser.add_argument('--class_idx', type=int, default=4,
                        help='first class index of the unlearned range')
    parser.add_argument('--class_idx_unlearn', type=int, default=1,
                        help='number of consecutive classes in the unlearned range')
    parser.add_argument('--sub_class_name', type=str, default='baby',
                        help='class folder name used by test_mode sub_class')

    return parser.parse_args(argv)

def _model_config(ckpt):
    """Infer the backbone architecture from a checkpoint filename.

    Args:
        ckpt: checkpoint path; only substrings of it are inspected.

    Returns:
        ``(model_name, final_size)``; ``final_size`` is the feature width handed to
        ``Separable`` / ``mlp``.

    Raises:
        ValueError: if the name contains none of ``ResNet18``, ``ResNet50``, ``ViT``.
    """
    if "ResNet18" in ckpt:
        return 'ResNet18', 512
    if "ResNet50" in ckpt:
        return 'ResNet50', 2048
    if "ViT" in ckpt:
        return 'ViT', 768
    raise ValueError(f"Cannot infer the model architecture from checkpoint name {ckpt!r}")


def _dataset_config(ckpt, test_mode, layer, class_idx_unlearn):
    """Infer dataset settings and the epoch budget from a checkpoint filename.

    Returns:
        ``(dataset_name, num_classes, epochs, exclude_num_per_class, mean, std)``.

    Raises:
        ValueError: if no supported dataset name occurs in ``ckpt`` or no positive
            epoch budget is defined for the (dataset, test_mode) combination.
    """
    # Order matters: 'cifar10' is a substring of 'cifar100' and 'imagenet' of
    # 'tinyimagenet', so the longer names must be tested first.
    epochs = None
    if 'cifar100' in ckpt:
        dataset_name, num_classes = 'cifar100', 100
        if test_mode == "class":
            epochs = 100 if layer == 1 else 50
        elif test_mode == "sample":
            epochs = 100
        elif test_mode == "sub_class":
            epochs = 200 if layer == 1 else 100
        exclude_num_per_class = 50
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
    elif 'cifar10' in ckpt:
        dataset_name, num_classes = 'cifar10', 10
        if test_mode == "class":
            epochs = 100 if layer == 1 else 50
        elif test_mode == "sample":
            epochs = 100
        exclude_num_per_class = 500
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    elif 'tinyimagenet' in ckpt:
        dataset_name, num_classes = 'tinyimagenet', 200
        # Budget shrinks as more classes are unlearned; independent of test_mode.
        epochs = 1000 // class_idx_unlearn
        exclude_num_per_class = 50
        mean, std = (0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262)
    elif 'imagenet' in ckpt:
        dataset_name, num_classes = 'imagenet', 1000
        if test_mode == "class":
            epochs = 50
        exclude_num_per_class = 128
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    else:
        raise ValueError(f"Cannot infer the dataset from checkpoint name {ckpt!r}")

    if epochs is None or epochs < 1:
        raise ValueError(
            f"No positive epoch budget for dataset '{dataset_name}' with "
            f"test_mode '{test_mode}' (got {epochs})"
        )
    return dataset_name, num_classes, epochs, exclude_num_per_class, mean, std


def _infonce_lr(model_name, dataset_name, test_mode):
    """Return the Adam learning rate of the InfoNCE critic for train_mode 'mi'.

    Raises:
        ValueError: for (model, dataset, test_mode) combinations with no tuned rate.
    """
    if model_name == "ResNet18":
        lr = {
            'cifar10': {"class": 2e-5, "sample": 2e-6, "sub_class": 2e-6},
            'cifar100': {"class": 1e-5, "sample": 1e-6, "sub_class": 1e-6},
        }.get(dataset_name, {}).get(test_mode)
    elif model_name == "ResNet50":
        lr = {'cifar10': 1e-5, 'cifar100': 1e-5, 'imagenet': 1e-5}.get(dataset_name)
    elif model_name == "ViT":
        lr = {'cifar10': 1e-5, 'cifar100': 1e-5, 'imagenet': 2e-6}.get(dataset_name)
    else:
        lr = None
    if lr is None:
        raise ValueError(
            f"No InfoNCE learning rate configured for {model_name} on {dataset_name} "
            f"with test_mode '{test_mode}'"
        )
    return lr


def _build_models(args, model_name, dataset_name, num_classes, mean, std, pretrained_model):
    """Load the model under evaluation and build the trainable head and input transform.

    Returns:
        ``(eval_model, intermediate_head, transform)``.
    """
    layer = args.layer
    if model_name == "ViT":
        if dataset_name == "imagenet":
            eval_model = ViTLOAD(ckpt=args.ckpt)
        elif dataset_name in ['cifar10', 'cifar100']:
            eval_model = TV_ViTLOAD(ckpt=args.ckpt, num_classes=num_classes)
        else:
            raise ValueError(f"ViT checkpoints are not supported for dataset '{dataset_name}'")
        # Batches are moved to 'cuda' in the training loops. Module.to is a no-op
        # if the loader already placed the model there.
        eval_model = eval_model.to('cuda')

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        # Freeze the embeddings and every transformer block with index < `layer`.
        # Two parameter-naming schemes are handled:
        # 'patch_embed'/'pos_embed'/'blocks.<i>.' and
        # 'position_embeddings'/'base.encoder.layer.<i>.'. Which one applies
        # depends on the loader chosen above.
        for name, param in eval_model.named_parameters():
            parts = name.split('.')
            if 'patch_embed' in name or 'pos_embed' in name or 'position_embeddings' in name:
                param.requires_grad = False
            if 'blocks' in name and int(parts[1]) < layer:
                param.requires_grad = False
            if 'base.encoder.layer' in name and int(parts[3]) < layer:
                param.requires_grad = False

        # Log the freeze pattern so it can be checked.
        for name, param in eval_model.named_parameters():
            print(name, param.requires_grad)

        return eval_model, nn.Identity(), transform

    # ResNet18 / ResNet50 (guaranteed by _model_config).
    eval_model = get_model(model_name, num_classes=num_classes, ckpt_path=args.ckpt).to('cuda')
    # The trainable tail starts from a fresh get_model() instance (no checkpoint)
    # or from a copy of the caller-supplied model.
    if pretrained_model is None:
        init_model = get_model(model_name, num_classes=num_classes).to('cuda')
    else:
        init_model = deepcopy(pretrained_model).to('cuda')

    steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    if dataset_name == "imagenet":
        steps.insert(0, transforms.Resize((224, 224)))
    transform = transforms.Compose(steps)

    eval_model = ResNet_wrapper(eval_model)
    intermediate_head = ResNetTails(init_model, layer + 1).to('cuda')
    # At the deepest layer index (6 for ResNet18, 7 for ResNet50), features go to
    # the critic unchanged.
    if (model_name == "ResNet18" and layer == 6) or (model_name == "ResNet50" and layer == 7):
        intermediate_head = nn.Identity()
    return eval_model, intermediate_head, transform


def _build_dataloaders(args, model_name, dataset_name, num_classes, exclude_num_per_class, transform):
    """Build the forget(0)/retain(1) training loader and, in train_mode 'acc', a test loader.

    Returns:
        ``(dataloader, test_dataloader)``; ``test_dataloader`` is ``None`` unless
        ``args.train_mode == 'acc'``.
    """
    dataset = ImageFolder(root=f'./dataset/{dataset_name}/train', transform=transform)
    test_dataloader = None

    if args.test_mode == "class":
        class_idx = list(range(args.class_idx, args.class_idx + args.class_idx_unlearn))
        forget_classes = set(class_idx)
        labels = dataset.targets
        forget_idx = [i for i, label in enumerate(labels) if label in forget_classes]
        rest_idx = [i for i, label in enumerate(labels) if label not in forget_classes]

        if dataset_name == "imagenet":
            # Subsample the retain set: keep at most `class_per_class` images per
            # class, taking the first ones in dataset order. Previously noted
            # choices: 5 (1%, 1:1), 25 (5%, 1:5), 100 (20%, 1:20). Set
            # `args.class_per_class` to choose; the default is
            # `exclude_num_per_class` (128 for ImageNet). This was an undefined
            # name before, which raised NameError.
            class_per_class = getattr(args, 'class_per_class', None)
            if class_per_class is None:
                class_per_class = exclude_num_per_class
            kept_per_class = [0] * num_classes
            kept = []
            for i in rest_idx:
                if kept_per_class[labels[i]] < class_per_class:
                    kept.append(i)
                    kept_per_class[labels[i]] += 1
            rest_idx = kept

        random.shuffle(forget_idx)
        random.shuffle(rest_idx)
        idx = forget_idx + rest_idx
        random.shuffle(idx)
        print("InfoNCE UNLEARNED CLASSES: ", class_idx)
        print("SAMPLED INDICES: ", len(idx), len(forget_idx), len(rest_idx))

        dataset = BinaryDataset(dataset, class_idx, idx)
        if dataset_name == "imagenet" and model_name == "ResNet50":
            dataloader = DataLoader(dataset, batch_size=45, shuffle=True, num_workers=8, pin_memory=True)
        else:
            print("BATCH SIZE: ", 128)
            dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

        if args.train_mode == "acc":
            # Balanced test split: all forget-class samples plus an equal number of
            # randomly chosen retain samples.
            test_dataset = ImageFolder(root=f'./dataset/{dataset_name}/test', transform=transform)
            test_labels = test_dataset.targets
            test_forget_idx = [i for i, label in enumerate(test_labels) if label in forget_classes]
            test_rest_idx = [i for i, label in enumerate(test_labels) if label not in forget_classes]

            random.shuffle(test_rest_idx)
            test_rest_idx = test_rest_idx[:len(test_forget_idx)]
            test_idx = test_forget_idx + test_rest_idx
            random.shuffle(test_idx)
            print("TEST UNLEARNED CLASSES: ", class_idx)
            print(len(test_idx))

            test_dataset = BinaryDataset(test_dataset, class_idx, test_idx)
            test_dataloader = DataLoader(test_dataset, batch_size=500, shuffle=True, num_workers=4)

    elif args.test_mode == "sample":
        dataset = BinarySampleDataset(dataset, exclude_num_per_class=exclude_num_per_class, num_classes=num_classes)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

        if args.train_mode == "acc":
            test_dataset = ImageFolder(root=f'./dataset/{dataset_name}/test', transform=transform)
            test_dataset = BinarySampleDataset(test_dataset, exclude_num_per_class=exclude_num_per_class, num_classes=num_classes)
            test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, num_workers=4)

    elif args.test_mode == "sub_class":
        dataset = BinarySubClassDataset(dataset, args.sub_class_name, transform=transform)
        dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    else:
        raise ValueError(f"Unknown test_mode {args.test_mode!r}")

    return dataloader, test_dataloader


def _extract_features(eval_model, model_name, data, layer):
    """Return the representation of ``data`` that the critic / probe consumes.

    ViT: the embeddings from ``eval_model(data, get_embeddings=True)``. Gradients
    flow into ``eval_model`` because the 'mi' optimizer includes its parameters.
    ResNet: the ``l{layer}`` entry of ``eval_model(data, get_all_features=True)``,
    computed under ``torch.no_grad()``. No optimizer ever receives the ResNet
    ``eval_model`` parameters, so backpropagating into it only accumulated unused
    gradients.
    """
    if model_name == "ViT":
        _, inp = eval_model(data, get_embeddings=True)
        return inp
    with torch.no_grad():
        _, features = eval_model(data, get_all_features=True)
    return features[f'l{layer}']


def _train_mi(eval_model, intermediate_head, dataloader, model_name, dataset_name,
              final_size, layer, epochs, test_mode, lr):
    """Train the InfoNCE critic and return the best per-epoch mean MI estimate.

    Each epoch runs one optimisation pass over ``dataloader`` and then re-evaluates
    the InfoNCE lower bound over the same loader without gradients.
    NOTE(review): no module is switched to ``.eval()``, so any BatchNorm/Dropout
    layers run in training mode in both phases. Confirm this is intended.
    """
    boolean_net = Boolean(1, 256, 1, 6).to('cuda')
    infonce_net = Separable(intermediate_head, final_size, dataset=dataset_name, test_mode=test_mode).to('cuda')
    boolean_optimizer = torch.optim.Adam(boolean_net.parameters(), lr=5e-4)

    critic_params = list(infonce_net.parameters())
    if model_name == "ViT":
        # The unfrozen ViT blocks are optimised together with the critic.
        critic_params += list(eval_model.parameters())
    infonce_optimizer = torch.optim.Adam(critic_params, lr=lr)

    start = time()
    print("Learning rate", infonce_optimizer.param_groups[0]['lr'])
    answers = []
    for epoch in range(epochs):
        end = time()
        print(f"Epoch: {epoch+1}/{epochs}, Elapsed Time: {end-start:.2f}, ETA: {(end-start)*(epochs-epoch):.2f}")
        start = time()

        # Training phase
        for data, target in tqdm(dataloader):
            data, target = data.to('cuda'), target.to('cuda')
            inp = _extract_features(eval_model, model_name, data, layer)
            y_encoded = boolean_net(target.unsqueeze(1).float())
            scores = infonce_net(inp, y_encoded)
            loss = -infonce_lower_bound(scores)

            boolean_optimizer.zero_grad()
            infonce_optimizer.zero_grad()
            loss.backward()
            boolean_optimizer.step()
            infonce_optimizer.step()

        # Evaluation phase: InfoNCE lower bound per batch, no parameter updates.
        epoch_mi = []
        with torch.no_grad():
            for data, target in dataloader:
                data, target = data.to('cuda'), target.to('cuda')
                inp = _extract_features(eval_model, model_name, data, layer)
                y_encoded = boolean_net(target.unsqueeze(1).float())
                scores = infonce_net(inp, y_encoded)
                epoch_mi.append(infonce_lower_bound(scores).item())

        mean_mi = np.mean(np.array(epoch_mi))
        print(f"Epoch: {epoch+1}/{epochs}, MI: {mean_mi:.4f}")
        answers.append(mean_mi)

    return max(answers)


def _train_acc(eval_model, intermediate_head, dataloader, test_dataloader, model_name,
               final_size, layer, epochs):
    """Train a forget/retain classifier on the features and return the best test accuracy.

    The classifier is ``intermediate_head`` followed by two ``mlp`` layers. Only
    these are optimised. Returns the maximum per-epoch test accuracy as a fraction.
    """
    criterion = nn.CrossEntropyLoss()
    proj = mlp(final_size, 512, 256, 2).to('cuda')
    head = mlp(256, 256, 2, 1).to('cuda')
    optimizer = torch.optim.Adam(
        list(intermediate_head.parameters()) + list(proj.parameters()) + list(head.parameters()),
        lr=1e-5,
    )

    test_accuracies = []
    for epoch in range(epochs):
        num, correct = 0, 0
        print(f"Epoch: {epoch+1}/{epochs}")
        for data, target in dataloader:
            data, target = data.to('cuda'), target.to('cuda')
            inp = _extract_features(eval_model, model_name, data, layer)
            out = head(proj(intermediate_head(inp)))

            loss = criterion(out, target.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            num += len(target)
            correct += (out.argmax(1) == target.long()).sum().item()

        with torch.no_grad():
            test_num, test_correct = 0, 0
            for data, target in test_dataloader:
                data, target = data.to('cuda'), target.to('cuda')
                inp = _extract_features(eval_model, model_name, data, layer)
                out = head(proj(intermediate_head(inp)))
                test_num += len(target)
                test_correct += (out.argmax(1) == target.long()).sum().item()

        print(f"Accuracy: {correct / num * 100:.2f}% (train), {test_correct / test_num * 100:.2f}% (test)")
        test_accuracies.append(test_correct / test_num)

    return max(test_accuracies)


def get_infonce(args, pretrained_model=None):
    """Probe how much a checkpoint's layer-``args.layer`` features reveal about forget vs. retain data.

    Samples are labelled 0 (forget) or 1 (retain). Forget means the unlearned class
    range, the named sub-class, or the first ``exclude_num_per_class`` samples per
    class, depending on ``args.test_mode``.

    Args:
        args: parsed CLI namespace (see ``get_args``). It may additionally carry
            ``class_per_class`` to size the ImageNet retain subsample.
        pretrained_model: optional model whose deep copy initialises the trainable
            ResNet tail. If ``None``, a fresh ``get_model`` instance is used.

    Returns:
        train_mode 'mi': the largest per-epoch mean InfoNCE lower bound on the
        training loader. train_mode 'acc': the best per-epoch test accuracy of the
        forget/retain classifier, as a fraction.

    Raises:
        ValueError: for unsupported train_mode / test_mode / model / dataset
            combinations. These are detected before any model or data is loaded
            where possible.
    """
    print("CHECKPOINT: ")
    print(args.ckpt)

    if args.train_mode not in ('mi', 'acc'):
        raise ValueError(f"Unknown train_mode {args.train_mode!r}; expected 'mi' or 'acc'")
    if args.train_mode == 'acc' and args.test_mode == 'sub_class':
        raise ValueError("train_mode 'acc' has no held-out test split for test_mode 'sub_class'")

    model_name, final_size = _model_config(args.ckpt)
    dataset_name, num_classes, epochs, exclude_num_per_class, mean, std = _dataset_config(
        args.ckpt, args.test_mode, args.layer, args.class_idx_unlearn)
    layer = args.layer
    # Resolve the critic learning rate up front so unsupported combinations fail
    # before any model or dataset is loaded.
    infonce_lr = _infonce_lr(model_name, dataset_name, args.test_mode) if args.train_mode == 'mi' else None

    eval_model, intermediate_head, transform = _build_models(
        args, model_name, dataset_name, num_classes, mean, std, pretrained_model)
    dataloader, test_dataloader = _build_dataloaders(
        args, model_name, dataset_name, num_classes, exclude_num_per_class, transform)

    print("====================================")
    print("Layer: ", layer)
    print("====================================")

    if args.train_mode == 'mi':
        result = _train_mi(eval_model, intermediate_head, dataloader, model_name, dataset_name,
                           final_size, layer, epochs, args.test_mode, infonce_lr)
    else:
        result = _train_acc(eval_model, intermediate_head, dataloader, test_dataloader,
                            model_name, final_size, layer, epochs)

    print(result)
    return result
            

class BinaryDataset(Dataset):
    """Relabel a subset of a classification dataset as forget (0.) vs. retain (1.).

    Args:
        dataset: indexable dataset returning ``(data, target)`` pairs.
        Uclasses: class labels whose samples are labelled ``0.``; all others get ``1.``.
        idx: indices into ``dataset`` forming this view. A shuffled copy is stored
            and the caller's list is left untouched.
    """

    def __init__(self, dataset, Uclasses, idx):
        self.dataset = dataset
        self.Uclasses = Uclasses
        print("CLASS IDX: ", Uclasses)
        # Copy before shuffling. The original shuffled the caller's list in place,
        # a hidden side effect on the argument.
        self.idx = list(idx)
        random.shuffle(self.idx)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, idx):
        data, target = self.dataset[self.idx[idx]]
        return data, 0. if target in self.Uclasses else 1.

class BinarySampleDataset(Dataset):
    """Label the first ``exclude_num_per_class`` samples of each class 0. (forget), the rest 1. (retain).

    "First" means first in ``dataset.samples`` order. ``dataset`` must expose an
    ImageFolder-style ``samples`` list of ``(path, class_index)`` pairs and return
    ``(data, target)`` from ``__getitem__``. Forget samples come first in
    ``self.indices``; the retain samples follow in shuffled order.
    """

    def __init__(self, dataset, exclude_num_per_class=500, num_classes=10):
        self.dataset = dataset
        self.classes = [0] * num_classes  # per-class count of samples taken as forget
        self.exclude_num_per_class = exclude_num_per_class
        self.indices0, self.indices1 = self._filter_indices()
        random.shuffle(self.indices1)
        print("FORGET:", len(self.indices0), "REMAIN:", len(self.indices1))
        self.indices = self.indices0 + self.indices1
        # Set view of indices0. Testing membership against the list cost
        # O(len(indices0)) per sample, i.e. quadratic work per epoch.
        self._forget = set(self.indices0)

    def _filter_indices(self):
        """Split dataset positions into (forget, retain) index lists."""
        indices0, indices1 = [], []
        for idx, (_, class_idx) in enumerate(tqdm(self.dataset.samples)):
            if self.classes[class_idx] < self.exclude_num_per_class:
                indices0.append(idx)
                self.classes[class_idx] += 1
            else:
                indices1.append(idx)
        return indices0, indices1

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        data, _ = self.dataset[actual_idx]
        return data, 0. if actual_idx in self._forget else 1.
    
class BinarySubClassDataset(Dataset):
    """Forget (0.) = one fine class; retain (1.) = the other fine classes of its coarse group.

    Samples of classes outside that coarse group are dropped. Forget indices come
    first in ``self.indices``, followed by retain indices, both in dataset order.

    Args:
        dataset: ImageFolder-style dataset exposing ``class_to_idx`` and ``samples``.
        sub_class_name: class-folder name of the forget class.
        transform: stored as ``self.transform`` for compatibility only. It is NOT
            applied here, because ``dataset`` already applies its own transform.

    Raises:
        KeyError: if ``sub_class_name`` is not a class of ``dataset``.
        ValueError: if that class belongs to no group of ``coarse_map``.
    """

    def __init__(self, dataset, sub_class_name, transform=None):
        self.dataset = dataset
        # Coarse group -> its five fine-class indices. This is the CIFAR-100
        # superclass grouping; sub_class mode is only scheduled for cifar100.
        self.coarse_map = {
            0:[4, 30, 55, 72, 95],
            1:[1, 32, 67, 73, 91],
            2:[54, 62, 70, 82, 92],
            3:[9, 10, 16, 28, 61],
            4:[0, 51, 53, 57, 83],
            5:[22, 39, 40, 86, 87],
            6:[5, 20, 25, 84, 94],
            7:[6, 7, 14, 18, 24],
            8:[3, 42, 43, 88, 97],
            9:[12, 17, 37, 68, 76],
            10:[23, 33, 49, 60, 71],
            11:[15, 19, 21, 31, 38],
            12:[34, 63, 64, 66, 75],
            13:[26, 45, 77, 79, 99],
            14:[2, 11, 35, 46, 98],
            15:[27, 29, 44, 78, 93],
            16:[36, 50, 65, 74, 80],
            17:[47, 52, 56, 59, 96],
            18:[8, 13, 48, 58, 90],
            19:[41, 69, 81, 85, 89]
        }
        self._class = dataset.class_to_idx[sub_class_name]
        self.coarse_class = next(
            (members for members in self.coarse_map.values() if self._class in members), None)
        if self.coarse_class is None:
            # Previously this surfaced as an AttributeError on self.coarse_class.
            raise ValueError(
                f"Class {sub_class_name!r} (index {self._class}) belongs to no coarse group")
        print("COARSE CLASS", self.coarse_class)
        self.indices0, self.indices1 = self._filter_indices()
        self.transform = transform
        self.indices = self.indices0 + self.indices1
        # Set view of indices0 for O(1) label lookup in __getitem__.
        self._forget = set(self.indices0)

    def _filter_indices(self):
        """Collect forget-class positions and the positions of its coarse-group siblings."""
        indices0, indices1 = [], []
        for idx, (_, class_idx) in enumerate(self.dataset.samples):
            if class_idx == self._class:
                indices0.append(idx)
            elif class_idx in self.coarse_class:
                indices1.append(idx)
        print("Forget", len(indices0), "Remain", len(indices1))
        return indices0, indices1

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        data, _ = self.dataset[actual_idx]
        return data, 0. if actual_idx in self._forget else 1.
    
if __name__ == '__main__':
    # Fix the seed before any model weights or dataset shuffles are created.
    seed_torch(42)
    get_infonce(get_args(), pretrained_model=None)
    

