import os
from typing import List, Tuple, Union
import torch
from torch.nn.functional import one_hot
import random
import copy
import numpy as np
from tqdm import tqdm
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import models, transforms
from matplotlib import pyplot as plt
from torch.utils import data
from datasets.waterbirds import Waterbirds
from datasets.imagenet9 import Imagenet9
from datasets.imagenetA import ImageNetA
from datasets.bar import BAR
from datasets.BFFHQ import BFFHQ
from sklearn.metrics import classification_report
import pandas as pd
from utils.wandb_wrapper import WandbWrapper
import argparse
from utils.metrics import *
from tqdm import tqdm

import torch.nn as nn

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators.

    Also turns off cuDNN benchmark mode and turns on its deterministic flag.

    Args:
        seed: Integer seed handed to every generator.
    """
    random.seed(seed)

    # The explicit CUDA seeding is only attempted when a GPU is visible.
    gpu_visible = torch.cuda.is_available()
    if gpu_visible:
        torch.cuda.manual_seed(seed)

    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

    # cuDNN flags: no autotuned kernels, deterministic algorithms only.
    for flag, value in (("benchmark", False), ("deterministic", True)):
        setattr(torch.backends.cudnn, flag, value)

def load_model(path, num_classes: int = 9, device: Union[str, torch.device] = "cuda"):
    """Build a DenseNet-121 classifier and restore its weights from ``path``.

    Args:
        path: File holding a ``state_dict`` saved with ``torch.save``.
        num_classes: Output size of the replacement classifier layer.
        device: Device the restored model is moved to.

    Returns:
        The DenseNet-121 with the checkpoint weights loaded, on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint keys or
            shapes do not match the architecture (strict loading).
    """
    # No pretrained weights are requested: the strict load below replaces
    # every parameter and buffer, so downloading them would be wasted work.
    model = models.densenet121(weights=None)
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)

    # Deserialize onto the CPU so a checkpoint written on a GPU that is not
    # present here still loads; the model is moved to ``device`` afterwards.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)

    return model

def load_gdro_model(
    path: str = "./logs/imnetarizona/best_model_59.pth",
    num_classes: int = 9,
    device: Union[str, torch.device] = "cuda",
):
    """Load the GDRO checkpoint and return it as a model on ``device``.

    The checkpoint may be either a fully pickled ``nn.Module`` (what the
    default file is loaded as) or a plain ``state_dict``. In the second case
    the weights are loaded into a ResNet-18 whose final layer has
    ``num_classes`` outputs.

    Args:
        path: Checkpoint file written with ``torch.save``.
        num_classes: Output size of the ResNet-18 head; only used when the
            checkpoint is a ``state_dict``.
        device: Device the model is moved to.

    Returns:
        The restored model, on ``device``.

    Raises:
        TypeError: If the checkpoint is neither a module nor a ``state_dict``.
    """
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    # NOTE(review): recent PyTorch releases default ``weights_only=True``,
    # which rejects fully pickled modules - confirm against the installed
    # version and pass ``weights_only=False`` for a trusted file if needed.
    checkpoint = torch.load(path, map_location=device)

    if isinstance(checkpoint, nn.Module):
        # The file stores the whole model; no architecture has to be built.
        model = checkpoint
    elif isinstance(checkpoint, dict):
        # Weights only: rebuild the network without pretrained weights, since
        # the strict load below overwrites all of them.
        # NOTE(review): assumes the state_dict belongs to a ResNet-18 - confirm.
        model = models.resnet18(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model.load_state_dict(checkpoint)
    else:
        raise TypeError(
            f"Unsupported checkpoint type {type(checkpoint).__name__!r} in {path!r}"
        )

    model.to(device)

    return model

@torch.no_grad()
def evaluate_model(model, test_loader, criterion, epoch, device, wb, prefix="val"):
    """Run one evaluation pass over ``test_loader`` and report loss and top-1.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        test_loader: Sized iterable of ``(inputs, labels, _)`` batches, where
            ``labels[0]`` holds the targets given to ``criterion``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        epoch: Tag shown in the progress bar and logged under ``"epoch"``.
        device: Device the inputs and targets are moved to.
        wb: Optional logger exposing ``log_output(dict)``; ``None`` disables
            logging.
        prefix: Prefix of the ``<prefix>_acc1`` / ``<prefix>_loss`` keys.

    Returns:
        The running average loss read from the loss meter after the last
        batch, or ``nan`` when the loader yields no batches.
    """
    model.eval()
    loss_meter: AverageMeter = AverageMeter()
    top1_meter: AverageMeter = AverageMeter()

    progress = tqdm(
        test_loader, total=len(test_loader), leave=True, dynamic_ncols=True
    )

    # Defined up front so an empty loader no longer hits unbound names below.
    avg_loss = float("nan")
    postfix = {"epoch": epoch}

    for inputs, labels, _ in progress:
        inputs = inputs.to(device)
        target = labels[0].to(device)
        output = model(inputs)
        batch_size = inputs.size(0)

        loss = criterion(output, target)
        loss_meter.update(loss.item(), batch_size)

        acc1 = accuracy(output, target, topk=(1,))
        top1_meter.update(acc1[0], batch_size)

        avg_loss = loss_meter.avg
        postfix = {
            "epoch": epoch,
            f"{prefix}_acc1": top1_meter.avg,
            f"{prefix}_loss": avg_loss,
        }
        progress.set_postfix(postfix)

    # Only the final running averages are sent to the logger.
    if wb is not None:
        wb.log_output(postfix)

    return avg_loss

if __name__ == "__main__":
    # Seed everything before any model or data object is created.
    set_seed(0)

    # Disabled alternatives: two DenseNet checkpoints can be evaluated the
    # same way by loading them with load_model(...) and passing the result to
    # evaluate_model with epoch tags "baseline-model" / "ba_vanilla":
    #   saved_models/biased_model_95_imagenet9-final_original_trset.pt
    #   saved_models/biased_model_95_imagenet9-final.pt
    gdro_model = load_gdro_model()

    imagenet_a = ImageNetA()
    test_loader = data.DataLoader(
        imagenet_a,
        batch_size=128,
        shuffle=False,
        num_workers=4,
    )

    # Settings shared by every evaluation run of this script.
    eval_settings = {
        "criterion": nn.CrossEntropyLoss(),
        "device": "cuda",
        "wb": None,
        "prefix": "imagenetA",
    }

    evaluate_model(gdro_model, test_loader, epoch="ba_ddpm+gdro", **eval_settings)
    
    
    

