import torch
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json
from sklearn.metrics import f1_score
import math 
import numpy as np
from dataset import datasets, ImbalancedDatasetSampler
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam, SGD
import torch.nn.functional as F
import torch_geometric
from transformers import get_scheduler, get_cosine_schedule_with_warmup
from tqdm import tqdm
from models import GCN, GNN, ConceptGNN
import matplotlib.pyplot as plt
import logging
import networkx as nx
from selection_algo import global_submodular_select

# Command-line interface for the concept-intervention evaluation performed by ``main``.
# Several options (lr, scheduler, save_dir, epoch, scheme, supervision) are not read by
# main() in this file; they are kept so the same command line as training is accepted.
parser = argparse.ArgumentParser(description='Settings for creating CBM')


parser.add_argument("--dataset", type=str, default="cifar10")
# Required: main() derives the concept activation file name from this path.
parser.add_argument("--concept_set", type=str, default=None, required=True,
                    help="path to the concept set file; its basename (without extension) "
                         "names the saved concept text activations")
parser.add_argument("--clip_name", type=str, default="ViT-B/16", help="Which CLIP model to use")

parser.add_argument("--device", type=str, default="cuda", help="Which device to use")
parser.add_argument("--batch_size", type=int, default=128, help="Batch size of the validation DataLoader")
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument('--scheduler', type=str, default=None)
parser.add_argument("--activation_dir", type=str, default='saved_activations', help="save location for backbone and CLIP activations")
parser.add_argument("--save_dir", type=str, default='saved_models', help="where to save trained models")
# Required: main() loads <saved_model>/model.pkl.
parser.add_argument("--saved_model", type=str, default=None, required=True,
                    help="directory containing the checkpoint (model.pkl) to evaluate")
parser.add_argument("--epoch", type=int, default=100, help="number of training epochs")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--scheme", type=str, default='joint', help='training schema [independent, sequential, joint]')
parser.add_argument("--supervision", default=False, action='store_true')
parser.add_argument("--filtering", type=str, default=None,
                    help="choose from [None, submodular, random, similarity, discriminative, coverage] "
                         "(this script only applies 'submodular')")
parser.add_argument("--filtering_size", type=int, default=200, help="number of concepts kept by filtering")
parser.add_argument("--tau", default=0.9, type=float, help="tau passed to ConceptGNN and ConceptDataset")
parser.add_argument("--reg_alpha", default=0.7, type=float, help="alpha passed to ConceptGNN")
parser.add_argument("--init_beta", default=0.8, type=float, help="beta passed to ConceptGNN")
# Keep the int default: the value is interpolated into the log file name ("ratio_1").
parser.add_argument("--ratio", default=1, type=float,
                    help="fraction of graph-connected concepts to mask; the same number of "
                         "unconnected concepts is masked for comparison")
# store_false: the attribute is True unless the flag is given, i.e. passing --graph DISABLES the graph.
parser.add_argument("--graph", action='store_false', default=True,
                    help="pass this flag to disable the concept graph during intervention")
parser.add_argument("--logger", action='store_true',
                    help="write results to a log file under --output_dir instead of printing them")
parser.add_argument("--output_dir", type=str, default='./log/')

