import os
import torch
import random
import copy
import numpy as np
import matplotlib

matplotlib.use("GTK3Agg")
from torchvision import models, transforms
from torch.utils import data
from datasets.waterbirds import Waterbirds
from datasets.urbancars import UrbanCars
from datasets.bar import BAR
from datasets.BFFHQ import BFFHQ
import torch.optim.lr_scheduler as lr_scheduler
from unbiased_case import get_real_cifar, read_DDPM_unbiased_images
from erm_training import evaluate_model
from utils.metrics import *
from tqdm import tqdm

from utils.wandb_wrapper import WandbWrapper
wb = WandbWrapper("LFF_Debiasing")

import torch.nn as nn

class FromNpyDataset(torch.utils.data.Dataset):
    """In-memory dataset over pre-loaded images and their class labels.

    Each item is ``(x, (y, y, ..., y), index)``: the class label ``y`` is
    repeated ``num_biases + 1`` times so the tuple mimics the
    ``(target, bias, ...)`` label layout consumed elsewhere in this file
    (``labels[0]`` as target, ``labels[1:]`` as bias attributes).
    """

    def __init__(self, data, targets, transform=None, num_biases=1):
        """
        Args:
            data: indexable collection of samples (e.g. an image tensor).
            targets: indexable collection of labels, aligned with ``data``.
            transform: optional callable applied to each sample on access.
            num_biases: number of extra copies of the label in the label tuple.
        """
        self.data = data
        self.targets = targets
        self.transform = transform
        self.num_biases = num_biases

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]

        if self.transform:
            x = self.transform(x)
        return x, (y,) * (self.num_biases + 1), index

    def __len__(self):
        return len(self.data)


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed shared by every random number generator.
    """
    # cuDNN: disable auto-tuning and select deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

class EMA:
    """Per-sample exponential moving average of a scalar signal (e.g. a loss).

    ``parameter[i]`` holds the running average for dataset sample ``i`` and
    ``label[i]`` its class label.
    """

    def __init__(self, label, alpha=0.9, device="cuda"):
        """
        Args:
            label: per-sample class labels (tensor or sequence); its length
                fixes the number of tracked samples.
            alpha: decay factor of the moving average.
            device: device holding all internal state (defaults to CUDA).
        """
        self.device = torch.device(device)
        self.label = torch.as_tensor(label).to(self.device)
        num_samples = self.label.size(0)
        self.alpha = torch.tensor([alpha], dtype=torch.float32, device=self.device)
        self.parameter = torch.zeros(num_samples, device=self.device)
        # 1 once a sample has received at least one update, else 0.
        self.updated = torch.zeros(num_samples, device=self.device)

    def update(self, data, index):
        """Fold the new values ``data`` into the averages of samples ``index``."""
        index = torch.as_tensor(index, device=self.device)
        data = torch.as_tensor(data).to(self.device)
        # For a sample's first update the coefficient is 1 (updated == 0), so the
        # average starts at the observed value instead of being biased toward 0.
        new_weight = 1 - self.alpha * self.updated[index]
        self.parameter[index] = self.alpha * self.parameter[index] + new_weight * data
        self.updated[index] = 1

    def max_loss(self, label):
        """Return the largest running average among samples of class ``label``.

        Raises if no tracked sample carries that label (max of an empty tensor).
        """
        label_index = torch.where(self.label == label)[0]
        return self.parameter[label_index].max()



class lld(torch.nn.Module):
    """Frozen, pre-trained classifier used as the bias-capturing model.

    Wraps a deep copy of ``model`` whose weights are loaded from
    ``weight_path``; inputs are resized before every forward pass and the
    wrapped model always runs in eval mode without gradients.
    """

    def __init__(
        self, num_classes, model, weight_path, dataset="waterbirds", *args, **kwargs
    ):
        """
        Args:
            num_classes: unused; kept for interface compatibility.
            model: architecture to copy; its weights are replaced by the checkpoint.
            weight_path: path of a ``state_dict`` saved with ``torch.save``.
            dataset: ``'cifar10'`` resizes inputs to 32x32, anything else to 64x64.
        """
        super().__init__()
        self.model = copy.deepcopy(model)
        self.model.to("cuda")
        size = (32, 32) if dataset == 'cifar10' else (64, 64)
        self.transform = transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC)

        # Deserialize on CPU: load_state_dict copies into the CUDA parameters,
        # and this avoids failures for checkpoints saved on a different GPU.
        self.model.load_state_dict(torch.load(weight_path, map_location="cpu"))
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, target=None):
        """Return logits, or ``(logits, per-sample cross-entropy)`` if ``target`` is given."""
        self.model.eval()
        with torch.no_grad():
            # Single forward pass; the logits are reused for the loss.
            logits = self.model(self.transform(x))
            if target is None:
                return logits
            return logits, self.loss(logits, target)


# Directory (relative to the working directory) where model checkpoints are saved.
PATH_TO_MODELS = "./saved_models"

def _report_biased_model(model, val_loader, device, epoch, dataset):
    """Evaluate ``model`` on ``val_loader`` and print bias-aligned/-conflicting stats.

    ``val_loader`` yields ``(images, labels, _)`` with ``labels[0]`` the target
    and ``labels[1]`` the bias attribute; a sample counts as "aligned" when the
    two are equal. For ``dataset == 'bar'`` only a 6x6 target-vs-prediction
    confusion matrix is accumulated.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    with torch.no_grad():
        # Confusion matrices indexed [target, prediction].
        label_results_aligned = torch.zeros((10, 10))
        label_results_conflicting = torch.zeros((10, 10))
        label_results_bar = torch.zeros((6, 6))
        # Row 0: summed loss of correctly classified samples; row 1: misclassified.
        loss_per_class_aligned = torch.zeros((2, 1))
        loss_per_class_conflict = torch.zeros((2, 1))

        for dat, labels, _ in val_loader:
            dat = dat.to(device)
            target = labels[0].to(device).long()
            bias_l = labels[1].to(device)
            output = model(dat)
            loss_sample = criterion(output, target)
            pred = torch.argmax(output, dim=1)

            for i in range(dat.shape[0]):
                if dataset == 'bar':
                    label_results_bar[target[i], pred[i]] += 1
                    continue
                row = 0 if target[i] == pred[i] else 1
                if bias_l[i] == target[i]:
                    label_results_aligned[target[i], pred[i]] += 1
                    loss_per_class_aligned[row] += loss_sample[i].item()
                else:
                    label_results_conflicting[target[i], pred[i]] += 1
                    loss_per_class_conflict[row] += loss_sample[i].item()

    print(f'epoch: {epoch}')
    A = label_results_aligned.numpy()
    B = label_results_conflicting.numpy()
    bar_results = label_results_bar.numpy()
    # Groups with no samples give 0/0; print NaN without a RuntimeWarning.
    with np.errstate(divide="ignore", invalid="ignore"):
        print(f'average_aligned {np.trace(A)/np.sum(A)}')
        print(f'average_conflicting {np.trace(B)/np.sum(B)}')
        if dataset == 'waterbirds' or dataset == 'BFFHQ':
            print(f'ALIGNED_class_0 {A[0,0]/(A[0,0]+A[0,1])}')
            print(f'ALIGNED_class_1 {A[1,1]/(A[1,0]+A[1,1])}')
            print(f'CONFLICTING_class_0 {B[0,0]/(B[0,0]+B[0,1])}')
            print(f'CONFLICTING_class_1 {B[1,1]/(B[1,0]+B[1,1])}')
        elif dataset == 'bar':
            print(f'average {np.trace(bar_results)/np.sum(bar_results)}')
            class_totals = np.sum(bar_results, axis=1)
            for i in range(6):
                print(bar_results[i, i] / class_totals[i])
        else:
            print(f'Aligned correctly {loss_per_class_aligned[0]},wrong {loss_per_class_aligned[1]}')
            print(f'Conflict correctly {loss_per_class_conflict[0]},wrong {loss_per_class_conflict[1]}')


