import sys
import os
# Project root taken from the BASE_PATH environment variable (empty string when
# unset). It is used as the prefix of every data/checkpoint path in this file
# and is appended to sys.path so project-local modules can be imported.
BASE_PATH = os.environ.get("BASE_PATH", "")
if BASE_PATH.endswith("/"):
    # Drop exactly one trailing slash so "{BASE_PATH}/..." has no double slash.
    BASE_PATH = BASE_PATH[:-1]
sys.path.append(BASE_PATH)
# from scripts.notebooks.true_loss_level.get_transition_probabilities import load_model_corelogic
from hydra import compose, initialize
import numpy as np
import pickle
import pandas as pd
from adjustText import adjust_text
import json
def get_config():
    """Return the settings that identify the trained model to analyse.

    Returns:
        dict with three string entries: ``experiment`` (the Hydra experiment
        override), ``checkpoint_path`` (checkpoint file under ``BASE_PATH``)
        and ``data_path`` (dataset archive under ``BASE_PATH``).
    """
    return {
        "experiment": "equities/attention_factors_equities",
        "checkpoint_path": f"{BASE_PATH}/outputs/outputs/2025-04-12/10-43-38/last.ckpt",
        "data_path": f"{BASE_PATH}/data/corelogic/loan_data_60000_unemployment_curr_balance_cir_spi.npz",
    }


def hydraload_corelogic(config, path, dataset):
    """Restore a trained ``SequenceLightningModule`` from a checkpoint.

    Args:
        config: Composed Hydra configuration for the experiment. It is passed
            through ``utils.train.process_config`` and then updated with the
            checkpoint and dataset locations.
        path: Checkpoint file to load; stored in
            ``config.train.pretrained_model_path``.
        dataset: Dataset location; stored in ``config.dataset.data_path``.

    Returns:
        The module returned by ``SequenceLightningModule.load_from_checkpoint``.
    """
    # Project-local imports are resolved at call time, after the module-level
    # code has extended sys.path.
    from train import SequenceLightningModule
    import src.utils as utils

    config = utils.train.process_config(config)
    config.train.pretrained_model_path = path
    utils.train.print_config(config, resolve=True)

    # The dataset path is overridden after the configuration was printed, so
    # the printed value is the one from the composed config, not `dataset`.
    config.dataset.data_path = dataset

    # load_from_checkpoint is a classmethod that instantiates the module
    # itself, so no throw-away instance is constructed beforehand.
    return SequenceLightningModule.load_from_checkpoint(
        config.train.pretrained_model_path,
        config=config,
        strict=config.train.pretrained_model_strict_load,
    )

def load_model_corelogic(experiment, checkpoint_path, data_path, **kwargs):
    """Compose the Hydra config for ``experiment`` and load the trained model.

    Args:
        experiment: Value for the ``experiment=`` Hydra override.
        checkpoint_path: Checkpoint file forwarded to ``hydraload_corelogic``.
        data_path: Dataset location forwarded to ``hydraload_corelogic``.
        **kwargs: Accepted and ignored, so a configuration dict with extra keys
            can be unpacked into this function.

    Returns:
        The model returned by ``hydraload_corelogic``.
    """
    from hydra.core.global_hydra import GlobalHydra

    # Hydra may only be initialised once per process. Check explicitly instead
    # of swallowing every exception raised by initialize(), so that genuine
    # initialisation errors are no longer reported as "Already initialized".
    if GlobalHydra.instance().is_initialized():
        print("Already initialized")
    else:
        initialize(version_base=None, config_path="./../../../configs/")

    cfg = compose(
        config_name="config.yaml",
        overrides=["experiment=" + experiment],
    )
    return hydraload_corelogic(cfg, checkpoint_path, data_path)

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def _drop_word_sequence(words, sequence):
    """Return ``words`` without whole-word occurrences of ``sequence``.

    The leading word is never removed, which keeps a name from being emptied.
    """
    if not sequence:
        return list(words)
    span = len(sequence)
    kept = []
    pos = 0
    while pos < len(words):
        if pos > 0 and words[pos:pos + span] == sequence:
            pos += span
        else:
            kept.append(words[pos])
            pos += 1
    return kept


def _merge_single_letter_words(words):
    """Join runs of consecutive one-character words (["S", "A", "P"] -> ["SAP"])."""
    merged = []
    run = ""
    for word in words:
        if len(word) == 1:
            run += word
            continue
        if run:
            merged.append(run)
            run = ""
        merged.append(word)
    if run:
        merged.append(run)
    return merged


def _clean_company_name(name, removable_sequences):
    """Strip corporate-form tokens from ``name`` and return it title-cased."""
    words = name.split()
    for sequence in removable_sequences:
        words = _drop_word_sequence(words, sequence)
    titled_words = " ".join(words).title().split()
    return " ".join(_merge_single_letter_words(titled_words))


