import torch
import torch.nn as nn
# import torch.nn.functional as F
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json
from sklearn.metrics import f1_score, roc_auc_score, jaccard_score
import math 
import numpy as np
import logging
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam, SGD
import torch.nn.functional as F

from transformers import get_scheduler, get_cosine_schedule_with_warmup
from tqdm import tqdm
from models import MLP, End2EndModel

parser = argparse.ArgumentParser(description='Settings for creating CBM')

# ---- data / backbone selection ---------------------------------------------
parser.add_argument("--dataset", default="cifar10", type=str)
parser.add_argument("--concept_set", default=None, type=str,
                    help="path to concept set name")
parser.add_argument("--clip_name", default="ViT-B/16", type=str,
                    help="Which CLIP model to use")

# ---- runtime / optimisation ------------------------------------------------
parser.add_argument("--device", default="cuda", type=str,
                    help="Which device to use")
parser.add_argument("--batch_size", default=128, type=int,
                    help="Batch size used when saving model/CLIP activations")
parser.add_argument("--lr", default=0.01, type=float)
parser.add_argument('--scheduler', default=None, type=str)

# ---- input / output locations ----------------------------------------------
parser.add_argument("--activation_dir", default='saved_activations', type=str,
                    help="save location for backbone and CLIP activations")
parser.add_argument("--save_dir", default='saved_models', type=str,
                    help="where to save trained models")
parser.add_argument("--saved_model", default=None, type=str,
                    help='directory to where the desired checkpoint is location')

# ---- training schedule / CBM variant ---------------------------------------
parser.add_argument("--epoch", default=100, type=int,
                    help="number of training epochs")
parser.add_argument("--seed", default=42, type=int)
parser.add_argument("--scheme", default='joint', type=str,
                    help='training schema [independent, sequential, joint]')
parser.add_argument("--supervision", action='store_true', default=False)

# ---- concept filtering -----------------------------------------------------
parser.add_argument("--filtering", default=None, type=str,
                    help="choose from [None, submodular, random, similarity, discriminative, coverage]")
parser.add_argument("--filtering_size", default=200, type=int)

# ---- loss / regularisation hyper-parameters --------------------------------
parser.add_argument("--tau", type=float, default=0.9)
parser.add_argument("--reg_alpha", type=float, default=0.7)
parser.add_argument("--init_beta", type=float, default=0.8)
# Fraction of concepts (ranked by get_ucp) that main() allows to be intervened on.
parser.add_argument("--ratio", type=float, default=1)
# NOTE(review): store_false means passing `--graph` sets args.graph to False;
# main() in this file does not read args.graph -- confirm it is still needed.
parser.add_argument("--graph", default=True, action='store_false')
parser.add_argument("--logger", action='store_true')
parser.add_argument("--output_dir", default='./log/', type=str)
parser.add_argument("--use_graph", action='store_true', default=False)
parser.add_argument("--alpha", type=float, default=0.1)
parser.add_argument("--beta", type=float, default=0.1)

def get_ucp(score):
    """Return the element-wise weight ``1 / (relu(score) - 0.5) ** 2``.

    Entries whose rectified score is close to 0.5 receive large weights; a
    rectified score of exactly 0.5 yields ``inf``. Because of the ReLU, every
    non-positive score maps to ``1 / 0.25 == 4``.
    """
    offset = F.relu(score).sub(0.5).abs()
    return offset.pow(2).reciprocal()

def _report(logger, message):
    """Emit ``message`` through ``logger`` when logging is enabled, otherwise print it."""
    if logger is not None:
        logger.info(message)
    else:
        print(message)


def _setup_logger(args):
    """Return a logger writing to ``<output_dir>/intervention/standard_<dataset>_<scheme>/seed_<seed>[_graph].log``.

    Returns ``None`` when ``--logger`` is not set; callers then print instead.
    """
    if not args.logger:
        return None
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    # makedirs (instead of a shell `mkdir`) creates missing parents, avoids building a
    # shell command from CLI input, and targets the same directory the handler writes to.
    log_dir = os.path.join(args.output_dir, 'intervention', f'standard_{args.dataset}_{args.scheme}')
    os.makedirs(log_dir, exist_ok=True)

    file_name = f'seed_{args.seed}'
    if args.use_graph:
        file_name += '_graph'
    file_name += '.log'
    logger.addHandler(logging.FileHandler(os.path.join(log_dir, file_name), mode="w", encoding="utf-8"))
    return logger