def train_biased_model(
    model: torch.nn.Module,
    train_loader,
    val_loader,
    device,
    optimizer,
    num_classes,
    epochs=10,
    make_figures=False,
    name="Biased_model",
    dataset="cifar",
):
    """Train the bias-capturing model with plain ERM and save its weights.

    Every 20 epochs (starting at epoch 0) the model is evaluated on
    ``val_loader`` and aligned/conflicting statistics are printed. The final
    ``state_dict`` is written to ``PATH_TO_MODELS/<name>``.

    Args:
        model: network exposing a per-sample ``loss_fn`` attribute.
        train_loader: yields ``(images, labels, index)``; ``labels[0]`` is the target.
        val_loader: yields ``(images, labels, _)``; ``labels[0]`` target,
            ``labels[1]`` bias attribute.
        device: device for the model and the batches.
        optimizer: optimizer over ``model``'s parameters (no LR schedule is applied).
        num_classes: unused; kept for interface compatibility.
        epochs: number of training epochs.
        make_figures: unused; kept for interface compatibility.
        name: file name of the saved checkpoint.
        dataset: selects the printed metrics and the ``results_<dataset>.txt`` file.
    """
    print("Starting Biased Model training...")
    with open(f"results_{dataset}.txt", "a") as f:
        f.write("Biased Model Training\n")
    model = model.to(device)

    for epoch in tqdm(range(epochs)):
        model.train()
        with torch.enable_grad():
            for dat, labels, _ in train_loader:
                dat = dat.to(device)
                target = labels[0].long().to(device)
                output = model(dat)
                loss_task: torch.Tensor = model.loss_fn(output, target)
                loss_task.mean().backward()
                optimizer.step()
                optimizer.zero_grad()

        if epoch % 20 == 0:
            _report_biased_model(model, val_loader, device, epoch, dataset)

    os.makedirs(PATH_TO_MODELS, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(PATH_TO_MODELS, name))