def plot_betas_tsne(
    model,
    company_names,
    market_caps,
    sectors,
    save_path="plot/betas_tsne.pdf",
    assets_to_plot=80,  # Only top 80
):
    """Plot a 2-D t-SNE embedding of the learned betas, coloured by sector.

    Every asset is drawn as a point; the assets at the highest positions are
    additionally labelled with a cleaned-up company name, and ``adjust_text``
    moves the labels apart.

    Args:
        model: Object exposing the beta tensor at
            ``model.model.layers[0].layer.attention_factors.beta``.
        company_names: Company name per asset, indexed by asset position.
        market_caps: Only its length is used. The labelled assets are the ones
            at the last positions, which assumes the assets are ordered by
            increasing market capitalisation -- TODO confirm with the dataset.
        sectors: Sector label per asset, indexed by asset position.
        save_path: Destination PDF; its directory is created if missing.
        assets_to_plot: Approximate number of labelled assets. The window
            covers the last ``assets_to_plot + 8`` positions, and the
            hard-coded excluded positions inside it are dropped.
    """
    # One row per asset; the original note gives the shape as (N, 30).
    betas = model.model.layers[0].layer.attention_factors.beta.data.detach().cpu().numpy()

    # Fixed random_state keeps the embedding reproducible between runs.
    tsne = TSNE(n_components=2, random_state=42)
    betas_2d = tsne.fit_transform(betas)  # shape: (N, 2)

    # One colour per distinct sector, in sorted order.
    sector_options = sorted(set(sectors))
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
    # supported call with the same (name, lut) signature.
    cmap = plt.get_cmap('tab20', len(sector_options))
    sector_to_color = {sector: cmap(i) for i, sector in enumerate(sector_options)}

    # The sectors at sorted positions 9 and 3 get fixed, hand-picked colours.
    if len(sector_options) > 9:
        sector_to_color[sector_options[9]] = "yellow"
    if len(sector_options) > 3:
        sector_to_color[sector_options[3]] = "lightcoral"

    # Positions that are never labelled (hard-coded; reason not recorded here).
    indices_to_exclude = {405, 461, 443, 480, 491, 498, 494, 479}
    n_assets = len(market_caps)
    # Clamp the window so a large assets_to_plot cannot yield negative indices,
    # which would silently wrap around to the end of the arrays.
    window = min(assets_to_plot + len(indices_to_exclude), n_assets)
    top_indices = [
        idx
        for idx in range(n_assets - 1, n_assets - 1 - window, -1)
        if idx not in indices_to_exclude
    ]

    # Everything that is not labelled.
    other_indices = sorted(set(range(len(betas_2d))) - set(top_indices))

    # Corporate-form tokens removed from the labels. They are matched as whole
    # words: plain substring replacement would also cut into other words
    # (e.g. " CO" inside " COLA", " CORP" inside " CORPORATION").
    suffixes_to_remove = [
        " S A B DE C V", " INC", " LTD", " CORP", " CO", " INC.", " LTD.",
        " CORPORATION", " COMPANY", " SA NV", " LLC", " LP", " PLC", " GROUP",
        "GRP", " A S A", " PUB", " NEW", " S A", " L P", " SE", " U S", "A G",
        " HOLDINGS", " HLDGS", " HLDG",
    ]
    removable_sequences = [suffix.split() for suffix in suffixes_to_remove]
    cleaned_names = [
        _clean_company_name(company_names[idx], removable_sequences)
        for idx in top_indices
    ]

    plt.figure(figsize=(20, 10))
    points_x = []
    points_y = []

    # 1) Plot the non-labelled points.
    for idx in other_indices:
        x, y = betas_2d[idx]
        sec_color = sector_to_color[sectors[idx]]
        plt.scatter(x, y, s=25, alpha=0.7, color=sec_color, edgecolors='none')
        points_x.append(x)
        points_y.append(y)

    # 2) Plot the labelled points (same marker style) and attach their names.
    labels = []
    for label_text, idx in zip(cleaned_names, top_indices):
        x, y = betas_2d[idx]
        sec_color = sector_to_color[sectors[idx]]
        plt.scatter(x, y, s=25, alpha=0.7, color=sec_color, edgecolors='none')
        labels.append(plt.text(x, y, label_text, fontsize=12))
        points_x.append(x)
        points_y.append(y)

    # Move the labels so they overlap neither each other nor the points.
    adjust_text(
        labels,
        x=points_x,
        y=points_y,
        arrowprops=dict(arrowstyle='->', color='gray', lw=0.5, alpha=0.7),
        only_move={'points':'', 'texts':'xy'},
        expand_points=(2.0, 2.0),
        expand_text=(1.2, 1.2),
        force_points=1.4,
        force_text=3.0,
        lim=5000
    )

    plt.title("t-SNE of Learned Betas to 30 Attention Factors", fontsize=24)
    plt.xlabel("Dimension 1", fontsize=17)
    plt.ylabel("Dimension 2", fontsize=17)

    # 3) Empty scatters provide one legend entry per sector.
    for sector in sector_options:
        plt.scatter([], [], color=sector_to_color[sector], label=sector)

    # Legend goes below the plot.
    plt.legend(
        loc='upper center',
        bbox_to_anchor=(0.5, -0.05),
        fancybox=True,
        shadow=False,
        fontsize=15.5,
        ncol=6
    )

    # Make sure the target directory exists so the (slow) layout above is not
    # lost to a FileNotFoundError at the very last step.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight", format="pdf")
    plt.close()