def _resolve_device(requested):
    """Honour ``--device``, falling back to CPU when a CUDA device is requested but unavailable."""
    if requested.startswith("cuda") and not torch.cuda.is_available():
        return "cpu"
    return requested


def _load_split(args, split, backbone):
    """Load cached image features, concept labels and class labels for one split.

    Args:
        args: parsed CLI namespace (reads ``dataset`` and ``activation_dir``).
        split: split suffix appended to the dataset name, e.g. ``"train"`` or ``"val"``.
        backbone: CLIP model name with ``/`` removed, used in the feature file name.

    Returns:
        ``(image_features, concepts, labels)`` as CPU tensors (features/concepts cast
        to float, labels to int64). For ``chestxpert`` rows labelled ``-1`` are dropped.
    """
    name = f"{args.dataset}_{split}"
    if args.dataset == 'chestxpert':
        image_path = "{}/{}_BioViL_T.pt".format(args.activation_dir, name)
    else:
        image_path = "{}/{}_clip_{}.pt".format(args.activation_dir, name, backbone)
    concept_path = "{}/{}_concepts.pt".format(args.activation_dir, name)

    image_features = torch.load(image_path, map_location="cpu").float()
    concepts = torch.load(concept_path, map_location="cpu").float()

    if args.dataset in ('celebA', 'chestxpert'):
        label_path = "{}/{}_labels.pt".format(args.activation_dir, name)
        labels = torch.LongTensor(torch.load(label_path, map_location="cpu"))
    else:
        labels = torch.LongTensor(data_utils.get_targets_only(name))

    if args.dataset == 'chestxpert':
        # Rows labelled -1 are excluded (presumably unlabelled/uncertain -- confirm).
        keep = labels != -1
        image_features, concepts, labels = image_features[keep], concepts[keep], labels[keep]
    return image_features, concepts, labels