def _targets_by_index(loader, device):
    """Return a LongTensor mapping dataset index -> target label for ``loader``.

    ``loader`` yields ``(images, labels, index)`` with ``labels[0]`` the target.
    Labels are scattered by ``index`` so that position ``i`` holds the label of
    dataset sample ``i`` even when the loader shuffles or samples with
    replacement. Indices the loader never produced are filled with -1, which
    matches no class.
    """
    targets, indexes = [], []
    for _, labels, index in loader:
        targets.append(torch.as_tensor(labels[0]).long())
        indexes.append(torch.as_tensor(index).long())
    targets = torch.cat(targets)
    indexes = torch.cat(indexes)

    size = int(indexes.max().item()) + 1
    dataset = getattr(loader, "dataset", None)
    if dataset is not None and hasattr(dataset, "__len__"):
        size = max(size, len(dataset))
    by_index = torch.full((size,), -1, dtype=torch.long)
    by_index[indexes] = targets
    return by_index.to(device)


def train_model(model: torch.nn.Module, BIASED_MODEL, train_loader, val_loader, test_loader, device, optimizer, num_classes, epochs=10, dataset='cifar', wb=wb, make_figures=True, file_name="tempt.csv"):
    """Debias ``model`` with Learning-from-Failure style per-sample loss weights.

    Each sample's loss is weighted by
    ``ema_b / (ema_b + ema_u + 1e-8) + exp(epoch)``, where ``ema_b`` and
    ``ema_u`` are per-sample EMAs (alpha 0.95) of the frozen ``BIASED_MODEL``'s
    and ``model``'s losses, each divided by its per-class maximum. Gradients
    are accumulated over a dataset-dependent number of batches. After every
    epoch the weights are saved, the last progress dict is logged to ``wb``,
    and ``model`` is evaluated on ``val_loader``.

    Args:
        model: network being debiased; must expose a per-sample ``loss_fn``.
        BIASED_MODEL: callable returning ``(logits, per-sample loss)`` for
            ``(images, target)``.
        train_loader, make_figures, file_name: unused; kept for compatibility.
        val_loader: evaluation loader passed to ``evaluate_model``.
        test_loader: loader the model is trained on; yields ``(images, labels, index)``.
        device: device for batches and metrics.
        optimizer: optimizer over ``model``'s parameters (no LR schedule is applied).
        num_classes: number of target classes.
        epochs: number of epochs.
        dataset: selects gradient-accumulation steps and the results file name.
        wb: logger exposing ``log_output(dict)``, or ``None``.

    Raises:
        NameError: if an EMA loss becomes NaN (message ``loss_b_ema``/``loss_d_ema``).
    """
    cur_model_name = "DDPM_biased_model.pt"
    BIASED_MODEL.eval()
    with open(f'results_{dataset}.txt', 'a') as f:
        f.write('\n')
        f.write('\n')
        f.write('Debiasing')

    print(f"Saving {cur_model_name}")
    # NOTE(review): training iterates ``test_loader`` (and evaluation uses
    # ``val_loader`` with prefix "te"); confirm the caller passes training data here.
    targets_by_index = _targets_by_index(test_loader, device)
    sample_loss_ema_b = EMA(targets_by_index, alpha=0.95)
    sample_loss_ema_d = EMA(targets_by_index, alpha=0.95)

    accum_iter = {"waterbirds": 16, "bar": 8, "BFFHQ": 16}.get(dataset, 4)
    subgroups_shape = (num_classes,) * (1 + 2)

    for epoch in range(epochs):
        loss_task_tot = AverageMeter()
        top1 = AverageMeter()
        subgroup_top1 = AverageMeterSubgroups(subgroups_shape, device=device)
        al_b_loss_meter = AverageMeter()
        al_u_loss_meter = AverageMeter()
        co_b_loss_meter = AverageMeter()
        co_u_loss_meter = AverageMeter()
        tk0 = tqdm(
            test_loader, total=int(len(test_loader)), leave=True, dynamic_ncols=True
        )
        # NOTE(review): exp(epoch) grows without bound and overflows float32 to
        # inf from epoch 89 on; confirm this additive term is intended.
        epoch_offset = torch.exp(torch.tensor([float(epoch)], device=device))
        postifix_dict = {}

        model.train()
        with torch.enable_grad():
            for batch, (dat, labels, index) in enumerate(tk0):
                dat = dat.to(device)
                target = labels[0].to(device).long()
                output = model(dat)
                if len(labels) > 2:
                    labels = torch.vstack(labels)
                    conf_mask = (labels[0] != labels[1]) | (labels[0] != labels[2])
                else:
                    # Single bias attribute: conflicting when it disagrees with the target.
                    conf_mask = labels[0] != labels[1]

                _, loss_b = BIASED_MODEL(dat, target)
                per_sample_loss = model.loss_fn(output, target)
                loss_b = loss_b.detach()
                loss_u = per_sample_loss.detach()

                sample_loss_ema_b.update(loss_b, index)
                sample_loss_ema_d.update(loss_u, index)

                loss_b = sample_loss_ema_b.parameter[index].clone().detach()
                loss_u = sample_loss_ema_d.parameter[index].clone().detach()

                if torch.isnan(loss_b.mean()).item():
                    raise NameError("loss_b_ema")
                if torch.isnan(loss_u.mean()).item():
                    raise NameError('loss_d_ema')

                # Class-wise normalization by the largest EMA loss of each class.
                for c in range(num_classes):
                    class_index = torch.where(target == c)[0]
                    loss_b[class_index] /= sample_loss_ema_b.max_loss(c)
                    loss_u[class_index] /= sample_loss_ema_d.max_loss(c)

                weight = loss_b / (loss_b + loss_u + 1e-8) + epoch_offset
                loss_task: torch.Tensor = (weight * per_sample_loss).mean()
                loss_task.backward()
                loss_task_tot.update(loss_task.item(), n=dat.size(0))

                #  -----> Gradient-Accumulation <-----
                if ((batch + 1) % accum_iter == 0) or (batch + 1 == len(test_loader)):
                    optimizer.step()
                    optimizer.zero_grad()

                acc1 = accuracy(output, target)
                subgroup_masks = get_subgroup_masks(labels, num_classes=subgroups_shape, device=device)
                subgroup_acc1 = accuracy_subgroup(output, target, subgroup_masks, num_classes=num_classes)
                top1.update(acc1[0], dat.size(0))
                subgroup_top1.update(subgroup_acc1, subgroup_masks)
                # Read after updating so the reported values include this batch.
                acc_a = regroup_by(subgroup_top1, ("aligned", ))
                acc_c = regroup_by(subgroup_top1, ("misaligned", ))

                # Skip empty groups: the mean of an empty selection is NaN and
                # would poison the running average for the rest of the epoch.
                aligned_mask = ~conf_mask
                n_aligned = int(aligned_mask.sum())
                n_conflicting = int(conf_mask.sum())
                if n_aligned:
                    al_b_loss_meter.update(loss_b[aligned_mask].mean().item(), n=n_aligned)
                    al_u_loss_meter.update(loss_u[aligned_mask].mean().item(), n=n_aligned)
                if n_conflicting:
                    co_b_loss_meter.update(loss_b[conf_mask].mean().item(), n=n_conflicting)
                    co_u_loss_meter.update(loss_u[conf_mask].mean().item(), n=n_conflicting)

                postifix_dict = {
                    "epoch": epoch,
                    "acc1": top1.avg,
                    "lr": optimizer.param_groups[0]['lr'],
                    "acc_a": acc_a[0].item(),
                    "acc_c": acc_c[0].item()
                }
                subgroup_avg = subgroup_top1.avg
                if len(subgroup_avg.size()) > 2:
                    for cl in range(subgroup_avg.size(0)):
                        for b0 in range(subgroup_avg.size(1)):
                            for b1 in range(subgroup_avg.size(2)):
                                value = subgroup_avg[cl, b0, b1].item()
                                if value <= 0:
                                    continue
                                postifix_dict[f"({4*cl+2*b0+b1})"] = value
                else:
                    for cl in range(subgroup_avg.size(0)):
                        for g in range(subgroup_avg.size(1)):
                            value = subgroup_avg[cl, g].item()
                            if value <= 0:
                                continue
                            postifix_dict[f"({cl},{g})"] = value
                postifix_dict["loss"] = loss_task_tot.avg
                tk0.set_postfix(postifix_dict)

                postifix_dict["al_b_loss_meter"] = al_b_loss_meter.avg
                postifix_dict["al_u_loss_meter"] = al_u_loss_meter.avg
                postifix_dict["co_b_loss_meter"] = co_b_loss_meter.avg
                postifix_dict["co_u_loss_meter"] = co_u_loss_meter.avg

        # Checkpoint once per epoch instead of after every batch.
        os.makedirs(PATH_TO_MODELS, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(PATH_TO_MODELS, cur_model_name))

        if wb is not None:
            wb.log_output(postifix_dict)

        evaluate_model(model, val_loader, num_classes, num_biases=2, criterion=torch.nn.CrossEntropyLoss(reduction="mean"), epoch=epoch, device=device, wb=wb, prefix="te")