def load_company_names(path_dataset, path_names, path_sectors="sector_classification.json"):
    """Load company names, market caps and sectors for the dataset's assets.

    Args:
        path_dataset: ``.npz`` archive whose ``stats`` entry is a pickled dict
            with the keys ``valid_permnos`` and ``market_caps``.
        path_names: Parquet file with the columns ``permno`` and ``comnam``.
        path_sectors: JSON file mapping ``str(permno)`` to an object whose
            ``"sector"`` entry is a sequence; its last element is used.
            Defaults to ``sector_classification.json`` in the working directory.

    Returns:
        Tuple ``(company_names, market_caps, sectors)``. ``company_names`` and
        ``sectors`` are lists in the order of ``valid_permnos``; a permno with
        no name is reported as ``"Unknown (<permno>)"``. ``market_caps`` is
        returned as stored in the archive.

    Raises:
        KeyError: If a permno has no entry in the sector classification.
    """
    # SECURITY: allow_pickle=True and pickle.loads execute arbitrary code
    # embedded in the file -- only load archives from a trusted source.
    with np.load(path_dataset, allow_pickle=True) as data:
        stats = pickle.loads(data['stats'].item())
    print(f"Dataset loaded from {path_dataset}")

    permnos = stats["valid_permnos"]
    market_caps = stats["market_caps"]

    # Map each permno to a company name.
    names_df = pd.read_parquet(path_names)
    permno_to_name = (
        names_df[['permno', 'comnam']]
        .drop_duplicates()
        .set_index('permno')['comnam']
        .to_dict()
    )
    company_names = [permno_to_name.get(permno, f"Unknown ({permno})") for permno in permnos]

    # Sector per permno: the last element of its "sector" list.
    with open(path_sectors, "r", encoding="utf-8") as f:
        sector_classification = json.load(f)
    sectors = [sector_classification[str(permno)]["sector"][-1] for permno in permnos]

    return company_names, market_caps, sectors

def plot_model_kernel(model, save_path="plot/kernel.pdf"):
    """Plot the first layer's kernel combined with the decoder weights.

    Args:
        model: Object exposing the kernel tensor at
            ``model.model.layers[0].layer.kernel.kernel`` and the decoder
            weight at ``model.decoder[0].linear.weight``.
        save_path: Destination PDF; its directory is created if missing.
            Defaults to ``plot/kernel.pdf`` relative to the working directory.
    """
    # First slice along the leading axis; per the original note the remaining
    # axes are (d_model, seq_len) -- TODO confirm.
    kernel = model.model.layers[0].layer.kernel.kernel.data.detach().cpu().numpy()[0, :, :]
    # First row of the decoder's linear weight, one entry per kernel row.
    weights = model.decoder[0].linear.weight.detach().cpu().numpy()[0, :]
    # Weighted sum of the kernel rows: contracts axis d, leaving a vector over l.
    joint_kernel = np.einsum("dl,d->l", kernel, weights)

    # Draw on a dedicated figure so earlier plots are not overlaid, and close
    # it afterwards so it does not linger in pyplot's figure registry.
    plt.figure()
    plt.plot(joint_kernel)
    plt.title("Kernel of the first layer of the model")
    plt.xlabel("Lookback window")
    plt.ylabel("Kernel value")

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(save_path, bbox_inches="tight", format="pdf")
    plt.close()

def main():
    """Load the dataset metadata and the trained model, then plot the kernel."""
    model_kwargs = get_config()

    dataset_file = f"{BASE_PATH}/data/equities/equity_dataset_2021.npz"
    names_file = f"{BASE_PATH}/scripts/notebooks/equities/stocknames_full.parquet"
    company_names, market_caps, sectors = load_company_names(dataset_file, names_file)

    model = load_model_corelogic(**model_kwargs)

    # Settings for the t-SNE figure, which is currently switched off below.
    assets_to_plot = 75
    save_path = f"{BASE_PATH}/scripts/notebooks/equities/plot/betas_tsne_2021_{assets_to_plot}.pdf"

    plot_model_kernel(model)
    # To also produce the t-SNE figure, re-enable:
    # plot_betas_tsne(model, company_names, market_caps, sectors, save_path, assets_to_plot=assets_to_plot)
    
    



# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()