import os
import pickle

import torch
import numpy as np

from omegaconf import OmegaConf
from loguru import logger
from tqdm import tqdm
from scipy.stats import gaussian_kde

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import to_hex

from utils.model_utils import *
from utils.utils import *
from utils.dataset_loader_mini import (
    SubgraphIsomorphismDataset as SubgraphIsomorphismDataset_mini,
    GraphEditDistanceDataset as GraphEditDistanceDataset_mini,
    UneqGraphEditDistanceDataset as UneqGraphEditDistanceDataset_mini,
)
from utils.training_utils import pairwise_ranking_loss

## Functions for training the models and plotting the embedding distributions (distribution experiment).

## The associated experiment config is configs/distn.yaml.

def latexify():
    """Switch matplotlib to LaTeX-rendered, bold, serif styling.

    Mutates the global ``matplotlib.rcParams``, so every figure created after
    this call picks up the settings.
    """
    # LaTeX text rendering with bold math, no top/right spines, frameless legend.
    matplotlib.rcParams.update({
        'text.usetex': True,
        'text.latex.preamble': r"\usepackage{bm} \usepackage{amsmath,amsfonts} \boldmath",
        'axes.spines.right': False,
        'axes.spines.top': False,
        'legend.frameon': False,
    })

    # Per-group settings, applied in the same order as before.
    rc_groups = (
        ('font', {'family': 'serif', 'weight': 'bold'}),
        ('xtick', {'labelsize': 30}),
        ('ytick', {'labelsize': 30}),
        ('axes', {'linewidth': 1, 'labelsize': 32, 'labelweight': 'bold'}),
        ('legend', {'fontsize': 26, 'handlelength': 2, 'loc': 'upper right'}),
    )
    for group, settings in rc_groups:
        plt.rc(group, **settings)
 

def get_all_embeds(model,dataset):
    """Run the model over the dataset's corpus and query graphs without gradients.

    The model is put in eval mode with ``model.fetch_embed = True`` for the two
    forward passes (presumably this makes the forward pass return embeddings
    rather than scores -- confirm against the model class).

    Args:
        model: callable model exposing ``eval()``, ``train()`` and a
            ``fetch_embed`` attribute.
        dataset: object exposing ``_pack_batch_1d`` plus the corpus/query graph
            lists and their node/edge size attributes.

    Returns:
        Tuple ``(corpus_embeds, query_embeds)`` of the two forward-pass outputs.

    Note:
        The model is always left in train mode with ``fetch_embed = False``,
        even when a forward pass raises.
    """
    model.eval()
    model.fetch_embed = True
    try:
        with torch.no_grad():
            corpus_embeds = model(
                dataset._pack_batch_1d(dataset.corpus_graphs),
                dataset.corpus_graph_node_sizes,
                dataset.corpus_graph_edge_sizes,
            )
            query_embeds = model(
                dataset._pack_batch_1d(dataset.query_graphs),
                dataset.query_graph_node_sizes,
                dataset.query_graph_edge_sizes,
            )
    finally:
        # Restore the training configuration even if a forward pass failed, so
        # the caller never continues with a model stuck in embedding mode.
        model.train()
        model.fetch_embed = False
    return corpus_embeds,query_embeds

def get_embeds(model,dataset):
    """Run the model over the dataset's corpus graphs without gradients.

    The model is put in eval mode with ``model.fetch_embed = True`` for the
    forward pass (presumably this makes the forward pass return embeddings
    rather than scores -- confirm against the model class).

    Args:
        model: callable model exposing ``eval()``, ``train()`` and a
            ``fetch_embed`` attribute.
        dataset: object exposing ``_pack_batch_1d``, ``corpus_graphs``,
            ``corpus_graph_node_sizes`` and ``corpus_graph_edge_sizes``.

    Returns:
        The output of the forward pass on the packed corpus graphs.

    Note:
        The model is always left in train mode with ``fetch_embed = False``,
        even when the forward pass raises.
    """
    model.eval()
    model.fetch_embed = True
    try:
        with torch.no_grad():
            embeds = model(
                dataset._pack_batch_1d(dataset.corpus_graphs),
                dataset.corpus_graph_node_sizes,
                dataset.corpus_graph_edge_sizes,
            )
    finally:
        # Restore the training configuration even if the forward pass failed.
        model.train()
        model.fetch_embed = False
    return embeds