def _class_balanced_sampler(train_set, class_sample_count):
    """Build a sampler that draws each class with equal expected frequency.

    ``train_set.samples[i]["class_label"]`` gives the class of sample ``i`` and
    ``class_sample_count`` the number of samples per class (indexable by class).
    Draws ``len(train_set)`` samples with replacement.
    """
    weight = 1.0 / class_sample_count
    targets = np.array(
        [train_set.samples[i]["class_label"] for i in range(len(train_set))]
    )
    samples_weight = torch.from_numpy(np.array([weight[t] for t in targets]))
    return torch.utils.data.WeightedRandomSampler(
        weights=samples_weight, num_samples=len(samples_weight), replacement=True
    )


def create_loaders(dataset="waterbirds", de_bias=1, amount_bias=95):
    """Build the train/val/test DataLoaders and the evaluation transform.

    Args:
        dataset: one of ``'waterbirds'``, ``'urbancars'``, ``'unbiased_cifar10'``,
            ``'bar'``, ``'BFFHQ'``.
        de_bias: ``0`` builds loaders for the bias-capturing stage (64x64 images
            for waterbirds/bar/BFFHQ, unshuffled training set); any other value
            builds loaders for the debiasing stage (224x224 images,
            class-balanced or shuffled training sampling).
        amount_bias: unused; kept for interface compatibility.

    Returns:
        ``(train_loader, val_loader, test_loader, eval_transform)``; for
        ``'unbiased_cifar10'`` whatever ``get_real_cifar()`` returns.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    supported = ("waterbirds", "urbancars", "unbiased_cifar10", "bar", "BFFHQ")
    if dataset not in supported:
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    if dataset == "unbiased_cifar10":
        return get_real_cifar()

    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    bicubic = transforms.InterpolationMode.BICUBIC

    if dataset == "waterbirds":
        side = 64 if de_bias == 0 else 224
        eval_transform = transforms.Compose([
            transforms.Resize((side, side), interpolation=bicubic),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])
        train_set = Waterbirds(root="./data", env="train", transform=eval_transform)
        val_set = Waterbirds(root="./data", env="val", transform=eval_transform)
        test_set = Waterbirds(root="./data", env="test", transform=eval_transform)

        if de_bias == 0:
            train_loader = data.DataLoader(
                train_set, batch_size=16, drop_last=False, shuffle=False,
                pin_memory=True, num_workers=4,
            )
        else:
            sampler = _class_balanced_sampler(train_set, train_set.perclass_populations())
            train_loader = data.DataLoader(
                train_set, batch_size=16, drop_last=False, sampler=sampler,
                pin_memory=True, num_workers=4,
            )
        val_loader = data.DataLoader(val_set, batch_size=16, shuffle=False, pin_memory=True, num_workers=4)
        test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, pin_memory=True, num_workers=4)

    elif dataset == "urbancars":
        eval_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])
        # NOTE(review): the training split also uses ``eval_transform`` (no
        # augmentation); confirm this is intended.
        train_set = UrbanCars(env="train", transform=eval_transform)
        val_set = UrbanCars(env="val", transform=eval_transform)
        test_set = UrbanCars(env="test", transform=eval_transform)
        train_loader = data.DataLoader(train_set, batch_size=128, shuffle=de_bias != 0, num_workers=4)
        val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4)
        test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)

    elif dataset == "bar":
        crop = (64, 64) if de_bias == 0 else (224, 224)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(crop),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])
        eval_transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=bicubic),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])
        bar_root = os.path.join("data", "bar")
        train_set = BAR(root=bar_root, env="train", transform=train_transform, return_index=True)
        val_set = BAR(root=bar_root, env="val", transform=eval_transform, return_index=True)
        test_set = BAR(root=bar_root, env="test", transform=eval_transform, return_index=True)

        if de_bias == 0:
            train_loader = data.DataLoader(
                train_set, batch_size=128, drop_last=False, shuffle=False,
                pin_memory=True, num_workers=4,
            )
        else:
            # Class populations come from a plain (transform-free) view of the train split.
            class_sample_count = BAR(root=bar_root, env="train").perclass_populations()
            sampler = _class_balanced_sampler(train_set, class_sample_count)
            train_loader = data.DataLoader(
                train_set, batch_size=16, drop_last=False, sampler=sampler,
                pin_memory=True, num_workers=4,
            )
        val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False, pin_memory=True, num_workers=4)
        test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, pin_memory=True, num_workers=4)

    else:  # dataset == "BFFHQ" (validated above)
        if de_bias == 0:
            train_transform = transforms.Compose([
                transforms.Resize((64, 64)),
                transforms.ToTensor(),
                transforms.Normalize(imagenet_mean, imagenet_std),
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize((224, 224), interpolation=bicubic, antialias=True),
                transforms.RandomCrop(224, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(imagenet_mean, imagenet_std),
            ])
        eval_transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=bicubic),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])
        bffhq_root = os.path.join("data", "bffhq")
        train_set = BFFHQ(root=bffhq_root, env="train", transform=train_transform, return_index=True)
        val_set = BFFHQ(root=bffhq_root, env="val", transform=eval_transform, return_index=True)
        test_set = BFFHQ(root=bffhq_root, env="test", transform=eval_transform, return_index=True)

        train_loader = data.DataLoader(
            train_set,
            batch_size=128 if de_bias == 0 else 16,
            drop_last=False,
            shuffle=de_bias != 0,
            pin_memory=True,
            num_workers=4,
        )
        val_loader = data.DataLoader(val_set, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)
        test_loader = data.DataLoader(test_set, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)

    return train_loader, val_loader, test_loader, eval_transform


def create_model_for_bias(dataset="waterbirds"):
    """Return an ImageNet-pretrained DenseNet-121 on CUDA for the bias-capturing stage.

    The classifier head is replaced by a linear layer with one output per
    class of ``dataset``, and a per-sample ``loss_fn`` (cross-entropy with
    ``reduction='none'``) is attached to the model.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    num_classes_by_dataset = {
        "waterbirds": 2,
        "urbancars": 2,
        "unbiased_cifar10": 10,
        "bar": 6,
        "BFFHQ": 2,
    }
    if dataset not in num_classes_by_dataset:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    model_for_bias = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    model_for_bias.classifier = torch.nn.Linear(
        model_for_bias.classifier.in_features, num_classes_by_dataset[dataset]
    )
    model_for_bias.loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    return model_for_bias.to("cuda")


