import torch
import torch.nn as nn
# import torch.nn.functional as F
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json
from sklearn.metrics import f1_score, roc_auc_score
import math 
import numpy as np
import logging
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam, SGD
import torch.nn.functional as F

from transformers import get_scheduler, get_cosine_schedule_with_warmup
from tqdm import tqdm
from models import MLP, End2EndModel

# Command-line interface for the script. Options are declared as a table of
# (flag, add_argument keyword arguments) and registered in this exact order,
# which is also the order in which they appear in ``--help``.
parser = argparse.ArgumentParser(description="Settings for creating CBM")

_CLI_OPTIONS = (
    # Dataset, concept set and feature backbone.
    ("--dataset", dict(type=str, default="cifar10")),
    ("--concept_set", dict(type=str, default=None,
                           help="path to concept set name")),
    ("--clip_name", dict(type=str, default="ViT-B/16",
                         help="Which CLIP model to use")),
    # Runtime and optimisation settings.
    ("--device", dict(type=str, default="cuda", help="Which device to use")),
    ("--batch_size", dict(type=int, default=128,
                          help="Batch size used when saving model/CLIP activations")),
    ("--lr", dict(type=float, default=0.01)),
    ("--scheduler", dict(type=str, default=None)),
    # Locations of cached activations and checkpoints.
    ("--activation_dir", dict(type=str, default="saved_activations",
                              help="save location for backbone and CLIP activations")),
    ("--save_dir", dict(type=str, default="saved_models",
                        help="where to save trained models")),
    ("--saved_model", dict(type=str, default=None,
                           help="directory to where the desired checkpoint is location")),
    ("--epoch", dict(type=int, default=100, help="number of training epochs")),
    ("--seed", dict(type=int, default=42)),
    ("--scheme", dict(type=str, default="joint",
                      help="training schema [independent, sequential, joint]")),
    ("--supervision", dict(default=False, action="store_true")),
    # Concept filtering.
    ("--filtering", dict(type=str, default=None,
                         help="choose from [None, submodular, random, similarity, discriminative, coverage]")),
    ("--filtering_size", dict(type=int, default=200)),
    ("--tau", dict(default=0.9, type=float)),
    ("--reg_alpha", dict(default=0.7, type=float)),
    ("--init_beta", dict(default=0.8, type=float)),
    ("--ratio", dict(default=1, type=float)),
    # Note: passing --graph turns the flag OFF (store_false, default True).
    ("--graph", dict(action="store_false", default=True)),
    # Logging.
    ("--logger", dict(action="store_true")),
    ("--output_dir", dict(type=str, default="./log/")),
    # Graph-augmented model settings.
    ("--use_graph", dict(default=False, action="store_true")),
    ("--alpha", dict(default=0.1, type=float)),
    ("--beta", dict(default=0.1, type=float)),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

def get_ucp(score):
    """Return the inverse squared distance of the rectified score from 0.5.

    ``score`` is passed through ReLU first, so every negative entry is
    treated as 0 and maps to ``1 / 0.5**2 == 4``. An entry whose rectified
    value is exactly 0.5 yields ``inf`` because the distance is zero.
    """
    distance_from_half = torch.abs(F.relu(score) - 0.5)
    squared_distance = distance_from_half ** 2
    return 1 / squared_distance

def intervene_on_graph(edge_param):
    """Fill the inactive edges of each hard-coded concept group with a weak weight.

    Three groups of concept indices are wired in below. Within a group every
    ordered pair ``(a, b)`` with ``a != b`` is a candidate edge. For each
    group, the entries of ``edge_param`` that are non-zero count as "active";
    every candidate entry that is exactly zero is overwritten with
    ``0.1 * mean(active entries of that group)``. A group without any active
    entry is left untouched (its zero entries are rewritten with 0).

    Args:
        edge_param: 2-D floating-point tensor indexed as ``[source, target]``.
            It must be at least 11 x 11 because indices up to 10 are used.

    Returns:
        The same tensor object, modified in place.
    """
    # Concept indices per group; each group is treated as fully connected.
    # Generating both directions from the node sets guarantees that no
    # direction is missed or listed twice (the previous hand-written table
    # repeated [0, 4] and omitted [4, 0]).
    groups = {
        0: (0, 1, 4),
        1: (2, 3, 5, 6),
        2: (7, 8, 9, 10),
    }

    for nodes in groups.values():
        pairs = torch.tensor(
            [(a, b) for a in nodes for b in nodes if a != b],
            dtype=torch.long,
            device=edge_param.device,
        )
        rows, cols = pairs[:, 0], pairs[:, 1]

        # Detached so the fill value is a plain number, not part of any
        # autograd graph attached to ``edge_param``.
        weights = edge_param.detach()[rows, cols]
        inactive = weights == 0
        active = weights[~inactive]

        if active.numel() == 0:
            fill_value = 0
        else:
            fill_value = 0.1 * active.mean()

        edge_param[rows[inactive], cols[inactive]] = fill_value

    return edge_param

def _build_logger(args):
    """Return a file-backed logger for this run, or None when --logger is off."""
    if not args.logger:
        return None

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    file_name = f'/intervention/standard_{args.dataset}_{args.scheme}/seed_{args.seed}'
    if args.use_graph:
        file_name += '_graph'
    file_name += '.log'

    # Create the directory the handler actually writes to (under
    # args.output_dir), including any missing parents.
    log_path = args.output_dir + file_name
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    logger.addHandler(logging.FileHandler(log_path, mode="w", encoding="utf-8"))
    return logger


def _load_split(args, split, backbone):
    """Load cached image features, concept annotations and labels for one split.

    ``split`` is ``"train"`` or ``"val"``. For ``chestxpert`` the samples whose
    label is -1 are dropped. Everything is returned on the CPU.
    """
    name = f"{args.dataset}_{split}"

    if args.dataset == 'chestxpert':
        feature_path = "{}/{}_BioViL_T.pt".format(args.activation_dir, name)
    else:
        feature_path = "{}/{}_clip_{}.pt".format(args.activation_dir, name, backbone)
    concept_path = "{}/{}_concepts.pt".format(args.activation_dir, name)

    features = torch.load(feature_path, map_location="cpu").float()
    concepts = torch.load(concept_path, map_location="cpu").float()

    if args.dataset in ('celebA', 'chestxpert'):
        label_path = "{}/{}_labels.pt".format(args.activation_dir, name)
        labels = torch.LongTensor(torch.load(label_path, map_location="cpu"))
    else:
        labels = torch.LongTensor(data_utils.get_targets_only(name))

    if args.dataset == 'chestxpert':
        keep = labels != -1
        features, concepts, labels = features[keep], concepts[keep], labels[keep]

    return features, concepts, labels


def _evaluate(model, loader, device, baseline=False, collect_concepts=False):
    """Run one pass over ``loader`` and return ``(accuracy, pred_concepts, true_concepts)``.

    With ``baseline=True`` the model's regular forward pass is used; otherwise
    predictions come from ``model.intervene(model.get_score(X))``. Accuracy is
    computed from an integer count of correct predictions so that equal
    accuracies compare exactly equal. The concept tensors are binarised at 0.5
    and returned on the CPU when ``collect_concepts`` is set, else ``None``.
    """
    correct, total = 0, 0
    pred_chunks, true_chunks = [], []

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for X, y, c in loader:
            X, y, c = X.to(device), y.to(device), c.to(device)
            if baseline:
                _, c_pred, y_pred = model(X, y, c)
            else:
                c_pred, y_pred = model.intervene(model.get_score(X))

            correct += (y_pred.max(1)[1] == y).sum().item()
            total += X.size(0)

            if collect_concepts:
                pred_chunks.append(torch.where(c_pred > 0.5, 1, 0).cpu())
                true_chunks.append(c.cpu())

    accuracy = correct / total
    if not collect_concepts:
        return accuracy, None, None
    return accuracy, torch.cat(pred_chunks, dim=0), torch.cat(true_chunks, dim=0)


def _apply_edge_mask(model, original_edges, mask):
    """Set each edge parameter of ``model`` to its original value times ``mask``."""
    for name, weights in original_edges.items():
        getattr(model, name).data = weights * mask


def main(args):
    """Greedily prune concept-graph edges of a saved model using validation accuracy.

    Steps:
      1. Load the cached features/concepts/labels and rebuild the model from
         ``<args.saved_model>/model.pkl``.
      2. Measure the baseline validation accuracy.
      3. Build the candidate edge list: non-zero entries in the strict upper
         triangle of the sum, over the three edge parameters, of
         ``relu(0.5 * (P + P.T))``.
      4. For every candidate, zero that edge (both directions) in all three
         edge parameters on top of the removals accepted so far, re-evaluate,
         and keep the removal when accuracy is not below the best seen so far.
      5. Report the final accuracy and save the indices of the surviving
         edges to ``<args.saved_model>/edge.pt``.

    Args:
        args: Parsed command-line namespace produced by the module's parser.

    Raises:
        ValueError: If ``args.saved_model`` is not given.
    """
    if args.saved_model is None:
        raise ValueError("--saved_model must point to the directory containing model.pkl")

    logger = _build_logger(args)
    if logger is not None:
        logger.info(f"Intervention this model {args.saved_model}")

    # Honour --device, but fall back to the CPU when CUDA is requested and
    # unavailable (the default "cuda" therefore behaves as before).
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    backbone = args.clip_name.replace('/', '')
    image_features, concepts, labels = _load_split(args, "train", backbone)
    val_image_features, val_concepts, val_labels = _load_split(args, "val", backbone)

    if args.dataset == 'celebA':
        num_classes = 256
    elif args.dataset == 'chestxpert':
        num_classes = 4
    else:
        num_classes = len(torch.unique(labels))

    # The training split is only needed for the model dimensions.
    input_dim = image_features.size(-1)
    num_concepts = concepts.size(-1)

    model1 = MLP(input_dim=input_dim, num_classes=num_concepts, expand_dim=input_dim)
    model2 = MLP(input_dim=num_concepts, num_classes=num_classes, expand_dim=None)
    if args.use_graph:
        model3 = MLP(input_dim=input_dim, num_classes=num_concepts, expand_dim=input_dim)
        model4 = MLP(input_dim=num_concepts, num_classes=num_classes, expand_dim=None)
    else:
        model3 = None
        model4 = None

    model = End2EndModel(model1, model2, model3=model3, model4=model4, use_relu=True,
                         use_graph=args.use_graph, alpha=args.alpha, beta=args.beta)
    if args.use_graph:
        model.graph_init(num_concepts)
    # map_location keeps the checkpoint loadable on hosts without the device
    # it was saved from; the model is moved to ``device`` right after.
    state_dict = torch.load(os.path.join(args.saved_model, 'model.pkl'), map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # Validation data lives on the target device because it is traversed once
    # per candidate edge. The loader is not shuffled so every pass is identical.
    val_dataset = TensorDataset(val_image_features.to(device),
                                val_labels.to(device),
                                val_concepts.to(device))
    val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)

    # Concept predictions are only needed for the ROC-AUC that goes to the logger.
    want_concepts = logger is not None

    original_acc, pred_concepts, true_concepts = _evaluate(
        model, val_dataloader, device, baseline=True, collect_concepts=want_concepts)
    if logger is not None:
        logger.info("Original Validation Acc: {:.4f}, ROC_AUC: {:.4f}".format(
            original_acc, roc_auc_score(true_concepts, pred_concepts, average='micro')))
    else:
        print("Original Validation Acc: {:.4f}".format(original_acc))

    # NOTE(review): edge_param1..3 are read unconditionally; presumably they
    # only exist after model.graph_init (i.e. with --use_graph) - confirm that
    # running without --use_graph is meant to be supported.
    edge_names = ('edge_param1', 'edge_param2', 'edge_param3')
    original_edges = {name: getattr(model, name).data.clone() for name in edge_names}

    # Undirected edge strength: symmetrise, rectify and sum the three matrices,
    # then keep one entry per unordered pair (strict upper triangle).
    edge_strength = sum(F.relu(0.5 * (w + w.t())) for w in original_edges.values())
    edge_param = torch.tril(edge_strength, -1).T

    edge_index = torch.nonzero(edge_param)
    print('Original number of edges:', edge_index.size(0))

    graph_mask = torch.ones((num_concepts, num_concepts)).to(device)
    best_acc = original_acc
    for i, j in tqdm(edge_index.tolist()):
        # Tentatively remove this edge on top of the removals accepted so far.
        tempo_mask = graph_mask.clone()
        tempo_mask[i, j] = 0
        tempo_mask[j, i] = 0
        _apply_edge_mask(model, original_edges, tempo_mask)

        acc, _, _ = _evaluate(model, val_dataloader, device)
        # Ties count as success, so edges that do not affect accuracy are pruned.
        if acc >= best_acc:
            graph_mask = tempo_mask
            best_acc = acc

    _apply_edge_mask(model, original_edges, graph_mask)
    final_acc, pred_concepts, true_concepts = _evaluate(
        model, val_dataloader, device, collect_concepts=want_concepts)
    if logger is not None:
        logger.info("After intervention on graph Acc: {:.4f}, ROC_AUC: {:.4f}".format(
            final_acc, roc_auc_score(true_concepts, pred_concepts, average='micro')))
    else:
        print("After intervention on graph Acc: {:.4f}".format(final_acc))

    kept_edges = torch.nonzero(edge_param * graph_mask)
    print('After filtering number of edges:', kept_edges.size(0))
    torch.save(kept_edges.detach(), f'{args.saved_model}/edge.pt')


    

if __name__ == "__main__":
    # Script entry point: parse the CLI, seed every random number generator
    # with the same value, request deterministic cuDNN kernels, then run.
    args = parser.parse_args()
    seed = args.seed

    _seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for _apply_seed in _seeders:
        _apply_seed(seed)
    torch.backends.cudnn.deterministic = True

    main(args)