def _build_model(args, input_dim, num_concepts, num_classes, device):
    """Rebuild the End2EndModel architecture and load ``<saved_model>/model.pkl`` into it."""
    model1 = MLP(input_dim=input_dim, num_classes=num_concepts, expand_dim=input_dim)
    model2 = MLP(input_dim=num_concepts, num_classes=num_classes, expand_dim=None)
    if args.use_graph:
        model3 = MLP(input_dim=input_dim, num_classes=num_concepts, expand_dim=input_dim)
        model4 = MLP(input_dim=num_concepts, num_classes=num_classes, expand_dim=None)
    else:
        model3 = model4 = None

    model = End2EndModel(model1, model2, model3=model3, model4=model4, use_relu=True,
                         use_graph=args.use_graph, alpha=args.alpha, beta=args.beta)
    if args.use_graph:
        model.graph_init(num_concepts)
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only host;
    # the model is moved to `device` afterwards.
    state_dict = torch.load(os.path.join(args.saved_model, 'model.pkl'), map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(device)


def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` without autograd and gather per-sample outputs.

    Returns a dict of tensors (on ``device``) concatenated over batches along dim 0:
    ``score`` (``model.get_score`` output), ``c_pred`` (second forward output
    thresholded at 0.5 into 0/1), ``c_true`` (ground-truth concepts), ``y`` (labels)
    and ``y_hat`` (argmax of the class output).
    """
    outputs = {key: [] for key in ("score", "c_pred", "c_true", "y", "y_hat")}
    # no_grad: evaluation only; without it every stored tensor keeps its autograd graph.
    with torch.no_grad():
        for X, y, c in loader:
            X, y, c = X.to(device), y.to(device), c.to(device)
            _, c_out, y_logits = model(X, y, c)
            outputs["score"].append(model.get_score(X))
            outputs["c_pred"].append(torch.where(c_out > 0.5, 1, 0))
            outputs["c_true"].append(c)
            outputs["y"].append(y)
            outputs["y_hat"].append(y_logits.max(1)[1])
    return {key: torch.cat(chunks, dim=0) for key, chunks in outputs.items()}


def _intervened_correct(model, preds, num_classes, num_concepts, ratio):
    """Count misclassified samples whose class becomes correct after concept intervention.

    For every class ``k`` that has both correctly and wrongly classified samples:

    * ``difference`` = mean over correct samples of their score on correctly predicted
      concepts (0 elsewhere) minus mean over wrong samples of their score on
      mispredicted concepts (0 elsewhere);
    * concepts are ranked by ``get_ucp`` of the correct samples' mean score
      (descending) and only the top ``int(ratio * num_concepts)`` may be changed;
    * ``difference`` is added to the wrong samples' scores on the concepts they
      mispredicted (within that top set), the result is passed to
      ``model.intervene``, and predictions equal to ``k`` are counted.

    Classes without correct samples are skipped: their reference mean would be NaN.
    """
    keep = int(ratio * num_concepts)
    score, y, y_hat = preds["score"], preds["y"], preds["y_hat"]
    concept_hit = preds["c_pred"] == preds["c_true"]
    corrected = 0
    with torch.no_grad():
        for label in range(num_classes):
            of_class = y == label
            right = of_class & (y_hat == label)
            wrong = of_class & (y_hat != label)
            if not wrong.any() or not right.any():
                continue

            right_score, wrong_score = score[right], score[wrong]
            wrong_miss = ~concept_hit[wrong]
            right_ref = torch.where(concept_hit[right], right_score, 0).mean(dim=0)
            wrong_ref = torch.where(wrong_miss, wrong_score, 0).mean(dim=0)
            difference = right_ref - wrong_ref

            ucp_order = torch.sort(get_ucp(right_score.mean(dim=0)), descending=True)[1]
            # Build the mask on the scores' device so it can be indexed by ucp_order.
            mask = torch.ones(num_concepts, device=score.device, dtype=score.dtype)
            mask[ucp_order[keep:]] = 0

            # Only intervene on concepts the model actually got wrong (binarised
            # prediction != ground truth), matching how wrong_ref is defined above.
            delta = torch.where(wrong_miss, difference, 0) * mask
            _, y_int = model.intervene(wrong_score + delta)
            corrected += (y_int.max(1)[1] == label).sum().item()
    return corrected


def main(args):
    """Evaluate a trained CBM on the validation split and measure concept-intervention gains.

    Optionally sets up file logging. Loads cached features, concepts and labels, then
    rebuilds the End2EndModel from ``<saved_model>/model.pkl``. Reports the original
    validation accuracy together with micro ROC-AUC and Jaccard of the binarised concept
    predictions. Finally it applies per-class concept interventions to the misclassified
    samples and reports the accuracy afterwards.

    Raises:
        ValueError: if ``--saved_model`` is not given or the validation split is empty.
    """
    logger = _setup_logger(args)
    if logger is not None:
        logger.info(f"Intervention this model {args.saved_model}")
    if args.saved_model is None:
        raise ValueError("--saved_model must be the directory that contains model.pkl")

    backbone = args.clip_name.replace('/', '')
    image_features, concepts, labels = _load_split(args, "train", backbone)
    val_image_features, val_concepts, val_labels = _load_split(args, "val", backbone)
    if val_labels.numel() == 0:
        raise ValueError(f"validation split {args.dataset}_val is empty")

    if args.dataset == 'celebA':
        num_classes = 256
    elif args.dataset == 'chestxpert':
        num_classes = 4
    else:
        num_classes = len(torch.unique(labels))

    # The training split only supplies the model dimensions (and class count above).
    input_dim = image_features.size(-1)
    num_concepts = concepts.size(-1)
    del image_features, concepts, labels

    device = _resolve_device(args.device)
    model = _build_model(args, input_dim, num_concepts, num_classes, device)
    model.eval()

    val_loader = DataLoader(TensorDataset(val_image_features, val_labels, val_concepts),
                            batch_size=args.batch_size, shuffle=False, num_workers=0)
    preds = _collect_predictions(model, val_loader, device)

    n = preds["y"].size(0)
    correct = (preds["y_hat"] == preds["y"]).sum().item()
    c_preds = preds["c_pred"].cpu().numpy()
    c_trues = preds["c_true"].cpu().numpy()

    _report(logger, "Original Validation Acc: {:.4f}".format(correct / n))
    # NOTE(review): ROC-AUC is computed on hard 0/1 predictions, i.e. a single
    # operating point; passing continuous concept scores may be what is intended.
    _report(logger, "Original ROC-AUC: {:.4f}".format(roc_auc_score(c_trues, c_preds, average='micro')))
    _report(logger, "Original Jaccard Score: {:.4f}".format(jaccard_score(c_trues, c_preds, average='micro')))

    corrected = _intervened_correct(model, preds, num_classes, num_concepts, args.ratio)
    _report(logger, "After Intervention Validation Acc: {:.4f}".format((correct + corrected) / n))


    

if __name__ == '__main__':
    args = parser.parse_args()
    seed = args.seed

    # Seed every RNG source the script may touch (Python, NumPy, torch CPU and CUDA),
    # in the same order as before, then request deterministic cuDNN kernels.
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

    main(args)