def debiasing_model(dataset="waterbirds"):
    """Return the ImageNet-pretrained ResNet (on CUDA) that is debiased for ``dataset``.

    waterbirds/urbancars use ResNet-50; unbiased_cifar10, bar and BFFHQ use
    ResNet-18. The ``fc`` head is resized to the dataset's class count and a
    per-sample ``loss_fn`` (cross-entropy with ``reduction='none'``) is attached.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    # dataset -> (ResNet depth, number of classes)
    specs = {
        "waterbirds": (50, 2),
        "urbancars": (50, 2),
        "unbiased_cifar10": (18, 10),
        "bar": (18, 6),
        "BFFHQ": (18, 2),
    }
    if dataset not in specs:
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    depth, num_classes = specs[dataset]

    if depth == 50:
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    else:
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model = model.to("cuda")
    model.loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    return model


def _load_npy_image_dir(address, label_of):
    """Load every ``.npy`` image batch in ``address`` accepted by ``label_of``.

    Each file is assumed to hold a channels-last ``(N, H, W, C)`` array, which
    is moved to channels-first. ``label_of(filename)`` returns the integer
    class of all images in that file, or ``None`` to skip the file. Files are
    read grouped by label (ascending), then by name, so the result does not
    depend on ``os.listdir`` order.

    Returns:
        ``(images, labels)``: a float32 ``(total, C, H, W)`` tensor and an
        int64 tensor with one label per image.
    """
    entries = []
    for fname in os.listdir(address):
        if "npy" not in fname:
            continue
        label = label_of(fname)
        if label is not None:
            entries.append((label, fname))

    images, labels = [], []
    for label, fname in sorted(entries):
        batch = np.load(os.path.join(address, fname))
        images.append(np.moveaxis(batch, 3, 1))
        labels.append(np.full(len(batch), label, dtype=np.int64))
    x = torch.from_numpy(np.concatenate(images, axis=0)).float()
    y = torch.from_numpy(np.concatenate(labels, axis=0))
    return x, y


def read_DDPM_images(dataset="waterbirds", num_images=1000, alternative_path=None):
    """Return a shuffled DataLoader over the synthetic (DDPM) images of ``dataset``.

    Args:
        dataset: one of ``'waterbirds'``, ``'urbancars'``, ``'unbiased_cifar10'``,
            ``'bar'``, ``'BFFHQ'``.
        num_images: unused; kept for interface compatibility.
        alternative_path: directory to read instead of the default one; only
            honoured for ``'waterbirds'``.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    synthetic_root = os.path.join("data", "synthetic", "w_1", "imagenet")

    def label_from_suffix(name):
        # Assumes names end in "<digit>.npy": the label is the digit before ".npy".
        return int(name[-5])

    if dataset == "waterbirds":
        if alternative_path is not None:
            address = alternative_path
        else:
            address = "./data/synthetic/w_1/imagenet/waterbirds"
        train_transform = transforms.Compose(
            [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
        )
        print("Alternative Path:", alternative_path)

        def waterbirds_label(name):
            if "class_0" in name:
                return 0
            if "class_1" in name:
                return 1
            return None

        x_train, y_train = _load_npy_image_dir(address, waterbirds_label)
        dataset_DDPM = FromNpyDataset(x_train, y_train.int(), train_transform)
        return data.DataLoader(dataset_DDPM, batch_size=256, shuffle=True, pin_memory=True, num_workers=4)

    if dataset == "urbancars":
        transform = transforms.Compose([transforms.Resize((64, 64))])
        x_train, y_train = _load_npy_image_dir(os.path.join(synthetic_root, "urbancars"), label_from_suffix)
        dataset_DDPM = FromNpyDataset(x_train, y_train, transform, num_biases=2)
        return data.DataLoader(dataset_DDPM, batch_size=256, shuffle=True, num_workers=4)

    if dataset == "unbiased_cifar10":
        # NOTE(review): no augmentation transform is passed to these images;
        # confirm this is intended.
        address = os.path.join("data", "synthetic", "unbiased", "cifar10")
        x_train, y_train = _load_npy_image_dir(address, label_from_suffix)
        dataset_DDPM = FromNpyDataset(x_train, y_train)
        return data.DataLoader(dataset_DDPM, batch_size=256, shuffle=True, pin_memory=True, num_workers=4)

    if dataset == "bar":
        transform = transforms.Compose([
            transforms.RandomResizedCrop((64, 64)),
            transforms.RandomHorizontalFlip(),
        ])
        x_train, y_train = _load_npy_image_dir(os.path.join(synthetic_root, "bar"), label_from_suffix)
        dataset_DDPM = FromNpyDataset(x_train, y_train, transform)
        return data.DataLoader(dataset_DDPM, batch_size=128, shuffle=True, pin_memory=True, num_workers=4)

    if dataset == "BFFHQ":
        transform = transforms.Compose([
            transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
        ])
        # NOTE(review): assumes the synthetic images live in a directory named
        # after the dataset ("BFFHQ"); confirm the casing on disk.
        x_train, y_train = _load_npy_image_dir(os.path.join(synthetic_root, dataset), label_from_suffix)
        dataset_DDPM = FromNpyDataset(x_train, y_train, transform)
        return data.DataLoader(dataset_DDPM, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)

    raise ValueError(f"Unsupported dataset: {dataset!r}")


# AdamW (learning rate, weight decay) for the bias-capturing model, per dataset.
_BIASED_OPTIM_HPARAMS = {
    "bar": (0.00005, 0.01),
    "waterbirds": (0.0005, 0.01),
    "urbancars": (0.0001, 0.01),
    "BFFHQ": (0.0005, 0.01),
    "unbiased_cifar10": (0.0005, 0.01),
}

# AdamW (learning rate, weight decay) for the debiased model, per dataset.
_DEBIAS_OPTIM_HPARAMS = {
    "bar": (0.0001, 0.01),
    "waterbirds": (0.00005, 0.01),
    "urbancars": (1e-04, 1e-04),
    "BFFHQ": (0.00005, 1.0),
    "unbiased_cifar10": (0.0005, 0.01),
}

# Number of target classes per dataset; any dataset not listed uses 10.
_NUM_CLASSES = {"waterbirds": 2, "urbancars": 2, "bar": 6, "BFFHQ": 2}


def _make_adamw(params, dataset, hparams, role):
    """Build an AdamW optimizer for *params* from the per-dataset *hparams* table.

    Args:
        params: iterable of parameters to optimize.
        dataset: key looked up in *hparams*.
        hparams: mapping ``dataset -> (lr, weight_decay)``.
        role: human-readable model name, used only in the error message.

    Raises:
        ValueError: if *dataset* has no entry in *hparams*.
    """
    try:
        lr, weight_decay = hparams[dataset]
    except KeyError:
        raise ValueError(
            f"no {role} optimizer hyper-parameters defined for dataset {dataset!r}"
        ) from None
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)


