import os
from typing import List, Tuple, Union
import torch
from torch.nn.functional import one_hot
import random
import copy
import numpy as np
from tqdm import tqdm
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import models, transforms
from matplotlib import pyplot as plt
from torch.utils import data
from datasets.waterbirds import Waterbirds
from datasets.bar import BAR
from datasets.BFFHQ import BFFHQ
from sklearn.metrics import classification_report
import pandas as pd
from utils.wandb_wrapper import WandbWrapper
import argparse
import torchvision 

from erm_training import train_model_erm

import torch.nn as nn

class FromNpyDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-loaded samples and their labels.

    Indexing yields ``(sample, (label, label), index)``. The label is
    duplicated so the item has the same ``(x, (y, b), idx)`` layout that the
    evaluation loop in this file unpacks.
    """

    def __init__(self, data, targets, transform=None):
        self.data, self.targets, self.transform = data, targets, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample, label = self.data[index], self.targets[index]
        labels = (label, label)

        # Truthiness test on purpose: any falsy transform is simply skipped.
        if not self.transform:
            return sample, labels, index
        return self.transform(sample), labels, index
    
class FromTorchvisionDataset(torch.utils.data.Dataset):
    """Adapter turning an ``(x, y)`` dataset into ``(x, (y, y), index)`` items.

    When a ``transform`` is supplied it is written onto the wrapped dataset's
    ``transform`` attribute; this wrapper itself never applies it.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        if transform is None:
            return
        # Hand the transform to the wrapped dataset, which is what applies it.
        self.dataset.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __repr__(self):
        # Delegate the textual representation to the wrapped dataset.
        return str(self.dataset)

    def __getitem__(self, index):
        sample, label = self.dataset[index]
        return sample, (label, label), index

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: value passed to every seeding function.
    """
    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

class ModelOnDiffused(torch.nn.Module):
    """Inference-only wrapper around a classifier restored from a checkpoint.

    The wrapper deep-copies ``model``, moves the copy to ``"cuda"``, loads the
    state dict stored at ``weight_path`` into it, and exposes a ``forward``
    that returns both the logits and the per-sample cross-entropy loss.

    Args:
        num_classes: stored on the instance as ``num_classes``.
        model: network whose architecture matches the checkpoint; it is
            copied, so the caller's object is not modified.
        weight_path: path of the state dict passed to ``torch.load``.
        dataset: accepted for call-site compatibility; not used here.
    """

    def __init__(self, num_classes, model, weight_path, dataset='waterbirds'):
        super().__init__()
        self.num_classes = num_classes
        # Private copy: loading the checkpoint must not touch the caller's model.
        self.model: nn.Module = copy.deepcopy(model)
        self.model.to("cuda")

        # Inputs are resized to 32x32 (bicubic) before reaching the model.
        self.transform = transforms.Resize((32, 32), interpolation=transforms.InterpolationMode.BICUBIC)
        self.model.load_state_dict(torch.load(weight_path))
        # reduction='none' keeps one loss value per sample.
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, x, target):
        """Return ``(logits, per_sample_loss)`` for a batch, without gradients.

        Args:
            x: input batch, resized by ``self.transform`` before inference.
            target: class targets passed to the cross-entropy loss.
        """
        self.model.eval()
        with torch.no_grad():
            # Single forward pass: the same logits feed both the returned
            # predictions and the loss (previously the model ran twice).
            logits = self.model(self.transform(x))
            return logits, self.loss_fn(logits, target)

def get_real_cifar():
    """Build CIFAR-10 data loaders that share one evaluation transform.

    Returns:
        ``(train_loader, val_loader, test_loader, eval_transform)``. The
        validation loader is built from the training split (``train=True``);
        only the test loader uses the held-out split.
    """
    eval_transform = transforms.Compose([
        transforms.Resize((32, 32), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    def _cifar_split(is_train):
        return torchvision.datasets.CIFAR10(root="./data", train=is_train, transform=eval_transform, download=True)

    # Order: train, val, test. Original author's note on the val split:
    # "Never actually used".
    raw_splits = [_cifar_split(True), _cifar_split(True), _cifar_split(False)]
    train_set, val_set, test_set = [FromTorchvisionDataset(split) for split in raw_splits]

    shared_kwargs = dict(batch_size=256, pin_memory=True, num_workers=4)
    train_loader = data.DataLoader(train_set, shuffle=True, **shared_kwargs)
    val_loader = data.DataLoader(val_set, shuffle=False, **shared_kwargs)
    test_loader = data.DataLoader(test_set, shuffle=False, **shared_kwargs)

    return train_loader, val_loader, test_loader, eval_transform

def create_bias_amplifier():
    """Create a DenseNet-121 with a 10-way classifier head, placed on CUDA.

    The network starts from torchvision's default DenseNet-121 weights; its
    ``classifier`` layer is replaced by a fresh 10-output linear layer and a
    mean-reduced cross-entropy loss is attached as ``loss_fn``.
    """
    net = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)

    head_inputs = net.classifier.in_features
    net.classifier = torch.nn.Linear(head_inputs, 10)
    net.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    return net.to("cuda")

def read_DDPM_unbiased_images():
    """Load the synthetic image arrays and wrap them in a dataset and loader.

    Every ``*.npy`` file in ``data/synthetic/unbiased/cifar10`` is loaded. The
    class label of a file is the single character immediately before the
    ``.npy`` suffix, so only one-digit labels are supported. Axis 3 of each
    array is moved to position 1 before conversion to a tensor (presumably
    channels-last to channels-first — confirm against the generated files).

    Returns:
        ``(dataset, train_loader)``: a ``FromNpyDataset`` over all images and a
        shuffling ``DataLoader`` with batch size 256.

    Raises:
        FileNotFoundError: if the directory holds no ``.npy`` file.
    """
    transform = transforms.Compose([
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
    ])
    address = os.path.join("data", "synthetic", "unbiased", "cifar10")

    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # index order reproducible across runs and machines.
    npy_files = sorted(name for name in os.listdir(address) if name.endswith(".npy"))
    if not npy_files:
        raise FileNotFoundError(f"no .npy files found in {address!r}")

    x_train = []
    y_train = []
    for name in npy_files:
        DDPM_images = np.load(os.path.join(address, name))
        label = int(name[-5])  # character right before ".npy"
        x_train.append(torch.from_numpy(np.moveaxis(DDPM_images, 3, 1)))
        y_train.append(torch.full((len(DDPM_images),), label, dtype=torch.long))

    # Explicit float32 cast instead of the legacy torch.Tensor(...) constructor.
    x_train = torch.cat(x_train, dim=0).float()
    y_train = torch.cat(y_train, dim=0)

    dataset_DDPM = FromNpyDataset(x_train, y_train, transform)
    data_loader_DDPM_train = data.DataLoader(dataset_DDPM, batch_size=256, shuffle=True, pin_memory=True, num_workers=4)

    return dataset_DDPM, data_loader_DDPM_train

@torch.no_grad()
def extract_misclassified(dataset_name: str, biased_model: nn.Module, dataset: data.Dataset, device="cuda", rho = 95, wb: WandbWrapper = None):
    """Flag samples that the biased model gets wrong with an outlying loss.

    The model is run over ``dataset`` in order (no shuffling). A sample is
    flagged (value 1) when its prediction differs from its class label AND its
    loss exceeds ``mean(losses) + 3 * std(losses)``; otherwise it gets 0.

    Side effects:
        * overwrites ``dataset.transform`` with a resize/to-tensor/normalize
          pipeline;
        * saves labels and predictions to
          ``outputs/<dataset_name>/<dataset_name>_biased_model_predictions.pt``;
        * writes the flags to ``outputs/cifar10_unbiased_metadata_aug.csv``
          (column ``ddb``);
        * prints classification reports.

    Args:
        dataset_name: used for the output directory and file name.
        biased_model: called as ``biased_model(x, y)``; must return a sequence
            whose first item is the logits and second the per-sample losses.
        dataset: yields ``(x, (y, b), index)`` items.
        device: device the batches are moved to.
        rho: accepted but not used; the threshold is fixed at mean + 3 std.
        wb: accepted but not used.

    Returns:
        NumPy int64 array of 0/1 flags, one per sample, in dataset order.
    """
    class_label_batches: List[torch.Tensor] = []
    bias_label_batches: List[torch.Tensor] = []
    class_pred_batches: List[torch.Tensor] = []
    bias_pred_batches: List[torch.Tensor] = []
    loss_batches: List[torch.Tensor] = []

    # NOTE(review): for FromTorchvisionDataset this sets the wrapper's own
    # attribute, which its __getitem__ does not read; the wrapped dataset's
    # transform is what is actually applied. Confirm this is intended.
    dataset.transform = transforms.Compose([
        transforms.Resize(TARGET_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    print(dataset.transform)
    print(dataset)

    dataloader = data.DataLoader(dataset, batch_size=128, shuffle=False, pin_memory=True, num_workers=4)

    biased_model.eval()
    for (x, (y, b), _) in dataloader:
        x = x.to(device)
        y = y.to(device)
        b = b.to(device)

        outs: Tuple[torch.Tensor, torch.Tensor] = biased_model(x, y)
        y_preds: torch.Tensor = outs[0].argmax(dim=1)

        class_label_batches.append(y)
        bias_label_batches.append(b)
        class_pred_batches.append(y_preds)
        # The stored "bias prediction" is a copy of the class prediction.
        bias_pred_batches.append(y_preds.clone())
        loss_batches.append(outs[1])

    class_labels: torch.Tensor = torch.cat(class_label_batches, dim=0).cpu()
    bias_labels: torch.Tensor = torch.cat(bias_label_batches, dim=0).cpu()
    class_predictions: torch.Tensor = torch.cat(class_pred_batches, dim=0).cpu()
    bias_predictions: torch.Tensor = torch.cat(bias_pred_batches, dim=0).cpu()
    losses: torch.Tensor = torch.cat(loss_batches, dim=0).cpu()

    threshold = torch.mean(losses) + 3*torch.std(losses)
    print(classification_report(class_labels, class_predictions))

    outpath = os.path.join("outputs", dataset_name)
    os.makedirs(outpath, exist_ok=True)
    torch.save({
        "class_labels": class_labels,
        "bias_labels": bias_labels,
        "class_predictions": class_predictions,
        "bias_predictions": bias_predictions,
    }, os.path.join(outpath, f"{dataset_name}_biased_model_predictions.pt"))

    # 1 = misclassified with an outlying loss, 0 = everything else.
    misclassified_flags = ((class_predictions != class_labels) & (losses > threshold)).long().numpy()
    pred_csv = pd.DataFrame(misclassified_flags, columns=["ddb", ])
    print(torch.unique(torch.from_numpy(misclassified_flags), return_counts=True))

    # Per-class reports over the classes actually present (sorted), using
    # NumPy masks so NumPy arrays are never indexed with torch tensors.
    class_labels_np = class_labels.numpy()
    bias_labels_np = bias_labels.numpy()
    for c in torch.unique(class_labels).tolist():
        in_class = class_labels_np == c
        print("Class ", c)
        print(classification_report(bias_labels_np[in_class], misclassified_flags[in_class]))

    new_metadata_path = os.path.join("outputs", "cifar10_unbiased_metadata_aug.csv")
    pred_csv.to_csv(new_metadata_path)

    return misclassified_flags

def extract_groups(args: argparse.Namespace):
    """Run the full pipeline: baseline training, bias model, pseudo-labelling.

    Steps, in order:
        1. build the CIFAR-10 loaders and a wrapped training set;
        2. train a ResNet-18 baseline on the real data (always);
        3. create the bias-amplifier network and, if ``args.retrain`` is set,
           train it on the synthetic images;
        4. restore the bias model from
           ``<PATH_TO_MODELS>/biased_model_unbiased_<dataset>-final.pt``;
        5. flag misclassified training samples and print the flag counts.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``retrain``,
            ``no_wb`` and ``rho``.
    """
    dataset_name  = args.dataset
    retrain_model = args.retrain
    # ``--no_wb`` is a store_false flag defaulting to True, so the attribute
    # is True when logging is ENABLED and False when the flag is passed.
    use_wb        = args.no_wb
    rho           = args.rho

    train_loader, val_loader, test_loader, eval_transform = get_real_cifar()
    train_set = FromTorchvisionDataset(
        torchvision.datasets.CIFAR10(
            root="./data",
            train=True,
            transform=eval_transform,
            download=True
        )
    )

    config = {
        "dataset_name": dataset_name,
        "lr": 0.0005,
        "num_classes": 10,
        "epochs": 50,
        "wd": 0.01,
    }

    wb = None
    if use_wb:
        wb = WandbWrapper("BaselineUnbiased_DiffuseDebias_0", config)

    baseline_model = make_baseline_model()
    train_model_erm(
        baseline_model,
        train_loader,
        val_loader,
        test_loader,
        "cuda",
        torch.optim.AdamW(baseline_model.parameters(), lr=config["lr"], weight_decay=config["wd"]),
        num_classes=config["num_classes"],
        epochs=config["epochs"],
        wb=wb,
        name=f"baseline_model_unbiased_{dataset_name}"
    )

    biased_model = create_bias_amplifier()
    _, data_loader_DDPM_train = read_DDPM_unbiased_images()

    # Separate logging run for the bias model.
    wb = None
    if use_wb:
        wb = WandbWrapper("DiffuseDebias_0", config)

    print(train_set)
    if retrain_model:
        train_model_erm(
            biased_model,
            data_loader_DDPM_train,
            val_loader,
            None,
            "cuda",
            torch.optim.AdamW(biased_model.parameters(), lr=config["lr"], weight_decay=config["wd"]),
            num_classes=config["num_classes"],
            epochs=config["epochs"],
            wb=wb,
            name=f"biased_model_unbiased_{dataset_name}"
        )

    # Whether or not it was retrained above, the bias model is restored from
    # the checkpoint on disk before being used for pseudo-labelling.
    biased_model = ModelOnDiffused(
        config["num_classes"],
        biased_model,
        os.path.join(PATH_TO_MODELS, f"biased_model_unbiased_{dataset_name}-final.pt"),
        dataset=dataset_name
    )
    # extract_misclassified returns a NumPy array of 0/1 flags.
    bias_pseudolabels: np.ndarray = extract_misclassified(dataset_name, biased_model, train_set, "cuda", rho=rho, wb=wb)
    print(torch.unique(torch.as_tensor(bias_pseudolabels), return_counts=True))

def make_baseline_model():
    """Create a ResNet-18 with a 10-way head, placed on CUDA.

    Starts from torchvision's default ResNet-18 weights, swaps the final
    ``fc`` layer for a fresh 10-output linear layer, and attaches a
    mean-reduced cross-entropy loss as ``loss_fn``.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, 10)
    net.loss_fn = nn.CrossEntropyLoss(reduction="mean")

    return net.to("cuda")

