import torch
import clip
from datasets.waterbirds import Waterbirds
from datasets.celebA import CelebA
from datasets.BAR import BAR
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision.utils import make_grid, save_image
import pandas as pd
from tqdm import tqdm
from sklearn import metrics
import numpy as np
import os
from argparse import ArgumentParser
from wandb_wrapper import WandbWrapper

# Restrict CUDA to the device with index 1; this is set before any CUDA
# query is made in this file.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def setup_model(model_name: str = "ViT-B/32"):
    """Load a CLIP backbone together with its image preprocessing transform.

    Args:
        model_name: Name of the CLIP architecture handed to ``clip.load``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load`` for the
        module-level ``device``.
    """
    loaded = clip.load(model_name, device)
    clip_model, image_transform = loaded
    return clip_model, image_transform

def waterbirds_setup(preprocess):
    """Build the Waterbirds training data and the background text prompts.

    Args:
        preprocess: Image transform applied by the dataset to every sample.

    Returns:
        A ``(train_set, train_loader, text_inputs)`` triple. ``text_inputs``
        stacks one tokenised prompt per background word: the first six rows
        are land-related words, the last six are water-related words.
    """
    def build_prompt(background):
        return f"An image of a bird over {background}"

    land_words = ["tree", "forest", "foliage", "branch", "vegetation", "leaves"]
    water_words = ["sea", "ocean", "beach", "waters", "shore", "coastal"]

    train_set = Waterbirds(env="train", transform=preprocess)

    # Unshuffled loader: predictions come out in dataset order.
    loader_options = dict(batch_size=512, shuffle=False, pin_memory=True, num_workers=4)
    train_loader = DataLoader(train_set, **loader_options)

    tokenised = [clip.tokenize(build_prompt(word)) for word in land_words + water_words]
    text_inputs = torch.cat(tokenised, dim=0).to(device)

    return train_set, train_loader, text_inputs

def bar_setup(preprocess):
    """Build the BAR training data and one context prompt per action class.

    Args:
        preprocess: Image transform applied by the dataset to every sample.

    Returns:
        A ``(train_set, train_loader, text_inputs)`` triple. ``text_inputs``
        has one tokenised prompt per row, in the order of the pairs below.
    """
    # (action class, context attribute) pairs. Only the attribute is used in
    # the prompt text; the class name is listed for reference.
    class_attribute_pairs = [
        ("climbing", "cliff"),
        ("diving", "underwater"),
        ("fishing", "boat"),
        ("racing", "cars"),
        ("throwing", "baseball"),
        ("vaulting", "pole jump"),
    ]

    train_set = BAR(env="train", transform=preprocess)

    # Unshuffled loader: predictions come out in dataset order.
    loader_options = dict(batch_size=512, shuffle=False, pin_memory=True, num_workers=4)
    train_loader = DataLoader(train_set, **loader_options)

    prompts = [f"A photo about {attribute}" for _action, attribute in class_attribute_pairs]
    tokenised = [clip.tokenize(prompt) for prompt in prompts]
    text_inputs = torch.cat(tokenised, dim=0).to(device)

    return train_set, train_loader, text_inputs

def celeba_setup(preprocess):
    """Build the CelebA training data and the two gender text prompts.

    Args:
        preprocess: Image transform applied by the dataset to every sample.

    Returns:
        A ``(train_set, train_loader, text_inputs)`` triple. ``text_inputs``
        holds the tokenised words "woman" (row 0) and "man" (row 1).
    """
    gender_words = ["woman", "man"]

    train_set = CelebA(root="./data", split="train", transform=preprocess)

    # Unshuffled loader: predictions come out in dataset order.
    loader_options = dict(batch_size=512, shuffle=False, pin_memory=True, num_workers=4)
    train_loader = DataLoader(train_set, **loader_options)

    # The prompt is the bare word itself.
    tokenised = [clip.tokenize(f"{word}") for word in gender_words]
    text_inputs = torch.cat(tokenised, dim=0).to(device)

    return train_set, train_loader, text_inputs