def Learning_DEbias(
    train_bias=1,
    train_debias=1,
    dataset="waterbirds",
    epochs_bias_model=25,
    epochs_debiasing=50,
    amount_bias=95,
    num_images=1000,
    alternative_path=None,
):
    """Train a bias-capturing model on synthetic images, then a debiased model.

    The bias-capturing model (``create_model_for_bias``) is optionally trained
    with ERM on the synthetic (DDPM) images from ``read_DDPM_images``. Its
    weights are then loaded from ``PATH_TO_MODELS`` via ``lld``, and the
    debiased model (``debiasing_model``) is optionally trained with
    ``train_model`` against it.

    Args:
        train_bias: when 1, train the bias-capturing model.
        train_debias: when 1, evaluate the loaded biased model on the real
            training loader and train the debiased model.
        dataset: dataset key. It selects the models, the optimizer
            hyper-parameters and the number of classes.
        epochs_bias_model: NOTE(review) currently unused; the biased model
            is trained for a fixed 5 epochs below.
        epochs_debiasing: number of epochs passed to ``train_model``.
        amount_bias: forwarded to ``create_loaders``.
        num_images: forwarded to ``read_DDPM_images``.
        alternative_path: forwarded to ``read_DDPM_images``.

    Raises:
        ValueError: if a requested stage has no optimizer hyper-parameters
            for ``dataset``.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model_for_bias = create_model_for_bias(dataset)
    model = debiasing_model(dataset)
    name = "biased-final.pt"

    data_loader_DDPM_train = read_DDPM_images(
        dataset, num_images=num_images, alternative_path=alternative_path
    )
    num_classes = _NUM_CLASSES.get(dataset, 10)

    if train_bias == 1:
        # The real-data loaders are not used by this stage. The call is kept
        # only so that any side effects of create_loaders stay unchanged.
        create_loaders(dataset, 0, amount_bias)
        from erm_training import train_model_erm

        optim_biased = _make_adamw(
            model_for_bias.parameters(), dataset, _BIASED_OPTIM_HPARAMS, "biased"
        )
        model_for_bias.loss_fn = torch.nn.CrossEntropyLoss()
        # NOTE(review): epochs is fixed at 5 (epochs_bias_model is ignored).
        # The two positional 2s are also passed for every dataset, including
        # the 6-class 'bar'; confirm against train_model_erm's signature.
        train_model_erm(
            model_for_bias,
            data_loader_DDPM_train,
            None,
            None,
            DEVICE,
            optim_biased,
            2,
            2,
            epochs=5,
            wb=wb,
            name="biased",
        )

    # Load the biased model's weights (presumably the checkpoint written by the
    # stage above, or by an earlier run).
    if dataset == "unbiased_cifar10":
        BIASED_MODEL = lld(
            10,
            model_for_bias,
            os.path.join(PATH_TO_MODELS, "biased_model_unbiased_cifar10-final.pt"),
        )
    else:
        BIASED_MODEL = lld(
            10, model_for_bias, os.path.join(PATH_TO_MODELS, name), dataset
        )

    if train_debias == 1:
        train_loader, _, test_loader, _ = create_loaders(dataset, 1, amount_bias)
        # Log how the biased model performs on the real training set.
        evaluate_model(
            BIASED_MODEL,
            train_loader,
            2,
            2,
            torch.nn.CrossEntropyLoss(),
            "inference",
            DEVICE,
            wb=wb,
            prefix="ba",
        )
        optim = _make_adamw(
            model.parameters(), dataset, _DEBIAS_OPTIM_HPARAMS, "debiasing"
        )
        train_model(
            model,
            BIASED_MODEL,
            data_loader_DDPM_train,
            test_loader,
            train_loader,
            DEVICE,
            optimizer=optim,
            num_classes=num_classes,
            dataset=dataset,
            epochs=epochs_debiasing,
        )


from argparse import Namespace, ArgumentParser

# Command-line interface for the biased-model / debiasing experiments.
parser = ArgumentParser()
# required=True means argparse never falls back to a default value, so none
# is declared here (a default alongside required=True would be dead code).
parser.add_argument("--dataset", type=str, required=True, help="dataset name.")
parser.add_argument("--retrain", action="store_true", help="repeat experiment and overwrite biased model, default=False")
parser.add_argument("--seed", type=int, default=0, help="random state for stochastic operations")
parser.add_argument(
    "--rho",
    type=float,
    default=95,
    help="bias amount, passed to Learning_DEbias as amount_bias. "
    "10, 20 or 30 selects the increasing-rho waterbirds ablation.",
)
# NOTE(review): the __main__ block never reads args.subset_size; confirm
# whether it was meant to be forwarded (e.g. as num_images).
parser.add_argument(
    "--subset_size",
    type=float,
    default=None,
    help="Subset size of synthetic images for training the biased model",
)
parser.add_argument(
    "--guidance_strength",
    type=int,
    default=1,
    help="Ablation study on classifier strength 'w'. Default=1, choose among [0, 1, 2, 3, 5]",
)


if __name__ == "__main__":
    args = parser.parse_args()
    # Kept as a module-level global rather than a local of a main() function:
    # code elsewhere in this file may read ``dataset_name`` as a global.
    dataset_name = args.dataset
    set_seed(args.seed)
    amount_bias = args.rho
    guidance_strength = args.guidance_strength
    retrain = args.retrain

    # Append this run's configuration to a per-dataset results log.
    with open(f"results_{dataset_name}.txt", "a") as f:
        f.write(f"Args: \n {args}\n")

    if guidance_strength == 1 and dataset_name != "unbiased_cifar10":
        # Default setting: optionally retrain the biased model, then debias.
        if retrain:
            Learning_DEbias(
                1,
                0,
                dataset_name,
                50,
                50,
                amount_bias=amount_bias,
                alternative_path=None,
            )
        Learning_DEbias(
            0,
            1,
            dataset_name,
            50,
            50,
            amount_bias=amount_bias,
            alternative_path=None,
        )
    else:
        # Ablations are only available for waterbirds. Each one first trains the
        # biased model (1, 0) and then the debiased model (0, 1) on an
        # alternative synthetic-data directory.
        if guidance_strength != 1:
            # parser.error (unlike assert) is not stripped under `python -O`.
            if dataset_name != "waterbirds":
                parser.error(
                    "--guidance_strength other than 1 is only supported for --dataset waterbirds"
                )
            alt_path = f"./data/synthetic/w_{guidance_strength}/imagenet/waterbirds"
            for train_bias, train_debias in ((1, 0), (0, 1)):
                Learning_DEbias(
                    train_bias,
                    train_debias,
                    "waterbirds",
                    50,
                    50,
                    amount_bias=amount_bias,
                    alternative_path=alt_path,
                )

        if amount_bias in {10, 20, 30}:
            amount_bias = int(amount_bias)
            if dataset_name != "waterbirds":
                parser.error(
                    "--rho 10/20/30 (increasing-rho ablation) is only supported for --dataset waterbirds"
                )
            alt_path = f"./data/synthetic/increasing_rho/waterbirds/{amount_bias}"
            for train_bias, train_debias in ((1, 0), (0, 1)):
                Learning_DEbias(
                    train_bias,
                    train_debias,
                    "waterbirds",
                    50,
                    50,
                    amount_bias=amount_bias,
                    alternative_path=alt_path,
                )

    if dataset_name == "unbiased_cifar10":
        Learning_DEbias(1, 1, dataset_name, 50, 50, amount_bias=amount_bias)

    if dataset_name == "urbancars":
        # NOTE(review): urbancars can only reach this point through the first
        # branch above, which already ran Learning_DEbias. This second,
        # hard-coded run (retrains the biased model, 100 debiasing epochs, rho
        # fixed at 95) executes in addition to it. Confirm both are intended.
        Learning_DEbias(1, 1, "urbancars", 50, 100, amount_bias=95)

   
