import os
import logging
import time
import tqdm
import copy
import pickle 
import numpy as np
import torch
import torch.nn as nn
from models.diffusion import Conditional_Model
from models.ema import EMAHelper
from functions import get_optimizer, cycle, create_class_labels
from functions.losses import loss_registry_conditional
from datasets import get_dataset, data_transform, inverse_data_transform, all_but_one_class_path_dataset, groups_for_one_dataset, groups_for_one_path_dataset

import torchvision.utils as tvu
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision

from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora,get_lora_state_dict

def torch2hwcuint8(x, clip=False):
    """Affinely rescale ``x`` so that the interval [-1, 1] maps onto [0, 1].

    Despite the name, no dtype conversion or axis permutation is performed;
    the result is simply ``(x + 1) / 2``.

    Args:
        x: Tensor to rescale.
        clip: When true, clamp ``x`` to [-1, 1] before rescaling.

    Returns:
        The rescaled tensor.
    """
    bounded = torch.clamp(x, -1, 1) if clip else x
    return (bounded + 1.0) / 2.0


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the array of per-timestep betas for a named schedule.

    Args:
        beta_schedule: One of ``"quad"``, ``"linear"``, ``"const"``, ``"jsd"``
            or ``"sigmoid"``.
        beta_start: First beta (ignored by ``"const"`` and ``"jsd"``).
        beta_end: Last beta (the constant value for ``"const"``; ignored by
            ``"jsd"``).
        num_diffusion_timesteps: Number of entries in the schedule.

    Returns:
        A NumPy array of shape ``(num_diffusion_timesteps,)``.

    Raises:
        NotImplementedError: If ``beta_schedule`` is not a known name.
    """
    steps = num_diffusion_timesteps

    if beta_schedule == "quad":
        # Interpolate linearly in sqrt-space, then square.
        roots = np.linspace(
            beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64
        )
        betas = roots ** 2
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, steps, dtype=np.float64)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(steps, dtype=np.float64)
    elif beta_schedule == "jsd":
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        denominators = np.linspace(steps, 1, steps, dtype=np.float64)
        betas = 1.0 / denominators
    elif beta_schedule == "sigmoid":
        # Logistic curve sampled on [-6, 6], rescaled to [beta_start, beta_end].
        grid = np.linspace(-6, 6, steps)
        squashed = 1 / (np.exp(-grid) + 1)
        betas = squashed * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (steps,)
    return betas


class Diffusion(object):
    def __init__(self, args, config):
        """Store the run settings and precompute the diffusion constants.

        Args:
            args: Run arguments; stored as ``self.args``.
            config: Configuration namespace; ``config.model.var_type`` and the
                ``config.diffusion`` schedule settings are read here.

        Sets ``self.device`` (CUDA when available, otherwise CPU),
        ``self.betas``, ``self.num_timesteps`` and, for the ``"fixedlarge"``
        and ``"fixedsmall"`` variance types only, ``self.logvar``.
        """
        self.args = args
        self.config = config

        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        self.model_var_type = config.model.var_type
        schedule = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )
        self.betas = torch.from_numpy(schedule).float().to(self.device)
        betas = self.betas
        self.num_timesteps = betas.shape[0]

        # Cumulative product of alphas, and the same sequence shifted right by
        # one step with a leading 1.
        alphas_cumprod = (1.0 - betas).cumprod(dim=0)
        leading_one = torch.ones(1).to(self.device)
        alphas_cumprod_prev = torch.cat([leading_one, alphas_cumprod[:-1]], dim=0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        var_type = self.model_var_type
        if var_type == "fixedlarge":
            self.logvar = betas.log()
        elif var_type == "fixedsmall":
            # Clamp before the log: the first posterior variance is exactly 0.
            self.logvar = posterior_variance.clamp(min=1e-20).log()
            
    def save_fim(self):
        """Estimate the diagonal Fisher information of the trained model and pickle it.

        Loads ``ckpts/ckpt.pth`` from ``args.ckpt_folder`` and iterates over the
        images under ``class_samples``. For every sample, the gradient of the
        conditional loss is summed over all diffusion timesteps (processed in
        ``args.n_chunks`` chunks to bound the size of each autograd graph).
        The squared per-sample gradient, divided by the dataset size, is
        accumulated per parameter. The result is written to
        ``fisher_dict.pkl`` in the checkpoint folder every
        ``config.training.save_freq`` batches and once more at the end.
        """
        args, config = self.args, self.config
        # Process 1 sample per GPU, so bs == number of GPUs. On CPU-only hosts
        # device_count() is 0, which DataLoader rejects, so fall back to 1.
        bs = max(torch.cuda.device_count(), 1)
        fim_dataset = ImageFolder(os.path.join(args.ckpt_folder, "class_samples"),
                                  transform=transforms.ToTensor())
        fim_loader = DataLoader(fim_dataset, batch_size=bs,
                                num_workers=config.data.num_workers, shuffle=True)

        print("Loading checkpoints {}".format(args.ckpt_folder))
        model = Conditional_Model(self.config)
        states = torch.load(
            os.path.join(self.args.ckpt_folder, "ckpts/ckpt.pth"),
            map_location=self.device,
        )
        model = model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)
        model.eval()

        loss_fn = loss_registry_conditional[config.model.type]
        fisher_path = os.path.join(args.ckpt_folder, 'fisher_dict.pkl')
        dataset_size = len(fim_loader.dataset)

        # fisher_dict holds the running estimate; fisher_dict_temp_list holds,
        # for each slot in the batch, that sample's gradient summed over timesteps.
        fisher_dict = {}
        fisher_dict_temp_list = [{} for _ in range(bs)]
        for name, param in model.named_parameters():
            fisher_dict[name] = torch.zeros_like(param.data)
            for per_sample in fisher_dict_temp_list:
                per_sample[name] = torch.zeros_like(param.data)

        # calculate Fisher information diagonals
        for step, data in enumerate(tqdm.tqdm(fim_loader, desc="Calculating Fisher information matrix")):

            x, c = data
            x, c = x.to(self.device), c.to(self.device)
            # NOTE(review): unlike train(), x is not passed through
            # data_transform here — confirm this is intended.

            # The final batch may hold fewer than bs samples (drop_last is
            # off), so size everything by the batch actually received.
            n = x.size(0)

            b = self.betas
            ts = torch.chunk(torch.arange(0, self.num_timesteps), args.n_chunks)

            for _t in ts:
                loss = None
                for t_value in _t:
                    e = torch.randn_like(x)
                    t = torch.tensor([t_value]).expand(n).to(self.device)

                    # keepdim=True is expected to give one loss per sample, so
                    # gradients across the batch are NOT averaged yet.
                    step_loss = loss_fn(model, x, t, c, e, b, keepdim=True)
                    loss = step_loss if loss is None else loss + step_loss

                # store first-order gradients for each sample separately in the
                # temp dictionaries, for each timestep chunk
                for i in range(n):
                    model.zero_grad()
                    # The graph is shared by all samples of the chunk; release
                    # it only after the last one has been differentiated.
                    loss[i].backward(retain_graph=(i != n - 1))
                    for name, param in model.named_parameters():
                        # zero_grad() may reset grads to None; a parameter that
                        # did not take part in the loss has nothing to add.
                        if param.grad is not None:
                            fisher_dict_temp_list[i][name] += param.grad.data
                del loss

            # after looping through all time steps, aggregate each individual
            # sample's gradient: square it, average it, then reset the slot
            for name in fisher_dict:
                for per_sample in fisher_dict_temp_list:
                    fisher_dict[name] += (per_sample[name] ** 2) / dataset_size
                    per_sample[name].zero_()

            if (step+1) % config.training.save_freq == 0:
                with open(fisher_path, 'wb') as f:
                    pickle.dump(fisher_dict, f)

        # save at the end
        with open(fisher_path, 'wb') as f:
            pickle.dump(fisher_dict, f)

    def train(self):
        """Train the conditional diffusion model from scratch.

        Runs ``config.training.n_iters`` optimisation steps over batches drawn
        from ``get_dataset(args, config)``, optionally maintaining an EMA of
        the weights. Every ``config.training.snapshot_freq`` steps the model,
        optimizer, step counter (and EMA state when enabled) are saved to
        ``ckpt.pth`` under ``config.ckpt_dir`` and a sample visualisation is
        produced from a copy of the model.
        """
        args, config = self.args, self.config
        D_train_loader = get_dataset(args, config)
        D_train_iter = cycle(D_train_loader)

        model = Conditional_Model(config)

        optimizer = get_optimizer(config, model.parameters())
        model.to(self.device)
        model = torch.nn.DataParallel(model)

        if config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        # Gradient clipping is optional: a config without optim.grad_clip
        # trains unclipped. Looking the value up explicitly (instead of
        # wrapping the call in a blanket try/except) lets genuine clipping
        # errors surface.
        grad_clip = getattr(config.optim, "grad_clip", None)

        model.train()

        start = time.time()
        for step in range(0, config.training.n_iters):

            model.train()
            x, c = next(D_train_iter)
            n = x.size(0)
            x = x.to(self.device)
            x = data_transform(config, x)
            e = torch.randn_like(x)
            b = self.betas

            # antithetic sampling: pair every drawn timestep t with T - t - 1
            t = torch.randint(
                low=0, high=self.num_timesteps, size=(n // 2 + 1,)
            ).to(self.device)
            t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
            # NOTE(review): other methods of this class unpack a
            # (loss, outputs) tuple from the same registry call — confirm
            # which return shape is current.
            loss = loss_registry_conditional[config.model.type](model, x, t, c, e, b)

            if (step+1) % config.training.log_freq == 0:
                end = time.time()
                logging.info(
                    f"step: {step}, loss: {loss.item()}, time: {end-start}"
                )
                start = time.time()

            optimizer.zero_grad()
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            if config.model.ema:
                ema_helper.update(model)

            if (step+1) % config.training.snapshot_freq == 0:
                states = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    step,
                ]
                if config.model.ema:
                    states.append(ema_helper.state_dict())

                torch.save(
                    states,
                    os.path.join(config.ckpt_dir, "ckpt.pth"),
                )

                # Visualise with the EMA weights when available, otherwise
                # with a detached copy of the live model.
                test_model = ema_helper.ema_copy(model) if config.model.ema else copy.deepcopy(model)
                test_model.eval()
                self.sample_visualization(test_model, step, args.cond_scale)
                del test_model
                
    def select_outlier(self, p1, p2):
        """Flag entries whose squared difference is large relative to its spread.

        The threshold is half the range (max minus min) of the element-wise
        squared differences between ``p1`` and ``p2``.

        Args:
            p1: Tensor (or array) of values.
            p2: Tensor (or array) broadcastable against ``p1``.

        Returns:
            Boolean mask with the shape of ``p1 - p2``; true where the squared
            difference exceeds the threshold. The mask used to be computed and
            then dropped, so the method always returned ``None``.
        """
        sq_diff = (p1 - p2) ** 2
        # .max()/.min() reduce over every element, so inputs of any rank work
        # (the built-in max/min only handled 1-D input).
        threshold = (sq_diff.max() - sq_diff.min()) * 0.5
        return sq_diff > threshold
    
    def get_topk_label(self, dicts, count):
        """Return the keys whose values are among the ``count`` largest distinct values.

        Args:
            dicts: Mapping from label to score.
            count: Number of distinct top scores to keep.

        Returns:
            List of keys ordered by score (descending); keys sharing a score
            are all included and appear in descending key order.
        """
        ranked_items = sorted(dicts.items(), reverse=True)
        distinct_scores = {score for _, score in ranked_items}
        top_scores = sorted(distinct_scores, reverse=True)[:count]
        return [
            key
            for wanted in top_scores
            for key, score in ranked_items
            if score == wanted
        ]

    def select_groups_for_one_class(self):
        """Pick the labels least affected by a short forgetting run.

        The checkpointed model is scored on every batch of the grouping
        dataset, then trained for ``config.grouping.num_samples`` steps on
        uniform-noise images labelled ``args.label_to_forget``, then scored
        again. Each label gets the score ``1 - loss_before / loss_after``; the
        labels ranked highest by ``get_topk_label`` (with
        ``config.grouping.topk``) are removed.

        Returns:
            The entries of ``config.grouping.label_list`` that are not among
            those top-ranked labels.
        """
        args, config = self.args, self.config
        logging.info("Loading diffusion model")
        print("Loading checkpoints {}".format(args.ckpt_folder))
        model = Conditional_Model(config)
        states = torch.load(
            os.path.join(args.ckpt_folder, "ckpts/ckpt.pth"),
            map_location=self.device,
        )
        model = model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)
        optimizer = get_optimizer(config, model.parameters())
        channels = config.data.channels
        img_size = config.data.image_size
        b = self.betas
        data_loader = groups_for_one_path_dataset(config, os.path.join(args.ckpt_folder, "class_samples"), args.label_to_forget)

        def _antithetic_timesteps(batch_size):
            # Antithetic sampling: pair every drawn timestep t with T - t - 1.
            half = torch.randint(
                low=0, high=self.num_timesteps, size=(batch_size // 2 + 1,)
            ).to(self.device)
            return torch.cat([half, self.num_timesteps - half - 1], dim=0)[:batch_size]

        def _remember_loss(batch):
            # Score one batch; the batch's label is taken from its first sample.
            images, classes = batch
            t = _antithetic_timesteps(images.size(0))
            label = classes[0].item()
            images, classes = images.to(self.device), classes.to(self.device)
            images = data_transform(config, images)
            noise = torch.randn_like(images)
            value, _ = loss_registry_conditional[config.model.type](model, images, t, classes, noise, b, cond_drop_prob = 0.)
            return label, value

        # Pass 1: per-label loss of the untouched model, summed over batches.
        baseline = {label: 0 for label in self.config.grouping.label_list}
        for batch in data_loader:
            label, value = _remember_loss(batch)
            baseline[label] += value.item()
            optimizer.zero_grad()

        # Short forgetting run on uniform-noise images in [-1, 1).
        for _ in range(config.grouping.num_samples):
            n = config.grouping.batch_num
            c_forget = (torch.ones(n, dtype=int) * args.label_to_forget).to(self.device)
            x_forget = (torch.rand((n, channels, img_size, img_size), device=self.device) - 0.5) * 2.
            e_forget = torch.randn_like(x_forget)
            t = _antithetic_timesteps(n)
            forget_loss, _ = loss_registry_conditional[config.model.type](model, x_forget, t, c_forget, e_forget, b, cond_drop_prob = 0.)

            optimizer.zero_grad()
            forget_loss.backward()
            optimizer.step()

        # Pass 2: relative loss change per label; a later batch with the same
        # label overwrites the earlier entry.
        relative_change = {}
        for batch in data_loader:
            label, value = _remember_loss(batch)
            relative_change[label] = 1 - baseline[label]/value.item()

        group_labels = self.get_topk_label(relative_change, config.grouping.topk)
        return [lbl for lbl in config.grouping.label_list if lbl not in group_labels]
            
            #print(label)
        
    def train_forget_old(self):
        """Fine-tune the checkpointed model to forget ``args.label_to_forget`` with an EWC penalty.

        Each step combines the conditional loss on uniform-noise images
        labelled with the class to forget and ``config.training.gamma`` times
        the loss on a batch of the remaining classes. Added to that is
        ``config.training.lmbda`` times the Fisher-weighted squared distance
        of every parameter from its checkpoint value. The Fisher weights are
        read from ``fisher_dict.pkl`` in ``args.ckpt_folder``. Checkpoints and
        sample visualisations are written every
        ``config.training.snapshot_freq`` steps.
        """
        args, config = self.args, self.config
        logging.info(f"Training diffusion forget with contrastive and EWC. Gamma: {config.training.gamma}, lambda: {config.training.lmbda}")
        D_train_loader = all_but_one_class_path_dataset(config, os.path.join(args.ckpt_folder, "class_samples"), args.label_to_forget)
        D_train_iter = cycle(D_train_loader)

        print("Loading checkpoints {}".format(args.ckpt_folder))
        model = Conditional_Model(config)
        states = torch.load(
            os.path.join(args.ckpt_folder, "ckpts/ckpt.pth"),
            map_location=self.device,
        )
        model = model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)
        optimizer = get_optimizer(config, model.parameters())

        if config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
        else:
            ema_helper = None

        # NOTE: pickle.load can execute arbitrary code from the file; only
        # load a Fisher dump that you produced yourself.
        with open(os.path.join(args.ckpt_folder, 'fisher_dict.pkl'), 'rb') as f:
            fisher_dict = pickle.load(f)

        # Pair every parameter with its Fisher weight and its checkpoint value
        # once, already on the training device, instead of moving both on
        # every step.
        ewc_terms = [
            (param, fisher_dict[name].to(self.device), param.data.clone().to(self.device))
            for name, param in model.named_parameters()
        ]

        # Not used below, but list.remove raises ValueError when
        # label_to_forget is not a valid class index, so it is kept as a check.
        label_choices = list(range(config.data.n_classes))
        label_choices.remove(args.label_to_forget)

        # Gradient clipping is optional: a config without optim.grad_clip
        # trains unclipped. Looking the value up explicitly (instead of a
        # blanket try/except) lets genuine clipping errors surface.
        grad_clip = getattr(config.optim, "grad_clip", None)

        channels = config.data.channels
        img_size = config.data.image_size
        loss_fn = loss_registry_conditional[config.model.type]

        for step in range(0, config.training.n_iters):

            model.train()
            x_remember, c_remember = next(D_train_iter)
            x_remember, c_remember = x_remember.to(self.device), c_remember.to(self.device)
            x_remember = data_transform(config, x_remember)

            n = x_remember.size(0)
            c_forget = (torch.ones(n, dtype=int) * args.label_to_forget).to(self.device)
            # Uniform noise in [-1, 1) stands in for the class to forget.
            x_forget = (torch.rand((n, channels, img_size, img_size), device=self.device) - 0.5) * 2.
            e_remember = torch.randn_like(x_remember)
            e_forget = torch.randn_like(x_forget)
            b = self.betas

            # antithetic sampling: pair every drawn timestep t with T - t - 1
            t = torch.randint(
                low=0, high=self.num_timesteps, size=(n // 2 + 1,)
            ).to(self.device)
            t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
            # NOTE(review): other methods of this class unpack a
            # (loss, outputs) tuple from the same registry call — confirm
            # which return shape is current.
            loss = loss_fn(model, x_forget, t, c_forget, e_forget, b, cond_drop_prob = 0.) + \
                   config.training.gamma * loss_fn(model, x_remember, t, c_remember, e_remember, b, cond_drop_prob = 0.)
            forgetting_loss = loss.item()

            # EWC penalty: Fisher-weighted squared drift from the checkpoint.
            ewc_loss = 0.
            for param, fisher, anchor in ewc_terms:
                penalty = config.training.lmbda * (fisher * (param - anchor) ** 2).sum()
                loss = loss + penalty
                # Detached copy for logging only.
                ewc_loss = ewc_loss + penalty.detach()

            if (step+1) % config.training.log_freq == 0:
                logging.info(
                    f"step: {step}, loss: {loss.item()}, forgetting loss: {forgetting_loss}, ewc loss: {float(ewc_loss)}"
                )

            optimizer.zero_grad()
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            if config.model.ema:
                ema_helper.update(model)

            if (step+1) % config.training.snapshot_freq == 0:
                states = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    step,
                ]
                if config.model.ema:
                    states.append(ema_helper.state_dict())

                torch.save(
                    states,
                    os.path.join(config.ckpt_dir, "ckpt.pth"),
                )

                # Visualise with the EMA weights when available, otherwise
                # with a detached copy of the live model.
                test_model = ema_helper.ema_copy(model) if config.model.ema else copy.deepcopy(model)
                test_model.eval()
                self.sample_visualization(test_model, step, args.cond_scale)
                del test_model
    
    def train_forget_lora(self):
        """Fine-tune LoRA adapters so the model forgets ``args.label_to_forget``.

        LoRA adapters are added to the checkpointed model and only the
        parameters returned by ``get_lora_params`` are handed to the optimizer.
        Each step computes a forgetting loss (uniform-noise images with the
        label to forget) and a remember loss (a batch of the other classes).
        The per-parameter gradients of both are snapshotted, and a term built
        from their largest magnitudes, scaled by ``config.grouping.alpha``, is
        added to ``loss_forget + gamma * loss_remember``.

        Every ``config.training.snapshot_freq`` steps the LoRA state dict,
        optimizer state and step (plus EMA state when enabled) are saved to
        ``lora_ckpt.pth`` under ``config.ckpt_dir`` and a sample visualisation
        is produced.
        """
        args, config = self.args, self.config
        logging.info(f"Training diffusion forget with contrastive and EWC. Gamma: {config.training.gamma}, lambda: {config.training.lmbda}")
        D_train_loader = all_but_one_class_path_dataset(config, os.path.join(args.ckpt_folder, "class_samples"), args.label_to_forget)
        D_train_iter = cycle(D_train_loader)

        print("Loading checkpoints {}".format(args.ckpt_folder))
        model = Conditional_Model(config)
        states = torch.load(
            os.path.join(args.ckpt_folder, "ckpts/ckpt.pth"),
        )
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=False)
        # The adapters are attached while the model sits on the CPU; the
        # whole thing is moved to the training device afterwards.
        model = model.to(torch.device('cpu'))
        add_lora(model)
        model = model.to(self.device)

        # Collected once: the same adapter parameters feed the optimizer and
        # the per-step gradient snapshots.
        lora_params = list(get_lora_params(model))
        optimizer = get_optimizer(config, [{"params": lora_params}])

        if config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
        else:
            ema_helper = None

        # NOTE: pickle.load can execute arbitrary code from the file; only
        # load a Fisher dump that you produced yourself.
        # NOTE(review): fisher_dict is not used while the EWC penalty is
        # disabled; the load is kept so a missing file still fails up front.
        with open(os.path.join(args.ckpt_folder, 'fisher_dict.pkl'), 'rb') as f:
            fisher_dict = pickle.load(f)

        # Not used below, but list.remove raises ValueError when
        # label_to_forget is not a valid class index, so it is kept as a check.
        label_choices = list(range(config.data.n_classes))
        label_choices.remove(args.label_to_forget)

        # Gradient clipping is optional: a config without optim.grad_clip
        # trains unclipped. Looking the value up explicitly (instead of a
        # blanket try/except) lets genuine clipping errors surface.
        grad_clip = getattr(config.optim, "grad_clip", None)

        def _grad_snapshot():
            # Copy the current LoRA gradients. clone() matters: zero_grad()
            # may zero .grad in place, which would wipe a stored reference.
            return [
                torch.zeros_like(p, dtype=torch.float).to(self.device)
                if p.grad is None else p.grad.clone()
                for p in lora_params
            ]

        channels = config.data.channels
        img_size = config.data.image_size
        loss_fn = loss_registry_conditional[config.model.type]

        begin = time.time()
        for step in range(0, config.training.n_iters):

            model.train()
            x_remember, c_remember = next(D_train_iter)
            x_remember, c_remember = x_remember.to(self.device), c_remember.to(self.device)
            x_remember = data_transform(config, x_remember)

            n = x_remember.size(0)
            c_forget = (torch.ones(n, dtype=int) * args.label_to_forget).to(self.device)
            # Uniform noise in [-1, 1) stands in for the class to forget.
            x_forget = (torch.rand((n, channels, img_size, img_size), device=self.device) - 0.5) * 2.
            e_remember = torch.randn_like(x_remember)
            e_forget = torch.randn_like(x_forget)
            b = self.betas

            # antithetic sampling: pair every drawn timestep t with T - t - 1
            t = torch.randint(
                low=0, high=self.num_timesteps, size=(n // 2 + 1,)
            ).to(self.device)
            t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]

            # NOTE(review): other methods of this class unpack a
            # (loss, outputs) tuple from the same registry call — confirm
            # which return shape is current.
            # Graphs are retained because both losses are differentiated
            # again as part of the combined loss below.
            loss_forget = loss_fn(model, x_forget, t, c_forget, e_forget, b, cond_drop_prob = 0.)
            loss_forget.backward(retain_graph=True)
            loss_forget_params = _grad_snapshot()
            optimizer.zero_grad()

            loss_remember = loss_fn(model, x_remember, t, c_remember, e_remember, b, cond_drop_prob = 0.)
            loss_remember.backward(retain_graph=True)
            loss_remember_params = _grad_snapshot()

            loss = loss_forget + config.training.gamma * loss_remember
            # The EWC penalty is currently disabled in this method, so this
            # stays 0 in the log line.
            ewc_loss = 0.

            grad_loss = 0.
            for forget_grad, remember_grad in zip(loss_forget_params, loss_remember_params):
                # torch.sort orders along the last dimension only.
                forget_a, _ = torch.sort(abs(forget_grad), descending=True)
                remember_b, _ = torch.sort(abs(remember_grad), descending=True)
                # NOTE(review): the cut-off indexes the first dimension and
                # goes out of range when part_params == 1.0 — confirm intent.
                a_value = forget_a[int(len(forget_a)*config.grouping.part_params)]
                b_value = remember_b[int(len(remember_b)*config.grouping.part_params)]
                used_forget = torch.where(forget_a > a_value, forget_a, 0)
                used_remember = torch.where(remember_b > b_value, remember_b, 0)
                _loss = config.grouping.alpha * abs(used_forget - used_remember).sum()
                # NOTE(review): _loss is built from gradient snapshots, which
                # carry no autograd history, so marking it as requiring grad
                # appears not to influence the parameter gradients — confirm.
                _loss.requires_grad = True
                grad_loss += _loss
            loss = loss + grad_loss

            if (step+1) % config.training.log_freq == 0:
                end = time.time()
                cost_time = end-begin
                logging.info(
                    f"step: {step}, loss: {loss_remember.item()}, forgetting loss: {loss_forget.item()}, ewc loss: {ewc_loss}, cost time: {cost_time}"
                )

            # NOTE(review): the optimizer only holds the LoRA parameters, so
            # this does not reset .grad on other parameters, while the
            # clipping norm below spans model.parameters(). If the base
            # weights still require grad, their stale gradients enter that
            # norm — confirm.
            optimizer.zero_grad()
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            if (step+1) % config.training.snapshot_freq == 0:
                states = [
                    get_lora_state_dict(model),
                    optimizer.state_dict(),
                    step,
                ]

                if config.model.ema:
                    states.append(ema_helper.state_dict())

                torch.save(
                    states,
                    os.path.join(config.ckpt_dir, "lora_ckpt.pth"),
                )

                # Visualise from a detached copy of the live model.
                test_model = copy.deepcopy(model)
                test_model.eval()
                test_model = test_model.to(self.device)
                self.sample_visualization(test_model, step, args.cond_scale)
                del test_model
    
    
    def train_forget(self):
        
        args, config = self.args, self.config
        logging.info(f"Training diffusion forget with contrastive and EWC. Gamma: {config.training.gamma}, lambda: {config.training.lmbda}")
        label_to_use = self.select_groups_for_one_class()

        D_train_loader = all_but_one_class_path_dataset(config, os.path.join(args.ckpt_folder, "class_samples"), label_to_use)#args.label_to_forget)
        All_train_loader = all_but_one_class_path_dataset(config, os.path.join(args.ckpt_folder, "class_samples"), args.label_to_forget)
        #config.grouping.label_list)#args.label_to_forget)
        D_train_iter = cycle(D_train_loader)
        All_train_iter = cycle(All_train_loader)
        #label_to_use = self.select_groups_for_one_class()
        print("Loading classifier")
        classifier = torchvision.models.resnet34(pretrained=False)

        num_ftrs = classifier.fc.in_features
        classifier.fc = nn.Linear(num_ftrs, 10)
        classifier.load_state_dict(torch.load("cifar10_resnet34.pth", map_location='cpu'))
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        classifier = classifier.to(self.device)
        
        print("Loading checkpoints {}".format(args.ckpt_folder))
        model = Conditional_Model(config)
        states = torch.load(
            os.path.join(args.ckpt_folder, "ckpts/ckpt.pth"),
            map_location=self.device,
        )
        model = model.to(self.device)
        model = torch.nn.DataParallel(model)
        params_random_dict = {}
        for name, param in model.named_parameters():
            params_random_dict[name] = param.data.clone()
            #print(name)
        
        model.load_state_dict(states[0], strict=True)
        
        
        for name, param in model.named_parameters():  #frozen the attention layer
            if config.grouping.frozen in name:
                param.requires_grad =False
               
        optimizer = get_optimizer(config, model.parameters())
        
        criterion = nn.CrossEntropyLoss()

        '''
        if self.config.model.ema:
            ema_helper = EMAHelper(mu=config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            #model = ema_helper.ema_copy(model_no_ema)
        else:
            ema_helper = None
        '''
        with open(os.path.join(args.ckpt_folder, 'fisher_dict.pkl'), 'rb') as f:
            fisher_dict = pickle.load(f)
        
        params_mle_dict = {}
        for name, param in model.named_parameters():
            params_mle_dict[name] = param.data.clone()
        
        #label_choices = list(range(config.data.n_classes))
        #label_choices.remove(args.label_to_forget)
        label_choices = label_to_use
        begin = time.time()
        cost_times = []
        print(label_to_use)
        #for name, param in model.named_parameters():
            #if 'attn' not in name:
        #    param.requires_grad = True
        
        for step in range(config.training.n_iters):
            
            model.train()
            if step % config.grouping.rem_step == 0:
                x_remember, c_remember = next(All_train_iter)
            else:
                x_remember, c_remember = next(D_train_iter)
            x_remember, c_remember = x_remember.to(self.device), c_remember.to(self.device)
            x_remember = data_transform(config, x_remember)
            
            n = x_remember.size(0)
            channels = config.data.channels
            img_size = config.data.image_size
            c_forget = (torch.ones(n, dtype=int) * args.label_to_forget).to(self.device)
            x_forget = ( torch.rand((n, channels, img_size, img_size), device=self.device) - 0.5 ) * 2.
            e_remember = torch.randn_like(x_remember)
            e_forget = torch.randn_like(x_forget)
            b = self.betas
            
            # antithetic sampling
            t = torch.randint(
                low=0, high=self.num_timesteps, size=(n // 2 + 1,)
            ).to(self.device)
            t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
            #outputs = loss_registry_conditional['output'](model, x_forget, t, c_forget, e_forget, b, cond_drop_prob = 0.)
            
            #forget_grad = torch.autograd.grad(preds, outputs, grad_outputs=torch.ones_like(preds), retain_graph=True)
            #print(class_forget_loss)
            loss_forget, outputs = loss_registry_conditional[config.model.type](model, x_forget, t, c_forget, e_forget, b, cond_drop_prob = 0.) 
            loss_forget.backward(retain_graph=True)
              # the outputs denote the image
            preds = classifier(outputs)
            #preds_grad = torch.autograd.grad(preds, outputs, grad_outputs=torch.ones_like(preds))
            #print(preds.shape)
            class_forget_loss = criterion(preds, c_forget)
            #loss_forget_params = copy.deepcopy(model.named_parameters())
            loss_forget_params = {}
            for name, param in model.named_parameters():
                #print(type(param.grad))
                if param.grad == None:
                    #print(param)
                    loss_forget_params[name] = torch.zeros_like(param, dtype=torch.float).to(self.device)
                else:
                    loss_forget_params[name] = param.grad
            #optimizer.zero_grad()
            #r_outputs = loss_registry_conditional['output'](model, x_remember, t, c_remember, e_remember, b, cond_drop_prob = 0.)
            a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
            #print(a.shape, x0.shape, e.shape)
            r_outputs = x_remember * a.sqrt() + e_remember * (1.0 - a).sqrt()
            r_outputs.requires_grad = True
            r_preds = classifier(r_outputs)
            preds_grad = torch.autograd.grad(r_preds, r_outputs, grad_outputs=torch.ones_like(r_preds), retain_graph=True)
            loss_remember, r_outputs = loss_registry_conditional[config.model.type](model, x_remember, t, c_remember, e_remember, b, cond_drop_prob = 0.)  
            #loss_remember, _ = loss_registry_conditional[config.model.type](model, x_remember, t, c_remember, preds_grad[0], b, cond_drop_prob = 0.)  
            r_preds = classifier(r_outputs)
            #print(preds_grad[0].shape)
            #print(e_remember.shape)
            class_remember_loss = criterion(r_preds, c_remember)
            
            loss_remember.backward(retain_graph=True)
            loss_remember_params = {}
            for name, param in model.named_parameters():
                #print(type(param.grad))
                if param.grad == None:
                    loss_remember_params[name] = torch.zeros_like(param, dtype=torch.float).to(self.device)
                else:
                    loss_remember_params[name] = param.grad
            
            #optimizer.zero_grad()
            '''
            t = torch.randint(
                low=0, high=self.num_timesteps, size=(n + 1,)
            ).to(self.device)
            t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n+n]
            loss = loss_registry_conditional[config.model.type](model, torch.cat([x_forget, x_remember]), t, torch.cat([c_forget, c_remember]), torch.cat([e_forget, e_remember]), b, cond_drop_prob=0.)
            '''                                                  
            loss = loss_forget + config.training.gamma * loss_remember + config.grouping.delta * 1/(torch.exp(class_forget_loss) + 1)
            #+ config.grouping.delta * class_remember_loss 
            #+ config.grouping.delta * 1/(torch.exp(class_forget_loss) + 1) + class_remember_loss#config.grouping.delta * class_remember_loss #(1/(1+class_forget_loss) + class_remember_loss)
            
            
            ewc_loss = 0.
            effect_grads = {}
            for name, param in model.named_parameters():
                #print(loss_forget_params[name])
                forget_a, index1 = torch.sort(abs(loss_forget_params[name]), descending=True)
                remember_b, index2 = torch.sort(abs(loss_remember_params[name]), descending=True)
                a_value = forget_a[int(len(forget_a)*config.grouping.part_params)]
                b_value = remember_b[int(len(remember_b)*config.grouping.part_params)]
                used_forget = torch.where(forget_a > a_value, forget_a, 0)
                used_remember = torch.where(remember_b>b_value, remember_b, 0)
                effect_grads[name] = torch.where((used_remember!=0) & (used_forget!=0), used_forget * config.grouping.alpha +used_remember*(1-config.grouping.alpha), 0)
                #effect_grads[name] = torch.where((used_forget != used_remember) & (used_remember!=0) & (used_forget!=0), used_forget * config.grouping.alpha +used_remember*(1-config.grouping.alpha), used_forget).to(self.device)
                #effect_grads[name] = torch.where((effect_grads == 0) & (used_remember != 0), used_remember, effect_grads)
                gap_params = torch.where((effect_grads[name]==0) & (used_forget==0), ((param - params_mle_dict[name]).to(self.device)) ** 2, 0) 
                #print(param.grad)
                _loss = fisher_dict[name].to(self.device) * gap_params.to(self.device) #(param - params_mle_dict[name].to(self.device)) ** 2
                #print(_loss.sum())
                #if param.grad == None:
                #    continue
                #param.grad = effect_grads
                loss = config.training.lmbda * _loss.sum()
                ewc_loss += config.training.lmbda * _loss.sum()
                #if step == 0:
                #    #kaiming_param = nn.init.kaiming_uniform_(model.named_parameters())
                #    param = torch.where(effect_grads[name]!=0, params_mle_dict[name], param)
                
            if (step+1) % config.training.log_freq == 0:
                end = time.time()
                cost_time = end-begin
                cost_times.append(cost_time)
                print(class_remember_loss)
                logging.info(
                    f"step: {step}, loss: {loss.item()}, class loss: {(1/(torch.exp(class_forget_loss)+1).item())*config.grouping.delta}, forgetting loss: {loss_forget}, ewc loss: {ewc_loss}, cost time: {cost_time}"
                )

            #optimizer.zero_grad()
            loss.backward()
            
            
            for name, param in model.named_parameters():
                #if effect_grads[name].size() != param.grad.size():
                #    print(effect_grads[name])
                #    continue
                if param.grad != None:
                    pass
                    #param.grad = torch.where(effect_grads[name] == 0, param.grad, param.grad + effect_grads[name])
            
            optimizer.step()
            '''
            for name, param in model.named_parameters():
                if effect_grads[name].size() != param.grad.size():
                    print(effect_grads[name])
                    continue
                param = torch.where(effect_grads[name] == 0, params_mle_dict[name], param)
            '''
            try:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), config.optim.grad_clip
                )
            except Exception:
                pass
            '''
            if self.config.model.ema:
                ema_helper.update(model)
            '''
            if (step+1) % config.training.snapshot_freq == 0:
                states = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    #epoch,
                    step,
                ]
                '''
                if config.model.ema:
                    states.append(ema_helper.state_dict())
                '''
                np.save(os.path.join(config.ckpt_dir, "cost.npy"), np.array(cost_times))
                torch.save(
                    states,
                    os.path.join(config.ckpt_dir, "ckpt.pth"),
                )
                #torch.save(states, os.path.join(self.config.ckpt_dir, "ckpt_latest.pth"))

                test_model = ema_helper.ema_copy(model) if config.model.ema else copy.deepcopy(model)
                test_model.eval()
                self.sample_visualization(test_model, step, args.cond_scale)
                del test_model

    def load_ema_model(self):
        """Load the checkpointed model, with EMA weights applied when enabled.

        Reads ``ckpts/ckpt.pth`` under ``self.args.ckpt_folder``, restores the
        first saved entry into a ``DataParallel``-wrapped ``Conditional_Model``
        and, when ``self.config.model.ema`` is truthy, builds an EMA copy from
        the last saved entry.

        Returns:
            The EMA copy when ``config.model.ema`` is truthy, otherwise the
            model holding the raw checkpoint weights. ``eval()`` has been
            called on the returned model.
        """
        model = Conditional_Model(self.config)
        states = torch.load(
            os.path.join(self.args.ckpt_folder, "ckpts/ckpt.pth"),
            map_location=self.device,
        )
        model = model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
            # The last checkpoint entry is handed to the helper as its EMA state.
            ema_helper.load_state_dict(states[-1])
            # Previously the EMA copy was built and then dropped, so callers
            # always received the non-averaged weights.
            model = ema_helper.ema_copy(model)

        model.eval()
        return model
    
    def sample(self):
        """Load the checkpoint and run the sampling routine chosen by ``args.mode``.

        The model is restored from ``ckpts/ckpt.pth`` under
        ``self.args.ckpt_folder``. When ``self.config.model.ema`` is truthy the
        sampler receives an EMA copy built from the last checkpoint entry;
        otherwise it receives a deep copy of the restored model.

        Supported modes are ``'sample_fid'``, ``'sample_classes'`` and
        ``'visualization'``. Any other mode logs a warning and generates
        nothing.
        """
        model = Conditional_Model(self.config)
        states = torch.load(
            os.path.join(self.args.ckpt_folder, "ckpts/ckpt.pth"),
            map_location=self.device,
        )
        model = model.to(self.device)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(states[0], strict=True)

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
            ema_helper.load_state_dict(states[-1])
            test_model = ema_helper.ema_copy(model)
        else:
            test_model = copy.deepcopy(model)

        model.eval()
        # The copy is what actually gets sampled from, so it must be switched
        # to eval mode as well (the training snapshot path does the same).
        test_model.eval()

        mode = self.args.mode
        cond_scale = self.args.cond_scale
        if mode == 'sample_fid':
            self.sample_fid(test_model, cond_scale)
        elif mode == 'sample_classes':
            self.sample_classes(test_model, cond_scale)
        elif mode == 'visualization':
            # The guidance scale doubles as the tag in the output file name.
            self.sample_visualization(test_model, str(cond_scale), cond_scale)
        else:
            # 'gen_data' and 'one_class' are not wired up here.
            logging.warning("sample(): unsupported mode %r; nothing was generated", mode)
    

    def sample_classes(self, model, cond_scale):
        """
        Generate ``args.n_samples_per_class`` images for every requested class.

        Images are written to ``<args.ckpt_folder>/class_samples/<class_label>/``
        as ``<img_id>.png``, where ``img_id`` is a single counter shared by all
        classes. Per the original author, the samples can serve for FIM
        calculation, generative replay or classifier evaluation.
        """
        cfg = self.config
        opts = self.args
        out_root = os.path.join(opts.ckpt_folder, "class_samples")
        os.makedirs(out_root, exist_ok=True)
        next_id = 0
        class_ids, _ = create_class_labels(opts.classes_to_generate, n_classes=cfg.data.n_classes)
        per_class = opts.n_samples_per_class

        for label in class_ids:
            os.makedirs(os.path.join(out_root, str(label)), exist_ok=True)

            # Ceiling division: one extra round for a trailing partial batch.
            rounds = -(-per_class // cfg.sampling.batch_size)
            remaining = per_class  # samples still to be produced for this class

            with torch.no_grad():
                progress = tqdm.tqdm(
                    range(rounds),
                    desc=f"Generating image samples for class {label} to use as dataset",
                )
                for round_idx in progress:
                    count = min(cfg.sampling.batch_size, remaining)
                    print('round:', round_idx)

                    noise = torch.randn(
                        count,
                        cfg.data.channels,
                        cfg.data.image_size,
                        cfg.data.image_size,
                        device=self.device,
                    )
                    labels = torch.full(
                        (noise.size(0),), int(label), dtype=torch.long, device=self.device
                    )
                    images = self.sample_image(noise, model, labels, cond_scale)
                    images = inverse_data_transform(cfg, images)

                    for idx in range(count):
                        target = os.path.join(out_root, str(labels[idx].item()), f"{next_id}.png")
                        tvu.save_image(images[idx], target, normalize=True)
                        next_id += 1

                    remaining -= count
                    

    def sample_one_class(self, model, cond_scale, class_label, total_n_samples=500):
        """
        Samples one class only for classifier evaluation.

        Images are saved as ``<img_id>.png`` under
        ``<args.ckpt_folder>/class_<class_label>/``.

        Args:
            model: model handed to ``self.sample_image``.
            cond_scale: guidance scale handed to ``self.sample_image``.
            class_label: label every generated sample is conditioned on.
            total_n_samples: number of images to generate. Defaults to 500,
                the value that used to be hard-coded.
        """
        config = self.config
        args = self.args
        sample_dir = os.path.join(args.ckpt_folder, "class_" + str(class_label))
        os.makedirs(sample_dir, exist_ok=True)
        img_id = 0

        batch_size = config.sampling.batch_size
        # Ceiling division: a trailing partial batch needs its own round.
        n_rounds = -(-total_n_samples // batch_size)
        n_left = total_n_samples  # tracker on how many samples left to generate

        with torch.no_grad():
            for _ in tqdm.tqdm(
                range(n_rounds), desc=f"Generating image samples for class {class_label}"
            ):
                n = min(batch_size, n_left)

                x = torch.randn(
                    n,
                    config.data.channels,
                    config.data.image_size,
                    config.data.image_size,
                    device=self.device,
                )
                c = torch.ones(x.size(0), device=self.device, dtype=int) * class_label
                x = self.sample_image(x, model, c, cond_scale)
                x = inverse_data_transform(config, x)

                for k in range(n):
                    tvu.save_image(x[k], os.path.join(sample_dir, f"{img_id}.png"), normalize=True)
                    img_id += 1

                n_left -= n
                
    def sample_fid(self, model, cond_scale):
        """
        Generate ``args.n_samples_per_class`` images per requested class into
        one flat folder for FID evaluation.

        The folder lives under ``args.ckpt_folder`` and is named
        ``fid_samples_guidance_<args.cond_scale>``, with an
        ``_excluded_class_<ids>`` suffix when ``create_class_labels`` reports
        excluded classes. Files are numbered with one counter across classes.
        """
        cfg = self.config
        opts = self.args
        next_id = 0

        class_ids, skipped = create_class_labels(opts.classes_to_generate, n_classes=cfg.data.n_classes)
        per_class = opts.n_samples_per_class

        folder = f"fid_samples_guidance_{opts.cond_scale}"
        if skipped:
            skipped_tag = "_".join(map(str, skipped))
            folder = f"{folder}_excluded_class_{skipped_tag}"
        out_dir = os.path.join(opts.ckpt_folder, folder)
        os.makedirs(out_dir, exist_ok=True)

        for label in class_ids:
            # Ceiling division: one extra round for a trailing partial batch.
            rounds = -(-per_class // cfg.sampling.batch_size)
            remaining = per_class  # samples still to be produced for this class

            with torch.no_grad():
                progress = tqdm.tqdm(
                    range(rounds),
                    desc=f"Generating image samples for class {label} for FID",
                )
                for _ in progress:
                    count = min(cfg.sampling.batch_size, remaining)

                    noise = torch.randn(
                        count,
                        cfg.data.channels,
                        cfg.data.image_size,
                        cfg.data.image_size,
                        device=self.device,
                    )
                    labels = torch.full(
                        (noise.size(0),), int(label), dtype=torch.long, device=self.device
                    )
                    images = self.sample_image(noise, model, labels, cond_scale)
                    images = inverse_data_transform(cfg, images)

                    for idx in range(count):
                        tvu.save_image(images[idx], os.path.join(out_dir, f"{next_id}.png"), normalize=True)
                        next_id += 1

                    remaining -= count

    def sample_image(self, x, model, c, cond_scale, last=True):
        """Run the configured conditional sampler starting from ``x``.

        The sampler is picked by ``args.sample_type`` (``"generalized"`` or
        ``"ddpm_noisy"``) and the timestep schedule by ``args.skip_type``
        (``"uniform"`` or ``"quad"``), using ``args.timesteps`` steps out of
        ``self.num_timesteps``.

        Args:
            x: starting tensor handed to the sampler (callers in this file
                pass Gaussian noise).
            model: model handed to the sampler.
            c: class labels handed to the sampler.
            cond_scale: guidance scale; only forwarded to the ``"generalized"``
                sampler.
            last: when true, return ``result[0][-1]`` of the sampler output
                instead of the full output (presumably the final denoised
                batch; the samplers live in ``functions.denoising``).

        Raises:
            NotImplementedError: for an unknown ``sample_type`` or
                ``skip_type``.
        """
        args = self.args
        sample_type = args.sample_type
        if sample_type not in ("generalized", "ddpm_noisy"):
            raise NotImplementedError(f"Unsupported sample_type: {sample_type!r}")

        # Both samplers consume the same timestep sequence, so build it once.
        if args.skip_type == "uniform":
            skip = self.num_timesteps // args.timesteps
            seq = range(0, self.num_timesteps, skip)
        elif args.skip_type == "quad":
            seq = np.linspace(0, np.sqrt(self.num_timesteps * 0.8), args.timesteps) ** 2
            seq = [int(s) for s in seq]
        else:
            raise NotImplementedError(f"Unsupported skip_type: {args.skip_type!r}")

        if sample_type == "generalized":
            from functions.denoising import generalized_steps_conditional

            x = generalized_steps_conditional(x, c, seq, model, self.betas, cond_scale, eta=args.eta)
        else:
            from functions.denoising import ddpm_steps_conditional

            x = ddpm_steps_conditional(x, c, seq, model, self.betas)

        if last:
            x = x[0][-1]
        return x

    def sample_visualization(self, model, name, cond_scale):
        """Sample an equal number of images per class and save them as one grid.

        ``config.training.visualization_samples`` images are generated in
        total, split evenly over ``config.data.n_classes`` labels in ascending
        label order. The grid puts as many images in a row as there are
        samples per class, so each row shows a single class.

        Args:
            model: model handed to ``self.sample_image``.
            name: tag used in the output file name ``sample-<name>.png``.
            cond_scale: guidance scale handed to ``self.sample_image``.

        The image is written to ``config.log_dir`` when the config has that
        attribute, otherwise to ``args.ckpt_folder``.
        """
        config = self.config
        total_n_samples = config.training.visualization_samples
        assert total_n_samples % config.data.n_classes == 0
        samples_per_class = total_n_samples // config.data.n_classes
        n_rounds = total_n_samples // config.sampling.batch_size if config.sampling.batch_size < total_n_samples else 1
        c = torch.repeat_interleave(torch.arange(config.data.n_classes), samples_per_class)
        c_chunks = torch.chunk(c, n_rounds, dim=0)

        with torch.no_grad():
            all_imgs = []
            # torch.chunk may hand back fewer chunks than requested, so walk
            # the chunks themselves rather than indexing up to n_rounds.
            for c_chunk in tqdm.tqdm(
                c_chunks, desc="Generating image samples for visualization."
            ):
                labels = c_chunk.to(self.device)
                n = labels.size(0)
                x = torch.randn(
                    n,
                    config.data.channels,
                    config.data.image_size,
                    config.data.image_size,
                    device=self.device,
                )
                x = self.sample_image(x, model, labels, cond_scale)
                x = inverse_data_transform(config, x)

                all_imgs.append(x)

            all_imgs = torch.cat(all_imgs)
            grid = tvu.make_grid(all_imgs, nrow=samples_per_class, normalize=True, padding=0)

            # Only the attribute lookup is guarded, so an AttributeError raised
            # while saving is no longer mistaken for a missing log_dir.
            try:
                out_dir = self.config.log_dir  # if called during training of base model
            except AttributeError:
                out_dir = self.args.ckpt_folder  # if called from sample.py
            tvu.save_image(grid, os.path.join(out_dir, f'sample-{name}.png'))