# Module-level constants. extract_groups reads PATH_TO_MODELS and
# extract_misclassified reads TARGET_SIZE, so they must exist even when this
# file is imported rather than executed as a script.
PATH_TO_MODELS = 'saved_models'
TARGET_SIZE = (32, 32)

# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="cifar10", help="dataset name. choose in [cifar10]")
parser.add_argument("--retrain", action="store_true", help="repeat experiment and overwrite biased model, default=False")
# store_false with default=True: args.no_wb is True while logging is enabled
# and becomes False only when --no_wb is given on the command line.
parser.add_argument("--no_wb", action="store_false", default=True, help="disables Weights & Biases logging")
parser.add_argument("--norm_mode", type=str, default="imagenet", help="Normalization protocol for generated images. choose in [imagenet, minmax]. Default: imagenet")
parser.add_argument("--seed", type=int, default=0, help="random state for stochastic operations")
parser.add_argument("--rho", type=float, default=95)
parser.add_argument("--subset_size", type=float, default=None, help="Subset size of synthetic images for training the biased model")
parser.add_argument("--guidance_strength", type=int, default=1, help="Ablation study on classifier strength 'w'. Default=1, choose among [0, 1, 2, 3, 5]")

if __name__ == "__main__":
    args = parser.parse_args()
    set_seed(args.seed)

    print(args)
    extract_groups(args)


