from utils.adv_generator import inf_generator
from tqdm import tqdm
import torch
from torch import nn
from methods.method import Method
from utils.eval import eval
from copy import deepcopy
import torch.nn.functional as F
from utils.backbone import get_model
from utils.utils import get_logits
from utils.supcon_loss import SupConLoss
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from itertools import cycle
from collections import defaultdict
from utils.data import IntraClassShuffleSampler
import torch.optim as optim
import torchvision.transforms as transforms

class COLA(Method):
    """COLA unlearning: SupCon training on retain data, then cross-entropy fine-tuning.

    ``unlearn`` first runs ``contrast_epochs`` epochs of supervised-contrastive
    (SupCon) training on L2-normalized embeddings of retain samples, then
    ``finetune_epochs`` epochs of cross-entropy fine-tuning on the same retain
    data. Learning-rate settings come from ``set_hyperparameters``.
    """

    def set_hyperparameters(self, args):
        """Pick epoch counts and learning-rate settings for this run.

        Reads ``args.model_name``, ``args.data_name`` and (for ResNet18 only)
        ``args.test_mode``. Sets:

        * ``contrast_epochs`` / ``finetune_epochs``: 10 each by default,
          1 and 2 for ResNet50/ViT on ImageNet;
        * ``contrast_scheduler_last``: cosine ``eta_min`` of the contrastive phase;
        * ``ft_lr`` / ``ft_scheduler_last``: initial LR and cosine ``eta_min``
          of the fine-tuning phase.

        Combinations not listed below leave the LR attributes unset, so a later
        ``unlearn`` call will fail with ``AttributeError`` unless they were set
        elsewhere.
        """
        self.contrast_epochs = 10
        self.finetune_epochs = 10
        if args.model_name == "ResNet18":
            if args.data_name == "cifar10":
                if args.test_mode == "class":
                    self.contrast_scheduler_last = 1e-5
                    self.ft_lr = 5e-6
                    self.ft_scheduler_last = 1e-6
                elif args.test_mode == "sample":
                    self.contrast_scheduler_last = 2e-4
                    self.ft_lr = 1e-4
                    self.ft_scheduler_last = 1e-5
            elif args.data_name == "cifar100":
                if args.test_mode == "class":
                    self.contrast_scheduler_last = 2e-4
                    self.ft_lr = 5e-6
                    self.ft_scheduler_last = 1e-6
                elif args.test_mode == "sample":
                    self.contrast_scheduler_last = 2e-4
                    self.ft_lr = 2e-5
                    self.ft_scheduler_last = 2e-6
        elif args.model_name == "ResNet50":
            if args.data_name == "cifar10":
                self.contrast_scheduler_last = 1e-4
                self.ft_lr = 1e-5
                self.ft_scheduler_last = 5e-6
            elif args.data_name == "cifar100":
                self.contrast_scheduler_last = 1e-4
                self.ft_lr = 1e-5
                self.ft_scheduler_last = 5e-6
            elif args.data_name == "imagenet":
                self.contrast_scheduler_last = 1e-5
                self.ft_lr = 1e-5
                self.ft_scheduler_last = 1e-5
                self.contrast_epochs = 1
                self.finetune_epochs = 2
        elif args.model_name == "ViT":
            if args.data_name == "cifar10":
                self.contrast_scheduler_last = 1e-5
                self.ft_lr = 5e-5
                self.ft_scheduler_last = 1e-5
            elif args.data_name == "cifar100":
                self.contrast_scheduler_last = 2e-4
                self.ft_lr = 5e-5
                self.ft_scheduler_last = 1e-5
            elif args.data_name == "imagenet":
                self.contrast_scheduler_last = 2e-5
                self.ft_scheduler_last = 1e-5
                self.ft_lr = 5e-5
                self.contrast_epochs = 1
                self.finetune_epochs = 2

    def unlearn(self, model, loaders, args):
        """Unlearn by contrastive training and then fine-tuning on the retain set.

        When ``args.retain_ratio < 1`` both phases train on a random subset of
        the retain set (drawn without replacement with ``np.random``), and
        ``self.train_remain_loader`` is replaced by a loader over that subset.

        Args:
            model: network updated in place. It must accept
                ``model(x, get_embeddings=True)`` and return a 2-tuple
                ``(outputs, embeddings)``.
            loaders: not read by this method; kept for interface compatibility.
            args: run configuration. Reads ``device``, ``retain_ratio``,
                ``batch_size``, ``remain_batch_size``, ``forget_batch_size``
                and ``test_mode``.

        Returns:
            The same ``model`` object after training.
        """
        device = args.device
        supcon = SupConLoss()

        retain_set = self.train_remain_set
        if args.retain_ratio < 1:
            subset_size = int(len(retain_set) * args.retain_ratio)
            subset_indices = np.random.choice(len(retain_set), subset_size, replace=False)
            retain_set = torch.utils.data.Subset(retain_set, subset_indices)
            print(f"Number of samples in the subset: {len(retain_set)}, retain ratio: {args.retain_ratio}")
            self.train_remain_loader = torch.utils.data.DataLoader(retain_set, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        sampler = IntraClassShuffleSampler(self.train_forget_set, batch_size=args.forget_batch_size)
        # drop_last=True: every contrastive batch has exactly remain_batch_size samples.
        train_remain_loader = torch.utils.data.DataLoader(retain_set, batch_size=args.remain_batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
        # Passed through to the phase methods to keep their signatures; the
        # current SupCon loss does not consume forget batches.
        sorted_train_forget_loader = torch.utils.data.DataLoader(self.train_forget_set, batch_size=args.forget_batch_size, sampler=sampler, num_workers=4, pin_memory=True, drop_last=True)
        finetune_train_remain_loader = torch.utils.data.DataLoader(retain_set, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

        # Cosine schedules step once per batch, so T_max must count the batches
        # of the loaders that are actually iterated.
        contrast_iter = len(train_remain_loader) * self.contrast_epochs
        finetune_iter = len(finetune_train_remain_loader) * self.finetune_epochs

        model.train()

        if args.test_mode == "class":
            self.repul_supcon(model, train_remain_loader, sorted_train_forget_loader, supcon, contrast_iter, device, args)
        elif args.test_mode == "sample":
            self.random_repel(model, train_remain_loader, sorted_train_forget_loader, supcon, contrast_iter, device, args)

        self.finetune(model, finetune_train_remain_loader, self.finetune_epochs, finetune_iter, device)

        return model

    def _contrast_on_retain(self, model, train_remain_loader, supcon, contrast_iter, device, log_steps=False):
        """Run ``self.contrast_epochs`` epochs of SupCon training on retain batches.

        Uses ``self.get_optimizer(model)`` with a cosine schedule that decays to
        ``self.contrast_scheduler_last`` over ``contrast_iter`` steps. When
        ``log_steps`` is true the loss is printed after every step.
        """
        optimizer = self.get_optimizer(model)
        # T_max of 0 would make CosineAnnealingLR divide by zero on its first step.
        scheduler = CosineAnnealingLR(optimizer, T_max=max(1, contrast_iter), eta_min=self.contrast_scheduler_last)

        num_iters = len(train_remain_loader) * self.contrast_epochs
        step = 0
        for epoch in range(1, self.contrast_epochs + 1):
            print(f"Epoch: {epoch}, lr: {optimizer.param_groups[0]['lr']}")
            for x_remain, y_remain in train_remain_loader:
                step += 1
                x_remain, y_remain = x_remain.to(device), y_remain.to(device)

                _, embeddings_remain = model(x_remain, get_embeddings=True)
                norm_embeddings_remain = F.normalize(embeddings_remain, dim=1)
                # unsqueeze(1) adds a single "view" axis; assumes SupConLoss takes
                # (batch, n_views, dim) features -- TODO confirm against utils.supcon_loss.
                loss = supcon(norm_embeddings_remain.unsqueeze(1), y_remain)

                # model.zero_grad() covers all model params; optimizer.zero_grad()
                # also covers any extra params get_optimizer may have added.
                model.zero_grad()
                optimizer.zero_grad()
                if log_steps:
                    print(f"[{step} / {num_iters}] Loss: {loss.item()}")
                loss.backward()
                optimizer.step()
                scheduler.step()

    def repul_supcon(self, model, train_remain_loader, sorted_train_forget_loader, supcon, contrast_iter, device, args):
        """Contrastive phase for class unlearning (``test_mode == "class"``).

        Trains with SupCon on retain batches only and prints the loss every step.
        ``sorted_train_forget_loader`` and ``args`` are accepted for interface
        compatibility but are not consumed by the loss.
        """
        self._contrast_on_retain(model, train_remain_loader, supcon, contrast_iter, device, log_steps=True)

    def random_repel(self, model, train_remain_loader, sorted_train_forget_loader, supcon, contrast_iter, device, args):
        """Contrastive phase for sample unlearning (``test_mode == "sample"``).

        Despite the name, forget batches are not used: the SupCon loss is
        computed on retain samples only. ``sorted_train_forget_loader`` and
        ``args`` are accepted for interface compatibility.
        """
        self._contrast_on_retain(model, train_remain_loader, supcon, contrast_iter, device, log_steps=False)

    def finetune(self, model, finetune_train_remain_loader, finetune_epochs, finetune_iter, device):
        """Fine-tune ``model`` with cross-entropy on retain data.

        Uses Adam at ``self.ft_lr`` with a cosine schedule that decays to
        ``self.ft_scheduler_last`` over ``finetune_iter`` steps (one step per batch).

        Args:
            model: network updated in place; called as ``model(x, get_embeddings=True)``.
            finetune_train_remain_loader: yields ``(inputs, labels)`` retain batches.
            finetune_epochs: number of passes over the loader.
            finetune_iter: cosine-schedule length, normally
                ``len(loader) * finetune_epochs``.
            device: device that batches are moved to.
        """
        optimizer = optim.Adam(model.parameters(), lr=self.ft_lr)
        # T_max of 0 would make CosineAnnealingLR divide by zero on its first step.
        scheduler = CosineAnnealingLR(optimizer, T_max=max(1, finetune_iter), eta_min=self.ft_scheduler_last)
        criterion = nn.CrossEntropyLoss()

        for _ in range(finetune_epochs):
            for x_remain, y_remain in tqdm(finetune_train_remain_loader):
                x_remain, y_remain = x_remain.to(device), y_remain.to(device)

                logits_remain, _ = model(x_remain, get_embeddings=True)
                loss = criterion(get_logits(logits_remain), y_remain)

                # The optimizer owns exactly model.parameters(), so this clears all grads.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()