## FUNCTION FOR TRAINING THE MODEL
def train_model_distn_expt(conf,gmn_config,every=1):
    """Train ``conf.expt.num_models`` independently seeded models, snapshotting each.

    For every model index the seed is set to that index, a fresh model and Adam
    optimiser are built, and ``conf.expt.num_epochs`` full-batch steps of the
    pairwise ranking loss are taken. Corpus embeddings and a copy of the model
    weights are recorded before training and after the optimiser step of every
    run with ``run % every == 0``.

    Files written under ``{conf.write_dir}/data/{conf.task.name}``:
        * ``embeddings_{model_num}.pt``   -- list of embedding snapshots
        * ``state_dicts_{model_num}.pkl`` -- list of weight snapshots
        * ``checkpoint_info.pkl``         -- ``{"model": next index, "losses": {...}}``

    If ``checkpoint_info.pkl`` already exists, training resumes from the first
    model that has not been completed.

    Args:
        conf: OmegaConf experiment config. It is mutated:
            ``actual_max_node_set_size``, ``actual_max_edge_set_size`` and
            ``training.seed`` are set, and ``task.name`` gets a
            ``_{conf.model.non_exchbl}`` suffix when that option is not None.
        gmn_config: configuration forwarded to the model constructor.
        every: snapshot period in training steps; ``None`` is treated as 1.

    Raises:
        ValueError: if ``conf.dataset.rel_mode`` is not one of ``sub_iso``,
            ``ged`` or ``uneq_ged``.
    """
    import copy  # local import: only needed for snapshotting state dicts

    if every is None:
        every = 1

    if conf.dataset.rel_mode == "sub_iso":
        dataset = SubgraphIsomorphismDataset_mini(conf,mode="random")
    elif conf.dataset.rel_mode == "ged":
        dataset = GraphEditDistanceDataset_mini(conf,mode="random")
    elif conf.dataset.rel_mode == "uneq_ged":
        dataset = UneqGraphEditDistanceDataset_mini(conf,mode="random")
    else:
        raise ValueError(f"Unknown rel_mode: {conf.dataset.rel_mode}")
    conf.actual_max_node_set_size = dataset.max_node_set_size
    conf.actual_max_edge_set_size = dataset.max_edge_set_size

    model_class = get_class(f"{conf.model.classPath}.{conf.model.name}")

    # Optional non-exchangeable strategy: best-effort probe, because the key
    # may be absent from the model config.
    try:
        non_exchbl = conf.model.non_exchbl is not None
    except Exception:
        non_exchbl = False
    if non_exchbl:
        conf.task.name = f"{conf.task.name}_{conf.model.non_exchbl}"

    out_dir = f"{conf.write_dir}/data/{conf.task.name}"
    checkpoint_path = f"{out_dir}/checkpoint_info.pkl"
    os.makedirs(out_dir,exist_ok=True)

    # Resume from a previous run. The existence check must look at the same
    # file that is written at the end of each model, otherwise resuming never
    # triggers (or fails on a missing file).
    if os.path.exists(checkpoint_path):
        # NOTE: pickle is only safe here because this function wrote the file
        # itself; never point write_dir at untrusted data.
        with open(checkpoint_path,"rb") as ckpt_file:
            status = pickle.load(ckpt_file)
        status.setdefault("losses", {})
        # status["model"] is bumped only after a model has been fully trained
        # and saved, so it is exactly the index of the next model to train.
        start_val = status["model"]
    else:
        status = {"model":0, "losses":{}}
        start_val = 0

    graphs,graph_node_sizes,graph_edge_sizes,labels = dataset.fetch()

    for model_num in tqdm(range(start_val,conf.expt.num_models)):
        # use the model index as the seed for model initialization
        conf.training.seed = model_num
        set_seed(conf.training.seed)
        model = model_class(conf,gmn_config).to(conf.training.device)

        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=conf.training.learning_rate,
            weight_decay=conf.training.weight_decay # do we need it for such few epochs?
            )

        # state_dict() returns references to the live parameter tensors, so each
        # snapshot must be deep-copied; otherwise every stored entry would end
        # up holding the final weights.
        embeds = [get_embeds(model,dataset)]
        state_dicts = [copy.deepcopy(model.state_dict())]
        status["losses"][model_num] = []

        for run in range(conf.expt.num_epochs):
            model.train()

            optimizer.zero_grad()
            prediction = model(graphs,graph_node_sizes,graph_edge_sizes)

            predictions_for_positives = prediction[labels > 0.5]
            predictions_for_negatives = prediction[labels < 0.5]
            losses = pairwise_ranking_loss(
                predictions_for_positives.unsqueeze(1),
                predictions_for_negatives.unsqueeze(1),
                conf.training.margin
            )
            losses.backward()
            optimizer.step()
            status["losses"][model_num].append(losses.item())

            if run % every == 0:
                embeds.append(get_embeds(model,dataset))
                state_dicts.append(copy.deepcopy(model.state_dict()))

        # save embeddings and weights, then record progress for resuming
        torch.save(embeds,f"{out_dir}/embeddings_{model_num}.pt")
        with open(f"{out_dir}/state_dicts_{model_num}.pkl","wb") as sd_file:
            pickle.dump(state_dicts,sd_file)
        status["model"] = model_num + 1
        with open(checkpoint_path,"wb") as ckpt_file:
            pickle.dump(status,ckpt_file)
            
 