def graph_visualize1(data_set, fig_name, labels=None, edge_indices=None, color=None):
    """Draw one fixed node layout with several edge sets overlaid and save it as a PDF.

    Args:
        data_set: object whose ``len()`` gives the node count when *labels* is None, and
            whose ``edge_index`` is drawn for any ``None`` entry of *edge_indices*
            (presumably a PyG-style (2, E) tensor — TODO confirm against callers).
        fig_name: file stem; the figure is written to ``./visualization/<fig_name>.pdf``.
        labels: optional per-node display labels; node ``i`` is shown as ``labels[i]``.
        edge_indices: sequence of edge-index tensors, each (2, E) or (E, 2). ``None``
            (for the whole argument or an entry) falls back to ``data_set.edge_index``.
        color: one edge colour per entry of *edge_indices*; defaults to black.
    """
    num_nodes = len(data_set) if labels is None else len(labels)
    if labels is None:
        labels = range(num_nodes)
    if edge_indices is None:
        edge_indices = [None]
    if color is None:
        color = ['k'] * len(edge_indices)

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    # One layout shared by the nodes and every edge set so the overlays line up.
    pos = nx.drawing.spring_layout(G)

    fig = plt.figure(figsize=(500, 100))
    try:
        nx.draw(G, pos=pos, labels=dict(enumerate(labels)), with_labels=True,
                node_size=1000, node_color='#dde5b6')
        for edge_index, c in zip(edge_indices, color):
            if edge_index is None:
                # Fallback edges are (2, E); networkx needs a list of (u, v) pairs.
                edge_pairs = data_set.edge_index.t()
            elif edge_index.size(-1) != 2:
                edge_pairs = edge_index.t()
            else:
                edge_pairs = edge_index
            nx.draw_networkx_edges(G, pos=pos, edgelist=edge_pairs.detach().tolist(), edge_color=c)

        os.makedirs('./visualization', exist_ok=True)
        plt.savefig('./visualization/' + fig_name + '.pdf', dpi=199)
    finally:
        # Release the (very large) figure; otherwise every call leaks one.
        plt.close(fig)
    return

def graph_visualize(data, fig_name, labels=None, edge_index=None, node_size=1000, font_size=12):
    """Draw a single graph and save it as ``./visualization/<fig_name>.pdf``.

    Args:
        data: object whose ``len()`` gives the node count when *labels* is None; its
            ``edge_index`` (presumably PyG-style (2, E) — TODO confirm) is used when
            *edge_index* is None.
        fig_name: output file stem.
        labels: optional per-node display labels; node ``i`` is shown as ``labels[i]``.
            Nodes are always keyed by integer index so they match the edge indices.
        edge_index: optional edge tensor of shape (2, E) or (E, 2).
        node_size, font_size: forwarded to ``nx.draw``.
    """
    num_nodes = len(data) if labels is None else len(labels)
    if labels is None:
        labels = range(num_nodes)

    if edge_index is None:
        # Fallback edges are (2, E); networkx needs a list of (u, v) pairs.
        edge_pairs = data.edge_index.t()
    elif edge_index.size(-1) != 2:
        edge_pairs = edge_index.t()
    else:
        edge_pairs = edge_index

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(edge_pairs.detach().tolist())

    fig = plt.figure(figsize=(75, 75))
    try:
        nx.draw(G, labels=dict(enumerate(labels)), with_labels=True, node_size=node_size,
                node_color='#bde0fe', font_size=font_size)
        os.makedirs('./visualization', exist_ok=True)
        plt.savefig('./visualization/' + fig_name + '.pdf')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return

def get_ucp(score):
    """Return ``1 / |ReLU(score) - 0.5| ** 2`` element-wise.

    Values grow as the rectified score approaches 0.5 and are ``inf`` exactly at 0.5.
    """
    distance_from_half = (F.relu(score) - 0.5).abs()
    return distance_from_half.pow(2).reciprocal()

def minimum_spanning_tree(adj_matrix):
    """Greedy degree-guided walk over a weighted adjacency matrix.

    Despite the name, this does not compute a minimum spanning tree. Node degree is
    the row sum of ``adj_matrix``. The walk starts at the highest-degree node (first
    such index on ties) and repeatedly steps to the unvisited neighbour
    (``adj_matrix[cur, nb] > 0``) with the largest degree (lowest index on ties).
    It stops, without backtracking, once the current node has no unvisited
    neighbour. Every node still unvisited afterwards then starts a new walk, in
    index order.

    Returns:
        (node_order, mst_edges): the visit order as a list of ints, and the traversed
        steps as ``(src, dst, weight)`` tuples with ``weight`` a Python number.
    """
    n_nodes = adj_matrix.size(0)
    degree = adj_matrix.sum(dim=1)
    seen = torch.zeros(n_nodes, dtype=torch.bool)
    order = []
    edges = []

    def _walk(node):
        seen[node] = True
        order.append(node)
        while True:
            best = None
            for cand in range(n_nodes):
                if seen[cand] or not (adj_matrix[node, cand] > 0):
                    continue
                # Strict '>' keeps the earliest candidate among equal degrees.
                if best is None or degree[cand] > degree[best]:
                    best = cand
            if best is None:
                return
            edges.append((node, best, adj_matrix[node, best].item()))
            seen[best] = True
            order.append(best)
            node = best

    _walk(torch.argmax(degree).item())
    for node in range(n_nodes):
        if not seen[node]:
            _walk(node)

    return order, edges


