import torch
import os
import random
import utils
import data_utils
import similarity
import argparse
import datetime
import json
from sklearn.metrics import f1_score, roc_auc_score, jaccard_score
import math 
import numpy as np
import logging
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.optim import Adam, SGD
import torch.nn.functional as F

from transformers import get_scheduler, get_cosine_schedule_with_warmup
from tqdm import tqdm
from models import MLP, End2EndModel

from cem_models import ConceptEmbeddingModel, GraphConceptEmbeddingModel
from scbm_models import SCBM, SCBLoss, GSCBM
import pytorch_lightning as pl

# Command-line interface for training a concept bottleneck model (CBM) on
# cached backbone activations. ``main(args)`` consumes the parsed namespace.
parser = argparse.ArgumentParser(description='Settings for creating CBM')


parser.add_argument("--dataset", type=str, default="cifar10")
parser.add_argument("--clip_name", type=str, default="ViT-B/16", help="Which CLIP model to use")

parser.add_argument("--device", type=str, default="cuda", help="Which device to use")
parser.add_argument("--batch_size", type=int, default=512, help="Mini-batch size used for training and validation")
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument('--scheduler', type=str, default=None,
                    help="LR scheduler: cosine, polynomial, inverse_sqrt, reduce_lr_on_plateau, "
                         "cosine_with_restarts or step (default: none)")
# Read by main() for --scheduler cosine_with_restarts; 0.5 matches the
# default of transformers.get_cosine_schedule_with_warmup.
parser.add_argument("--num_cycles", type=float, default=0.5,
                    help="Number of cosine cycles for --scheduler cosine_with_restarts")
parser.add_argument("--activation_dir", type=str, default='saved_activations', help="save location for backbone and CLIP activations")
parser.add_argument("--save_dir", type=str, default='cbm_supervision_outputs1', help="where to save trained models")
parser.add_argument("--epoch", type=int, default=100, help="number of training epochs")
parser.add_argument("--seed", type=int, default=42)
# Only these three schemes are implemented by main(); reject anything else up front.
parser.add_argument("--scheme", type=str, default='joint', choices=['joint', 'cem', 'scbm'],
                    help='training scheme [joint, cem, scbm]')
parser.add_argument("--supervision", default=False, action='store_true')
parser.add_argument("--use_graph", default=False, action='store_true')
parser.add_argument("--alpha", default=0.1, type=float,
                    help="Weight of the cl_loss term (cem/scbm); also forwarded to End2EndModel for joint")
parser.add_argument("--beta", default=0.1, type=float,
                    help="Weight of the l1_regularizer term (cem/scbm); also forwarded to End2EndModel for joint")

_SUPPORTED_SCHEMES = ('joint', 'cem', 'scbm')


def _build_saved_path(args):
    """Return (and create, including parents) the output directory for this run.

    The directory name encodes dataset, scheme, supervision flag, seed, lr,
    epochs, alpha, beta and the ``_graph`` suffix when ``--use_graph`` is set.
    """
    tag = f"{args.dataset}_{args.scheme}"
    if args.supervision:
        tag += "_supervised"
    saved_path = (
        f"/home/hxu2/CBM-Graph/{args.save_dir}/{tag}_seed_{args.seed}_lr_{args.lr}"
        f"_epoch_{args.epoch}_alpha_{args.alpha}_beta_{args.beta}"
    )
    if args.use_graph:
        saved_path += '_graph'
    # makedirs instead of shelling out to `mkdir`: no shell parsing of the path,
    # missing parents are created, and failures raise instead of being ignored.
    os.makedirs(saved_path, exist_ok=True)
    return saved_path


def _activation_path(args, split):
    """Return the cached backbone-activation file for ``split`` (e.g. ``cifar10_train``)."""
    backbone = args.clip_name.replace('/', '')
    if args.dataset == 'celebA':
        suffix = "RN101"
    elif args.dataset == 'chestxpert':
        suffix = "BioViL_T"
    elif args.clip_name == 'RN18':
        suffix = backbone
    else:
        suffix = f"clip_{backbone}"
    return f"{args.activation_dir}/{split}_{suffix}.pt"