## PLOTTING FUNCTIONS
def _plot_dim_histograms(embs_by_dim, dims, colors, out_path):
    """Overlay density histograms (with KDE curves) of selected dimensions and save them.

    Args:
        embs_by_dim: indexable of 1-D value arrays, one per embedding dimension.
        dims: indices into ``embs_by_dim`` to draw.
        colors: one colour per entry of ``dims``.
        out_path: file the figure is written to.
    """
    plt.figure(figsize=(7.2, 4.8))
    # 25 equal-width bins on [-1.5, 1.5], shared by all dimensions.
    bin_edges = np.linspace(-1.5, 1.5, 25 + 1)
    for dim, color in zip(dims, colors):
        vals = embs_by_dim[dim]
        if np.allclose(vals, 0):
            logger.info(f"Skipping histogram for dim {dim} (all zeros)")
            continue
        sns.histplot(
            vals, bins=bin_edges, stat='density', kde=True, color=color, alpha=0.2,
            edgecolor=color, line_kws={"linewidth": 3, "label": fr"$d={dim+1}$"}
        )
    plt.xticks((-1., 0., 1.), fontsize=32)
    plt.yticks(rotation=45, fontsize=30)
    plt.xlabel(r"$\pmb{X}^{(c)} [v,d] \longrightarrow$")
    plt.ylabel(r"\textbf{Density} $\longrightarrow$")
    # matplotlib ignores ``fontsize=`` whenever ``prop=`` is given, so the size
    # has to be part of ``prop`` to take effect.
    plt.legend(loc="upper left", bbox_to_anchor=(0.6, 1.0), prop={'weight': 'bold', 'size': 24})
    plt.xlim(-1.6, 1.6)
    plt.tight_layout(pad=0.1)
    plt.savefig(out_path)
    plt.close()