def _setup_logger(args):
    """Configure file logging for an intervention run; return the logger, or None if disabled.

    The log file is written to ``<output_dir>/intervention/<dataset>/seed_<seed>...log``
    and that directory is created if missing.
    """
    if not args.logger:
        return None
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Create the directory under --output_dir; the file is written there.
    log_dir = os.path.join(args.output_dir, 'intervention', args.dataset)
    os.makedirs(log_dir, exist_ok=True)

    file_name = f'seed_{args.seed}'
    if args.saved_model and 'no_graph' in args.saved_model:
        file_name += "_no_graph"
    if args.filtering is not None:
        file_name += f'_{args.filtering}_{args.filtering_size}'
    file_name += f'_intervention_ratio_{args.ratio}'
    file_name += '.log'

    fh = logging.FileHandler(os.path.join(log_dir, file_name), mode="w", encoding="utf-8")
    logger.addHandler(fh)
    return logger


def _report(logger, message):
    """Send *message* to *logger* when one is configured, otherwise print it."""
    if logger is not None:
        logger.info(message)
    else:
        print(message)


def _split_concepts_by_connectivity(model, num_concepts):
    """Split concept indices into (connected, unconnected) w.r.t. the model's concept graph.

    The graph is the sum of ``ReLU(0.5 * (P + P^T))`` over ``model.edge_param1..3``. A
    concept is "connected" when it is an endpoint of a non-zero strictly-lower-triangular
    entry, so self-loops are ignored. Both returned index tensors are on the CPU.
    """
    with torch.no_grad():
        edge_params = (model.edge_param1, model.edge_param2, model.edge_param3)
        edge_param = sum(F.relu(0.5 * (p + p.t())) for p in edge_params)
        edge_param = torch.tril(edge_param, -1)
        # Move to CPU: the boolean mask below lives on the CPU, and indexing a CPU
        # tensor with a CUDA index tensor raises.
        connected = torch.unique(torch.nonzero(edge_param).flatten()).cpu()

    is_connected = torch.zeros(num_concepts, dtype=torch.bool)
    is_connected[connected] = True
    unconnected = torch.arange(num_concepts)[~is_connected]
    return connected, unconnected


def _shuffled(indices):
    """Return a random permutation of a 1-D index tensor, using torch's global RNG.

    ``random.shuffle`` must not be used on tensors: its element swap goes through
    tensor views and silently duplicates entries instead of swapping them.
    """
    return indices[torch.randperm(indices.numel())]


def _masked_accuracy(model, loader, text_features, masked_nodes, graph, device):
    """Return validation accuracy with the scores of *masked_nodes* set to zero.

    For each batch, concept scores come from ``model.get_score``. The masked columns are
    zeroed, and predictions come from ``model.intervention(score, not graph)``, matching
    the original call. Returns ``nan`` for an empty loader.
    """
    n_correct, n_seen = 0, 0
    with torch.no_grad():  # evaluation only; avoids building autograd graphs
        for instance in tqdm(loader):
            target = instance['label'].to(device)
            image = instance['image'].to(device)

            score = model.get_score(image, text_features=text_features)
            score[:, masked_nodes] = 0
            pred = model.intervention(score, not graph)

            n_correct += (pred.max(dim=1).indices == target).sum().item()
            n_seen += target.numel()
    return n_correct / n_seen if n_seen else float('nan')