def _load_labels(args, split):
    """Load class labels for ``split`` as a LongTensor (cached file or data_utils)."""
    if args.dataset in ('celebA', 'chestxpert'):
        label_file = f"{args.activation_dir}/{split}_labels.pt"
        return torch.LongTensor(torch.load(label_file, map_location="cpu"))
    return torch.LongTensor(data_utils.get_targets_only(split))


def _resolve_device(requested):
    """Honour ``--device``, falling back to CPU when CUDA is requested but unavailable."""
    if str(requested).startswith("cuda") and not torch.cuda.is_available():
        return "cpu"
    return requested


def _build_scheduler(opt, args, num_samples):
    """Create the LR scheduler named by ``args.scheduler``, or return ``None``.

    Warmup spans 5 epochs of full batches; the total step count covers every
    batch (including a final partial one) for ``args.epoch`` epochs.
    """
    num_warmup_steps = (num_samples // args.batch_size) * 5
    num_training_steps = math.ceil(num_samples / args.batch_size) * args.epoch
    if args.scheduler in ['cosine', "polynomial", "inverse_sqrt", "reduce_lr_on_plateau"]:
        return get_scheduler(
            args.scheduler, opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
    if args.scheduler == "cosine_with_restarts":
        # getattr keeps this working with parsers that lack --num_cycles;
        # 0.5 is get_cosine_schedule_with_warmup's own default.
        return get_cosine_schedule_with_warmup(
            opt,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=getattr(args, "num_cycles", 0.5),
        )
    if args.scheduler == 'step':
        return StepLR(opt, step_size=10, gamma=0.99)
    return None


def _concept_scores(true_concepts, pred_scores):
    """Binarize concept scores at 0.5 and return (micro Jaccard, exact-match accuracy).

    Assumes ``pred_scores`` are per-concept probabilities (the 0.5 threshold
    only makes sense then) -- TODO confirm for every model type.
    """
    pred = torch.where(pred_scores > 0.5, 1, 0)
    jaccard = jaccard_score(true_concepts, pred, average='micro')
    exact_match = (pred == true_concepts).all(dim=-1).float().mean().item()
    return jaccard, exact_match


def _run_epoch(model, loader, device, forward, opt=None, scheduler=None):
    """Make one pass over ``loader``: optimise when ``opt`` is given, else evaluate.

    ``forward(X, y, c)`` must return ``(loss, concept_scores, y_pred)``.

    Returns:
        ``(mean_loss, mean_accuracy, true_concepts, predicted_concept_scores)``;
        the last two are CPU tensors concatenated over the whole pass.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = opt is not None
    model.train(training)
    n, total_loss, correct = 0, 0.0, 0.0
    true_chunks, pred_chunks = [], []
    for X, y, c in loader:
        X, y, c = X.to(device), y.to(device), c.to(device)
        if training:
            opt.zero_grad()
        with torch.set_grad_enabled(training):
            loss, concept_scores, y_pred = forward(X, y, c)
        if training:
            loss.backward()
            # Clip *after* backward so the freshly computed gradients are bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            opt.step()
            # ReduceLROnPlateau needs a metric and is stepped once per epoch instead.
            if scheduler is not None and not isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step()

        batch_size = X.size(0)
        correct += (y_pred.argmax(1) == y).sum().item()
        total_loss += loss.item() * batch_size
        n += batch_size
        # Collect per-batch chunks and concatenate once (avoids O(n^2) repeated cat).
        pred_chunks.append(concept_scores.detach().cpu())
        true_chunks.append(c.detach().cpu())
    if n == 0:
        raise ValueError("data loader produced no samples")
    torch.cuda.empty_cache()
    return total_loss / n, correct / n, torch.cat(true_chunks), torch.cat(pred_chunks)


def _make_joint_forward(model, supervision):
    """Per-batch forward for End2EndModel; concept targets are passed only when supervised."""
    def forward(X, y, c):
        if supervision:
            return model(X, y, c)
        return model(X, y)
    return forward


def _make_scbm_forward(model, loss_fct, epoch, alpha, beta):
    """Per-batch forward for (G)SCBM returning the SCBLoss total plus optional graph terms."""
    def forward(X, y, c):
        c_pred, c_cov, y_pred, cl_loss, l1_regularizer = model(X, epoch, c_true=c)
        _, _, _, loss = loss_fct(
            concepts_mcmc_probs=c_pred,
            concepts_true=c,
            target_pred_logits=y_pred,
            target_true=y,
            c_triang_cov=c_cov,
        )
        if cl_loss is not None:
            # Out-of-place add keeps autograd safe.
            loss = loss + alpha * cl_loss + beta * l1_regularizer
        # Average over the last axis of c_pred (presumably MCMC samples -- TODO confirm).
        return loss, c_pred.mean(-1), y_pred
    return forward


def _train_and_validate(model, args, device, train_loader, val_loader, make_forward, num_samples, logger):
    """Train ``model`` with Adam for ``args.epoch`` epochs, then run one validation pass.

    ``make_forward(epoch)`` returns the per-batch forward callable for that epoch.
    Training metrics are logged every 10 epochs; validation metrics once at the end.
    """
    model.to(device)
    opt = Adam(model.parameters(), lr=args.lr)
    scheduler = _build_scheduler(opt, args, num_samples)

    for epoch in tqdm(range(args.epoch)):
        loss, acc, true_c, pred_c = _run_epoch(
            model, train_loader, device, make_forward(epoch), opt, scheduler)
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(loss)
        if (epoch + 1) % 10 == 0:
            jaccard, _ = _concept_scores(true_c, pred_c)
            logger.info("Epoch {} Loss: {:.4f}, Acc: {:.4f}, Jaccard Score {:.4f}".format(
                epoch + 1, loss, acc, jaccard))

    # Validation reuses the last training epoch index, as models may depend on it.
    last_epoch = max(args.epoch - 1, 0)
    loss, acc, true_c, pred_c = _run_epoch(model, val_loader, device, make_forward(last_epoch))
    jaccard, concept_acc = _concept_scores(true_c, pred_c)
    logger.info("Validation Loss: {:.4f}, Acc: {:.4f}, Jaccard Score {:.4f}, Concept Acc {:.4f}".format(
        loss, acc, jaccard, concept_acc))


def main(args):
    """Train and evaluate a concept-bottleneck model on pre-computed activations.

    Loads cached features, concept annotations and labels for the
    ``<dataset>_train`` / ``<dataset>_val`` splits from ``args.activation_dir``.
    It then trains the model selected by ``args.scheme`` ('joint', 'cem' or
    'scbm') and logs to ``<saved_path>/training.log``. The final ``state_dict``
    is saved to ``<saved_path>/model.pkl``.

    Raises:
        ValueError: if ``args.scheme`` is not a supported scheme.
    """
    if args.scheme not in _SUPPORTED_SCHEMES:
        raise ValueError(f"Unsupported scheme {args.scheme!r}; expected one of {_SUPPORTED_SCHEMES}")

    log_format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
    date_format = '%m/%d/%Y %H:%M:%S'
    logging.basicConfig(format=log_format, datefmt=date_format, level=logging.INFO)
    logger = logging.getLogger(__name__)

    saved_path = _build_saved_path(args)
    print(saved_path)
    fh = logging.FileHandler(os.path.join(saved_path, 'training.log'), mode="w", encoding="utf-8")
    # Give the file the same timestamped format as the console output.
    fh.setFormatter(logging.Formatter(log_format, datefmt=date_format))
    logger.addHandler(fh)

    dataset_train = args.dataset + "_train"
    dataset_val = args.dataset + "_val"
    device = _resolve_device(args.device)

    image_features = torch.load(_activation_path(args, dataset_train), map_location="cpu").float()
    val_image_features = torch.load(_activation_path(args, dataset_val), map_location="cpu").float()
    concepts = torch.load(f"{args.activation_dir}/{dataset_train}_concepts.pt", map_location='cpu').float()
    val_concepts = torch.load(f"{args.activation_dir}/{dataset_val}_concepts.pt", map_location='cpu').float()
    labels = _load_labels(args, dataset_train)
    val_labels = _load_labels(args, dataset_val)

    if args.dataset == 'celebA':
        num_classes = 256
    elif args.dataset == 'chestxpert':
        num_classes = 4
    else:
        num_classes = len(torch.unique(labels))

    if args.dataset == 'chestxpert':
        # Drop samples whose label is -1.
        keep = labels != -1
        image_features, concepts, labels = image_features[keep], concepts[keep], labels[keep]
        val_keep = val_labels != -1
        val_image_features, val_concepts, val_labels = (
            val_image_features[val_keep], val_concepts[val_keep], val_labels[val_keep])

    train_loader = DataLoader(TensorDataset(image_features, labels, concepts),
                              shuffle=True, batch_size=args.batch_size, num_workers=1)
    val_loader = DataLoader(TensorDataset(val_image_features, val_labels, val_concepts),
                            shuffle=False, batch_size=args.batch_size, num_workers=1)

    input_dim = image_features.size(-1)
    num_concepts = concepts.size(-1)
    num_samples = image_features.size(0)

    if args.scheme == 'joint':
        model1 = MLP(input_dim=input_dim, num_classes=num_concepts, expand_dim=input_dim)
        # The label head's input size is the concept count when supervised, the feature size otherwise.
        model2 = MLP(input_dim=num_concepts if args.supervision else input_dim,
                     num_classes=num_classes, expand_dim=None)
        if args.use_graph:
            model3 = MLP(input_dim=input_dim, num_classes=num_concepts, expand_dim=input_dim)
            model4 = MLP(input_dim=num_concepts, num_classes=num_classes, expand_dim=None)
        else:
            model3 = None
            model4 = None

        model = End2EndModel(model1, model2, model3=model3, model4=model4, use_relu=True,
                             use_graph=args.use_graph, alpha=args.alpha, beta=args.beta)
        if args.use_graph:
            model.graph_init(num_concepts)

        joint_forward = _make_joint_forward(model, args.supervision)
        _train_and_validate(model, args, device, train_loader, val_loader,
                            lambda epoch: joint_forward, num_samples, logger)

    elif args.scheme == 'cem':
        if args.use_graph:
            model = GraphConceptEmbeddingModel(
                n_concepts=num_concepts,
                n_tasks=num_classes,
                feat_size=input_dim,
                emb_size=128,
                concept_loss_weight=1,
                auxiliary_loss_weight=1,
                cl_loss_weight=args.alpha,
                l1_regularizer_weight=args.beta,
                optimizer='adam',
                learning_rate=args.lr,
                training_intervention_prob=0.25,
                c_extractor_arch=None,
                c2y_model=None,
            )
        else:
            model = ConceptEmbeddingModel(
                n_concepts=num_concepts,
                n_tasks=num_classes,
                feat_size=input_dim,
                emb_size=128,
                concept_loss_weight=1,
                optimizer='adam',
                learning_rate=args.lr,
                training_intervention_prob=0.25,
                c_extractor_arch=None,
                c2y_model=None,
            )

        trainer = pl.Trainer(
            # Follow the resolved device instead of assuming a GPU exists.
            accelerator="gpu" if str(device).startswith("cuda") else "cpu",
            devices="auto",
            max_epochs=args.epoch,
            check_val_every_n_epoch=10,
        )
        trainer.fit(model, train_loader, val_loader)

    else:  # 'scbm'
        model_cls = GSCBM if args.use_graph else SCBM
        model = model_cls(
            n_features=input_dim,
            num_concepts=num_concepts,
            num_classes=num_classes,
        )
        loss_fct = SCBLoss(num_classes)
        _train_and_validate(
            model, args, device, train_loader, val_loader,
            lambda epoch: _make_scbm_forward(model, loss_fct, epoch, args.alpha, args.beta),
            num_samples, logger,
        )

    if args.use_graph:
        # Guard so a model without a learned adjacency cannot crash the run
        # after training and before the weights are saved.
        if hasattr(model, "edge_param3") and hasattr(model, "get_characteristic_matrix"):
            with torch.no_grad():
                c_adj = 0.5 * (model.edge_param3 + model.edge_param3.t())  # symmetrize
                c_adj = model.get_characteristic_matrix(c_adj)
                edge_index = torch.nonzero(torch.triu(c_adj, diagonal=1))
            logger.info(f'Number of edges is {edge_index.size(0)}, and number of nodes is {c_adj.size(-1)}')
        else:
            logger.warning("--use_graph set but %s has no edge_param3; skipping graph statistics",
                           type(model).__name__)

    torch.save(model.state_dict(), os.path.join(saved_path, 'model.pkl'))

if __name__ == '__main__':
    args = parser.parse_args()
    seed = args.seed

    # Seed Python's, NumPy's and torch's RNGs (CPU, current CUDA device and
    # all CUDA devices), in that order, before training starts.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # Request deterministic cuDNN kernels for repeatable runs.
    torch.backends.cudnn.deterministic = True

    main(args)