def make_plots(conf, embeddings: torch.Tensor, project: bool = True, variant: str | None = None):
    """Plot the per-dimension distribution of ``embeddings`` as PDF figures.

    Figures are written to
    ``{conf.write_dir}/plots/{conf.task.name}/mpl/{conf.title}-<kind>.pdf``.
    Which kinds are produced is controlled by the boolean flags
    ``conf.plt.boxplot``, ``conf.plt.hist``, ``conf.plt.minihist`` and
    ``conf.plt.ecdf``.

    Args:
        conf: experiment config; reads ``write_dir``, ``task.name``, ``title``,
            ``dataset.name``, ``model.name`` and the ``plt.*`` flags.
        embeddings: tensor whose last axis indexes embedding dimensions after
            the optional projection; it is transposed so that every dimension
            is plotted as one distribution over the remaining axis.
        project: if True, left-multiply ``embeddings`` by a random unit vector
            of length ``embeddings.shape[1]`` before plotting.
            NOTE(review): this appears to expect a 3-D input -- confirm.
        variant: title used for the ECDF figure; defaults to
            ``"{dataset name} {model name} embeddings"``.
    """
    latexify()
    if project:
        # Create the projection vector on the embeddings' own device and dtype
        # so the matmul cannot fail on a device mismatch.
        w = torch.randn(
            (1, 1, embeddings.shape[1]), device=embeddings.device, dtype=embeddings.dtype
        )
        w /= torch.linalg.norm(w)
        embeddings = (w @ embeddings).squeeze()

    os.makedirs(f"{conf.write_dir}/plots/{conf.task.name}/mpl", exist_ok=True)
    embs_by_dim = embeddings.T.cpu().numpy()
    num_dims = len(embs_by_dim)

    title = variant or f"{conf.dataset.name} {conf.model.name} embeddings"
    out_prefix = f"{conf.write_dir}/plots/{conf.task.name}/mpl/{conf.title}"

    # --- Boxplot ---
    if conf.plt.boxplot:
        plt.figure()
        sns.boxplot(data=list(embs_by_dim), orient="v")
        plt.xlabel("Dimension")
        plt.ylabel("Embedding value")
        # No legend: the boxes carry no labels, so a legend would be empty and
        # matplotlib would only emit a warning.
        plt.savefig(f"{out_prefix}-boxplot.pdf")
        plt.close()

    # --- Histograms + KDE for every dimension ---
    if conf.plt.hist:
        _plot_dim_histograms(
            embs_by_dim,
            range(num_dims),
            sns.color_palette("husl", num_dims),
            f"{out_prefix}-histogram.pdf",
        )

    # --- Histograms + KDE for selected dimensions (0, 4, 9) ---
    if conf.plt.minihist:
        wanted = tuple(zip((0, 4, 9), ("green", "red", "blue")))
        # Embeddings with fewer than 10 dimensions would otherwise raise IndexError.
        available = [(dim, color) for dim, color in wanted if dim < num_dims]
        for dim, _ in wanted:
            if dim >= num_dims:
                logger.info(f"Skipping histogram for dim {dim} (only {num_dims} dims available)")
        _plot_dim_histograms(
            embs_by_dim,
            [dim for dim, _ in available],
            [color for _, color in available],
            f"{out_prefix}-mini-histogram.pdf",
        )

    # --- ECDF ---
    if conf.plt.ecdf:
        plt.figure(figsize=(10, 6))
        for i, emb in enumerate(embs_by_dim):
            x = np.sort(emb)
            # Fraction of samples <= x[k] is (k + 1) / n, so the curve ends at 1.
            y = np.arange(1, len(x) + 1) / float(len(x))
            plt.plot(x, y, label=f"dim {i}")
        plt.title(f"{title} - Empirical CDF")
        plt.xlabel("Embedding value")
        plt.ylabel("ECDF")
        # Single legend call: a second plt.legend() would replace the first and
        # drop its placement, column count and font size.
        plt.legend(loc="lower right", ncol=2, prop={'weight': 'bold', 'size': 20})
        plt.tight_layout()
        plt.savefig(f"{out_prefix}-ecdf.pdf")
        plt.close()

    logger.info(f"Written plots to {out_prefix}-*.pdf")
 