def main(args):
    """Compare accuracy after masking graph-connected vs. unconnected concepts.

    Loads precomputed validation image / concept text activations and the ConceptGNN
    checkpoint ``<args.saved_model>/model.pkl``, finds which concepts participate in the
    learned concept graph, and reports validation accuracy twice. First, a random
    ``args.ratio`` fraction of connected concepts is masked (scores zeroed). Second, the
    same number of random unconnected concepts is masked.
    """
    logger = _setup_logger(args)
    if logger is not None:
        logger.info(f"Intervention this model {args.saved_model}")

    dataset_train = args.dataset + "_train"
    dataset_val = args.dataset + "_val"

    backbone = args.clip_name.replace('/', '')
    concept_set_name = (args.concept_set.split("/")[-1]).split(".")[0]
    if args.clip_name in ['patch_BioViL_T', 'BioViL_T', 'biomedclip']:
        # These backbones store image activations without the "clip_" infix.
        image_tag = backbone
        if args.clip_name == 'patch_BioViL_T':
            args.clip_name = 'BioViL_T'
    else:
        image_tag = f"clip_{backbone}"
    act_dir = args.activation_dir
    image_save_name = f"{act_dir}/{dataset_train}_{image_tag}.pt"
    val_image_save_name = f"{act_dir}/{dataset_val}_{image_tag}.pt"
    text_save_name = f"{act_dir}/{concept_set_name}_{backbone}.pt"

    # Honour --device, but fall back to CPU when CUDA is requested and unavailable.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    val_image_features = torch.load(val_image_save_name, map_location="cpu").float()
    text_features = torch.load(text_save_name, map_location="cpu").float()
    val_labels = torch.LongTensor(data_utils.get_targets_only(dataset_val))

    if args.filtering == 'submodular':
        # Training activations and labels are only needed to select the concept subset.
        image_features = torch.load(image_save_name, map_location="cpu").float()
        labels = torch.LongTensor(data_utils.get_targets_only(dataset_train))
        num_images_per_class = torch.LongTensor(
            [(labels == c).sum().item() for c in torch.unique(labels)])
        text_features, _ = global_submodular_select(
            image_features, text_features, 1, args.filtering_size, num_images_per_class, [1, 1])
    # NOTE: other --filtering values are accepted by the CLI but not applied here.

    # Computed after filtering so the score columns match the concepts the model uses.
    val_concept_scores = F.normalize(val_image_features, dim=-1) @ F.normalize(text_features, dim=-1).t()

    num_concepts = text_features.size(0)
    model = ConceptGNN(input_dim=val_image_features.size(-1), output_dim=len(torch.unique(val_labels)),
                       num_concepts=num_concepts, hidden_dim=text_features.size(-1), module='gcn',
                       tau=args.tau, alpha=args.reg_alpha, beta=args.init_beta)
    # map_location lets GPU-trained checkpoints load on CPU-only machines.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(os.path.join(args.saved_model, 'model.pkl'), map_location="cpu"))
    model = model.to(device)
    model.eval()

    val_set = datasets.ConceptDataset(image_features=val_image_features, text_features=text_features, tau=args.tau,
                                      concept_scores=val_concept_scores, labels=val_labels, graph_construction=False)
    loader = torch_geometric.loader.DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    text_features = text_features.to(device)

    connected_nodes, unconnected_nodes = _split_concepts_by_connectivity(model, num_concepts)
    print('# of connected nodes is ', connected_nodes.size(0))

    # Mask equally many connected and unconnected concepts, each chosen at random.
    length = int(args.ratio * connected_nodes.numel())
    connected_nodes = _shuffled(connected_nodes)[:length]
    unconnected_nodes = _shuffled(unconnected_nodes)[:length]

    acc = _masked_accuracy(model, loader, text_features, connected_nodes, args.graph, device)
    _report(logger, "After Masking Connected Nodes Acc: {:.4f}".format(acc))

    acc = _masked_accuracy(model, loader, text_features, unconnected_nodes, args.graph, device)
    _report(logger, "After Masking Unconnected Nodes Acc: {:.4f}".format(acc))

if __name__ == '__main__':
    args = parser.parse_args()

    # Seed Python, NumPy, torch (CPU) and every CUDA device from the single --seed
    # value, and request deterministic cuDNN kernels, so runs can be repeated.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                     torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_rng(args.seed)
    torch.backends.cudnn.deterministic = True

    main(args)