def clip_zeroshot(model, train_loader, text_inputs):
    """Assign every image the index of its most similar text prompt.

    Args:
        model: CLIP model exposing ``encode_image`` and ``encode_text``.
        train_loader: Iterable yielding ``(x, (y, b), idx)`` batches. Only
            the images ``x`` and the labels ``y`` are used here.
        text_inputs: Tokenised prompts, one row per prompt.

    Returns:
        A ``(predictions, labels)`` pair of NumPy arrays: for each sample the
        row index of the best-matching prompt, and the label ``y`` yielded by
        the loader.
    """
    model.eval()
    predictions = []
    labels = []
    with torch.no_grad():
        # The prompts are identical for every batch, so encode and normalise
        # them once instead of once per iteration.
        text_features: torch.Tensor = model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        for x, (y, _b), _idx in tqdm(train_loader):
            x = x.to(device)

            image_features: torch.Tensor = model.encode_image(x)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Scaled cosine similarity between every image and every prompt.
            similarity = 100.0 * image_features @ text_features.T
            _, prediction = similarity.max(dim=-1)

            # Collect on the CPU so results do not accumulate on the device.
            predictions.append(prediction.cpu())
            labels.append(y.cpu())

    predictions = torch.cat(predictions, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    return predictions, labels

def store_pseudolabels_and_stats(predictions: np.ndarray, labels: np.ndarray, dataset_name, train_set, new_metadata_path, wb: WandbWrapper = None):
    class_0_metrics = None
    class_1_metrics = None
    aligned_metrics = None
    conflic_metrics = None
    overall_metrics = None
    
    match dataset_name:
        case "celeba":
            split_df = pd.read_csv("./data/CelebA/list_eval_partition.csv")
            df = pd.read_csv("./data/CelebA/list_attr_celeba.csv", sep=" ").replace(-1, 0)
            train_ids = split_df[split_df["split"] == 0].index.to_numpy()
            df["clip"] = df["Male"].copy()
            df["split"] = split_df["split"].copy()
            df.loc[train_ids, "clip"] = predictions
            df.to_csv(new_metadata_path)

            csv = pd.read_csv(new_metadata_path, header="infer")
            train_set = csv[csv["split"] == 0]

            aligned     = train_set[train_set["Blond_Hair"] == train_set["Male"]]
            conflicting = train_set[train_set["Blond_Hair"] != train_set["Male"]] 

            class_0_metrics = metrics.classification_report(train_set[train_set["Blond_Hair"]==0]["Male"], train_set[train_set["Blond_Hair"]==0]["clip"], target_names=["Aligned", "Conflicting"])
            class_1_metrics = metrics.classification_report(train_set[train_set["Blond_Hair"]==1]["Male"], train_set[train_set["Blond_Hair"]==1]["clip"], target_names=["Conflicting", "Aligned"])
            aligned_metrics = metrics.classification_report(aligned["Male"], aligned["clip"], target_names=["Not Blond", "Blond"])
            conflic_metrics = metrics.classification_report(conflicting["Male"], conflicting["clip"], target_names=["Not Blond", "Blond"])
            overall_metrics = metrics.classification_report(train_set["Male"], train_set["clip"], target_names=["Not Blond", "Blond"])

            print("Class 0 (Non Blond)")
            print(class_0_metrics, "\n")
            print("Class 1 (Blond)")
            print(class_1_metrics, "\n")
            print("Bias-Aligned")
            print(aligned_metrics, "\n")
            print("Bias-Conflicting", "\n")
            print(conflic_metrics, "\n")
            print("Overall")
            print(overall_metrics, "\n")

        
        case "bar":
            predictions = (predictions != labels).astype(np.int32)
            pred_csv = pd.DataFrame(predictions, columns=["clip", ])
            print(torch.unique(torch.from_numpy(predictions), return_counts=True))
            pred_csv.to_csv(new_metadata_path)

        case "waterbirds":
            predictions = np.where(predictions < 6, 0, 1).astype(np.int32)
            df = pd.read_csv("./data/waterbirds/waterbird_complete95_forest2water2/metadata.csv", header="infer", index_col=0)
            train_ids = df[df["split"] == 0].index.to_numpy()
            df["clip"] = df["place"].copy()
            df.loc[train_ids, "clip"] = predictions
            df.to_csv(new_metadata_path)

            csv = pd.read_csv(new_metadata_path, header="infer")
            train_set = csv[csv["split"] == 0]
            aligned     = train_set[train_set["y"] == train_set["place"]]
            conflicting = train_set[train_set["y"] != train_set["place"]]

            class_0_metrics = metrics.classification_report(train_set[train_set["y"]==0]["place"], train_set[train_set["y"]==0]["clip"], target_names=["Aligned", "Conflicting"])
            class_1_metrics = metrics.classification_report(train_set[train_set["y"]==1]["place"], train_set[train_set["y"]==1]["clip"], target_names=["Conflicting", "Aligned"])
            aligned_metrics = metrics.classification_report(aligned["place"], aligned["clip"], target_names=["Landbird", "Waterbird"])
            conflic_metrics = metrics.classification_report(conflicting["place"], conflicting["clip"], target_names=["Landbird", "Waterbird"])
            overall_metrics = metrics.classification_report(train_set["place"], train_set["clip"], target_names=["Landbird", "Waterbird"])

            print("Class 0 (Landbird)")
            print(class_0_metrics, "\n")
            print("Class 1 (Waterbird)")
            print(class_1_metrics, "\n")
            print("Bias-Aligned")
            print(aligned_metrics, "\n")
            print("Bias-Conflicting", "\n")
            print(conflic_metrics, "\n")
            print("Overall")
            print(overall_metrics, "\n")

    if wb is not None:
        wb.log_output({
            "id_metrics": 
                wb.backend.Table(dataframe=pd.DataFrame.from_dict({ 
                "class_0": class_0_metrics,
                "class_1": class_1_metrics,
                "aligned": aligned_metrics,
                "conflic": conflic_metrics,
                "overall": overall_metrics
            }, orient="index"))
        })
        metadata_file = wb.backend.Artifact("pseudolabels_metadata_file", type="dataset")
        metadata_file.add_file(new_metadata_path)
        wb.backend.log_artifact(metadata_file)


# Command-line interface of the pseudo-labelling script.
parser = ArgumentParser()
# --dataset is mandatory and limited to the datasets this script can set up,
# so an unsupported name is rejected at parse time with a clear message.
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    choices=["waterbirds", "bar", "celeba"],
    help="dataset name. choose in [waterbirds, bar, celeba]",
)
parser.add_argument("--use_wb", type=str, default="true", help="whether to use weights and biases logging or not, default=true")
parser.add_argument("--model", type=str, default="ViT-B/32", help="Backbone for CLIP, default=ViT-B/32")

if __name__ == "__main__":
    args = parser.parse_args()
    dataset = args.dataset
    # W&B logging stays on unless --use_wb is given a non-empty value other than "true".
    use_wb = args.use_wb == "true" if args.use_wb else True
    model_name = args.model

    # One setup function per supported dataset; each returns
    # (train_set, train_loader, text_inputs).
    dataset_setups = {
        "celeba": celeba_setup,
        "waterbirds": waterbirds_setup,
        "bar": bar_setup,
    }
    if dataset not in dataset_setups:
        # Fail before a W&B run is opened or the model is loaded, rather than
        # hitting an unbound variable further down.
        parser.error(f"unsupported dataset {dataset!r}; choose one of {sorted(dataset_setups)}")

    if use_wb:
        wb = WandbWrapper(
            project_name="CLIP_Pseudolabeling",
            config=args
        )
    else:
        wb = None

    model, preprocess = setup_model(model_name)
    train_set, train_loader, text_inputs = dataset_setups[dataset](preprocess)

    predictions, labels = clip_zeroshot(model, train_loader, text_inputs)
    store_pseudolabels_and_stats(predictions, labels, dataset, train_set, f"{dataset}_metadata_aug.csv", wb)

    if wb is not None:
        wb.finish()






