### NOTE: only the first graph of every saved run is kept for plotting. ###
def load_data(folder,conf,max_models=10000):
    """Load the saved embedding snapshots of every trained model in ``folder``.

    Looks for ``embeddings_{i}.pt`` with ``i`` in ``range(max_models)`` under
    ``{conf.write_dir}/data/{folder}``. Each file holds a list of tensors that
    is stacked along a new leading axis; only the first entry along axis 1 is
    kept.

    Args:
        folder: sub-directory of ``{conf.write_dir}/data`` to read from.
        conf: config providing ``write_dir``.
        max_models: exclusive upper bound on the model indices probed.

    Returns:
        The per-model tensors concatenated along axis 1.

    Raises:
        FileNotFoundError: if no embedding file exists in the directory.
    """
    dirname = os.path.join(f"{conf.write_dir}/data", folder)
    candidates = [os.path.join(dirname, f"embeddings_{i}.pt") for i in range(max_models)]
    present = [(i, path) for i, path in enumerate(candidates) if os.path.isfile(path)]
    if not present:
        raise FileNotFoundError(f"No embeddings_<i>.pt files found in {dirname}")

    # Report only real gaps (missing indices below the highest one found)
    # instead of one log line for each of the thousands of unused indices.
    found = {i for i, _ in present}
    for i in range(present[-1][0]):
        if i not in found:
            logger.info(f'Did not find {candidates[i]}')

    obj = []
    for _, path in tqdm(present):
        emb = torch.stack(torch.load(path))
        # Keep only the first entry along axis 1; trailing axes are untouched.
        obj.append(emb[:, :1])
    logger.info(f"Loaded {len(obj)} tensors from {dirname}")
    # torch.cat is the long-standing name; torch.concatenate is a newer alias.
    return torch.cat(obj,dim=1)



if __name__=="__main__":
    # Script entry point: build the merged config, derive the task name, then
    # optionally train (conf.train) and/or plot (conf.plot).

    expt_config = OmegaConf.load("configs/distn.yaml")

    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.rel_mode}/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")

    # Make sure model.non_exchbl exists (defaulting to None) so later code can
    # read it unconditionally.
    try:
        non_exchbl_is_none = model_conf.model.non_exchbl is None
    except Exception:
        model_conf = OmegaConf.merge(model_conf, OmegaConf.create({"model": {"non_exchbl": None}}))
    else:
        logger.info(non_exchbl_is_none)

    # Later configs override earlier ones; CLI arguments win.
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, expt_config, cli_conf)

    # Encode every model hyperparameter except classPath/name as "key=value".
    # The previous "=".format(i, k) had no placeholders, so every entry became a
    # bare "=" and all hyperparameter settings shared one task name.
    # NOTE: outputs of earlier runs live under the old "=,=,..." names.
    name_removal_set = {'classPath', 'name'}
    task_name = ",".join(
        f"{key}={value}" for key, value in conf.model.items() if key not in name_removal_set
    )

    if conf.dataset.rel_mode in ("sub_iso", "ged"):
        conf.task.name = f"distn_{conf.model.name}_{conf.dataset.rel_mode}_{conf.dataset.name}_{task_name}_stemp={conf.training.sinkhorn_temp}_margin={conf.training.margin}"
    elif conf.dataset.rel_mode == "uneq_ged":
        conf.task.name = f"distn_Uneq_{conf.model.name}_{conf.dataset.name}_numC={conf.dataset.aug_num_cgraphs}_MaxSkew={conf.dataset.MaxSkew}_{task_name},stemp={conf.training.sinkhorn_temp},margin={conf.training.margin}"
    else:
        raise ValueError("Unknown rel_mode")
    logger.info(f"Task name: {conf.task.name}")
    logger.add(f"{conf.log.dir}/{conf.task.name}.log")
    logger.info(OmegaConf.to_yaml(conf))

    # Disable TF32 on CUDA matmul/cuDNN.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    # Snapshot period in training steps; it is part of the task name unless 1/None.
    every = 4
    if every not in (1,None):
        conf.task.name = f"{conf.task.name}_every{every}"

    gmn_config = modify_gmn_main_config(get_default_gmn_config(conf), conf, logger)

    if conf.train: # train
        train_model_distn_expt(conf,gmn_config,every=every)

    if conf.plot: # make plots
        # NOTE(review): train_model_distn_expt appends the non_exchbl value to
        # conf.task.name when it is set, so plotting without training in the
        # same process looks under the un-suffixed name -- confirm intended.
        nanl_exp = load_data(conf.task.name,conf)
        epochs = nanl_exp.shape[0]
        latexify()
        for i in tqdm(range(epochs)):
            for j in (0,): ## only plot the first node
                conf.title = f"{conf.task.name}_epoch{i*every}_node{j}"
                variant = f"$\\text{{Distribution of }}\\mathbf{{X}}({j})[\\cdot]:\\ \\text{{epoch}}\\ {i * every}$"
                make_plots(conf,nanl_exp[i,:,j,:],variant=variant,project=False)
    
