# -*- coding: utf-8 -*-
"""TRAJECTORY_FULL_RUN_FINAL_SUBMIT_DOMAIN_AWARE_PRE_TRAINED (1).ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1obUoy99x86S4MGZ0b8qRyOQjg8vFkLot
"""

# Colab/IPython shell magic (not valid plain Python): installs the listed
# packages quietly (-q). Remove or run manually when executing as a script.
!pip install transformers datasets tokenizers peft accelerate bitsandbytes -q



import numpy as np

# Ticker universe. Note the dotted share classes ('BRK.B', 'BF.B'). The whole
# list is also appended to the token list assembled further down
# (special_tokens = all_tokens + companies + ...).
companies = ['MSFT', 'NVDA', 'AAPL', 'AMZN', 'META', 'AVGO', 'TSLA', 'GOOGL', 'BRK.B', 'GOOG', 'JPM', 'V', 'LLY', 'NFLX', 'XOM', 'MA', 'COST', 'WMT', 'PG', 'HD', 'JNJ', 'ABBV', 'BAC', 'UNH', 'CRM', 'KO', 'PLTR', 'ORCL', 'PM', 'WFC', 'CSCO', 'GE', 'IBM', 'CVX', 'ABT', 'MCD', 'LIN', 'NOW', 'DIS', 'ISRG', 'ACN', 'GS', 'AMD', 'T', 'UBER', 'MRK', 'INTU', 'VZ', 'PEP', 'RTX', 'ADBE', 'BKNG', 'TXN', 'QCOM', 'CAT', 'AXP', 'PGR', 'MS', 'SPGI', 'TMO', 'BA', 'BSX', 'SCHW', 'NEE', 'TJX', 'AMAT', 'C', 'HON', 'AMGN', 'BLK', 'UNP', 'SYK', 'CMCSA', 'ETN', 'LOW', 'PANW', 'DE', 'ADP', 'PFE', 'GILD', 'DHR', 'GEV', 'COP', 'TMUS', 'ADI', 'MMC', 'LRCX', 'BX', 'VRTX', 'MDT', 'FI', 'CRWD', 'KLAC', 'MU', 'CB', 'APH', 'ANET', 'PLD', 'ICE', 'SBUX', 'CME', 'AMT', 'MO', 'TT', 'LMT', 'INTC', 'SO', 'CEG', 'BMY', 'CDNS', 'WELL', 'DUK', 'KKR', 'ELV', 'PH', 'MCK', 'AJG', 'EQIX', 'CI', 'MDLZ', 'SHW', 'WM', 'MMM', 'SNPS', 'TDG', 'AON', 'ORLY', 'CVS', 'COF', 'MCO', 'CTAS', 'UPS', 'NKE', 'PYPL', 'CL', 'WMB', 'CMG', 'PNC', 'MSI', 'ZTS', 'USB', 'GD', 'EMR', 'DASH', 'HCA', 'FTNT', 'ITW', 'EOG', 'HWM', 'APO', 'JCI', 'ADSK', 'BK', 'ECL', 'MAR', 'RCL', 'NOC', 'AZO', 'HLT', 'ROP', 'APD', 'REGN', 'CSX', 'TRV', 'ABNB', 'CARR', 'WDAY', 'FCX', 'NEM', 'CPRT', 'NSC', 'TFC', 'OKE', 'NXPI', 'ALL', 'KMI', 'AXON', 'VST', 'AEP', 'DLR', 'FICO', 'MPC', 'PSX', 'AFL', 'FDX', 'PWR', 'SLB', 'DFS', 'AMP', 'GM', 'ROST', 'PCAR', 'SPG', 'BDX', 'PAYX', 'AIG', 'RSG', 'COR', 'TEL', 'O', 'GWW', 'SRE', 'PSA', 'URI', 'CTVA', 'MET', 'FAST', 'CMI', 'D', 'EW', 'KVUE', 'KDP', 'KMB', 'MSCI', 'KR', 'TGT', 'MNST', 'CCI', 'VRSK', 'VLO', 'EXC', 'IDXX', 'AME', 'F', 'LHX', 'FIS', 'YUM', 'CHTR', 'CTSH', 'XEL', 'PEG', 'CBRE', 'OTIS', 'PRU', 'TTWO', 'BKR', 'HES', 'PCG', 'TRGP', 'RMD', 'HIG', 'GLW', 'CAH', 'LULU', 'VMC', 'MPWR', 'EA', 'WAB', 'SYY', 'ROK', 'DELL', 'DHI', 'ETR', 'ED', 'IT', 'ACGL', 'DXCM', 'EFX', 'EQT', 'NDAQ', 'IR', 'GEHC', 'EBAY', 'MLM', 'VICI', 'MCHP', 'DAL', 'WEC', 'ODFL', 'CSGP', 'A', 'NRG', 'EXR', 'GRMN', 'MTB', 'XYL', 'ANSS', 
'WTW', 'OXY', 'CNC', 'GIS', 'STZ', 'AVB', 'IRM', 'DD', 'KEYS', 'STT', 'VTR', 'RJF', 'BR', 'HUM', 'NUE', 'DTE', 'TSCO', 'FANG', 'HPQ', 'TPL', 'IP', 'GDDY', 'FITB', 'AWK', 'UAL', 'PPG', 'BRO', 'AEE', 'DOV', 'LEN', 'CDW', 'FTV', 'PPL', 'VLTO', 'CPAY', 'DRI', 'ATO', 'TYL', 'HSY', 'SBAC', 'CCL', 'SYF', 'IQV', 'EXE', 'CNP', 'KHC', 'ADM', 'EQR', 'HPE', 'HBAN', 'MTD', 'SW', 'TDY', 'CINF', 'CHD', 'SMCI', 'PODD', 'VRSN', 'STE', 'LYV', 'DVN', 'CBOE', 'ES', 'STX', 'K', 'EIX', 'TROW', 'NVR', 'WRB', 'DOW', 'WSM', 'FE', 'AMCR', 'NTRS', 'EXPE', 'HUBB', 'FSLR', 'PHM', 'PTC', 'GPN', 'WBD', 'CMS', 'WAT', 'RF', 'LH', 'NTAP', 'LDOS', 'DECK', 'DG', 'DGX', 'IFF', 'INVH', 'ULTA', 'ON', 'ZBH', 'LII', 'STLD', 'WY', 'LUV', 'MKC', 'MAA', 'HAL', 'JBL', 'CTRA', 'CFG', 'ESS', 'NI', 'BIIB', 'FDS', 'DLTR', 'TRMB', 'MOH', 'GPC', 'TPR', 'PKG', 'SNA', 'PFG', 'WDC', 'DPZ', 'KEY', 'CLX', 'FFIV', 'PNR', 'EXPD', 'COO', 'APTV', 'BALL', 'LNT', 'GEN', 'TSN', 'BAX', 'ROL', 'J', 'L', 'ZBRA', 'LYB', 'EL', 'WST', 'CF', 'OMC', 'EVRG', 'EG', 'LVS', 'AVY', 'BBY', 'IEX', 'KIM', 'MAS', 'BLDR', 'TER', 'TXT', 'ALGN', 'JKHY', 'HOLX', 'UDR', 'CPT', 'ALLE', 'PAYC', 'JNPR', 'FOXA', 'DOC', 'REG', 'JBHT', 'SJM', 'POOL', 'AKAM', 'SWKS', 'CHRW', 'SWK', 'RVTY', 'UHS', 'BG', 'ARE', 'NDSN', 'LKQ', 'HST', 'RL', 'TKO', 'NWSA', 'CAG', 'MOS', 'KMX', 'EPAM', 'VTRS', 'AIZ', 'PNW', 'GL', 'SOLV', 'INCY', 'BXP', 'TAP', 'EMN', 'DAY', 'IPG', 'ERIE', 'AES', 'HII', 'HSIC', 'WYNN', 'NCLH', 'HAS', 'HRL', 'MRNA', 'AOS', 'WBA', 'MKTX', 'MGM', 'GNRC', 'TECH', 'MTCH', 'LW', 'FRT', 'ALB', 'CRL', 'PARA', 'IVZ', 'BEN', 'CPB', 'APA', 'FOX', 'CZR', 'ENPH', 'BF.B']
# Instruction/prefix strings; also appended to the token list built further down.
instructions=['ESG:','Company:','ESG First Order:','ESG Second Order:','ESG Moving Average:','Returns:','Returns First Order:','Returns Second Order:','Sentiment:']

# Per-company token lists keyed by ticker, filled by the tokenisation loops below.
# moving_avg_token_map is currently never filled (its producer is commented out).
moving_avg_token_map = {}
first_order_token_map = {}
second_order_token_map = {}
esg_token_map = {}
# Every ticker seen while tokenising any series (repeats included).
company_for_append=[]
# One whitespace-joined training line per (ticker, series).
data_set_tokenized_lines = []

# Token lists for the ENV / SOC / GOV sub-score series, keyed by ticker.
e_token_map = {}
s_token_map = {}
g_token_map = {}

import ast


def _load_literal_list(path):
    """Read *path* and parse its entire content with ``ast.literal_eval``.

    The score files are expected (not verified here) to contain a Python
    list of dicts with a ``'text'`` key. ``literal_eval`` only accepts
    literals, so the file cannot execute code.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return ast.literal_eval(handle.read())


def _preview(records):
    """Print the container type and, if present, the first record's ``'text'``."""
    print(type(records))          # expected: <class 'list'>
    if records:
        print(records[0]['text'])


# Environmental, social and governance sub-scores, then the overall ESG rating.
e_data_full_list = _load_literal_list('PaperReady_e_scores.txt')
_preview(e_data_full_list)

s_data_full_list = _load_literal_list('PaperReady_s_scores.txt')
_preview(s_data_full_list)

g_data_full_list = _load_literal_list('PaperReady_g_scores.txt')
_preview(g_data_full_list)

esg_data_full_list = _load_literal_list('esg_risk_ratings_1.txt')
_preview(esg_data_full_list)

import re
def create_train_test_split(data_list, stype,test_size=4):
    """Split every series record into a training prefix and a held-out tail.

    Each ``data['text']`` is expected to look like
    ``"Company: <TICKER> <label>: v1 v2 ..."``. Every decimal number
    (digits, '.', digits) found in the text is treated as part of the series.

    Args:
        data_list: iterable of dicts with a ``'text'`` key.
        stype: label written into the rebuilt training text (e.g. ``"ESG"``).
        test_size: number of trailing values to hold out. ``0`` holds out
            nothing. A value larger than the series holds out everything.

    Returns:
        ``(data_new_list, actual_dict)``:
        - ``data_new_list`` holds the training records as
          ``{'text': "Company: <TICKER> <stype>: ..."}``.
        - ``actual_dict`` maps each ticker to its held-out values as floats.

    Raises:
        ValueError: if a record contains no ``Company:`` ticker.
    """
    data_new_list = []
    actual_dict = {}

    for data in data_list:
        values = re.findall(r'\d+\.\d+', data['text'])
        # Split by index. values[:-0] is empty, so slicing with -test_size
        # used to send the whole series to the test set when test_size == 0.
        split_at = max(len(values) - test_size, 0)
        data_new_values = values[:split_at]
        last_values = [float(v) for v in values[split_at:]]

        # Accept dotted share classes such as BRK.B / BF.B; the previous
        # pattern [A-Z]+ silently truncated them to BRK / BF.
        company_match = re.search(r'Company:\s*([A-Z]+(?:\.[A-Z]+)*)', data['text'])
        if not company_match:
            raise ValueError("Company name not found in data['text']")
        company_name = company_match.group(1)

        trimmed_text = f"Company: {company_name} {stype}: " + " ".join(data_new_values)
        data_new_list.append({'text': trimmed_text})
        actual_dict[company_name] = last_values

    return data_new_list, actual_dict

# Hold out the final 20 observations of every series for evaluation.
# Each call yields (training records with that tail removed,
#                   {ticker: held-out float values}).
data_new_list, actual_dict = create_train_test_split(
    esg_data_full_list, "ESG", test_size=20
)
data_new_list_e, actual_dict_e = create_train_test_split(
    e_data_full_list, "ENV", test_size=20
)
data_new_list_s, actual_dict_s = create_train_test_split(
    s_data_full_list, "SOC", test_size=20
)
data_new_list_g, actual_dict_g = create_train_test_split(
    g_data_full_list, "GOV", test_size=20
)

import numpy as np

def longest_consistency_streak(values):
    """Return the length of the longest run of consecutive equal values.

    Works on lists and 1-D NumPy arrays. An empty sequence has no run and
    returns 0; any non-empty sequence returns at least 1.

    >>> longest_consistency_streak([3, 3, 1, 1, 1])
    3
    """
    if len(values) == 0:
        return 0
    best = current = 1
    for previous, item in zip(values, values[1:]):
        current = current + 1 if item == previous else 1
        best = max(best, current)
    return best

# Kernel extraction functions
def compute_kernels(esg_values, epsilon=0.4, *, verbose=True):
    """Return the rounded first- and second-order differences of a series.

    Args:
        esg_values: the score series (anything ``np.array`` accepts).
        epsilon: unused; kept for backward compatibility. It belonged to a
            disabled step that added random noise where a difference was 0,
            which is also why the former streak/credit computation is gone.
        verbose: when True (the default, matching earlier behaviour), print
            the input series and both difference arrays.

    Returns:
        ``(first_order, second_order)`` as NumPy arrays, where
        ``first_order = round(diff(esg_values), 2)`` and
        ``second_order = round(diff(first_order), 2)``. The second order is
        deliberately derived from the already-rounded first order.
    """
    esg_values = np.array(esg_values)
    first_order = np.round(np.diff(esg_values), 2)
    second_order = np.round(np.diff(first_order), 2)

    if verbose:
        print("esg")
        print(esg_values)
        print("first_order")
        print(first_order)
        print("second_order")
        print(second_order)

    return first_order, second_order


def compute_moving_average_tokens(values, window,kind):
    """Return one ``<MA_..>`` token per full sliding window over *values*.

    Args:
        values: sequence of numbers.
        window: window length; must be >= 1. A window longer than *values*
            yields an empty list.
        kind: currently unused; every token is tagged ``MA``.
            NOTE(review): confirm whether the prefix was meant to come from
            *kind*. The only (commented-out) caller passes the ticker here.

    Raises:
        ValueError: if *window* < 1 (``window == 0`` used to surface as a
            ZeroDivisionError).
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    return [
        get_SSNT_format(sum(values[i:i + window]) / window, "MA")
        for i in range(len(values) - window + 1)
    ]

#def moving_average_to_token(ma_value, window,kind):
 #   return f"<{kind}_{format_diff_token(ma_value)}>"

def format_diff_token(value, kind="FO"):
    """Format *value* as a signed, zero-padded fixed-point string ``[-]II.DD``.

    The magnitude is rounded to hundredths before splitting, so a fraction
    of .995 or more carries into the integer part (``5.996 -> "06.00"``).
    Previously this produced the malformed ``"05.100"``. Negative values that
    round to zero keep their sign (``-0.001 -> "-00.00"``). *kind* is
    accepted for compatibility and ignored.
    """
    sign = "-" if value < 0 else ""
    integer_part, decimal_part = divmod(int(round(abs(value) * 100)), 100)
    return f"{sign}{integer_part:02}.{decimal_part:02}"


def format_diff_token_int(value, kind="FO"):
    """Format *value* as a signed, zero-padded integer string such as ``-07``.

    The fractional part is truncated toward zero (``-12.7 -> "-12"``), so a
    negative value with magnitude below 1 becomes ``"-00"``. *kind* is unused.
    """
    magnitude = int(abs(value))
    prefix = "-" if value < 0 else ""
    return prefix + format(magnitude, "02")

def get_SSNT_format(numeric_value, kind):
    """Return the vocabulary token ``<KIND_[-]II.DD>`` for *numeric_value*."""
    formatted_value = format_diff_token(numeric_value)
    return "<{}_{}>".format(kind, formatted_value)

def get_SSNT_format_RETURN(numeric_value, kind):
    """Return the integer-valued vocabulary token ``<KIND_[-]II>`` (used for returns)."""
    formatted_value = format_diff_token_int(numeric_value)
    return "<{}_{}>".format(kind, formatted_value)

# Per-company, space-joined token strings keyed by ticker; populated by the
# tokenisation loops that follow.
esg_core, esg_fo, esg_so = {}, {}, {}
e_core, s_core, g_core = {}, {}, {}
senti_core, ret_core = {}, {}

def _parse_series_record(record, label):
    """Split a ``'Company: <TICKER> <label>: v1 v2 ...'`` record.

    Returns ``(ticker, [float values])``; any ``<EOS>`` marker is ignored.
    """
    text = record['text'].replace('<EOS>', '')
    company = text.split('Company: ')[1].split(f' {label}:')[0].strip()
    values = [float(v) for v in text.split(f'{label}:')[1].strip().split()]
    return company, values


def _tokenize_subscore_series(records, label, core, token_map):
    """Tokenise one E/S/G sub-score series into *core* and *token_map*.

    For each record:
    - ``core[ticker]`` gets the space-joined value tokens.
    - ``token_map[ticker]`` gets ``["Company: <T> <label>:", *tokens]``.
    - The ticker is appended to the module-level ``company_for_append``.
    """
    for record in records:
        company, values = _parse_series_record(record, label)
        company_for_append.append(company)
        value_tokens = [get_SSNT_format(v, kind=label) for v in values]
        core[company] = " ".join(value_tokens)
        token_map[company] = [f"Company: {company} {label}:"] + value_tokens


# Each sub-score series is tokenised exactly once. The ENV loop used to be
# duplicated verbatim, which appended every ENV ticker to company_for_append
# twice.
_tokenize_subscore_series(data_new_list_e, "ENV", e_core, e_token_map)
_tokenize_subscore_series(data_new_list_s, "SOC", s_core, s_token_map)
_tokenize_subscore_series(data_new_list_g, "GOV", g_core, g_token_map)

for record in data_new_list:
    company, esg_values = _parse_series_record(record, "ESG")
    company_for_append.append(company)

    # compute_kernels prints the series and both difference arrays itself.
    fo, so = compute_kernels(esg_values)

    esg_value_tokens = [get_SSNT_format(v, kind="ESG") for v in esg_values]
    fo_value_tokens = [get_SSNT_format(v, kind="ESGFO") for v in fo]
    so_value_tokens = [get_SSNT_format(v, kind="ESGSO") for v in so]

    esg_core[company] = " ".join(esg_value_tokens)
    esg_fo[company] = " ".join(fo_value_tokens)
    esg_so[company] = " ".join(so_value_tokens)

    esg_token_map[company] = [f"Company: {company} ESG:"] + esg_value_tokens
    first_order_token_map[company] = [f"Company: {company} ESG First Order:"] + fo_value_tokens
    second_order_token_map[company] = [f"Company: {company} ESG Second Order:"] + so_value_tokens

    # Training lines for this ticker, in order: ESG, ENV, SOC, GOV, ESG first
    # order, ESG second order. ENV/SOC/GOV are looked up per ticker. The loop
    # used to append the leftover e_tokens/s_tokens/g_tokens variables, i.e.
    # the LAST company's sub-scores, for every ticker.
    data_set_tokenized_lines.append(" ".join(esg_token_map[company]))
    for sub_map in (e_token_map, s_token_map, g_token_map):
        if company in sub_map:
            data_set_tokenized_lines.append(" ".join(sub_map[company]))
    data_set_tokenized_lines.append(" ".join(first_order_token_map[company]))
    data_set_tokenized_lines.append(" ".join(second_order_token_map[company]))

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer


#with open("training_data_set_company.txt", "w") as f:
 #   for line in tokenized_lines:
  #      f.write(line + "\n")
def _fixed_point_tokens(kind, integer_parts):
    """Return ``<KIND_[-]II.DD>`` tokens for each integer part and every 2-digit fraction.

    The string layout matches ``format_diff_token``: a leading '-' only for
    negative integer parts, both fields zero-padded to two digits.
    """
    return [
        f"<{kind}_{'-' if a < 0 else ''}{abs(a):02}.{b:02}>"
        for a in integer_parts
        for b in range(100)
    ]


def _negative_zero_tokens(kind):
    """Return ``<KIND_-00.DD>`` tokens; format_diff_token emits these for values in (-1, 0)."""
    return [f"<{kind}_-00.{b:02}>" for b in range(100)]


# Placeholders for series that currently contribute no vocabulary.
ma_tokens = []
return_1y_tokens_fo = []
return_1y_tokens_so = []

# Non-negative score series: 00.00 .. 99.99.
esg_tokens = _fixed_point_tokens("ESG", range(0, 100))
e_tokens = _fixed_point_tokens("ENV", range(0, 100))
s_tokens = _fixed_point_tokens("SOC", range(0, 100))
g_tokens = _fixed_point_tokens("GOV", range(0, 100))

# Signed ESG differences: -99.99 .. 99.99, plus the -00.DD family.
fo_tokens = _fixed_point_tokens("ESGFO", range(-99, 100)) + _negative_zero_tokens("ESGFO")
so_tokens = _fixed_point_tokens("ESGSO", range(-99, 100)) + _negative_zero_tokens("ESGSO")

# Integer return tokens <RET_-99> .. <RET_99>, built with the encoder's own
# formatter. The former "-00" loop never applied its sign and re-appended
# <RET_00>..<RET_99>, creating 100 duplicate vocabulary entries.
return_1y_tokens = [get_SSNT_format_RETURN(a, "RET") for a in range(-99, 100)]

all_tokens = (
    fo_tokens
    + so_tokens
    + esg_tokens
    + ["<SENTI_10>", "<SENTI_-10>", "<SENTI_00>"]
    + return_1y_tokens
    + e_tokens
    + s_tokens
    + g_tokens
)

# NOTE: despite its name, `special_tokens` ends up holding the full token list
# (numeric tokens + tickers + control tokens + instruction strings).
special_tokens = all_tokens + companies + ["<PAD>", "<EOS>", "<UNK>"] + instructions

all_possible_tokens = special_tokens

import numpy as np
import re

import numpy as np
import re

import numpy as np
import re

def generate_blockwise_series_embedding(token: str, dim: int = 768, scale: float = 1.0) -> np.ndarray:
    """
    Return a deterministic ``dim``-dimensional embedding for a numeric token.

    ``token`` must look like ``<PREFIX_value>`` (e.g. ``<ESG_70.25>``,
    ``<RET_-05>``), where PREFIX is one of the eight series keys below.

    Layout: each series owns ``dim // 8`` consecutive positions, in the order
    RET, SOC, GOV, ESG, ESGFO, ESGSO, ENV, SENTI. All other positions are
    zero, so embeddings of different series are orthogonal.

    Encoding: inside its block, with ``x = linspace(0, 1, block_size)``, the
    value ``v`` is encoded as
    ``sin((x + v/100) * pi) * scale + 0.5 * x**2 * v/100``. The block is then
    standardised to zero mean and unit variance. Because of that
    standardisation, ``scale`` only changes the weight of the sine relative to
    the quadratic term.

    Note: the sine repeats every 200 units of ``v``, so cosine similarity
    does NOT decrease monotonically with numeric distance over wide ranges
    such as ESGFO/ESGSO (-99.99 .. 99.99). For example, -99 and +99 get
    almost the same sine.

    Raises:
        ValueError: if the token is malformed, its prefix is not a known
            series, or ``dim`` leaves fewer than 2 positions per series.
    """
    series_keys = ["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"]

    match = re.match(r"^<([A-Z]+)_(-?\d+(?:\.\d+)?)>$", token)
    if not match:
        raise ValueError(f"Invalid token: {token}")
    series_prefix, raw_value = match.groups()
    if series_prefix not in series_keys:
        raise ValueError(f"Unknown series prefix {series_prefix!r} in token {token}")
    numeric_value = float(raw_value)

    block_size = dim // len(series_keys)
    if block_size < 2:
        # A 0/1-wide block standardises to NaN or all zeros.
        raise ValueError(
            f"dim={dim} is too small: need at least {2 * len(series_keys)} dimensions"
        )
    start_idx = series_keys.index(series_prefix) * block_size

    x = np.linspace(0, 1, block_size)
    block = np.sin((x + numeric_value / 100) * np.pi) * scale  # phase-shifted sine
    block += (x ** 2) * (numeric_value / 100) * 0.5            # small non-linear term
    block = (block - block.mean()) / (block.std() + 1e-8)      # epsilon guards a constant block

    pe = np.zeros(dim)
    pe[start_idx:start_idx + block_size] = block
    return pe



from sklearn.decomposition import PCA
from itertools import product
import numpy as np

# === Step 1: Define Token Set ===
# You already have: all_tokens = ["<ESG_70.0>", "<RET_10.0>", ...]
# We use this as-is.

##best_config = optimize_series_base_indices(all_tokens)

import torch
from transformers import GPT2Tokenizer, GPT2Model
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

import numpy as np
import re







import ast

# Weekly sentiment labels per ticker, stored as a Python dict literal.
# ast.literal_eval only accepts literals, so the file cannot run code.
with open('weekly_sentiments.txt', 'r') as sentiment_file:
    content = sentiment_file.read()
raw_sentiments = ast.literal_eval(content)

# Dump the parsed mapping for a quick visual check.
print(raw_sentiments)

import random

sentiment_choices = ['<SENTI_POSITIVE>', '<SENTI_NEUTRAL>', '<SENTI_NEGATIVE>']
# Collapse the three sentiment labels onto the numeric SENTI vocabulary tokens.
sentiment_map = {
    '<SENTI_POSITIVE>': '<SENTI_10>',
    '<SENTI_NEUTRAL>': '<SENTI_00>',
    '<SENTI_NEGATIVE>': '<SENTI_-10>',
}

sentiment_series = {}

for company in companies:
    try:
        sentiment_series[company] = [sentiment_map[label] for label in raw_sentiments[company]]
    except (KeyError, TypeError):
        # Ticker absent from the sentiment file, an unrecognised label, or a
        # non-iterable entry: skip this company instead of aborting the run.
        # (A bare `except:` here also hid unrelated errors.)
        print("skipped for sentiment",company)

# Copy instead of aliasing. Later code overwrites senti_core[company] with a
# joined string, and with an alias that also replaced the token lists
# stored in sentiment_series.
senti_core = dict(sentiment_series)

# Empty placeholder; it is replaced by the parsed contents of
# monthly_returns.txt a few lines further down.
return_series = dict()

import ast

def parse_and_scale_file(filename):
    """Parse ``TICKER: [r1, r2, ...]`` lines into ``{ticker: [round(r * 100), ...]}``.

    Lines without a ``:`` are ignored. Only the first ``:`` separates the
    key from the value, which is parsed with ``ast.literal_eval``.
    """
    def _to_whole_percent(fractions):
        return [round(fraction * 100) for fraction in fractions]

    parsed = {}
    with open(filename, 'r') as handle:
        for raw_line in handle:
            if ':' not in raw_line:
                continue
            ticker, _, literal = raw_line.strip().partition(':')
            parsed[ticker.strip()] = _to_whole_percent(ast.literal_eval(literal.strip()))
    return parsed

# Monthly returns per ticker, as whole percentages (fraction * 100, rounded).
filename = 'monthly_returns.txt'
return_series = parse_and_scale_file(filename)

# One "Company: <T> Returns: r1 r2 ..." record per ticker; ret_core keeps the
# raw space-joined numbers for now.
return_strings = []
for ticker, monthly in return_series.items():
    joined = " ".join(str(r) for r in monthly)
    ret_core[ticker] = joined
    return_strings.append({'text': f"Company: {ticker} Returns: {joined}"})

# One "Company: <T> Sentiment: <SENTI_..> ..." record per ticker that has
# sentiment data; senti_core keeps the joined token string.
sentiment_strings = []

for ticker, labels in sentiment_series.items():
    joined_labels = " ".join(labels)
    senti_core[ticker] = joined_labels
    sentiment_strings.append({'text': f"Company: {ticker} Sentiment: {joined_labels}"})

# Print the result
#for item in sentiment_strings:
    #print(item)

# Reset scratch lists. The numeric vocabulary was already captured in
# all_tokens, so clearing esg_tokens / return_1y_tokens here does not affect it.
fo_r_tokens, so_r_tokens, esg_tokens, return_1y_tokens = [], [], [], []
senti_values = []

# Emit one training line per ticker: "Company: <T> Sentiment: <SENTI_..> ...".
for record in sentiment_strings:
    text = record['text'].replace('<EOS>', '')
    company = text.split('Company: ')[1].split(' Sentiment:')[0].strip()
    company_for_append.append(company)
    senti_values = text.split('Sentiment:')[1].replace('<EOS>', '').strip().split()
    header = f"Company: {company} Sentiment:"
    senti_tokens = [header, *senti_values]
    data_set_tokenized_lines.append(" ".join(senti_tokens))

# Emit one training line per ticker with its monthly returns as <RET_..>
# tokens. ret_core[company] is overwritten with those tokens (it previously
# held the raw numbers).
for record in return_strings:
    text = record['text'].replace('<EOS>', '')
    company = text.split('Company: ')[1].split(' Returns:')[0].strip()
    company_for_append.append(company)

    raw_numbers = text.split('Returns:')[1].replace('<EOS>', '').strip().split()
    r_values = list(map(float, raw_numbers))

    ret_symbols = [get_SSNT_format_RETURN(v, kind="RET") for v in r_values]
    r_tokens = [f"Company: {company} Returns:", *ret_symbols]
    ret_core[company] = " ".join(ret_symbols)

    data_set_tokenized_lines.append(" ".join(r_tokens))

## === Step 2: Create Token Embeddings ===
# Numeric token kinds in the vocabulary (the previous list named "RET" twice).
token_types = ["ESG", "RET", "ESGFO", "ESGSO", "ENV", "SOC", "GOV", "SENTI"]
tokens = []
# embeddings[i] is the vector for all_tokens[i]. token_embedding_dict maps each
# distinct token to its vector; a repeated token overwrites its earlier entry.
embeddings = []
token_embedding_dict = {}
embd_1 = None
for token in all_tokens:
    embd_1 = generate_blockwise_series_embedding(token)
    embeddings.append(embd_1)
    token_embedding_dict[token] = embd_1

def print_series_block_map(series_keys, dim=768):
    """Print the inclusive embedding position range owned by each series key.

    Each key gets ``dim // len(series_keys)`` consecutive positions, assigned
    in list order starting at 0.
    """
    width = dim // len(series_keys)
    layout = ((key, idx * width) for idx, key in enumerate(series_keys))
    for key, first in layout:
        last = first + width - 1
        print(f"{key:6} → positions {first:3} to {last:3}")

print_series_block_map(["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"])

from sklearn.preprocessing import StandardScaler

# Group every (token, embedding) pair by series prefix. Prefixes are tried in
# dict order. "<ESG_" cannot match ESGFO/ESGSO tokens because the character
# after "ESG" differs.
grouped_tokens = {
    'ESG': [],
    'ESGFO': [],
    'ESGSO': [],
    'RET': [],
    'ENV': [],
    'SOC': [],
    'GOV': [],
    'SENTI': []
}
for token, embedding in token_embedding_dict.items():
    for prefix in grouped_tokens:
        if token.startswith(f"<{prefix}_"):
            grouped_tokens[prefix].append((token, embedding))
            break

# 2-D PCA projection of all grouped embeddings, concatenated in group order.
# A StandardScaler-normalised copy used to be computed first, but this list
# replaced it before it was ever used, so that dead step was removed.
all_embeddings = [embedding for group in grouped_tokens.values() for _, embedding in group]
pca = PCA(n_components=2)
reduced_embeddings = pca.fit_transform(all_embeddings)

# The per-series scatter plot is disabled. No figure is opened for it any
# more: the empty 12x8 figure created here would otherwise also be rendered
# by the later plt.show().

# Reset the grouping for later cells.
grouped_tokens = {
    'ESG': [],
    'ESGFO': [],
    'ESGSO': [],
    'RET': [],
    'ENV': [],
    'SOC': [],
    'GOV': [],
    'SENTI': []
}

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib.patches import Rectangle

# === Token types and colors for diagonal ===
token_types = ['RET', 'SOC', 'GOV', 'ESG', 'ESGFO', 'ESGSO', 'ENV', 'SENTI']
n = len(token_types)

label_to_color = {
    'ESG': '#f781bf',    # Pink
    'ESGFO': '#377eb8',  # Blue
    'ESGSO': '#4daf4a',  # Green
    'RET': '#ff7f00',    # Orange
    'ENV': '#984ea3',    # Purple
    'SOC': '#ffff33',    # Yellow
    'GOV': '#a65628',    # Brown
    'SENTI': '#999999'   # Gray
}


def _mean_unit_embedding(kind):
    """Mean of the L2-normalised embeddings of all ``<kind_...>`` tokens, or None if none.

    Why: mean over (a, b) of cos(u_a, v_b) equals
    mean_a(u_a/|u_a|) . mean_b(v_b/|v_b|). The average pairwise cosine
    similarity between two token groups can therefore be computed from two
    mean vectors. There is no need to build the full N x M similarity matrix,
    which for ESGFO vs ESGSO (20,000 tokens each) is ~3.2 GB of float64.
    """
    rows = [vec for tok, vec in token_embedding_dict.items() if tok.startswith(f"<{kind}_")]
    if not rows:
        return None
    mat = np.asarray(rows, dtype=float)
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    norms[norms == 0.0] = 1.0  # leave all-zero rows at zero, as sklearn's normalize() does
    return (mat / norms).mean(axis=0)


# === Compute similarity matrix (full 768D) ===
mean_unit = {kind: _mean_unit_embedding(kind) for kind in token_types}
heatmap_matrix = np.full((n, n), np.nan)
for i, t1 in enumerate(token_types):
    for j, t2 in enumerate(token_types):
        if mean_unit[t1] is not None and mean_unit[t2] is not None:
            heatmap_matrix[i, j] = float(mean_unit[t1] @ mean_unit[t2])

# === Plot ===
fig, ax = plt.subplots(figsize=(5, 5))

# Base white heatmap (off-diagonal)
sns.heatmap(
    np.zeros_like(heatmap_matrix),
    xticklabels=token_types,
    yticklabels=token_types,
    cmap=["white"],
    cbar=False,
    square=True,
    linewidths=0.5,
    linecolor='black',
    ax=ax
)

# Overlay diagonal coloring and custom annotations
for i in range(n):
    for j in range(n):
        val = heatmap_matrix[i, j]
        if np.isnan(val):
            val = 0.0

        if i == j:
            # Diagonal: colored block + float format
            color = label_to_color[token_types[i]]
            rect = Rectangle((j, i), 1, 1, fill=True, edgecolor='black', facecolor=color, lw=1.5)
            ax.add_patch(rect)
            text = f"{val:.2f}"
            fontsize = 14
            text_color = 'black'  # Force black for diagonal
        else:
            # Off-diagonal: white cell, integer label. Series occupy disjoint
            # embedding blocks, so these similarities are 0.
            text = f"{int(val)}"
            fontsize = 12
            text_color = 'black'

        ax.text(j + 0.5, i + 0.5, text,
                ha='center', va='center',
                fontsize=fontsize, fontweight='bold',
                color=text_color)

# Tick styling
ax.set_xticklabels(token_types, fontsize=12, fontweight='bold', rotation=45)
ax.set_yticklabels(token_types, fontsize=12, fontweight='bold', rotation=0)

# Force full border visibility
for spine in ax.spines.values():
    spine.set_visible(True)

plt.title("Token Type Cosine Similarity (768D)", fontsize=14, weight='bold')
plt.tight_layout()
plt.show()



"""*SVM"""

import numpy as np

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Compute the cosine similarity between two 1-D vectors.

    Returns a plain Python ``float``, matching the annotation and the
    redefinition of this function later in the file. The epsilon in the
    denominator makes the result 0.0 rather than NaN when either vector is
    all zeros.

    NOTE(review): this name shadows sklearn's pairwise ``cosine_similarity``
    imported earlier, which takes 2-D arrays and returns a matrix. Re-running
    earlier notebook cells after this one would silently pick up this
    scalar version.
    """
    dot_product = np.dot(vec1, vec2)
    norm1 = np.linalg.norm(vec1)
    norm2 = np.linalg.norm(vec2)
    return float(dot_product / (norm1 * norm2 + 1e-10))

import numpy as np
import matplotlib.pyplot as plt

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Cosine of the angle between two vectors; 0.0 when either one is all zeros."""
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-10
    return float(np.dot(vec1, vec2) / denominator)

def format_diff_token(value: float, is_integer: bool = False) -> str:
    """Format *value* as ``[-]II`` (integer mode) or ``[-]II.DD`` (fixed-point mode).

    Integer mode truncates toward zero. Fixed-point mode rounds the magnitude
    to hundredths before splitting, so .995 and above carries into the
    integer part (``5.996 -> "06.00"``) instead of producing ``"05.100"``.
    Negative values that round to zero keep their sign.
    """
    sign = "-" if value < 0 else ""
    magnitude = abs(value)
    if is_integer:
        return f"{sign}{int(magnitude):02d}"
    whole, cents = divmod(int(round(magnitude * 100)), 100)
    return f"{sign}{whole:02d}.{cents:02d}"

def get_SSNT_format(value, kind):
    """Render *value* as a ``<KIND_...>`` token; RET and SENTI use the integer form."""
    integer_kinds = {"RET", "SENTI"}
    formatted_value = format_diff_token(value, is_integer=kind in integer_kinds)
    return "<{}_{}>".format(kind, formatted_value)

def plot_similarity_decay_all(token_embedding_dict,token_configs):
    """Plot how cosine similarity to a base token changes with numeric distance.

    For each config, the base token ``get_SSNT_format(base, kind)`` is looked
    up in ``token_embedding_dict`` and compared with neighbouring tokens of the
    same kind. All series share one axes. The figure is shown, then the base
    token used for each kind is printed.

    Args:
        token_embedding_dict: mapping token string -> embedding vector.
        token_configs: list of dicts, each with ``"kind"`` and ``"base"`` plus either
            * ``"categorical": True`` and ``"steps"`` (explicit values, plotted
              at x = 0, 1, 2, ... in list order), or
            * ``"step"`` and ``"range"`` (values ``base + j*step`` for
              ``j`` in ``[-range, range]``, plotted against ``|value - base|``).
            An ``"int"`` key may be present but is not used here; the
            integer/decimal token form is chosen by get_SSNT_format from ``kind``.

    Configs whose base token is missing are reported and skipped. Neighbour
    tokens missing from the dictionary are left out of their series.
    """
    plt.rcParams.update({'font.size': 12})
    fig, ax = plt.subplots(figsize=(5, 5))

    marker_cycle = ['o', 's', 'D', '^', 'v', 'X', '*']
    color_cycle = plt.cm.tab10.colors  # matplotlib's tab10 categorical palette

    base_tokens = {}

    for i, cfg in enumerate(token_configs):
        kind = cfg["kind"]
        base = cfg["base"]
        is_cat = cfg.get("categorical", False)

        marker = marker_cycle[i % len(marker_cycle)]
        color = color_cycle[i % len(color_cycle)]
        # Small per-series horizontal shift so overlapping series stay visible.
        # NOTE(review): for the first three configs the shift is negative, so
        # their distance-0 point lies left of the x=0 limit set below.
        x_offset = (i - 3) * 0.15

        base_token = get_SSNT_format(base, kind)
        base_tokens[kind] = base_token

        if base_token not in token_embedding_dict:
            print(f"Base token {base_token} not found. Skipping.")
            continue

        base_vec = token_embedding_dict[base_token]

        if is_cat:
            # Append x and y together. Previously x had one entry per step
            # while y only had entries for tokens that exist, so any missing
            # token made ax.plot raise on mismatched lengths.
            x_vals = []
            y_vals = []
            for j, val in enumerate(cfg["steps"]):
                tok = get_SSNT_format(val, kind)
                if tok in token_embedding_dict:
                    x_vals.append(j + x_offset)
                    y_vals.append(cosine_similarity(base_vec, token_embedding_dict[tok]))
            if y_vals:
                ax.plot(x_vals, y_vals, marker=marker, label=kind, linewidth=5.0, color=color)
                # NOTE(review): these labels assume exactly three steps, and the
                # set_xticks call after the loop replaces these tick positions.
                ax.set_xticks([0, 1, 2])
                ax.set_xticklabels(["-10", "00", "+10"])
        else:
            # Numeric sweep around the base value. Values on both sides of the
            # base map to the same |distance|, so each x is visited twice.
            steps = [base + j * cfg["step"] for j in range(-cfg["range"], cfg["range"] + 1)]
            x_vals = []
            y_vals = []
            for val in steps:
                tok = get_SSNT_format(val, kind)
                if tok in token_embedding_dict:
                    sim = cosine_similarity(base_vec, token_embedding_dict[tok])
                    x_vals.append(abs(val - base) + x_offset)
                    y_vals.append(sim)
            if x_vals:
                ax.plot(x_vals, y_vals, marker=marker, label=kind, linewidth=2.5, color=color)

    ax.set_ylim(0.4, 1.0)
    ax.set_xlim(0, 40)
    ax.set_xticks(np.linspace(0, 40, 15))  # 15 evenly spaced ticks

    ax.set_xlabel("Numeric Distance from Base Token", fontsize=12)
    ax.set_ylabel("Cosine Similarity", fontsize=12)
    ax.set_title("Similarity vs Numeric Distance", fontsize=12)
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(True, linestyle='--', alpha=1)
    ax.legend(title="Token Type", fontsize=12, loc='best')
    plt.tight_layout()
    plt.show()

    print("\n=== Base Tokens Used ===")
    for kind, tok in base_tokens.items():
        print(f"{kind:<6} → {tok}")

import numpy as np
import matplotlib.pyplot as plt

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return ``dot(vec1, vec2) / (|vec1| * |vec2| + 1e-10)`` as a Python float.

    The epsilon keeps the result finite (0.0) when either vector is all zeros.
    """
    numerator = np.dot(vec1, vec2)
    norms = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return float(numerator / (norms + 1e-10))

def format_diff_token(value: float, is_integer: bool = False) -> str:
    """Format ``value`` as the zero-padded numeric part of a token.

    Integer mode truncates toward zero and prints at least two digits
    (``-3.7 -> "-03"``). Decimal mode rounds to hundredths
    (``78.5 -> "78.50"``, ``-2.25 -> "-02.25"``, ``2.996 -> "03.00"``).

    The sign comes from the unrounded value, so tiny negatives still render
    as ``"-00.00"``. This is kept unchanged because token vocabularies built
    with this function may contain such tokens.
    """
    sign = "-" if value < 0 else ""
    magnitude = abs(value)
    if is_integer:
        return f"{sign}{int(magnitude):02d}"
    # Round the whole magnitude to hundredths before splitting it, so a
    # fraction that rounds up carries into the integer part ("03.00")
    # instead of yielding a three-digit fraction ("02.100").
    whole, hundredths = divmod(int(round(magnitude * 100)), 100)
    return f"{sign}{whole:02d}.{hundredths:02d}"

def get_SSNT_format(value, kind):
    """Return the token for ``value`` of the given ``kind``, e.g. ``<ESGFO_-02.00>``.

    Kinds ``RET`` and ``SENTI`` are rendered as integers; all others with
    two decimals (both via format_diff_token).
    """
    use_integer_form = kind in frozenset({"RET", "SENTI"})
    return f"<{kind}_{format_diff_token(value, is_integer=use_integer_form)}>"

def plot_similarity_decay_all_new(token_embedding_dict, token_configs):
    """Plot cosine similarity to a base token vs numeric distance (zoomed variant).

    For each config, the base token ``get_SSNT_format(base, kind)`` is looked
    up in ``token_embedding_dict`` and compared with neighbouring tokens of the
    same kind. The view is limited to similarity 0.8-1.0 and distance 0-14 and
    uses a bright fixed palette. The figure is shown, then the base token used
    for each kind is printed.

    Args:
        token_embedding_dict: mapping token string -> embedding vector.
        token_configs: list of dicts, each with ``"kind"`` and ``"base"`` plus either
            * ``"categorical": True`` and ``"steps"`` (explicit values, plotted
              at x = 0, 1, 2, ... in list order), or
            * ``"step"`` and ``"range"`` (values ``base + j*step`` for
              ``j`` in ``[-range, range]``, plotted against ``|value - base|``).
            An ``"int"`` key may be present but is not used here.

    Configs whose base token is missing are reported and skipped. Neighbour
    tokens missing from the dictionary are left out of their series.
    """
    plt.rcParams.update({'font.size': 12})
    fig, ax = plt.subplots(figsize=(6, 5))

    marker_cycle = ['o', 's', 'D', '^', 'v', 'X', '*', 'P']
    # Use a bright, gaudy color palette
    color_cycle = ['magenta', 'cyan', 'lime', 'red', 'orange', 'blue', 'purple', 'gold']

    base_tokens = {}

    for i, cfg in enumerate(token_configs):
        kind = cfg["kind"]
        base = cfg["base"]
        is_cat = cfg.get("categorical", False)

        marker = marker_cycle[i % len(marker_cycle)]
        color = color_cycle[i % len(color_cycle)]
        # Small per-series horizontal shift so overlapping series stay visible.
        # NOTE(review): negative for the first three configs, which pushes their
        # distance-0 point left of the x=0 limit set below.
        x_offset = (i - 3) * 0.15

        base_token = get_SSNT_format(base, kind)
        base_tokens[kind] = base_token

        if base_token not in token_embedding_dict:
            print(f"Base token {base_token} not found. Skipping.")
            continue

        base_vec = token_embedding_dict[base_token]

        if is_cat:
            # Append x and y together so a missing token cannot leave the two
            # lists with different lengths (previously ax.plot then raised).
            x_vals = []
            y_vals = []
            for j, val in enumerate(cfg["steps"]):
                tok = get_SSNT_format(val, kind)
                if tok in token_embedding_dict:
                    x_vals.append(j + x_offset)
                    y_vals.append(cosine_similarity(base_vec, token_embedding_dict[tok]))
            if y_vals:
                ax.plot(x_vals, y_vals, marker=marker, label=kind, linewidth=5.0, color=color)
                # NOTE(review): these labels assume exactly three steps. This
                # function sets no other x ticks, so they also end up labelling
                # the shared numeric-distance axis.
                ax.set_xticks([0, 1, 2])
                ax.set_xticklabels(["-10", "00", "+10"])
        else:
            # Numeric sweep around the base; both sides map to the same |distance|.
            steps = [base + j * cfg["step"] for j in range(-cfg["range"], cfg["range"] + 1)]
            x_vals = []
            y_vals = []
            for val in steps:
                tok = get_SSNT_format(val, kind)
                if tok in token_embedding_dict:
                    sim = cosine_similarity(base_vec, token_embedding_dict[tok])
                    x_vals.append(abs(val - base) + x_offset)
                    y_vals.append(sim)
            if x_vals:
                ax.plot(x_vals, y_vals, marker=marker, label=kind, linewidth=2.5, color=color)

    ax.set_ylim(0.8, 1.0)
    ax.set_xlim(0, 14)
    ax.set_xlabel("Numeric Distance from Base Token", fontsize=12)
    ax.set_ylabel("Cosine Similarity", fontsize=12)
    ax.set_title("Similarity vs Numeric Distance", fontsize=12,fontweight='bold')
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(True, linestyle='--', alpha=1)
    ax.legend(title="Token Type", fontsize=14, loc='best')
    plt.tight_layout()
    plt.show()

    print("\n=== Base Tokens Used ===")
    for kind, tok in base_tokens.items():
        print(f"{kind:<6} → {tok}")



import pandas

import numpy as np

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Compute the cosine similarity of ``vec1`` and ``vec2``.

    The result is a NumPy scalar. A 1e-10 epsilon in the denominator avoids
    division by zero for all-zero vectors (which then score 0.0).
    """
    length_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    inner = np.dot(vec1, vec2)
    return inner / (length_product + 1e-10)

"""Trajectory Level"""



# Copy the tokenized training lines into a list and mirror them to disk,
# one sample per line.
# NOTE(review): data_set_tokenized_lines is built earlier in the notebook;
# each entry is assumed to be a str (entry + "\n" raises otherwise).
data_set_tokenized_samples = []
with open("training_data_set_company_kernel.txt", "w") as f:
    for entry in data_set_tokenized_lines:
        data_set_tokenized_samples.append(entry)

        # One training sample per line.
        f.write(entry + "\n")

from transformers import GPT2TokenizerFast

# GPT-2 byte-level BPE tokenizer built from local vocab/merges files.
tokenizer_extended = GPT2TokenizerFast(
    vocab_file="vocab.json",
    merges_file="merges.txt")

# NOTE(review): these `tokenizers` imports are not used in the visible code.
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

# Register every domain token (e.g. "<ESG_78.00>") as an atomic special token
# so BPE never splits it.
# NOTE(review): all_possible_tokens is defined elsewhere; presumably it contains
# "<PAD>", "<EOS>" and "<UNK>" so the role assignments below map to dedicated
# ids — confirm.
tokenizer_extended.add_tokens(all_possible_tokens, special_tokens=True)

tokenizer_extended.pad_token = "<PAD>"
tokenizer_extended.eos_token = "<EOS>"
tokenizer_extended.unk_token = "<UNK>"

from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
import torch

# Load tokenizer
#tokenizer = PreTrainedTokenizerFast(tokenizer_file="kernel_tokenizer.json")
#tokenizer.add_special_tokens({'pad_token': '<PAD>', 'eos_token': '<EOS>', 'unk_token': '<UNK>'})
# Vocabulary size after adding the domain tokens.
print(len(tokenizer_extended))
# Load pretrained GPT-2 and grow its token-embedding matrix to the extended
# vocabulary so the new token ids have rows.
model = GPT2LMHeadModel.from_pretrained("gpt2")  # or your fine-tuned base
model.resize_token_embeddings(len(tokenizer_extended))


from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig, TaskType

# LoRA adapters (rank 8, alpha 32, dropout 0.1) on the "c_attn" and "c_proj"
# modules.
# NOTE(review): the embedding matrix is not among the target modules, so the
# rows for the added tokens are presumably frozen during training — confirm
# that is intended.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["c_attn", "c_proj"],  # module names are architecture specific (GPT-2 here)
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Wrap the base model; `model` is a PEFT model from here on.
model = get_peft_model(model, lora_config)


# Record the pad id on the model config.
# NOTE(review): loss masking itself comes from the -100 label padding in the
# DataLoader collate; this value is presumably what generation falls back to
# for padding — confirm.
model.config.pad_token_id = tokenizer_extended.pad_token_id

# Seed the embedding rows of the added domain tokens with precomputed vectors.
# Tokens the tokenizer does not know are skipped. For those,
# convert_tokens_to_ids falls back to the <UNK> id (which would be overwritten
# again and again) or to None, and `weight[None] = vec` would broadcast the
# vector over the entire embedding matrix.
_embedding_weight = model.transformer.wte.weight
_unk_id = tokenizer_extended.unk_token_id
_skipped_tokens = []
with torch.no_grad():
    for token, embedding in token_embedding_dict.items():
        idx = tokenizer_extended.convert_tokens_to_ids(token)
        if idx is None or (idx == _unk_id and token != tokenizer_extended.unk_token):
            _skipped_tokens.append(token)
            continue
        # Cast explicitly so float64 NumPy vectors match the weight dtype.
        _embedding_weight[idx] = torch.as_tensor(embedding, dtype=_embedding_weight.dtype)
if _skipped_tokens:
    print(f"Skipped {len(_skipped_tokens)} tokens missing from the tokenizer, e.g. {_skipped_tokens[:5]}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

from torch.utils.data import Dataset
import torch

class PricePredictionDataset(Dataset):
    """Causal-LM dataset over raw text lines.

    Each item is ``{'input_ids': 1-D tensor, 'labels': copy of input_ids}``.
    Padding is left to the DataLoader's ``collate_fn``.

    Args:
        samples: indexable sequence of text strings.
        tokenizer: callable tokenizer returning a mapping with ``'input_ids'``
            of shape (1, seq_len) when called with ``return_tensors='pt'``.
        max_length: optional cap on sequence length. ``None`` (the default)
            passes no ``max_length``, which keeps the previous behaviour:
            with ``truncation=True`` the tokenizer applies its own default
            limit, if it has one.
    """

    def __init__(self, samples, tokenizer, max_length=None):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text = self.samples[idx]
        encode_kwargs = {
            'truncation': True,
            'return_tensors': 'pt',
            'add_special_tokens': True,
        }
        if self.max_length is not None:
            encode_kwargs['max_length'] = self.max_length
        encoded = self.tokenizer(text, **encode_kwargs)

        input_ids = encoded['input_ids'].squeeze(0)  # drop the batch dimension

        # Labels mirror the inputs; the causal-LM loss shifts them internally.
        return {
            'input_ids': input_ids,
            'labels': input_ids.clone(),
        }

dataset = PricePredictionDataset(data_set_tokenized_samples, tokenizer_extended)

from torch.utils.data import DataLoader

# Capture the pad id now. tokenizer_extended is rebound later in this file
# (reloaded from a checkpoint), and a lambda reading the global at batch time
# would silently use whichever tokenizer is current at that moment.
_train_pad_token_id = tokenizer_extended.pad_token_id


def _collate_causal_lm(batch, pad_token_id=_train_pad_token_id):
    """Right-pad dataset items into a batch.

    Padding values:
      * ``input_ids``: ``pad_token_id``;
      * ``labels``: -100, so padding is ignored by the loss;
      * ``attention_mask``: 1 for real tokens, 0 for padding.
    """
    input_ids = [e['input_ids'] for e in batch]
    labels = [e['labels'] for e in batch]
    masks = [torch.ones_like(ids) for ids in input_ids]
    return {
        'input_ids': torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_token_id
        ),
        'labels': torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100
        ),
        'attention_mask': torch.nn.utils.rnn.pad_sequence(
            masks, batch_first=True, padding_value=0
        ),
    }


dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=_collate_causal_lm,
)

# This will show max token ID across your dataset
def find_max_token_id(dataset, verbose=False):
    """Print and return the largest id found in the items' ``input_ids``.

    Items that are falsy, lack ``input_ids``, or have zero-length
    ``input_ids`` are skipped. Returns 0 when no id is found.

    Args:
        dataset: iterable of mappings such as PricePredictionDataset items;
            ``input_ids`` may be a 1-D tensor or a plain sequence of ints.
        verbose: also print every example while scanning (this was always on
            before and floods the output on a full dataset).
    """
    max_id = 0
    for example in dataset:
        if verbose:
            print(example)
        if not example or 'input_ids' not in example:
            continue
        ids = example['input_ids']
        if len(ids) == 0:
            continue  # max() of an empty sequence would raise ValueError
        # Tensors have a vectorised .max(); plain sequences use builtin max.
        current = ids.max() if hasattr(ids, 'max') else max(ids)
        max_id = max(max_id, int(current))
    print("Max token ID in dataset:", max_id)
    return max_id

# Sanity checks before training: every token id in the data must index a row
# of the resized embedding matrix, i.e. the maximum printed here should be
# below the vocab size printed next.
find_max_token_id(dataset)

print("Model vocab size:", model.get_input_embeddings().weight.size(0))

# The pad id must also be a valid embedding row.
# NOTE(review): this raises TypeError (not AssertionError) if pad_token_id is None.
assert tokenizer_extended.pad_token_id < model.config.vocab_size

def check_for_out_of_range_ids(dataset, vocab_size):
    """Report label ids that cannot index an embedding table of ``vocab_size`` rows.

    Prints one line per offending id (as a plain int rather than a tensor
    repr) and returns the offenders as ``(sample_index, token_id)`` tuples.
    The list is empty when every id is in range.

    Args:
        dataset: iterable of mappings with a ``'labels'`` sequence of ints
            or a 1-D tensor.
        vocab_size: number of embedding rows.
    """
    offenders = []
    for i, example in enumerate(dataset):
        for raw_id in example['labels']:
            token_id = int(raw_id)
            if token_id >= vocab_size:
                print(f"Out-of-range token ID {token_id} in sample {i}")
                offenders.append((i, token_id))
    return offenders

# Colab/IPython shell escape (not plain Python): extract the checkpoint
# archive into /NORM_LOSS_FORECAST_50/ for the reload below.
!unzip 'NORM_LOSS_FORECAST_50.zip' -d /NORM_LOSS_FORECAST_50/

#****************content/esg_finetuned_gpt2_v2_0/
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast
import torch
from peft import PeftModel, PeftConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load tokenizer
#tokenizer = PreTrainedTokenizerFast(tokenizer_file="ESG_2_tokenizer.json")

from transformers import PreTrainedTokenizerFast

# Reload the tokenizer saved with the fine-tuned checkpoint. This rebinds
# tokenizer_extended, replacing the tokenizer built for training above.
tokenizer_extended = PreTrainedTokenizerFast.from_pretrained("/NORM_LOSS_FORECAST_50/content/NORM_LOSS_FORECAST_50")

# Ensure the pad/eos/unk roles are set on the reloaded tokenizer.
tokenizer_extended.add_special_tokens({'pad_token': '<PAD>', 'eos_token': '<EOS>', 'unk_token': '<UNK>'})
print(len(tokenizer_extended))
# Fresh GPT-2 base, resized to the extended vocabulary *before* the adapters
# are attached, so its embedding shape matches the saved checkpoint.
# NOTE(review): the training LoRA config targets only c_attn/c_proj. Unless the
# checkpoint also stores the embedding matrix, the added-token rows here come
# from resize_token_embeddings' fresh initialisation, not from the embeddings
# injected before training — verify.
base_model = GPT2LMHeadModel.from_pretrained("gpt2")
base_model.resize_token_embeddings(len(tokenizer_extended))    # IMPORTANT!

# Attach the saved LoRA adapters to the resized base model (rebinds `model`).
model = PeftModel.from_pretrained(base_model, "/NORM_LOSS_FORECAST_50/content/NORM_LOSS_FORECAST_50")




# Adapters are kept separate here (not merged into the base weights).
model.config.pad_token_id = tokenizer_extended.pad_token_id  # presumably used by generate() for padding
# Move to GPU when available and switch to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def _parse_company_esg(text: str):
    """Split ``"Company: <ticker> ESG: v1 v2 ..."`` into ``(ticker, [v1, v2, ...])``.

    Raises IndexError if either marker is missing and ValueError if a score
    is not numeric.
    """
    company = text.split('Company: ')[1].split(' ESG:')[0].strip()
    esg_values = [float(v) for v in text.split('ESG:')[1].strip().split()]
    return company, esg_values


def get_zero_order_input(text: str):
    """Build the raw-score prompt, e.g. ``Company: AAPL ESG: <ESG_80.00> <ESG_78.00>``.

    Prints the input text and the resulting prompt.
    """
    print(text)
    company, esg_values = _parse_company_esg(text)

    esg_tokens = [f"Company: {company} ESG:"]
    esg_tokens += [get_SSNT_format(v, "ESG") for v in esg_values]

    esg_prompt = " ".join(esg_tokens)
    print("**zero Order**")
    print(esg_prompt)
    return esg_prompt


def get_first_order_input(text: str):
    """Build the first-order prompt: ``Company: <ticker> ESG First Order: <ESGFO_..> ...``.

    The series comes from the first output of compute_kernels (defined
    elsewhere; presumably first differences of the scores — confirm). Prints
    the input, both kernel outputs and the prompt.
    """
    print(text)
    company, esg_values = _parse_company_esg(text)

    fo, so = compute_kernels(esg_values)
    print(fo)
    print(so)
    fo_tokens = [f"Company: {company} ESG First Order:"]
    fo_tokens += [get_SSNT_format(v, "ESGFO") for v in fo]

    prompt2 = " ".join(fo_tokens)
    print("**First Order**")
    print(prompt2)
    return prompt2


def get_second_order_input(text: str):
    """Build the second-order prompt: ``Company: <ticker> ESG Second Order: <ESGSO_..> ...``.

    The series comes from the second output of compute_kernels (defined
    elsewhere; presumably second differences — confirm). Prints the input and
    the prompt.
    """
    print(text)
    company, esg_values = _parse_company_esg(text)

    _, so = compute_kernels(esg_values)
    so_tokens = [f"Company: {company} ESG Second Order:"]
    so_tokens += [get_SSNT_format(v, "ESGSO") for v in so]

    prompt3 = " ".join(so_tokens)
    print("**Second Order**")
    print(prompt3)
    return prompt3


def get_moving_average_input(text: str):
    """Build the moving-average prompt: ``Company: <ticker> ESG Moving Average: ...``.

    The tokens come from compute_moving_average_tokens(values, 5, company),
    defined elsewhere; the 5 is presumably the window length — confirm.
    Prints the input and the prompt.
    """
    print(text)
    company, esg_values = _parse_company_esg(text)

    moving_avg = compute_moving_average_tokens(esg_values, 5, company)
    ma_start = [f"Company: {company} ESG Moving Average:"]
    ma_start += moving_avg

    prompt4 = " ".join(ma_start)
    print("**Moving Averages**")
    print(prompt4)
    return prompt4

# Side channel written by predict_from_prompt_wrapper: the first number parsed
# from the most recent generation (None if the generation had no number).
predicted_float_no = 0.0


def predict_from_prompt_wrapper(model, tokenizer, prompt, device='cpu', max_new_tokens=12):
    """Greedily continue ``prompt`` and return the full output token sequence.

    Side effects:
        * prints the prompt and the newly generated text;
        * sets the module-level ``predicted_float_no`` to the first number
          found in the generated text, or ``None`` if there is none.

    Args:
        model: causal LM exposing a Hugging Face style ``generate``.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        prompt: prompt string.
        device: device the encoded inputs are moved to.
        max_new_tokens: generation budget.

    Returns:
        ``tokenizer.convert_ids_to_tokens`` of the whole output sequence
        (prompt tokens followed by the generated ones).
    """
    global predicted_float_no
    import re

    print(prompt)
    model.eval()
    # Encode with the `tokenizer` argument. The previous version used the
    # global tokenizer_extended here and silently ignored this parameter.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    # do_sample=False is greedy decoding. The former temperature/top_k/top_p
    # arguments have no effect in that mode (transformers only warns about them).
    output = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.convert_tokens_to_ids("<EOS>"),
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the generated ids. Slicing the full decoded string by
    # len(prompt) breaks whenever decoding does not reproduce the prompt
    # character-for-character.
    predicted_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=False).strip()
    print(predicted_text)
    match = re.search(r"[-+]?[0-9]*\.?[0-9]+", predicted_text)
    predicted_float_no = float(match.group()) if match else None

    return tokenizer.convert_ids_to_tokens(output[0])

import re




def compute_next_esg_NO_MA(esg_tokens, fo_tokens, so_tokens,ma_tokens) -> float:
    """
    Computes ESG_t = ESG_{t-1} + FO_{t-1} + SO_t from token sequences.

    Each argument is a list of token strings, such as the full prompt+output
    sequence returned by predict_from_prompt_wrapper. Only tokens with the
    matching prefix are used:

    * ``<ESG_...>``: ESG_{t-1} is the second-to-last value;
    * ``<ESGFO_...>``: FO_{t-1} is the second-to-last value;
    * ``<ESGSO_...>``: SO_t is the last value.

    "Second-to-last" is the last prompt value when the model generated exactly
    one value token (presumably the intended case — confirm). A missing value
    (too few matching tokens, or no number in a token) counts as 0.0.
    ``ma_tokens`` is accepted for interface compatibility but not used.

    Returns:
        ESG_t rounded to 2 decimals. Also prints the ingredients.
    """
    number_pattern = re.compile(r'[-+]?\d*\.\d+|\d+')

    def parse_token_value(token: str, expected_prefix: str) -> float:
        """Numeric value of a token like ``<ESGFO_-02.00>``; 0.0 if the prefix differs or no number."""
        if not token.startswith(expected_prefix):
            return 0.0
        match = number_pattern.search(token)
        if match is None:
            return 0.0
        value = float(match.group())
        # Integer-style tokens (e.g. "<ESGFO_-02>") only match "02", so the
        # sign is restored from the token text.
        if '-' in token and not token.startswith(f"{expected_prefix}0"):
            return -abs(value)
        return value

    def values_with_prefix(tokens, prefix):
        return [parse_token_value(tok, prefix) for tok in tokens if tok.startswith(prefix)]

    def second_to_last(values):
        # The previous `values[-2] if values else 0.0` raised IndexError on a
        # single value, contradicting the graceful-handling contract.
        return values[-2] if len(values) >= 2 else 0.0

    esg_t_minus_1 = second_to_last(values_with_prefix(esg_tokens, "<ESG_"))
    fo_t_minus_1 = second_to_last(values_with_prefix(fo_tokens, "<ESGFO_"))
    so_values = values_with_prefix(so_tokens, "<ESGSO_")
    so_t = so_values[-1] if so_values else 0.0

    print("Ingrieidients")
    print("ESG Numeric" ,esg_t_minus_1)
    print("FO ",fo_t_minus_1)
    print("SO",so_t)

    # Final ESG score prediction
    esg_t = esg_t_minus_1 + fo_t_minus_1 + so_t
    print("Next ESG Score",esg_t)
    return round(esg_t, 2)

from sklearn.metrics import mean_squared_error
summary_report = {}
skipped_companies = []


def evaluate_esg_predictions_rolling_window(numeric_companies_esg: dict, window_size: int = 8, target_length: int = 8):
    """Backtest next-step ESG predictions over each company's most recent points.

    For every company with at least ``window_size + target_length`` values,
    the last ``target_length`` values are predicted one at a time, each from
    the ``window_size`` values immediately before it. Two predictions are made
    per step:

    * ``predicted_using_ESG``: the number the model generates after the raw
      ESG-token prompt (read from the global ``predicted_float_no``);
    * ``predicted_using_FO_SO``: compute_next_esg_NO_MA applied to the model's
      ESG, first-order and second-order continuations.

    Module-level state (both accumulate across calls):

    * companies with too little data are appended to ``skipped_companies``;
    * per-company summaries are written into ``summary_report``.

    Returns:
        dict company -> {"mse", "actual", "predicted_using_FO_SO",
        "mse_using_FO_and_MO", "predicted_using_ESG"}. An MSE is NaN when a
        prediction is missing (None).
    """
    def _mse(actual, predicted):
        # None appears when no number could be parsed from a generation.
        if not actual or len(actual) != len(predicted) or any(p is None for p in predicted):
            return float("nan")
        errors = np.asarray(actual, dtype=float) - np.asarray(predicted, dtype=float)
        return float(np.mean(errors ** 2))

    all_company_results = {}
    required = window_size + target_length

    for company, values in numeric_companies_esg.items():
        if len(values) < required:
            print(f"⚠️ Skipping {company}: Not enough data (needs at least {required} values).")
            skipped_companies.append(company)
            continue

        # Never negative here because len(values) >= required.
        start_index = len(values) - required
        print("start_index", start_index)
        series = list(values)  # plain copy; values are used as-is

        predictions = []
        predictions_without_ma = []
        print("COMPANY STARTED", company)
        print(f"\n🧪 Predicting for {company} with rolling window...")
        prefix = f"Company: {company} ESG:"

        for i in range(start_index, start_index + target_length):
            # The window [i, i + window_size) predicts series[i + window_size].
            train_series = series[i : i + window_size]
            print(f"********iteration {i} company {company}— training on indices [{i}:{i + window_size}]")
            print("prompt_FULL for :", company, train_series)
            print(f"********TRAINING DATA {i} company {company}— training on indices [{i}:{i + window_size}]. {train_series}")

            esg_string = prefix + " " + ' '.join(format_diff_token(x) for x in train_series)

            # Each prompt is built once and reused for generation and logging.
            zero_prompt = get_zero_order_input(esg_string)
            output_number = predict_from_prompt_wrapper(model, tokenizer_extended, zero_prompt, device, 2)
            # Read the side channel before the next wrapper call overwrites it.
            predictions.append(predicted_float_no)

            fo_prompt = get_first_order_input(esg_string)
            output_FO = predict_from_prompt_wrapper(model, tokenizer_extended, fo_prompt, device, 2)
            so_prompt = get_second_order_input(esg_string)
            output_SO = predict_from_prompt_wrapper(model, tokenizer_extended, so_prompt, device, 2)

            esg_score_without_ma = compute_next_esg_NO_MA(output_number, output_FO, output_SO, 0)
            print("=====================================================")
            print("Company=", company)
            print("prompts")
            print("prompt orig=", esg_string)
            print("prompt FO=", fo_prompt)
            print("prompt SO=", so_prompt)
            print(" prompt ESG=", zero_prompt)
            print("predicted FO=", output_FO)
            print("predicted SO=", output_SO)
            print("predicted ESG=", output_number)
            print("Predicted ESG Score(using FO and MO)", esg_score_without_ma)
            print("=====================================================")

            predictions_without_ma.append(esg_score_without_ma)

        # Actual values to compare against (aligned with the predictions above).
        test_series = series[start_index + window_size : start_index + window_size + target_length]

        all_company_results[company] = {
            "mse": _mse(test_series, predictions),
            "actual": test_series,
            "predicted_using_FO_SO": predictions_without_ma,
            "mse_using_FO_and_MO": _mse(test_series, predictions_without_ma),
            "predicted_using_ESG": predictions,
        }

    print("\nALL COMPANY FINISHED\n")
    print("📊 Summary: MSE, Actual, and Predicted per Company\n")

    for company, result in all_company_results.items():
        summary_report[company] = {
            'MSE': round(result['mse'], 4),
            # Uses this company's MSE (it was previously read from a stale
            # loop variable left over from the loop above).
            'RTMSE': round(float(np.sqrt(result['mse'])), 4),
            'Actual': result['actual'],
            'Predicted': result['predicted_using_ESG'],
            'predicted_using_FO_SO': result['predicted_using_FO_SO'],
            'mse_using_FO_and_MO': result['mse_using_FO_and_MO'],
        }

    if all_company_results:
        print("Processing done")
    else:
        print("❌ No companies were processed.")

    return all_company_results

def parse_esg_data(data):
    """Map each entry's company ticker to its series of ESG scores.

    Args:
        data: iterable of dicts with a ``'text'`` field such as
            ``"Company: BRK.B ESG: 80.00 78.00 79.00"``.

    Returns:
        dict ticker -> list of floats: every decimal number (e.g. ``80.00``)
        in the text, in order. Tickers may contain ``.`` or ``-`` (``BRK.B``);
        the previous ``\\w+`` pattern cut those to ``BRK``. Entries without a
        ``Company:`` field go under ``'Unknown'``, and a later entry for the
        same key replaces an earlier one.
    """
    company_pattern = re.compile(r'Company:\s*(\w+(?:[.\-]\w+)*)')
    score_pattern = re.compile(r'\d+\.\d+')

    parsed = {}
    for entry in data:
        text = entry['text']
        company_match = company_pattern.search(text)
        company = company_match.group(1) if company_match else 'Unknown'
        parsed[company] = [float(score) for score in score_pattern.findall(text)]
    return parsed

#all_company_results= evaluate_esg_predictions_rolling_window(parse_esg_data(esg_data_full_list),target_length=20)





# Display the per-company summary (notebook cell output).
summary_report

import json

# Persist the summary as indented JSON text.
with open("summary_report_full_submission_pt.txt", "w") as f:
    f.write(json.dumps(summary_report, indent=2))

# `data` is an alias of summary_report (assigned twice in this run of cells).
data = summary_report

summary_report

summary_report

data=summary_report

import pandas as pd

# Notebook display options: show every column and row and do not truncate
# cell contents.
pd.set_option('display.max_columns', None)

pd.set_option('display.max_rows', None)

pd.set_option('display.max_colwidth', None)

"""Company-level tokens - math approach"""

# Number of companies in the summary (displayed by the bare `count` line).
count = len(summary_report)

count

import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score



# Evaluation functions
def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent.

    A 1e-10 offset is added to ``|y_true|`` so zero actuals do not divide by zero.
    """
    actual = np.array(y_true)
    forecast = np.array(y_pred)
    relative_errors = np.abs(actual - forecast) / (np.abs(actual) + 1e-10)
    return relative_errors.mean() * 100

def smape(y_true, y_pred):
    """Symmetric mean absolute percentage error, in percent (0-200 range).

    Each error is scaled by ``|pred| + |true|`` plus a 1e-10 epsilon.
    """
    actual = np.array(y_true)
    forecast = np.array(y_pred)
    numerator = 2 * np.abs(forecast - actual)
    denominator = np.abs(forecast) + np.abs(actual) + 1e-10
    return 100 * np.mean(numerator / denominator)

def directional_accuracy(y_true, y_pred):
    """Fraction of steps where the prediction moves in the same direction as the actual series.

    Directions are ``np.sign`` of consecutive differences, so a flat actual
    step only counts as a hit when the predicted step is also flat.

    Returns:
        A value in [0, 1], or NaN when there are fewer than two points
        (no step to compare). Previously that case took the mean of an empty
        array and emitted a RuntimeWarning.

    Raises:
        ValueError: if the two series have different lengths.
    """
    true_arr = np.asarray(y_true)
    pred_arr = np.asarray(y_pred)
    if true_arr.shape != pred_arr.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape, got {true_arr.shape} and {pred_arr.shape}"
        )
    if true_arr.size < 2:
        return float("nan")
    return float(np.mean(np.sign(np.diff(true_arr)) == np.sign(np.diff(pred_arr))))

# Per-company error metrics for the direct ESG-token predictions ("Predicted"),
# plus their unweighted average across companies.
df_summary = {}
overall = {
    "MAE_Pred": 0, "MSE_Pred": 0,  "MAPE_Pred": 0, "SMAPE_Pred": 0, "R2_Pred": 0, "DA_Pred": 0
}
count = len(summary_report)

for company, data in summary_report.items():
    actual = data['Actual']
    pred = data['Predicted']
    # FO/SO-based predictions; only needed by the currently disabled *_FO metrics.
    pred_fo = data['predicted_using_FO_SO']

    metrics = {
        "MAE_Pred": mean_absolute_error(actual, pred),
        "MSE_Pred": mean_squared_error(actual, pred),
        "MAPE_Pred": mape(actual, pred),
        "SMAPE_Pred": smape(actual, pred),
        "R2_Pred": r2_score(actual, pred),
        "DA_Pred": directional_accuracy(actual, pred),
    }

    df_summary[company] = metrics

    for key in overall:
        overall[key] += metrics[key]

# Average across all companies. With no companies (e.g. when the evaluation
# call earlier in the notebook is commented out) this yields NaN instead of
# raising ZeroDivisionError.
overall_avg = {k: (v / count if count else float("nan")) for k, v in overall.items()}

# Create final DataFrames
df_metrics = pd.DataFrame(df_summary).T
df_overall = pd.DataFrame(overall_avg, index=["Overall_Avg"])

df_overall

import numpy as np
import pandas as pd
from scipy.stats import pearsonr

# Load the summary report from a literal JSON-like dictionary
from ast import literal_eval



# === Metric Functions ===
def bias(y_true, y_pred):
    """Mean signed error (prediction minus actual); positive means over-forecasting."""
    signed_errors = np.array(y_pred) - np.array(y_true)
    return signed_errors.mean()

def wape(y_true, y_pred):
    """Weighted absolute percentage error, in percent.

    Total absolute error divided by total absolute actual value; a 1e-8
    epsilon guards against an all-zero actual series.
    """
    actual = np.array(y_true)
    total_abs_error = np.sum(np.abs(np.array(y_pred) - actual))
    scale = np.sum(np.abs(actual)) + 1e-8
    return 100 * total_abs_error / scale

def mase(y_true, y_pred):
    """Mean absolute scaled error against a one-step naive (previous-value) forecast.

    The first point is excluded because it has no previous value. A 1e-8
    epsilon keeps the ratio finite for a constant actual series.
    """
    actual = np.array(y_true)
    forecast = np.array(y_pred)
    lagged = np.roll(actual, 1)  # lagged[k] == actual[k - 1] for k >= 1
    model_mae = np.abs(actual[1:] - forecast[1:]).mean()
    naive_mae = np.abs(actual[1:] - lagged[1:]).mean()
    return model_mae / (naive_mae + 1e-8)

def pearson_corr(y_true, y_pred):
    """Pearson correlation coefficient, or NaN when scipy cannot compute it."""
    try:
        result = pearsonr(y_true, y_pred)
        coefficient = result[0]
    except Exception:
        return np.nan
    return coefficient

def rae(y_true, y_pred):
    """Relative absolute error: model error relative to always predicting the mean.

    Values below 1 beat the mean-of-actuals baseline; a 1e-8 term guards a
    constant `y_true`.
    """
    actual = np.array(y_true)
    forecast = np.array(y_pred)
    model_error = np.sum(np.abs(forecast - actual))
    baseline_error = np.sum(np.abs(np.mean(actual) - actual))
    return model_error / (baseline_error + 1e-8)

def theils_u(y_true, y_pred):
    """Theil's U (U1 form): RMSE divided by the sum of the two series' RMS values.

    Ranges from 0 (perfect forecast) to 1; 1e-8 guards two all-zero series.
    """
    actual = np.array(y_true)
    forecast = np.array(y_pred)

    def rms(values):
        return np.sqrt(np.mean(values ** 2))

    return rms(actual - forecast) / (rms(actual) + rms(forecast) + 1e-8)

# === Compute Metrics Per Company ===
# Rebinds the module-level names `metrics`, `df_metrics` and `overall` that the
# earlier per-company metrics cell also assigned.
# NOTE(review): `summary_report` is defined earlier in the notebook; each value
# is assumed to hold equal-length "Actual"/"Predicted" sequences — confirm.
metrics = {}

for company, values in summary_report.items():
    actual = values["Actual"]
    pred = values["Predicted"]
    # Only used by the commented-out *_FO_SO metrics, but still raises
    # KeyError if a company has no "predicted_using_FO_SO" entry.
    pred_fo_so = values["predicted_using_FO_SO"]

    metrics[company] = {
        # Without FO_SO
        "Bias_Pred": bias(actual, pred),
        "WAPE_Pred": wape(actual, pred),
        "MASE_Pred": mase(actual, pred),
        "Pearson_Pred": pearson_corr(actual, pred),
        "RAE_Pred": rae(actual, pred),
        "TheilsU_Pred": theils_u(actual, pred),

        # With FO_SO
        #"Bias_FO_SO": bias(actual, pred_fo_so),
       ## "WAPE_FO_SO": wape(actual, pred_fo_so),
        #"MASE_FO_SO": mase(actual, pred_fo_so),
       # "Pearson_FO_SO": pearson_corr(actual, pred_fo_so),
       # "RAE_FO_SO": rae(actual, pred_fo_so),
       # "TheilsU_FO_SO": theils_u(actual, pred_fo_so),
    }

# One row per company, one column per metric.
df_metrics = pd.DataFrame(metrics).T

# === Compute Overall Metrics ===
# Unweighted mean over companies for each metric; Series.mean() skips NaN
# (e.g. from pearson_corr failures) by default.
overall = {
    col: df_metrics[col].mean()
    for col in df_metrics.columns
}
df_overall_adv = pd.DataFrame([overall], index=["Overall"])

# Bare expression: displayed as the notebook cell output.
df_overall_adv

import pandas as pd

# Render DataFrames in full: no truncation of columns, rows, or cell width.
pd.set_option(
    'display.max_columns', None,
    'display.max_rows', None,
    'display.max_colwidth', None,
)

df_metrics

import numpy as np
import pandas as pd
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    median_absolute_error,
    explained_variance_score,
    max_error
)

def smape(y_true, y_pred):
    """Symmetric mean absolute percentage error, in percent.

    Each absolute error is divided by the mean of |actual| and |prediction|.
    The denominator is floored at 1e-8 so two zero values contribute 0.
    """
    actual = np.array(y_true)
    forecast = np.array(y_pred)
    scale = np.maximum((np.abs(actual) + np.abs(forecast)) / 2, 1e-8)
    return 100 * np.mean(np.abs(actual - forecast) / scale)

def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent (|actual| floored at 1e-8)."""
    actual = np.array(y_true)
    forecast = np.array(y_pred)
    relative_errors = np.abs((actual - forecast) / np.maximum(np.abs(actual), 1e-8))
    return 100 * np.mean(relative_errors)

def directional_accuracy(y_true, y_pred):
    """Percentage of consecutive steps where actual and predicted changes have the same sign."""
    actual_moves = np.sign(np.diff(np.array(y_true)))
    predicted_moves = np.sign(np.diff(np.array(y_pred)))
    hits = (actual_moves == predicted_moves).astype(int)
    return np.mean(hits) * 100

def rse(y_true, y_pred):
    """Relative squared error: SSE of the forecast over SSE of predicting the mean.

    Values below 1 beat the mean-of-actuals baseline. Like the other metrics
    here (wape, mase, rae, nrmse), a 1e-8 term guards the denominator. Without
    it, a constant `y_true` gives nan (0/0) or inf, and that propagates into
    the cross-company average.
    """
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    mean_y = np.mean(y_true)
    return np.sum((y_true - y_pred) ** 2) / (np.sum((y_true - mean_y) ** 2) + 1e-8)

def nrmse(y_true, y_pred):
    """RMSE normalised by the range (max - min) of `y_true`, with a 1e-8 guard."""
    actual, forecast = np.array(y_true), np.array(y_pred)
    rmse = np.sqrt(mean_squared_error(actual, forecast))
    value_range = np.max(actual) - np.min(actual) + 1e-8
    return rmse / value_range

# Compute metrics for each company
# Builds one record per company; "Company" becomes a column of the DataFrame.
# NOTE(review): `summary_report` comes from earlier in the notebook; "Actual"
# and "Predicted" are assumed to be equal-length numeric sequences — confirm.
company_metrics = []

for company, data in summary_report.items():
    actual = data["Actual"]
    pred = data["Predicted"]

    company_metrics.append({
        "Company": company,
        "MAE": mean_absolute_error(actual, pred),
        "MSE": mean_squared_error(actual, pred),
        "MAPE (%)": mape(actual, pred),
        "SMAPE (%)": smape(actual, pred),
        "R2 Score": r2_score(actual, pred),
        "Explained Variance": explained_variance_score(actual, pred),
        "Median AE": median_absolute_error(actual, pred),
        "Max Error": max_error(actual, pred),
        # Percent of steps whose actual and predicted changes share a sign.
        "Directional Accuracy (%)": directional_accuracy(actual, pred),
        "RSE": rse(actual, pred),
        "NRMSE": nrmse(actual, pred),
    })

df_company_metrics = pd.DataFrame(company_metrics)

import pandas as pd

# Render DataFrames in full: no truncation of columns, rows, or cell width.
pd.set_option(
    'display.max_columns', None,
    'display.max_rows', None,
    'display.max_colwidth', None,
)

df_company_metrics

# Single "Overall" row: unweighted mean of every numeric metric across companies.
df_overall_average = (
    df_company_metrics
    .drop(columns="Company")
    .mean()
    .to_frame()
    .transpose()
)
df_overall_average.insert(0, "Company", "Overall")

df_overall_average



# One dict per series for the full-history embeddings: company -> embedding
# returned by get_company_embedding_from_gpt. Each is a distinct, empty dict.
esg_context_embeddings_from_model = dict()
esg_fo_context_embeddings_from_model = dict()
esg_so_context_embeddings_from_model = dict()

ret_context_embeddings_from_model = dict()
e_context_embeddings_from_model = dict()
s_context_embeddings_from_model = dict()
g_context_embeddings_from_model = dict()

senti_context_embeddings_from_model = dict()

def get_company_embedding_from_gpt(prompt):
    """Mean-pool `model.transformer`'s last hidden state over the sequence axis.

    NOTE(review): dead code. A second `get_company_embedding_from_gpt` is
    defined immediately after this one, before any call, so this version is
    never used. Unlike the replacement, it does not move the inputs to
    `model.device` and runs with autograd enabled. `.detach().numpy()` would
    also fail if the outputs live on a GPU. Candidate for deletion.
    """
    tokens = tokenizer_extended(prompt, return_tensors="pt")
    outputs = model.transformer(**tokens)
    sequence_embedding = outputs.last_hidden_state.mean(dim=1)  # (batch_size, hidden_size)
    return sequence_embedding.detach().numpy()

def get_company_embedding_from_gpt(text):
    """Embed `text` as the mean of the model's last-layer hidden states.

    A list of strings is joined with single spaces first. The inputs are
    tokenised with `tokenizer_extended` (padded/truncated), moved to
    `model.device`, and run through `model` without gradient tracking. The
    last hidden layer is averaged over the sequence axis and `.squeeze()`d.
    The result is a torch tensor on `model.device`; for a single text this is
    presumably shape (hidden_size,) — confirm against the model config.
    """
    prompt = " ".join(text) if isinstance(text, list) else text

    #if not text.strip():  # Check if empty
     #   return torch.zeros(model.config.hidden_size).to(model.device)

    encoded = tokenizer_extended(prompt, return_tensors="pt", padding=True, truncation=True)
    device = model.device
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        final_layer = model(**encoded, output_hidden_states=True).hidden_states[-1]
        pooled = final_layer.mean(dim=1).squeeze()
    return pooled

# Embed every company's full token string for each series. Each pair maps a
# source dict (company -> space-separated token string) to the dict that
# receives the embeddings. Companies whose string contains no tokens are
# skipped. Series are processed in the same order as before:
# ESG, ESG-FO, ESG-SO, RET, E, S, G, SENTI.
for _token_source, _embedding_target in (
    (esg_core, esg_context_embeddings_from_model),
    (esg_fo, esg_fo_context_embeddings_from_model),
    (esg_so, esg_so_context_embeddings_from_model),
    (ret_core, ret_context_embeddings_from_model),
    (e_core, e_context_embeddings_from_model),
    (s_core, s_context_embeddings_from_model),
    (g_core, g_context_embeddings_from_model),
    (senti_core, senti_context_embeddings_from_model),
):
    for company, tokens in _token_source.items():
        token_list = tokens.split()
        if not token_list:
            continue
        # Re-joining collapses runs of whitespace into single spaces.
        _embedding_target[company] = get_company_embedding_from_gpt(" ".join(token_list))

# One dict per series for embeddings built from only the last 12 tokens:
# company -> embedding. Each is a distinct, empty dict.
esg_context_embeddings_from_model_12 = dict()
esg_fo_context_embeddings_from_model_12 = dict()
esg_so_context_embeddings_from_model_12 = dict()

ret_context_embeddings_from_model_12 = dict()
e_context_embeddings_from_model_12 = dict()
s_context_embeddings_from_model_12 = dict()
g_context_embeddings_from_model_12 = dict()

senti_context_embeddings_from_model_12 = dict()



def get_last_n_tokens(tokens_str, n=12):
    """Return the last `n` whitespace-separated tokens of `tokens_str`.

    The tokens are re-joined with single spaces. Returns None when the string
    holds fewer than `n` tokens, so callers can skip short series. `n` is
    expected to be non-negative; n == 0 yields "" (no tokens).
    """
    token_list = tokens_str.split()
    if len(token_list) < n:
        return None  # Not enough data, skip
    # Slice from an explicit start index: `token_list[-n:]` would return the
    # whole list for n == 0 because -0 == 0.
    return " ".join(token_list[len(token_list) - n:])

# === Process each series with last 12 tokens ===
# Same series and order as the full-history pass. Each company is embedded
# from only its most recent 12 tokens; get_last_n_tokens returns None for
# shorter series, and those companies are skipped.
for _token_source, _embedding_target in (
    (esg_core, esg_context_embeddings_from_model_12),
    (esg_fo, esg_fo_context_embeddings_from_model_12),
    (esg_so, esg_so_context_embeddings_from_model_12),
    (ret_core, ret_context_embeddings_from_model_12),
    (e_core, e_context_embeddings_from_model_12),
    (s_core, s_context_embeddings_from_model_12),
    (g_core, g_context_embeddings_from_model_12),
    (senti_core, senti_context_embeddings_from_model_12),
):
    for company, tokens in _token_source.items():
        last_tokens = get_last_n_tokens(tokens, n=12)
        if not last_tokens:
            continue
        _embedding_target[company] = get_company_embedding_from_gpt(last_tokens)

from sklearn.metrics import pairwise_distances

def get_overlapping_esg_ret_companies(reduced, labels, threshold=0.5):
    """Return companies whose ESG and RET points lie within `threshold` of each other.

    Args:
        reduced: array-like of projected points; row k corresponds to labels[k].
        labels: strings of the form "COMPANY (SERIES)", as built by
            plot_7_embedding_dicts.
        threshold: Euclidean distance below which an ESG/RET pair of the same
            company counts as overlapping.

    Returns:
        Sorted list of company names, each listed once. ESG and RET points are
        matched regardless of which comes first in `labels`.
    """
    print("LABELS")
    print(labels)

    points = np.asarray(reduced)

    # Group row indices by company for the two series of interest, so only
    # same-company pairs are compared instead of all O(n^2) pairs.
    esg_rows, ret_rows = {}, {}
    for row, label in enumerate(labels):
        # rsplit so company names that themselves contain " (" still parse.
        company, series = label.rsplit(" (", 1)
        series = series.strip(")")
        if series == "ESG":
            esg_rows.setdefault(company, []).append(row)
        elif series == "RET":
            ret_rows.setdefault(company, []).append(row)

    overlaps = set()
    for company, esg_indices in esg_rows.items():
        for i in esg_indices:
            for j in ret_rows.get(company, ()):
                if np.linalg.norm(points[i] - points[j]) < threshold:
                    overlaps.add(company)
    return sorted(overlaps)

import re

def extract_values_from_tokens(tokens, prefix):
    """Parse the numbers out of tokens shaped like `<PREFIX_value>`.

    Tokens are matched from their start. Tokens that are not of that form are
    ignored, and values are returned as floats in input order.
    """
    token_pattern = re.compile(rf"<{prefix}_(\-?\d+\.?\d*)>")
    found = (token_pattern.match(tok) for tok in tokens)
    return [float(m.group(1)) for m in found if m]

def _as_token_list(tokens):
    """Return `tokens` as a list of tokens; a string is split on whitespace, None gives []."""
    if tokens is None:
        return []
    if isinstance(tokens, str):
        return tokens.split()
    return list(tokens)


def plot_esg_ret_tokens_for_overlap(overlapping_companies, esg_token_map, ret_token_map):
    """Plot each company's ESG and RET token values on one line chart.

    Args:
        overlapping_companies: iterable of company names to plot.
        esg_token_map / ret_token_map: company -> tokens. Each value may be a
            list of tokens or a space-separated token string (as in esg_core /
            ret_core). Iterating a string directly would yield single
            characters and match nothing.

    Companies missing from either map, or with no parsable values, are skipped.
    Both series are truncated to the shorter length before plotting.
    """
    for company in overlapping_companies:
        esg_tokens = _as_token_list(esg_token_map.get(company, []))
        ret_tokens = _as_token_list(ret_token_map.get(company, []))
        if not esg_tokens or not ret_tokens:
            continue

        esg_vals = extract_values_from_tokens(esg_tokens, "ESG")
        ret_vals = extract_values_from_tokens(ret_tokens, "RET")

        min_len = min(len(esg_vals), len(ret_vals))
        if min_len == 0:
            continue  # nothing comparable to plot
        esg_vals = esg_vals[:min_len]
        ret_vals = ret_vals[:min_len]

        plt.figure(figsize=(10, 4))
        plt.plot(esg_vals, label="ESG", color="blue", marker='o')
        plt.plot(ret_vals, label="RET", color="red", marker='x')
        plt.title(f"{company} - ESG vs RET Tokens")
        plt.xlabel("Time Step")
        plt.ylabel("Value")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()



import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import torch

def plot_7_embedding_dicts(
    esg_core_dict,
    esg_fo_dict,
    esg_so_dict,
    ret_core_dict,
    e_dict,
    s_dict,
    g_dict,
    senti_dict
):
    """Scatter a 2-D PCA projection of company embeddings from eight series.

    Despite the function name, eight series are plotted, labelled ESG, ESGFO,
    ESGSO, RET, ENV, SOC, GOV and SENTI (in argument order). All embeddings
    share one PCA fit, and each series gets its own tab10 colour.

    Args:
        esg_core_dict, esg_fo_dict, esg_so_dict, ret_core_dict, e_dict, s_dict,
        g_dict, senti_dict: each maps company -> embedding (torch.Tensor or
            np.ndarray). None values are skipped. All embeddings must have the
            same length so they can be stacked.

    Prints a warning and returns without plotting if no embeddings are given.
    """
    def ensure_numpy(embedding):
        if isinstance(embedding, torch.Tensor):
            return embedding.cpu().numpy()
        return embedding  # assume numpy array

    series_names = ["ESG", "ESGFO", "ESGSO", "RET", "ENV", "SOC", "GOV", "SENTI"]
    series_dicts = [
        esg_core_dict,
        esg_fo_dict,
        esg_so_dict,
        ret_core_dict,
        e_dict,
        s_dict,
        g_dict,
        senti_dict,
    ]

    all_embeddings = []
    labels = []         # "COMPANY (SERIES)" per row of all_embeddings
    series_index = []   # position in series_names per row of all_embeddings

    color_map = plt.colormaps["tab10"]

    for idx, (series_name, series_dict) in enumerate(zip(series_names, series_dicts)):
        for company, embedding in series_dict.items():
            if embedding is None:
                continue
            all_embeddings.append(ensure_numpy(embedding))
            labels.append(f"{company} ({series_name})")
            series_index.append(idx)

    if not all_embeddings:
        print("Warning: No valid embeddings found to plot!")
        return

    reduced = PCA(n_components=2).fit_transform(np.stack(all_embeddings))
    # Select each series' rows by its recorded index rather than by
    # substring-matching the labels.
    series_index = np.asarray(series_index)

    plt.figure(figsize=(5, 5))
    for idx, series_name in enumerate(series_names):
        mask = series_index == idx
        if not mask.any():
            continue
        plt.scatter(reduced[mask, 0], reduced[mask, 1], label=series_name, alpha=0.8, color=color_map(idx))

    plt.title("PCA Projection of Company Embeddings by Series")
    plt.xlabel("PC1", fontsize=12, fontweight='bold')
    plt.ylabel("PC2", fontsize=12, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    # Only if you have esg_token_map and ret_token_map available
    #overlapping_companies = get_overlapping_esg_ret_companies(reduced, labels, threshold=0.5)
    #print(overlapping_companies)
    #plot_esg_ret_tokens_for_overlap(overlapping_companies, esg_core, ret_core)

# PCA overview of the last-12-token embeddings for all eight series.
plot_7_embedding_dicts(
    esg_context_embeddings_from_model_12,
    esg_fo_context_embeddings_from_model_12,
    esg_so_context_embeddings_from_model_12,
    ret_context_embeddings_from_model_12,
    e_context_embeddings_from_model_12,
    s_context_embeddings_from_model_12,
    g_context_embeddings_from_model_12,
    senti_context_embeddings_from_model_12,
)

from sklearn.metrics import pairwise_distances

def plot_7_embedding_dicts_with_overlap(
    esg_core_dict,
    esg_fo_dict,
    esg_so_dict,
    ret_core_dict,
    e_dict,
    s_dict,
    g_dict,
    senti_dict,
    overlap_distance_threshold=0.5,  # distance below which two points count as "close"
):
    """List near-coincident embeddings across eight series, then plot a PCA view.

    All embeddings are projected with PCA onto up to 3 components. Every pair
    of points closer than `overlap_distance_threshold` in that projected space
    is printed as "company (series) ≈ company (series)". PC1 vs PC2 is then
    scattered with one tab10 colour per series (ESG, ESGFO, ESGSO, RET, ENV,
    SOC, GOV, SENTI).

    Args:
        esg_core_dict ... senti_dict: company -> embedding (torch.Tensor or
            np.ndarray). None values are skipped.
        overlap_distance_threshold: Euclidean distance cut-off in PCA space.

    Prints a warning and returns early if fewer than two embeddings are given.
    """
    def ensure_numpy(embedding):
        if isinstance(embedding, torch.Tensor):
            return embedding.cpu().numpy()
        return embedding

    series_names = ["ESG", "ESGFO", "ESGSO", "RET", "ENV", "SOC", "GOV", "SENTI"]
    series_dicts = [
        esg_core_dict,
        esg_fo_dict,
        esg_so_dict,
        ret_core_dict,
        e_dict,
        s_dict,
        g_dict,
        senti_dict,
    ]

    all_embeddings = []
    labels = []  # (company, series_name) per row of all_embeddings

    color_map = plt.colormaps["tab10"]

    for series_name, series_dict in zip(series_names, series_dicts):
        for company, embedding in series_dict.items():
            if embedding is None:
                continue
            all_embeddings.append(ensure_numpy(embedding))
            labels.append((company, series_name))

    if not all_embeddings:
        print("Warning: No valid embeddings found to plot!")
        return
    if len(all_embeddings) < 2:
        print("Warning: need at least two embeddings for PCA; nothing to plot!")
        return

    all_embeddings = np.stack(all_embeddings)
    # PCA cannot return more components than samples or features.
    n_components = min(3, all_embeddings.shape[0], all_embeddings.shape[1])
    reduced = PCA(n_components=n_components).fit_transform(all_embeddings)

    # === Detect overlaps ===
    print("\n🔍 Possible overlapping or similar points:")
    dist_matrix = pairwise_distances(reduced)
    # Strict upper triangle (i < j), visited in row-major order.
    close_pairs = np.argwhere(np.triu(dist_matrix < overlap_distance_threshold, k=1))
    for i, j in close_pairs:
        (comp1, cat1), (comp2, cat2) = labels[i], labels[j]
        print(f"  {comp1} ({cat1})  ≈  {comp2} ({cat2})  [distance: {dist_matrix[i,j]:.3f}]")

    # === Plotting ===
    plt.figure(figsize=(12, 8))
    for idx, series_name in enumerate(series_names):
        points = [k for k, (_, cat) in enumerate(labels) if cat == series_name]
        if not points:
            continue
        plt.scatter(reduced[points, 0], reduced[points, 1], label=series_name, alpha=0.8, color=color_map(idx))

    plt.title("PCA Projection of Company Embeddings by Series")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

import contextlib

# Capture the overlap report printed by plot_7_embedding_dicts_with_overlap
# into a text file. The report contains non-ASCII characters ("🔍", "≈"), so
# the encoding is pinned to UTF-8 instead of relying on the platform default,
# which can raise UnicodeEncodeError.
with open("pairwise_distance_esg_12.txt", "w", encoding="utf-8") as f, contextlib.redirect_stdout(f):
    plot_7_embedding_dicts_with_overlap(
        esg_context_embeddings_from_model_12,
        esg_fo_context_embeddings_from_model_12,
        esg_so_context_embeddings_from_model_12,
        ret_context_embeddings_from_model_12,
        e_context_embeddings_from_model_12,
        s_context_embeddings_from_model_12,
        g_context_embeddings_from_model_12,
        senti_context_embeddings_from_model_12,
    )

def plot_blockwise_pca_with_energy_clean(
    esg_core_dict,
    esg_fo_dict,
    esg_so_dict,
    ret_core_dict,
    e_dict,
    s_dict,
    g_dict,
    senti_dict
):
    """Per-series PCA on that series' own column slice, plus explained-variance bars.

    Each series is paired with a fixed [start, end) column range (`blocks`).
    Only that series' companies, restricted to that slice, are fed to a
    separate PCA with up to 10 components. Each row of the figure shows PC1 vs
    PC2 (left) and the explained-variance ratio per component (right). Rows
    for series with fewer than two embeddings are left blank.

    NOTE(review): the slice layout assumes embeddings are at least 768 wide and
    that each series owns these ranges — confirm against the model definition.

    Prints a warning and returns if no embeddings are given.
    """
    def ensure_numpy(embedding):
        if isinstance(embedding, torch.Tensor):
            return embedding.cpu().numpy()
        return embedding

    # Block definitions (your allocation): column range [start, end) per series.
    blocks = {
        "RET": (0, 96),
        "SOC": (96, 192),
        "GOV": (192, 288),
        "ESG": (288, 384),
        "ESGFO": (384, 480),
        "ESGSO": (480, 576),
        "ENV": (576, 672),
        "SENTI": (672, 768),
    }

    series_dicts = {
        "ESG": esg_core_dict,
        "ESGFO": esg_fo_dict,
        "ESGSO": esg_so_dict,
        "RET": ret_core_dict,
        "ENV": e_dict,
        "SOC": s_dict,
        "GOV": g_dict,
        "SENTI": senti_dict
    }

    block_colors = {
        "RET": "blue",
        "SOC": "green",
        "GOV": "orange",
        "ESG": "purple",
        "ESGFO": "pink",
        "ESGSO": "cyan",
        "ENV": "brown",
        "SENTI": "red",
    }

    # Stack every available embedding; labels[k] = (company, series) for row k.
    all_embeddings, labels = [], []
    for series_name, series_dict in series_dicts.items():
        for company, embedding in series_dict.items():
            if embedding is None:
                continue
            all_embeddings.append(ensure_numpy(embedding))
            labels.append((company, series_name))
    if not all_embeddings:
        print("Warning: No valid embeddings found to plot!")
        return
    all_embeddings = np.stack(all_embeddings)

    fig, axes = plt.subplots(len(blocks), 2, figsize=(12, 4 * len(blocks)))
    fig.suptitle("Block-wise PCA Projections & Energy", fontsize=16)

    for i, (block_name, (start, end)) in enumerate(blocks.items()):
        ax_scatter, ax_energy = axes[i, 0], axes[i, 1]

        # Only this series' companies, restricted to this series' own slice.
        block_points = [idx for idx, (_, cat) in enumerate(labels) if cat == block_name]
        if len(block_points) < 2:
            ax_scatter.axis("off")
            ax_energy.axis("off")
            continue
        block_embeddings = all_embeddings[block_points, start:end]

        # PCA allows at most min(n_samples, n_features) components, and the
        # scatter needs both PC1 and PC2.
        n_components = min(10, block_embeddings.shape[0], block_embeddings.shape[1])
        if n_components < 2:
            ax_scatter.axis("off")
            ax_energy.axis("off")
            continue

        pca = PCA(n_components=n_components)
        reduced = pca.fit_transform(block_embeddings)
        explained_variance = pca.explained_variance_ratio_

        # === Scatter: only points for this block ===
        ax_scatter.scatter(reduced[:, 0], reduced[:, 1], alpha=0.8, color=block_colors[block_name])
        ax_scatter.set_title(f"{block_name} Block PCA")
        ax_scatter.set_xlabel("PC1")
        ax_scatter.set_ylabel("PC2")
        ax_scatter.grid(True)

        # === Energy plot ===
        ax_energy.bar(range(1, len(explained_variance) + 1), explained_variance, alpha=0.7, color=block_colors[block_name])
        ax_energy.set_title(f"{block_name} - Variance Explained")
        ax_energy.set_xlabel("Principal Component")
        ax_energy.set_ylabel("Variance Explained")
        ax_energy.grid(True)

    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()

import contextlib

# Run the block-wise PCA figure with stdout captured to a text file.
# NOTE(review): plot_blockwise_pca_with_energy_clean does not appear to print
# anything on its normal path, so this file is likely to end up empty.
with open("pairwise_distance_esg.txt", "w") as f, contextlib.redirect_stdout(f):
    plot_blockwise_pca_with_energy_clean(
        esg_context_embeddings_from_model_12,
        esg_fo_context_embeddings_from_model_12,
        esg_so_context_embeddings_from_model_12,
        ret_context_embeddings_from_model_12,
        e_context_embeddings_from_model_12,
        s_context_embeddings_from_model_12,
        g_context_embeddings_from_model_12,
        senti_context_embeddings_from_model_12,
    )

from sklearn.decomposition import PCA

def plot_embedding_distance_map_between_two_companies_blockwise(
    company1, company2,
    esg_dict, esg_fo_dict, esg_so_dict,
    ret_dict, e_dict, s_dict, g_dict, senti_dict
):
    """Draw, per series, a 2-point segment for two companies' embedding slices.

    For each series, both full embeddings are L2-normalised and that series'
    column slice (`blocks`) is taken. A PCA fitted on just those two rows
    gives a 2-D segment; its legend entry reports the cosine distance between
    the slices. Series in which either company has no embedding are skipped.

    NOTE(review): every series gets its own two-point PCA, so positions are not
    comparable across series. Only segment length (the Euclidean distance
    between the normalised slices) and the legend distances carry meaning.
    """
    # Imported locally so the function does not depend on a later notebook
    # cell having imported it at module level.
    from sklearn.metrics.pairwise import cosine_distances

    def as_numpy(embedding):
        # Tensors are detached and moved to CPU; None and arrays pass through.
        if embedding is None:
            return None
        if hasattr(embedding, "cpu"):
            return embedding.detach().cpu().numpy()
        return embedding

    # Column range [start, end) owned by each series, in the same order as
    # `series_dicts` below.
    blocks = {
        'ESG': (288, 384),
        'ESGFO': (384, 480),
        'ESGSO': (480, 576),
        'RET': (0, 96),
        'E': (576, 672),
        'S': (96, 192),
        'G': (192, 288),
        'SENTI': (672, 768)
    }
    series_dicts = [esg_dict, esg_fo_dict, esg_so_dict, ret_dict, e_dict, s_dict, g_dict, senti_dict]

    plt.figure(figsize=(10, 8))
    for (label, (start, end)), d in zip(blocks.items(), series_dicts):
        emb1 = as_numpy(d.get(company1))
        emb2 = as_numpy(d.get(company2))
        if emb1 is None or emb2 is None:
            continue
        # Normalise the full vectors, then slice out this series' block.
        emb1_block = (emb1 / (np.linalg.norm(emb1) + 1e-10))[start:end]
        emb2_block = (emb2 / (np.linalg.norm(emb2) + 1e-10))[start:end]
        # Reduce to 2D (PCA) using only these two points.
        block_proj = PCA(n_components=2).fit_transform(np.vstack([emb1_block, emb2_block]))
        vec1, vec2 = block_proj[0], block_proj[1]
        dist = cosine_distances([emb1_block], [emb2_block])[0][0]
        plt.plot([vec1[0], vec2[0]], [vec1[1], vec2[1]], marker='o', label=f"{label} (dist={dist:.3f})")

    plt.title(f"Block-wise Cosine Distance Map: {company1} vs {company2}")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.legend(loc='best')
    plt.grid(True)
    plt.axis('equal')
    plt.tight_layout()
    plt.show()

"""Verify embedding shapes.

Experiment: cosine similarity of company-level embeddings.

Experiment: calculated vs. predicted ESG.
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

def ensure_numpy(embedding):
    """Convert a tensor-like embedding to a NumPy array; pass anything else through.

    Objects exposing `.cpu()` (e.g. torch tensors) are detached before
    conversion. Torch refuses `.numpy()` on tensors that still require grad,
    so without the detach such tensors would raise. None is returned unchanged.
    """
    if embedding is None:
        return None
    if hasattr(embedding, "cpu"):
        return embedding.detach().cpu().numpy()
    return embedding

def plot_similarity_heatmap_between_two_companies_pdf(
    company1,
    company2,
    esg_dict,
    esg_fo_dict,
    esg_so_dict,
    ret_dict,
    e_dict,
    s_dict,
    g_dict,
    senti_dict,
    pdf
):
    """Save a one-row heatmap of two companies' cosine similarity per series to `pdf`.

    Cells are ESG, ESGFO, ESGSO, RET, ENV, SOC, GOV, SENTI, in that order. A
    series in which either company has no embedding is shown as 0.00. The
    colour scale is fixed to [0, 1], so negative similarities get the
    lowest colour.

    Args:
        company1, company2: keys looked up in every series dict.
        esg_dict ... senti_dict: company -> embedding (torch tensor or ndarray),
            converted with the module-level `ensure_numpy`.
        pdf: an open matplotlib PdfPages; the figure is saved to it and closed.
    """
    from matplotlib.colors import LinearSegmentedColormap

    series_labels = ['ESG', 'ESGFO', 'ESGSO', 'RET', 'ENV', 'SOC', 'GOV', 'SENTI']
    series_dicts = [
        esg_dict,
        esg_fo_dict,
        esg_so_dict,
        ret_dict,
        e_dict,
        s_dict,
        g_dict,
        senti_dict
    ]

    # Compute cosine similarities
    similarities = []
    for d in series_dicts:
        emb1 = ensure_numpy(d.get(company1))
        emb2 = ensure_numpy(d.get(company2))
        if emb1 is None or emb2 is None:
            similarities.append(0.0)
            continue
        emb1 = emb1 / (np.linalg.norm(emb1) + 1e-10)
        emb2 = emb2 / (np.linalg.norm(emb2) + 1e-10)
        similarities.append(cosine_similarity([emb1], [emb2])[0][0])

    sim_matrix = np.array(similarities).reshape(1, -1)

    # red -> yellow -> blue -> green so neighbouring values contrast sharply.
    gaudy_cmap = LinearSegmentedColormap.from_list(
        "red_yellow_blue_green", ["red", "yellow", "blue", "green"]
    )

    fig, ax = plt.subplots(figsize=(8, 2.5))  # Increased height
    im = ax.imshow(sim_matrix, cmap=gaudy_cmap, vmin=0, vmax=1)

    # Axis labels
    ax.set_xticks(np.arange(len(series_labels)))
    ax.set_xticklabels(series_labels, fontsize=12, rotation=30, ha='right', fontweight='bold')
    ax.set_yticks([0])
    ax.set_yticklabels([f"{company1} & {company2}"], fontsize=12, fontweight='bold')

    # Annotate similarity values; switch to white text on the darkest cells.
    for j, sim_val in enumerate(similarities):
        ax.text(j, 0, f"{sim_val:.2f}", ha='center', va='center',
                color='black' if sim_val < 0.85 else 'white', fontsize=12, fontweight='bold')

    # Colorbar
    cbar = fig.colorbar(im, ax=ax, orientation='vertical', shrink=0.8, pad=0.02)
    cbar.set_label("Cosine Similarity", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    plt.tight_layout(pad=2.0)
    pdf.savefig(fig)
    plt.close(fig)

from matplotlib.colors import LinearSegmentedColormap

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def ensure_numpy(embedding):
    """Return a NumPy view of a torch-style tensor; other values (including None) pass through.

    Anything with a `.cpu()` method is detached, moved to CPU and converted.
    """
    if embedding is not None and hasattr(embedding, "cpu"):
        return embedding.detach().cpu().numpy()
    return embedding

def _pairwise_series_similarities(company1, company2, series_dicts):
    """Cosine similarity of the two companies' embeddings in each dict (0.0 if either is missing)."""
    similarities = []
    for d in series_dicts:
        emb1, emb2 = d.get(company1), d.get(company2)
        if emb1 is None or emb2 is None:
            similarities.append(0.0)
            continue
        vectors = []
        for emb in (emb1, emb2):
            if hasattr(emb, "cpu"):  # torch tensor
                emb = emb.detach().cpu().numpy()
            vec = np.ravel(np.asarray(emb, dtype=float))
            vectors.append(vec / (np.linalg.norm(vec) + 1e-10))
        # Dot product of unit vectors == cosine similarity. sklearn's
        # cosine_similarity rejects the 1-D vectors used here.
        similarities.append(float(np.dot(vectors[0], vectors[1])))
    return np.array(similarities)


def plot_similarity_across_series(company1, company2, series_dicts, series_labels, pdf):
    """
    Plots cosine similarities between two companies across multiple embedding series.

    Colours show the similarities min-max normalised across the row; the cell
    text shows the raw cosine similarity. A series in which either company is
    missing counts as 0.0.

    Args:
        company1 (str): First company name.
        company2 (str): Second company name.
        series_dicts (list): List of dicts, each mapping company -> embedding for a series.
        series_labels (list): Labels for each embedding series, in the same order
            as series_dicts (pairs beyond the shorter list are ignored).
        pdf (PdfPages): Open PdfPages object to save the plot.
    """
    paired_dicts = [d for _, d in zip(series_labels, series_dicts)]
    similarities = _pairwise_series_similarities(company1, company2, paired_dicts)

    min_sim, max_sim = similarities.min(), similarities.max()
    normalized = (similarities - min_sim) / (max_sim - min_sim + 1e-10)
    sim_matrix = normalized.reshape(1, -1)

    # Colormap (better contrast)
    gaudy_cmap = LinearSegmentedColormap.from_list(
        "red_yellow_blue_green", ["red", "yellow", "blue", "green"]
    )

    fig, ax = plt.subplots(figsize=(8, 2.5))
    im = ax.imshow(sim_matrix, cmap=gaudy_cmap, vmin=0, vmax=1)
    ax.set_xticks(np.arange(len(series_labels)))
    ax.set_xticklabels(series_labels, fontsize=12, rotation=30, ha='right', fontweight='bold')
    ax.set_yticks([0])
    ax.set_yticklabels([f"{company1} & {company2}"], fontsize=12, fontweight='bold')

    # Annotate raw similarities
    for j, val in enumerate(similarities):
        ax.text(j, 0, f"{val:.2f}", ha='center', va='center', color='white', fontsize=12, fontweight='bold')

    cbar = fig.colorbar(im, ax=ax, orientation='vertical', shrink=0.8, pad=0.02)
    cbar.set_label("Normalized Similarity", fontsize=12)
    cbar.ax.tick_params(labelsize=12)

    plt.tight_layout(pad=2.0)
    pdf.savefig(fig)
    plt.close(fig)

from tqdm import tqdm

# One label per entry of `series_dicts`, in the same order, so each heatmap
# column is named after the series it actually shows.
series_labels = ["ESG", "ESGFO", "ESGSO", "RET", "ENV", "SOC", "GOV", "SENTI"]
series_dicts = [
    esg_context_embeddings_from_model_12,
    esg_fo_context_embeddings_from_model_12,
    esg_so_context_embeddings_from_model_12,
    ret_context_embeddings_from_model_12,
    e_context_embeddings_from_model_12,
    s_context_embeddings_from_model_12,
    g_context_embeddings_from_model_12,
    senti_context_embeddings_from_model_12,
]
from matplotlib.backends.backend_pdf import PdfPages
with PdfPages("company_similarity.pdf") as pdf:
    plot_similarity_across_series("AAPL", "MSFT", series_dicts, series_labels, pdf)

"""Although the silhouette score reaches its global maximum at k=2, we select k=6 as a trade-off between structure and interpretability. The corresponding silhouette score of 0.40 remains within an acceptable range, and our hierarchical clustering (see Figure Y) confirms a natural break at six groups, which aligns with domain-specific ESG distinctions."""



from sklearn.metrics.pairwise import cosine_similarity

def validate_cluster_similarity(company1, company2, embedding_dict):
    """Print and return the cosine similarity between two companies' embeddings.

    Tensors are moved to CPU and converted to NumPy first. Raises KeyError if
    either company is absent from `embedding_dict`.
    """
    def to_array(company):
        if hasattr(embedding_dict[company], "cpu"):
            return embedding_dict[company].cpu().numpy()
        return embedding_dict[company]

    first, second = to_array(company1), to_array(company2)
    sim = cosine_similarity([first], [second])[0][0]
    print(f"Cosine similarity between {company1} and {company2}: {sim:.4f}")
    return sim


validate_cluster_similarity("CDNS", "VRSN", esg_context_embeddings_from_model)

def _extract_token_series(token_str, prefix):
    """Return the numeric values of every `<PREFIX_value>` token in `token_str`, in order."""
    pattern = rf"<{re.escape(prefix)}_(-?\d+(?:\.\d+)?)>"
    return [float(m) for m in re.findall(pattern, token_str)]


def plot_time_series_comparison(company1, company2, esg_dict, prefix="ESG"):
    """Overlay two companies' token-encoded series on one line chart.

    Args:
        company1, company2: keys into `esg_dict`.
        esg_dict: company -> token string containing `<PREFIX_value>` tokens.
        prefix: token prefix to extract. Defaults to "ESG"; pass e.g. "SENTI"
            or "RET" to compare other series. It also sets the legend labels
            and title.
    """
    import matplotlib.pyplot as plt

    series1 = _extract_token_series(esg_dict[company1], prefix)
    series2 = _extract_token_series(esg_dict[company2], prefix)

    plt.figure(figsize=(10, 4))
    plt.plot(series1, label=f"{company1} {prefix}")
    plt.plot(series2, label=f"{company2} {prefix}")
    plt.title(f"{prefix} Series Comparison")
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_time_series_comparison_senti(company1, company2, esg_dict):
    """Overlay the ``<SENTI_x>`` token series of two companies on one line chart.

    Args:
        company1, company2: keys of ``esg_dict`` to compare.
        esg_dict: {company: token string}. The parameter name is historical;
            this function reads SENTI tokens from it.
    """
    import matplotlib.pyplot as plt

    def extract_series(token_str):
        values = [float(m) for m in re.findall(r"<SENTI_(-?\d+(?:\.\d+)?)>", token_str)]
        return values

    series1 = extract_series(esg_dict[company1])
    series2 = extract_series(esg_dict[company2])

    plt.figure(figsize=(10, 4))
    plt.plot(series1, label=f"{company1} SENTI")
    plt.plot(series2, label=f"{company2} SENTI")
    # The title names the series that is actually plotted (sentiment, not ESG).
    plt.title("SENTI Series Comparison")
    plt.legend()
    plt.grid(True)
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib
from sklearn.metrics.pairwise import cosine_distances
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from scipy.spatial import ConvexHull
import pandas as pd
import random

def cluster_and_plot_with_labels(embedding_dict, n_clusters=6, annotate_per_cluster=3, seed=None):
    """Cluster company embeddings with a fixed k and plot them in 2-D PCA space.

    Companies are grouped by average-linkage agglomerative clustering on
    pairwise cosine distances. Each cluster is drawn in its own colour with a
    dashed convex-hull outline, and up to ``annotate_per_cluster`` randomly
    chosen members are labelled with their name.

    Args:
        embedding_dict: {company: embedding}; values may be tensors (anything
            with ``.cpu()``) or array-likes of equal length.
        n_clusters: number of clusters to form.
        annotate_per_cluster: maximum number of names written per cluster.
        seed: optional seed for choosing which names to annotate. ``None``
            keeps using the global ``random`` state.

    Returns:
        pandas.DataFrame with columns Company, Cluster and Color (colour name).
    """
    from scipy.spatial import QhullError

    company_names = list(embedding_dict.keys())
    embeddings = np.array([
        embedding_dict[c].cpu().numpy() if hasattr(embedding_dict[c], "cpu") else embedding_dict[c]
        for c in company_names
    ])

    # Clustering
    cosine_dist = cosine_distances(embeddings)
    clustering = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='average')
    cluster_labels = clustering.fit_predict(cosine_dist)

    # Color mapping. matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
    # pyplot.get_cmap keeps the same (name, lut) signature and resampling.
    cmap = plt.get_cmap('tab10', n_clusters)
    color_map = {i: matplotlib.colors.to_hex(cmap(i)) for i in range(n_clusters)}
    hex_to_name = {
        "#1f77b4": "blue", "#ff7f0e": "orange", "#2ca02c": "green", "#d62728": "red",
        "#9467bd": "purple", "#8c564b": "brown", "#e377c2": "pink", "#7f7f7f": "gray",
        "#bcbd22": "olive", "#17becf": "cyan"
    }
    cluster_to_color_name = {i: hex_to_name.get(color_map[i], color_map[i]) for i in range(n_clusters)}

    # PCA for 2D plot
    reduced = PCA(n_components=2).fit_transform(embeddings)

    rng = random if seed is None else random.Random(seed)

    # Plot
    plt.figure(figsize=(14, 10))
    for i in range(n_clusters):
        idxs = np.where(cluster_labels == i)[0]
        x = reduced[idxs, 0]
        y = reduced[idxs, 1]
        plt.scatter(x, y, color=color_map[i], label=f"Cluster {i} ({cluster_to_color_name[i]})", alpha=0.7)

        # Convex hull. Collinear or duplicate points have no 2-D hull and make
        # Qhull raise, so skip the outline for such clusters.
        if len(idxs) >= 3:
            try:
                hull = ConvexHull(np.stack([x, y], axis=1))
            except QhullError:
                hull = None
            if hull is not None:
                for simplex in hull.simplices:
                    plt.plot(x[simplex], y[simplex], '--', color='gray', alpha=0.5)

        # Annotate a few company names
        sampled = rng.sample(list(idxs), min(annotate_per_cluster, len(idxs)))
        for j in sampled:
            plt.text(reduced[j, 0], reduced[j, 1], company_names[j], fontsize=8, alpha=0.8)

    plt.title("Company Embeddings Clustered (with Selected Labels)")
    plt.xlabel("PCA Component 1")
    plt.ylabel("PCA Component 2")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Table: all companies
    full_table = pd.DataFrame({
        "Company": company_names,
        "Cluster": cluster_labels,
        "Color": [cluster_to_color_name[c] for c in cluster_labels]
    })

    return full_table

# Fixed-k (k=6) clustering of the fine-tuned ESG embeddings; the table is kept in `df`.
df = cluster_and_plot_with_labels(
    embedding_dict=esg_context_embeddings_from_model_12,
    n_clusters=6,
    annotate_per_cluster=3,
)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib
from sklearn.metrics.pairwise import cosine_distances
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from scipy.spatial import ConvexHull
import pandas as pd
import random

def find_optimal_clusters(embeddings, min_k=2, max_k=10):
    """Pick the cluster count in ``[min_k, max_k]`` with the best cosine silhouette.

    Clusterings use average-linkage agglomerative clustering on precomputed
    cosine distances.

    Args:
        embeddings: 2-D array, one row per item.
        min_k, max_k: inclusive range of cluster counts to try. ``max_k`` is
            capped at ``n_samples - 1``, because silhouette_score only accepts
            between 2 and n_samples - 1 distinct labels.

    Returns:
        (best_k, best_score). When no k can be evaluated, this is
        ``(min_k, -1)``.
    """
    best_k = min_k
    best_score = -1
    upper_k = min(max_k, len(embeddings) - 1)
    if upper_k < min_k:
        return best_k, best_score

    dist_matrix = cosine_distances(embeddings)
    for k in range(min_k, upper_k + 1):
        clustering = AgglomerativeClustering(n_clusters=k, metric='precomputed', linkage='average')
        labels = clustering.fit_predict(dist_matrix)
        if len(set(labels)) > 1:  # silhouette needs >1 cluster
            score = silhouette_score(embeddings, labels, metric='cosine')
            if score > best_score:
                best_k = k
                best_score = score
    return best_k, best_score

def cluster_and_plot_with_labels(embedding_dict, annotate_per_cluster=3, min_k=2, max_k=10, seed=None):
    """Cluster embeddings with a silhouette-selected k and plot them in 2-D PCA space.

    k is chosen by ``find_optimal_clusters`` over ``[min_k, max_k]``. Companies
    are then grouped by average-linkage agglomerative clustering on cosine
    distances and drawn per cluster with a dashed convex-hull outline and a
    few randomly chosen name labels.

    Args:
        embedding_dict: {company: embedding}; tensors (anything with
            ``.cpu()``) or array-likes.
        annotate_per_cluster: maximum number of names written per cluster.
        min_k, max_k: inclusive search range for the number of clusters.
        seed: optional seed for the annotation sample. ``None`` keeps using
            the global ``random`` state.

    Returns:
        (table, n_clusters, best_score), where table is a DataFrame with
        columns Company, Cluster and Color.
    """
    from scipy.spatial import QhullError

    company_names = list(embedding_dict.keys())
    embeddings = np.array([
        embedding_dict[c].cpu().numpy() if hasattr(embedding_dict[c], "cpu") else embedding_dict[c]
        for c in company_names
    ])

    # === Find optimal number of clusters ===
    n_clusters, best_score = find_optimal_clusters(embeddings, min_k, max_k)
    print(f"Optimal clusters: {n_clusters} (Silhouette = {best_score:.3f})")

    # Clustering with optimal k
    cosine_dist = cosine_distances(embeddings)
    clustering = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='average')
    cluster_labels = clustering.fit_predict(cosine_dist)

    # Color mapping. matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
    # pyplot.get_cmap keeps the same (name, lut) signature and resampling.
    cmap = plt.get_cmap('tab10', n_clusters)
    color_map = {i: matplotlib.colors.to_hex(cmap(i)) for i in range(n_clusters)}
    hex_to_name = {
        "#1f77b4": "blue", "#ff7f0e": "orange", "#2ca02c": "green", "#d62728": "red",
        "#9467bd": "purple", "#8c564b": "brown", "#e377c2": "pink", "#7f7f7f": "gray",
        "#bcbd22": "olive", "#17becf": "cyan"
    }
    cluster_to_color_name = {i: hex_to_name.get(color_map[i], color_map[i]) for i in range(n_clusters)}

    # PCA for 2D plot
    reduced = PCA(n_components=2).fit_transform(embeddings)

    rng = random if seed is None else random.Random(seed)

    # Plot
    plt.figure(figsize=(14, 10))
    for i in range(n_clusters):
        idxs = np.where(cluster_labels == i)[0]
        x = reduced[idxs, 0]
        y = reduced[idxs, 1]
        plt.scatter(x, y, color=color_map[i], label=f"Cluster {i} ({cluster_to_color_name[i]})", alpha=0.7)

        # Convex hull. Collinear or duplicate points make Qhull raise; skip the outline then.
        if len(idxs) >= 3:
            try:
                hull = ConvexHull(np.stack([x, y], axis=1))
            except QhullError:
                hull = None
            if hull is not None:
                for simplex in hull.simplices:
                    plt.plot(x[simplex], y[simplex], '--', color='gray', alpha=0.5)

        # Annotate a few company names
        sampled = rng.sample(list(idxs), min(annotate_per_cluster, len(idxs)))
        for j in sampled:
            plt.text(reduced[j, 0], reduced[j, 1], company_names[j], fontsize=8, alpha=0.8)

    plt.title(f"Company Embeddings Clustered (Optimal k={n_clusters})")
    plt.xlabel("PCA Component 1")
    plt.ylabel("PCA Component 2")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Table: all companies
    full_table = pd.DataFrame({
        "Company": company_names,
        "Cluster": cluster_labels,
        "Color": [cluster_to_color_name[c] for c in cluster_labels]
    })

    return full_table, n_clusters, best_score

# Call the block-wise clustering helper on the ESG embeddings with block_type="ESG"
# and n_clusters=7, then print whatever it returns.
# NOTE(review): cluster_and_plot_with_labels_block is not defined in this part of the
# notebook; presumably it comes from an earlier cell — confirm it exists before running.
cluster_table = cluster_and_plot_with_labels_block(esg_context_embeddings_from_model, block_type="ESG", n_clusters=7)
print(cluster_table)

import pandas as pd

# Show every row when DataFrames are printed (the table lists all companies).
pd.set_option('display.max_rows', None)

# Persist the fixed-k ESG cluster table as plain text. An explicit UTF-8
# encoding keeps the file independent of the platform's default encoding.
with open("esg_clusters.txt", "w", encoding="utf-8") as f:
    f.write(df.to_string(index=False))  # Set index=True if you want to keep index

print(df)

import matplotlib.pyplot as plt
import re
import numpy as np

def plot_time_series_comparison_ret(companies, esg_dict, kind):
    """Overlay the ``<{kind}_x>`` token series of several companies on one compact chart.

    Args:
        companies: iterable of company keys to plot, in legend order.
        esg_dict: {company: token string}. Despite the name, any series kind works.
        kind: token prefix to extract, e.g. "RET", "ESG" or "SENTI".

    Companies missing from ``esg_dict``, or whose string has no matching
    token, are reported and skipped.
    """
    # Raw f-string: "\d" inside a plain f-string is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). re.escape keeps ``kind`` literal.
    token_pattern = re.compile(rf"<{re.escape(kind)}_(-?\d+(?:\.\d+)?)>")

    def extract_series(token_str):
        return [float(m) for m in token_pattern.findall(token_str)]

    # Setup plot
    plt.figure(figsize=(5, 3))  # Smaller, compact plot

    # Gaudy bright color palette
    bright_colors = ['magenta', 'lime', 'cyan', 'orange', 'red', 'blue', 'gold', 'purple']

    for i, company in enumerate(companies):
        if company not in esg_dict:
            print(f"Skipping {company}: not in input.")
            continue

        series = extract_series(esg_dict[company])
        if not series:
            print(f"Skipping {company}: no {kind} series found.")
            continue

        plt.plot(series,
                 label=f"{company} {kind}",
                 linewidth=2.5,
                 color=bright_colors[i % len(bright_colors)],
                 alpha=0.9)

    plt.title(f"{kind} Series Comparison", fontsize=12, fontweight='bold')
    plt.xlabel("Time", fontsize=10, fontweight='bold')
    plt.ylabel(f"{kind} Value", fontsize=10, fontweight='bold')

    legend = plt.legend(
        loc='upper right',  # place the legend in the top-right corner
        fontsize=12,
        frameon=True,
    )

    # Set background and transparency
    legend.get_frame().set_facecolor('#f0f0f0')   # light gray
    legend.get_frame().set_alpha(0.6)             # 60% opacity
    legend.get_frame().set_edgecolor('gray')      # optional edge

    plt.grid(True, linestyle='--', alpha=0.7)

    # Bold tick labels
    ax = plt.gca()
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontweight('bold')
        label.set_fontsize(12)

    plt.tight_layout()
    plt.show()

def plot_time_series_comparison_Kind(company1, company2, esg_dict):
    """Overlay the ``<RET_x>`` token series of two companies on one line chart.

    ``esg_dict`` maps company -> token string. This function reads RET tokens
    from it, so the legend and title are labelled RET.
    """
    import matplotlib.pyplot as plt

    def extract_series(token_str):
        values = [float(m) for m in re.findall(r"<RET_(-?\d+(?:\.\d+)?)>", token_str)]
        return values

    series1 = extract_series(esg_dict[company1])
    series2 = extract_series(esg_dict[company2])

    plt.figure(figsize=(10, 4))
    plt.plot(series1, label=f"{company1} RET")
    plt.plot(series2, label=f"{company2} RET")
    plt.title("RET Series Comparison")
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_time_series_comparison_Kind(company1, company2, esg_dict):
    """Draw the ``<RET_x>`` token series of two companies on shared axes.

    ``esg_dict`` maps each company to its token string. Both series are parsed
    before the figure is created.
    """
    import matplotlib.pyplot as plt

    ret_regex = r"<RET_(-?\d+(?:\.\d+)?)>"
    first_series = list(map(float, re.findall(ret_regex, esg_dict[company1])))
    second_series = list(map(float, re.findall(ret_regex, esg_dict[company2])))

    plt.figure(figsize=(10, 4))
    for company, series in ((company1, first_series), (company2, second_series)):
        plt.plot(series, label=f"{company} RET")
    plt.title("RET Series Comparison")
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_time_series_comparison_senti(company1, company2, esg_dict):
    """Draw the ``<SENTI_x>`` token series of two companies on shared axes.

    ``esg_dict`` maps each company to its token string (the name is historical).
    """
    import matplotlib.pyplot as plt

    def to_values(tokens):
        return [float(raw) for raw in re.findall(r"<SENTI_(-?\d+(?:\.\d+)?)>", tokens)]

    pairs = [(name, to_values(esg_dict[name])) for name in (company1, company2)]

    plt.figure(figsize=(10, 4))
    for name, values in pairs:
        plt.plot(values, label=f"{name} SENTI")
    plt.title("SENTI Series Comparison")
    plt.legend()
    plt.grid(True)
    plt.show()

"""# HIGH ESG SAME RETURN"""

# Notebook display of the raw RET token strings.
ret_core

# Fixed-k (k=6) clustering of the sentiment embeddings.
# NOTE(review): earlier in this notebook cluster_and_plot_with_labels is redefined as
# the optimal-k version, whose signature has no n_clusters parameter and which returns
# a (table, k, score) tuple. Run top-to-bottom, this call raises TypeError; it only
# works when the fixed-k definition is the active one — confirm the cell execution order.
senti_df_6 = cluster_and_plot_with_labels(senti_context_embeddings_from_model, n_clusters=6, annotate_per_cluster=3)

!pip install dtaidistance

import re
import numpy as np
from dtaidistance import dtw
from sklearn.metrics.pairwise import cosine_distances
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr
import pandas as pd

# === Step 1: Extract numeric series from token strings ===
def extract_series(token_str):
    """Return every signed decimal number in ``token_str`` as a float.

    Example: ``"<RET_01> <RET_-01>"`` -> ``[1.0, -1.0]``.
    """
    number = re.compile(r"[-+]?\d+(?:\.\d+)?")
    return list(map(float, number.findall(token_str)))

# === Step 2: Compute DTW distance matrix ===
def compute_dtw_matrix(series_dict):
    """Build the symmetric DTW distance matrix of the series in ``series_dict``.

    Returns ``(company_names, matrix)``. Row and column order follow the
    dict's key order, and the diagonal stays zero.
    """
    names = list(series_dict)
    size = len(names)
    distances = np.zeros((size, size))
    for row in range(size):
        left = series_dict[names[row]]
        for col in range(row + 1, size):
            gap = dtw.distance(left, series_dict[names[col]])
            distances[row, col] = distances[col, row] = gap
    return names, distances

# === Step 3: Compute embedding cosine distance matrix ===
def compute_embedding_distance_matrix(embedding_dict, companies):
    """Cosine-distance matrix of the embeddings of ``companies`` present in ``embedding_dict``.

    Returns ``(kept_companies, matrix)``. Companies without an embedding are
    dropped, and the original order is preserved.
    """
    kept = [name for name in companies if name in embedding_dict]
    rows = []
    for name in kept:
        vector = embedding_dict[name]
        rows.append(vector.cpu().numpy() if hasattr(vector, "cpu") else vector)
    return kept, cosine_distances(np.array(rows))

# === Step 4: Compare matrices (Spearman) ===
def compare_matrices(dtw_matrix, embed_matrix):
    """Spearman rank correlation between two square distance matrices.

    Both matrices are reduced to their condensed upper-triangle vectors
    before correlating.

    ``checks=False`` matters for the embedding side: a cosine-distance matrix
    can be asymmetric in the last floating-point bits, and squareform's
    default exact validation would reject it with ValueError.
    """
    dtw_flat = squareform(dtw_matrix, checks=False)
    embed_flat = squareform(embed_matrix, checks=False)
    rho, _ = spearmanr(dtw_flat, embed_flat)
    return rho

# === Step 5: Compare clustering consistency ===
def compare_clusters(dtw_matrix, embed_matrix, n_clusters=5):
    """Agreement (ARI, NMI) between clusterings of two precomputed distance matrices.

    Each matrix is clustered independently with average-linkage agglomerative
    clustering into ``n_clusters`` groups; the DTW matrix is fitted first.
    """
    label_sets = []
    for distances in (dtw_matrix, embed_matrix):
        model = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='average')
        label_sets.append(model.fit(distances).labels_)
    dtw_labels, emb_labels = label_sets
    ari = adjusted_rand_score(dtw_labels, emb_labels)
    nmi = normalized_mutual_info_score(dtw_labels, emb_labels)
    return ari, nmi

# === Step 6: Run experiment for all series types ===
def run_dtw_vs_embedding_experiment(series_sets, n_clusters=5):
    """Compare DTW distances of raw series with embedding cosine distances, per series set.

    For each ``name -> (series_dict, embed_dict)`` entry, the function keeps
    companies that have both a non-blank token string and an embedding. It
    then computes:
    - the Spearman correlation between the DTW and embedding distance
      matrices (via compare_matrices);
    - the ARI/NMI agreement of their clusterings (via compare_clusters).

    Sets with fewer than two usable companies are reported and skipped.

    Returns:
        pandas.DataFrame with columns Series, Spearman, ARI and NMI, rounded
        to 3 decimals.
    """
    results = []
    for name, (series_dict, embed_dict) in series_sets.items():
        # Extract numeric series
        numeric_series = {k: extract_series(v) for k, v in series_dict.items() if v.strip() and k in embed_dict}
        if len(numeric_series) < 2:
            print(f"Skipping {name}: Not enough companies with both series and embeddings.")
            continue

        companies, dtw_matrix = compute_dtw_matrix(numeric_series)
        companies, embed_matrix = compute_embedding_distance_matrix(embed_dict, companies)

        # Align DTW to the companies that also have embeddings
        if len(companies) < 2:
            print(f"Skipping {name}: No matching companies with embeddings.")
            continue
        # A position map replaces list(...).index(c), which rebuilt and scanned
        # the key list for every company (accidental O(n^2)).
        row_of = {c: i for i, c in enumerate(numeric_series)}
        idxs = [row_of[c] for c in companies]
        dtw_matrix = dtw_matrix[np.ix_(idxs, idxs)]

        # Compare distances and clusters
        rho = compare_matrices(dtw_matrix, embed_matrix)
        ari, nmi = compare_clusters(dtw_matrix, embed_matrix, n_clusters=n_clusters)

        results.append({
            "Series": name,
            "Spearman": round(rho, 3),
            "ARI": round(ari, 3),
            "NMI": round(nmi, 3)
        })

    return pd.DataFrame(results)

# === Step 7: Define your inputs ===
# Each series name maps to (per-company token strings, per-company embeddings
# from the fine-tuned model).
series_sets = dict(
    RET=(ret_core, ret_context_embeddings_from_model_12),
    SOC=(s_core, s_context_embeddings_from_model_12),
    GOV=(g_core, g_context_embeddings_from_model_12),
    ESG=(esg_core, esg_context_embeddings_from_model_12),
    ESGFO=(esg_fo, esg_fo_context_embeddings_from_model_12),
    ESGSO=(esg_so, esg_so_context_embeddings_from_model_12),
    ENV=(e_core, e_context_embeddings_from_model_12),
    SENTI=(senti_core, senti_context_embeddings_from_model_12),
)

# === Step 8: Run & print summary ===
results_df = run_dtw_vs_embedding_experiment(series_sets, n_clusters=5)
print(results_df)

# Block index mapping: eight consecutive 96-column slices of the embedding,
# stored as half-open [start, end) ranges in this fixed order.
BLOCKS = {
    name: (slot * 96, (slot + 1) * 96)
    for slot, name in enumerate(("RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"))
}

def compute_embedding_distance_matrix_blockwise(embedding_dict, companies, block_name):
    """Cosine-distance matrix over one slice (``BLOCKS[block_name]``) of each embedding.

    Returns ``(kept_companies, matrix)``. When fewer than two of ``companies``
    have embeddings, returns ``([], np.array([]))``.
    """
    lo, hi = BLOCKS[block_name]
    kept = [name for name in companies if name in embedding_dict]
    if len(kept) < 2:
        return [], np.array([])  # Not enough data

    def _slice(name):
        raw = embedding_dict[name]
        if hasattr(raw, "cpu"):
            return raw.cpu().numpy()[lo:hi]
        return raw[lo:hi]

    return kept, cosine_distances(np.array([_slice(name) for name in kept]))

def run_dtw_vs_embedding_for_block(series_dict, embed_dict, block_name="ESG", n_clusters=5):
    """Compare DTW distances with cosine distances of one embedding block.

    Only companies with a non-blank token string and an embedding are used.
    The embedding distances use the ``BLOCKS[block_name]`` slice of each
    vector.

    Returns:
        (rho, ari, nmi, companies, dtw_matrix, embed_matrix). When fewer than
        two companies are usable, returns
        ``(None, None, None, [], np.array([]), np.array([]))`` after printing
        the reason.
    """
    # Extract numeric series only for companies also present in embeddings
    numeric_series = {k: extract_series(v) for k, v in series_dict.items() if v.strip() and k in embed_dict}
    if len(numeric_series) < 2:
        print(f"Skipping {block_name}: Not enough companies with both series and embeddings.")
        return None, None, None, [], np.array([]), np.array([])

    companies, dtw_matrix = compute_dtw_matrix(numeric_series)
    companies, embed_matrix = compute_embedding_distance_matrix_blockwise(embed_dict, companies, block_name)
    if len(companies) < 2:
        print(f"Skipping {block_name}: No valid companies for embeddings.")
        return None, None, None, [], np.array([]), np.array([])

    # Align DTW matrix to the companies that also have embeddings. A position
    # map avoids rebuilding and scanning the key list for every company.
    row_of = {c: i for i, c in enumerate(numeric_series)}
    idxs = [row_of[c] for c in companies]
    dtw_matrix = dtw_matrix[np.ix_(idxs, idxs)]

    # Compare distances and clusters
    rho = compare_matrices(dtw_matrix, embed_matrix)
    ari, nmi = compare_clusters(dtw_matrix, embed_matrix, n_clusters=n_clusters)

    print(f"Block: {block_name}")
    print(f"Spearman: {rho:.3f}")
    print(f"ARI: {ari:.3f}")
    print(f"NMI: {nmi:.3f}")

    return rho, ari, nmi, companies, dtw_matrix, embed_matrix

# Run for the ESG block (embedding dimensions 288–383)
# NOTE(review): this unpacking rebinds the module-level name `companies`, which the
# superior-peer loop later iterates over — confirm that is intended rather than the
# full ticker list.
rho, ari, nmi, companies, dtw_matrix, embed_matrix = run_dtw_vs_embedding_for_block(
    esg_core,
    esg_context_embeddings_from_model_12,
    block_name="ESG",
    n_clusters=5
)

# Run for the RET block (embedding dimensions 0–95); overwrites the ESG results above
rho, ari, nmi, companies, dtw_matrix, embed_matrix = run_dtw_vs_embedding_for_block(
    ret_core,
    ret_context_embeddings_from_model_12,
    block_name="RET",
    n_clusters=5
)

import re
import numpy as np
from dtaidistance import dtw
from sklearn.metrics.pairwise import cosine_distances
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import squareform
from scipy.stats import spearmanr
import pandas as pd

# === Block index mapping (768D sliced by domain) ===
# Eight equal 96-column slices as half-open [start, end) ranges.
BLOCKS = dict(zip(
    ["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"],
    [(lo, lo + 96) for lo in range(0, 768, 96)],
))

# === Extract numeric series ===
def extract_series(token_str):
    """Parse all signed decimal numbers in ``token_str`` into a list of floats."""
    matches = re.finditer(r"[-+]?\d+(?:\.\d+)?", token_str)
    return [float(match.group(0)) for match in matches]

# === Compute DTW distance matrix (skip problematic pairs) ===
def compute_dtw_matrix(series_dict):
    """Symmetric DTW distance matrix over the values of ``series_dict``.

    Pairs for which computing or storing the distance raises are recorded as
    NaN instead of aborting. Returns ``(company_names, matrix)``; order follows
    the dict's keys and the diagonal stays zero.
    """
    names = list(series_dict.keys())
    matrix = np.zeros((len(names), len(names)))
    for i, left in enumerate(names):
        for j in range(i + 1, len(names)):
            try:
                matrix[i, j] = matrix[j, i] = dtw.distance(series_dict[left], series_dict[names[j]])
            except Exception:
                matrix[i, j] = matrix[j, i] = np.nan
    return names, matrix

# === Compute cosine distance for relevant block slice (skip invalid embeddings) ===
def compute_embedding_distance_matrix_blockwise(embedding_dict, companies, block_name):
    """Cosine-distance matrix over the ``BLOCKS[block_name]`` slice of each embedding.

    Companies whose embedding is missing or cannot be sliced are skipped
    silently. Returns ``(kept_companies, matrix)``, or ``([], None)`` when
    fewer than two embeddings are usable.
    """
    start, end = BLOCKS[block_name]
    kept_names, slices = [], []
    for name in companies:
        try:
            vector = embedding_dict[name]
            if hasattr(vector, "cpu"):
                vector = vector.cpu().numpy()
            piece = vector[start:end]
        except Exception:
            continue
        slices.append(piece)
        kept_names.append(name)
    if len(slices) < 2:
        return [], None
    return kept_names, cosine_distances(np.array(slices))

# === Spearman correlation of distance matrices ===
def compare_matrices(dtw_matrix, embed_matrix):
    """Spearman correlation of the condensed (upper-triangle) forms of two distance matrices.

    Symmetry and zero-diagonal validation is disabled, and NaN pairs are
    omitted from the correlation.
    """
    dtw_vec, embed_vec = (squareform(m, checks=False) for m in (dtw_matrix, embed_matrix))
    correlation, _pvalue = spearmanr(dtw_vec, embed_vec, nan_policy="omit")
    return correlation

# === Compare clustering consistency (ARI & NMI) ===
def compare_clusters(dtw_matrix, embed_matrix, n_clusters=5):
    """ARI and NMI between average-linkage clusterings of two precomputed distance matrices.

    Any exception raised while clustering or scoring yields ``(nan, nan)``
    instead of propagating.
    """
    def fit_labels(distances):
        model = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='average')
        return model.fit(distances).labels_

    try:
        reference = fit_labels(dtw_matrix)
        candidate = fit_labels(embed_matrix)
        return (adjusted_rand_score(reference, candidate),
                normalized_mutual_info_score(reference, candidate))
    except Exception:
        return np.nan, np.nan

# === Run for one block ===
def run_dtw_vs_embedding_for_block(series_dict, embed_dict, block_name, n_clusters=5):
    """Compare DTW distances with block-sliced embedding distances for one series set.

    Args:
        series_dict: {company: token string}; blank strings are ignored.
        embed_dict: {company: embedding}; sliced with ``BLOCKS[block_name]``.
        block_name: key into BLOCKS.
        n_clusters: cluster count used for the ARI/NMI comparison.

    Returns:
        (spearman, ari, nmi), each rounded to 3 decimals, or NaN when it
        cannot be computed.
    """
    numeric_series = {k: extract_series(v) for k, v in series_dict.items() if v.strip()}
    if not numeric_series:
        return np.nan, np.nan, np.nan
    dtw_companies, dtw_matrix = compute_dtw_matrix(numeric_series)
    companies, embed_matrix = compute_embedding_distance_matrix_blockwise(embed_dict, dtw_companies, block_name)
    if not companies or embed_matrix is None:
        return np.nan, np.nan, np.nan

    # The embedding helper drops companies without a usable embedding. Restrict
    # the DTW matrix to the same companies, in the same order, so both matrices
    # describe identical pairs.
    row_of = {c: i for i, c in enumerate(dtw_companies)}
    idxs = [row_of[c] for c in companies]
    dtw_matrix = dtw_matrix[np.ix_(idxs, idxs)]

    rho = compare_matrices(dtw_matrix, embed_matrix)
    ari, nmi = compare_clusters(dtw_matrix, embed_matrix, n_clusters=n_clusters)

    def _rounded(value):
        return round(value, 3) if value is not None else np.nan

    return _rounded(rho), _rounded(ari), _rounded(nmi)

# === Run for all blocks ===
def run_dtw_vs_embedding_all_blocks(series_sets, n_clusters=5):
    """Run the per-block DTW-vs-embedding comparison for every entry of ``series_sets``.

    Each key of ``series_sets`` doubles as the BLOCKS key for slicing. Returns a
    DataFrame with one row per series: Series, Spearman, ARI, NMI.
    """
    rows = []
    for block_name, pair in series_sets.items():
        tokens, embeddings = pair
        spearman, ari, nmi = run_dtw_vs_embedding_for_block(tokens, embeddings, block_name, n_clusters=n_clusters)
        rows.append(dict(Series=block_name, Spearman=spearman, ARI=ari, NMI=nmi))
    return pd.DataFrame(rows)

# === Example series sets ===
# Pair each block name with its per-company token strings and the fine-tuned
# model's per-company embeddings.
series_sets = {
    name: (tokens, embeddings)
    for name, tokens, embeddings in [
        ("RET", ret_core, ret_context_embeddings_from_model_12),
        ("SOC", s_core, s_context_embeddings_from_model_12),
        ("GOV", g_core, g_context_embeddings_from_model_12),
        ("ESG", esg_core, esg_context_embeddings_from_model_12),
        ("ESGFO", esg_fo, esg_fo_context_embeddings_from_model_12),
        ("ESGSO", esg_so, esg_so_context_embeddings_from_model_12),
        ("ENV", e_core, e_context_embeddings_from_model_12),
        ("SENTI", senti_core, senti_context_embeddings_from_model_12),
    ]
}

# === Run & print ===
results_df = run_dtw_vs_embedding_all_blocks(series_sets, n_clusters=5)
print(results_df)

import pandas as pd

# Ensure full row visibility
pd.set_option('display.max_rows', None)

# Save DataFrame to a text file with full formatting. An explicit UTF-8
# encoding keeps the output independent of the platform's default encoding.
with open("senti_df_6.txt", "w", encoding="utf-8") as f:
    f.write(senti_df_6.to_string(index=False))  # Set index=True if you want to keep index

senti_df_6

import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

import re
import numpy as np

def extract_avg_esg_from_string(token_string):
    """Return the mean of the numbers in ``<ESG_x>`` tokens, or None if there are none.

    Example: ``"<ESG_56.00> <ESG_57.22> <ESG_31.07>"`` -> about 48.10.
    """
    scores = [float(raw) for raw in re.findall(r"<ESG_(-?\d+(?:\.\d+)?)>", token_string)]
    if not scores:
        return None
    return np.mean(scores)


def find_esg_superior_ret_similar_peers(
    target_company,
    ret_embedding_dict,
    esg_token_dict,
    top_k=20,
    ret_token_dict=None,
):
    """Find return-similar peers with a lower average ESG value and no lower average RET.

    Similarity is the cosine similarity between return embeddings. A peer
    qualifies when both of the following hold:
    - its average ``<ESG_x>`` value is strictly lower than the target's
      (this function treats lower as superior);
    - its average ``<RET_x>`` value is greater than or equal to the target's.

    Args:
        target_company: key of the company to find peers for.
        ret_embedding_dict: {company: return embedding (tensor or array)}.
        esg_token_dict: {company: string of ``<ESG_x>`` tokens}.
        top_k: maximum number of peers returned.
        ret_token_dict: {company: string of ``<RET_x>`` tokens}. Defaults to the
            notebook-global ``ret_core`` when omitted.

    Returns:
        List of (company, similarity, avg_esg, avg_ret) tuples, most similar first.

    Raises:
        ValueError: the target is missing from an input dict, or has no ESG/RET tokens.
    """
    ret_tokens = ret_core if ret_token_dict is None else ret_token_dict

    # Ensure input presence
    if target_company not in ret_embedding_dict or target_company not in esg_token_dict:
        raise ValueError("Target company missing in one or both dictionaries.")
    if target_company not in ret_tokens:
        raise ValueError("Target company missing from the RET token dictionary.")

    # Prepare vector
    def to_np(x): return x.cpu().numpy() if hasattr(x, "cpu") else x
    target_vec = to_np(ret_embedding_dict[target_company]).reshape(1, -1)
    target_esg = extract_avg_esg_from_string(esg_token_dict[target_company])
    target_ret = extract_avg_kind_from_string(ret_tokens[target_company], "RET")

    if target_esg is None:
        raise ValueError("No valid ESG token found for target.")
    if target_ret is None:
        # Without a target RET value the ">=" filter below cannot be evaluated.
        raise ValueError("No valid RET token found for target.")

    peers = []
    for company, emb in ret_embedding_dict.items():
        if company == target_company or company not in esg_token_dict or company not in ret_tokens:
            continue
        esg_val = extract_avg_esg_from_string(esg_token_dict[company])
        ret_val = extract_avg_kind_from_string(ret_tokens[company], "RET")
        # Peers without parseable ESG/RET tokens cannot be compared; skip them
        # instead of failing the whole search.
        if esg_val is None or ret_val is None:
            continue
        if esg_val < target_esg and ret_val >= target_ret:
            sim = cosine_similarity(target_vec, to_np(emb).reshape(1, -1))[0, 0]
            peers.append((company, sim, esg_val, ret_val))

    # Sort by similarity (stable, so ties keep dictionary order)
    peers.sort(key=lambda peer: peer[1], reverse=True)
    return peers[:top_k]

def extract_avg_kind_from_string(token_string, kind):
    """Average the numeric payloads of ``<{kind}_x>`` tokens in ``token_string``.

    Example: ``extract_avg_kind_from_string("<RET_1.5> <RET_-0.5>", "RET")`` -> 0.5.

    Returns:
        The mean value, or None when no token of that kind is present.
    """
    # Raw f-string: "\d" in a plain f-string is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). re.escape keeps ``kind`` literal.
    pattern = rf"<{re.escape(kind)}_(-?\d+(?:\.\d+)?)>"
    values = [float(m) for m in re.findall(pattern, token_string)]
    return np.mean(values) if values else None

# Quick sanity checks on averaged token values (shown as notebook outputs).
extract_avg_kind_from_string(esg_core["TSLA"], "ESG")

extract_avg_kind_from_string(e_core["TSLA"], "ENV")

ret_core['AAPL']

# TSLA peers with a lower average ESG value and at least TSLA's average RET,
# ranked by return-embedding similarity.
peers = find_esg_superior_ret_similar_peers(
    target_company="TSLA",
    ret_embedding_dict=ret_context_embeddings_from_model_12,
    esg_token_dict=esg_core,
    top_k=10,
)

for company, sim, esg, ret in peers:
    print(
        f"{company}: return_sim = {sim:.4f}, "
        f"ESG = {esg:.2f} , RET = {ret:.2f}"
    )

"""SIMILAR COMPANY ANALYSIS

> Add blockquote


"""

import time



from contextlib import redirect_stdout

# Write, for every company, its average ESG/RET and its ESG-superior,
# return-similar peers to superior_2d_2.txt (stdout is redirected into the file).
with open("superior_2d_2.txt", "w", encoding="utf-8") as f:
    with redirect_stdout(f):
        for c in companies:
            try:
                print(f"For Company: {c} ESG: {extract_avg_esg_from_string(esg_core[c])}  RET: {extract_avg_kind_from_string(ret_core[c],'RET')}")

                # The peer search derives the RET threshold itself from the RET
                # tokens. Passing that value positionally here would bind it to
                # top_k and clash with top_k=10, raising TypeError for every company.
                peers = find_esg_superior_ret_similar_peers(
                    c,
                    ret_context_embeddings_from_model,
                    esg_core,
                    top_k=10,
                )
                print(peers)
            except Exception as exc:
                # Best-effort batch: record why this company was skipped and continue.
                print("------------------------SKIP-------------------------------", c, exc)
                continue
            print("---------------------------DONE----------------------------")

from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def find_similar_companies(target_company, embedding_dict, top_k=10):
    """
    Rank the other companies by cosine similarity to ``target_company``.

    Args:
        target_company (str): Key of the company to compare against.
        embedding_dict (dict): {company: embedding (tensor or np array)}.
        top_k (int): Maximum number of neighbours to return.

    Returns:
        List of (company_name, similarity_score) tuples, most similar first,
        with the target itself excluded.

    Raises:
        ValueError: if ``target_company`` is not a key of ``embedding_dict``.
    """
    if target_company not in embedding_dict:
        raise ValueError(f"Company '{target_company}' not found in embedding dictionary.")

    def to_np(x):
        return x.cpu().numpy() if hasattr(x, "cpu") else x

    query = to_np(embedding_dict[target_company]).reshape(1, -1)
    names = list(embedding_dict)
    matrix = np.array([to_np(embedding_dict[name]) for name in names])

    scores = cosine_similarity(query, matrix).flatten()
    ranked = sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)
    others = [pair for pair in ranked if pair[0] != target_company]
    return others[:top_k]

"""METRIC

Experiment: compare the portfolio-optimization results with competitive peers. Similar to the analysis above, define a function that finds companies with a better ESG score and better sentiment but similar returns; any ideal (reference) vectors it needs must be constructed inside the function.

HIGH ESG AND HIGH SENTIMENT
"""

def extract_avg_senti_from_string(token_string):
    """Average of the numbers in ``<SENTI_x>`` tokens, or None when none are present."""
    readings = list(map(float, re.findall(r"<SENTI_(-?\d+(?:\.\d+)?)>", token_string)))
    return np.mean(readings) if len(readings) > 0 else None


def find_esg_senti_superior_ret_similar_peers(
    target_company,
    ret_embedding_dict,
    esg_token_dict,
    senti_token_dict,
    top_k=10
):
    """Find return-similar peers with a lower average ESG and a higher average sentiment.

    Similarity is the cosine similarity between return embeddings. A peer
    qualifies when its average ``<ESG_x>`` value is strictly lower than the
    target's and its average ``<SENTI_x>`` value is strictly higher.

    Returns:
        List of (company, similarity, avg_esg, avg_senti) float tuples, most
        similar first, truncated to ``top_k``.

    Raises:
        ValueError: the target is missing from an input dict, or has no
            ESG/SENTI tokens.
    """
    # Ensure input presence
    if (
        target_company not in ret_embedding_dict or
        target_company not in esg_token_dict or
        target_company not in senti_token_dict
    ):
        raise ValueError("Target company missing in one or more dictionaries.")

    def to_np(x): return x.cpu().numpy() if hasattr(x, "cpu") else x
    target_vec = to_np(ret_embedding_dict[target_company]).reshape(1, -1)
    target_esg = extract_avg_esg_from_string(esg_token_dict[target_company])
    target_senti = extract_avg_senti_from_string(senti_token_dict[target_company])

    if target_esg is None or target_senti is None:
        raise ValueError("Invalid ESG or sentiment token for target.")

    peers = []
    for company, vec in ret_embedding_dict.items():
        if (
            company == target_company or
            company not in esg_token_dict or
            company not in senti_token_dict
        ):
            continue

        sim = cosine_similarity(target_vec, to_np(vec).reshape(1, -1))[0, 0]
        esg_val = extract_avg_esg_from_string(esg_token_dict[company])
        senti_val = extract_avg_senti_from_string(senti_token_dict[company])

        if esg_val is not None and senti_val is not None:
            # Lower ESG is better, HIGHER sentiment is better.
            if esg_val < target_esg and senti_val > target_senti:
                peers.append((company, float(sim), float(esg_val), float(senti_val)))

    # Sort by similarity
    peers = sorted(peers, key=lambda x: x[1], reverse=True)

    return peers[:top_k]

def extract_avg_senti_from_string(token_string):
    """Mean sentiment value parsed from ``<SENTI_x>`` tokens; None if no token matches."""
    values = []
    for raw in re.findall(r"<SENTI_(-?\d+(?:\.\d+)?)>", token_string):
        values.append(float(raw))
    if not values:
        return None
    return np.mean(values)


def find_esg_senti_superior_ret_similar_peers_with_ret(
    target_company,
    ret_embedding_dict,
    esg_token_dict,
    senti_token_dict,
    expected_ret,
    top_k=10,
    ret_token_dict=None,
):
    """Find return-similar peers with lower ESG, higher sentiment and RET >= ``expected_ret``.

    Similarity is the cosine similarity between return embeddings. A peer
    qualifies when all of the following hold:
    - its average ``<ESG_x>`` value is strictly lower than the target's;
    - its average ``<SENTI_x>`` value is strictly higher than the target's;
    - its average ``<RET_x>`` value is at least ``expected_ret``.

    Args:
        ret_token_dict: {company: string of ``<RET_x>`` tokens}. Defaults to the
            notebook-global ``ret_core`` when omitted.

    Returns:
        List of (company, similarity, avg_esg, avg_senti, avg_ret rounded to
        2 decimals) tuples, most similar first.

    Raises:
        ValueError: the target is missing from an input dict, or has no
            ESG/SENTI tokens.
    """
    ret_tokens = ret_core if ret_token_dict is None else ret_token_dict

    # Ensure input presence
    if (
        target_company not in ret_embedding_dict or
        target_company not in esg_token_dict or
        target_company not in senti_token_dict
    ):
        raise ValueError("Target company missing in one or more dictionaries.")

    def to_np(x): return x.cpu().numpy() if hasattr(x, "cpu") else x
    target_vec = to_np(ret_embedding_dict[target_company]).reshape(1, -1)
    target_esg = extract_avg_esg_from_string(esg_token_dict[target_company])
    target_senti = extract_avg_senti_from_string(senti_token_dict[target_company])

    if target_esg is None or target_senti is None:
        raise ValueError("Invalid ESG or sentiment token for target.")

    peers = []
    for company, vec in ret_embedding_dict.items():
        if (
            company == target_company or
            company not in esg_token_dict or
            company not in senti_token_dict or
            company not in ret_tokens  # avoid a KeyError that would abort the whole search
        ):
            continue

        esg_val = extract_avg_esg_from_string(esg_token_dict[company])
        senti_val = extract_avg_senti_from_string(senti_token_dict[company])
        ret_val = extract_avg_kind_from_string(ret_tokens[company], "RET")
        # A missing value cannot be compared (None >= x raises TypeError); skip the peer.
        if esg_val is None or senti_val is None or ret_val is None:
            continue
        # Lower ESG is better, HIGHER sentiment is better, and the return must reach expected_ret.
        if esg_val < target_esg and senti_val > target_senti and ret_val >= expected_ret:
            sim = cosine_similarity(target_vec, to_np(vec).reshape(1, -1))[0, 0]
            peers.append((company, float(sim), float(esg_val), float(senti_val), round(float(ret_val), 2)))

    # Sort by similarity
    peers = sorted(peers, key=lambda x: x[1], reverse=True)

    return peers[:top_k]

def get_block(vec, start, end):
    """Return ``vec[start:end]``, i.e. one contiguous block of a concatenated embedding."""
    block_slice = slice(start, end)
    return vec[block_slice]

def extract_avg_esg_from_string(token_string):
    """Average the numeric payloads of ``<ESG_x>`` tokens; None when no such token exists."""
    esg_pattern = r"<ESG_(-?\d+(?:\.\d+)?)>"
    scores = [float(raw) for raw in re.findall(esg_pattern, token_string)]
    return None if len(scores) == 0 else np.mean(scores)

def extract_avg_senti_from_string(token_string):
    """Mean sentiment score encoded as ``<SENTI_x>`` tokens, or None if none occur."""
    collected = []
    for match in re.finditer(r"<SENTI_(-?\d+(?:\.\d+)?)>", token_string):
        collected.append(float(match.group(1)))
    if collected:
        return np.mean(collected)
    return None

def extract_avg_kind_from_string(token_string, kind):
    """Return the mean of the numbers in ``<{kind}_x>`` tokens, or None if none match.

    Args:
        token_string: Text containing tokens such as ``<RET_1.5>``.
        kind: Token prefix, e.g. "RET", "ESG", "SENTI". It is regex-escaped,
            so labels containing metacharacters (``.``, ``+``...) are matched
            literally instead of being interpreted as a pattern.

    Returns:
        The numpy mean of the matched values, or None when nothing matches.
    """
    pattern = rf"<{re.escape(kind)}_(-?\d+(?:\.\d+)?)>"
    values = [float(m) for m in re.findall(pattern, token_string)]
    return np.mean(values) if values else None

def find_esg_senti_superior_ret_similar_peers_with_ret(
    target_company,
    ret_embedding_dict,
    esg_token_dict,
    senti_token_dict,
    ret_core,
    expected_ret,
    top_k=10,
    ret_block=(0, 96),
):
    """Return up to ``top_k`` peers that beat ``target_company`` on ESG and sentiment.

    A peer qualifies when its average ESG risk is strictly lower than the
    target's, its average sentiment is strictly higher, and its average RET
    value is at least ``expected_ret``. Qualifying peers are ranked by cosine
    similarity between the RET block (columns ``ret_block``) of their
    embedding and the target's.

    Args:
        target_company: Ticker to find peers for.
        ret_embedding_dict: Mapping ticker -> embedding (torch tensor or array).
        esg_token_dict: Mapping ticker -> string of ``<ESG_x>`` tokens.
        senti_token_dict: Mapping ticker -> string of ``<SENTI_x>`` tokens.
        ret_core: Mapping ticker -> string of ``<RET_x>`` tokens.
        expected_ret: Minimum average RET a peer must reach.
        top_k: Maximum number of peers returned.
        ret_block: ``(start, end)`` embedding columns used for similarity;
            defaults to the previously hard-coded ``(0, 96)``.

    Returns:
        List of ``(company, similarity, esg, senti, ret_rounded_to_2dp)``
        tuples sorted by descending similarity.

    Raises:
        ValueError: If the target is missing from a dict or has no ESG /
            sentiment tokens.
    """
    if (
        target_company not in ret_embedding_dict or
        target_company not in esg_token_dict or
        target_company not in senti_token_dict
    ):
        raise ValueError("Target company missing in one or more dictionaries.")

    def to_np(x):
        return x.cpu().numpy() if hasattr(x, "cpu") else x

    start, end = ret_block

    # === Extract RET block for target ===
    target_vec = get_block(to_np(ret_embedding_dict[target_company]), start, end).reshape(1, -1)
    target_esg = extract_avg_esg_from_string(esg_token_dict[target_company])
    target_senti = extract_avg_senti_from_string(senti_token_dict[target_company])

    if target_esg is None or target_senti is None:
        raise ValueError("Invalid ESG or sentiment token for target.")

    peers = []
    for company, vec in ret_embedding_dict.items():
        if (
            company == target_company or
            company not in esg_token_dict or
            company not in senti_token_dict or
            company not in ret_core  # previously raised KeyError mid-loop
        ):
            continue

        esg_val = extract_avg_esg_from_string(esg_token_dict[company])
        senti_val = extract_avg_senti_from_string(senti_token_dict[company])
        ret_val = extract_avg_kind_from_string(ret_core[company], "RET")
        if esg_val is None or senti_val is None or ret_val is None:
            continue

        # LOWER ESG = better, HIGHER sentiment = better, RET >= expected
        if esg_val < target_esg and senti_val > target_senti and ret_val >= expected_ret:
            # === Use only RET block for similarity ===
            peer_vec = get_block(to_np(vec), start, end).reshape(1, -1)
            sim = cosine_similarity(target_vec, peer_vec)[0, 0]
            peers.append((company, float(sim), float(esg_val), float(senti_val), round(float(ret_val), 2)))

    # Sort by similarity (descending)
    peers.sort(key=lambda peer: peer[1], reverse=True)
    return peers[:top_k]

# Example argument mapping for find_esg_senti_superior_ret_similar_peers_with_ret
# (this cell previously held a body-less `def` whose parameter list contained
# `--AAPL`-style notes, which is a SyntaxError):
#   target_company     -> "AAPL"
#   ret_embedding_dict -> ret_context_embeddings_from_model
#   esg_token_dict     -> esg_core
#   senti_token_dict   -> senti_core
#   ret_core           -> ret_core
#   expected_ret       -> e.g. 13.6
#   top_k              -> 10 (default)

find_esg_senti_superior_ret_similar_peers_with_ret("AAPL", ret_context_embeddings_from_model, esg_core, senti_core, ret_core, 13.6)

from contextlib import redirect_stdout

# For every ticker, write its average ESG / SENTI / RET followed by its list of
# superior, RET-similar peers to superior_3d_new.txt (stdout is redirected).
with open("superior_3d_new.txt", "w", encoding="utf-8") as f, redirect_stdout(f):
    for c in companies:
        try:
            # Computed once and reused as both the printed RET and the threshold.
            target_ret = extract_avg_kind_from_string(ret_core[c], "RET")
            print(f"For Company: {c} ESG: {extract_avg_esg_from_string(esg_core[c])}  SENTI: {extract_avg_kind_from_string(senti_core[c],'SENTI')} RET: {target_ret} ")

            peers = find_esg_senti_superior_ret_similar_peers_with_ret(c, ret_context_embeddings_from_model, esg_core, senti_core, ret_core, target_ret)

            print(peers)
        except Exception:
            # Best-effort report: tickers with missing or invalid data are
            # marked and skipped. `Exception` (not a bare except) lets
            # KeyboardInterrupt / SystemExit stop the loop.
            print("------------------------SKIP-------------------------------")
            continue
        print("---------------------------DONE----------------------------")

def unconditional_find_esg_senti_superior_ret_similar_peers_with_ret(
    target_company,
    ret_embedding_dict,
    esg_token_dict,
    senti_token_dict,
    ret_core,
    expected_ret,
    top_k=10,
    ret_block=(0, 96),
):
    """Return the ``top_k`` peers most similar to ``target_company``, WITHOUT superiority filters.

    Same inputs and output format as
    ``find_esg_senti_superior_ret_similar_peers_with_ret``, but the ESG /
    sentiment / RET conditions are disabled. Every peer that has ESG,
    sentiment and RET tokens is ranked by cosine similarity of the RET block
    (columns ``ret_block``) of its embedding to the target's.

    Args:
        target_company: Ticker to rank peers for.
        ret_embedding_dict: Mapping ticker -> embedding (torch tensor or array).
        esg_token_dict: Mapping ticker -> string of ``<ESG_x>`` tokens.
        senti_token_dict: Mapping ticker -> string of ``<SENTI_x>`` tokens.
        ret_core: Mapping ticker -> string of ``<RET_x>`` tokens.
        expected_ret: Accepted for signature compatibility; not used.
        top_k: Maximum number of peers returned.
        ret_block: ``(start, end)`` embedding columns used for similarity;
            defaults to the previously hard-coded ``(0, 96)``.

    Returns:
        List of ``(company, similarity, esg, senti, ret_rounded_to_2dp)``
        tuples sorted by descending similarity.

    Raises:
        ValueError: If the target is missing from a dict or has no ESG /
            sentiment tokens.
    """
    if (
        target_company not in ret_embedding_dict or
        target_company not in esg_token_dict or
        target_company not in senti_token_dict
    ):
        raise ValueError("Target company missing in one or more dictionaries.")

    def to_np(x):
        return x.cpu().numpy() if hasattr(x, "cpu") else x

    start, end = ret_block

    # === Extract RET block for target ===
    target_vec = get_block(to_np(ret_embedding_dict[target_company]), start, end).reshape(1, -1)
    target_esg = extract_avg_esg_from_string(esg_token_dict[target_company])
    target_senti = extract_avg_senti_from_string(senti_token_dict[target_company])

    if target_esg is None or target_senti is None:
        raise ValueError("Invalid ESG or sentiment token for target.")

    peers = []
    for company, vec in ret_embedding_dict.items():
        if (
            company == target_company or
            company not in esg_token_dict or
            company not in senti_token_dict or
            company not in ret_core  # previously raised KeyError mid-loop
        ):
            continue

        esg_val = extract_avg_esg_from_string(esg_token_dict[company])
        senti_val = extract_avg_senti_from_string(senti_token_dict[company])
        ret_val = extract_avg_kind_from_string(ret_core[company], "RET")
        if esg_val is None or senti_val is None or ret_val is None:
            continue

        # === Use only RET block for similarity ===
        peer_vec = get_block(to_np(vec), start, end).reshape(1, -1)
        sim = cosine_similarity(target_vec, peer_vec)[0, 0]
        # No superiority filter here (unlike the conditional variant).
        peers.append((company, float(sim), float(esg_val), float(senti_val), round(float(ret_val), 2)))

    # Sort by similarity (descending)
    peers.sort(key=lambda peer: peer[1], reverse=True)
    return peers[:top_k]

# Rank AAPL's RET-similar peers without any ESG / sentiment / RET filtering.
unconditional_find_esg_senti_superior_ret_similar_peers_with_ret(
    target_company="AAPL",
    ret_embedding_dict=ret_context_embeddings_from_model,
    esg_token_dict=esg_core,
    senti_token_dict=senti_core,
    ret_core=ret_core,
    expected_ret=13.6,
)



!pip install finnhub-python

import finnhub

def get_media_articles(ticker: str, start_date: str = "2025-05-01", end_date: str = "2025-06-10"):
    """Fetch Finnhub company news for *ticker* between ``start_date`` and ``end_date``.

    The API key is read from the ``FINNHUB_API_KEY`` environment variable
    instead of being hard-coded. A key committed to a notebook is leaked and
    should be revoked / rotated.

    Args:
        ticker: Stock symbol, e.g. "AAPL".
        start_date: Start of the news window, "YYYY-MM-DD" (previous fixed value kept as default).
        end_date: End of the news window, "YYYY-MM-DD" (previous fixed value kept as default).

    Returns:
        The response of ``finnhub.Client.company_news``. Per Finnhub's Company
        News API this is a list of article dicts, not a ``str`` as the old
        annotation claimed.

    Raises:
        RuntimeError: If ``FINNHUB_API_KEY`` is not set.
    """
    import os

    api_key = os.environ.get("FINNHUB_API_KEY")
    if not api_key:
        raise RuntimeError("Set the FINNHUB_API_KEY environment variable to query Finnhub.")
    finnhub_client = finnhub.Client(api_key=api_key)
    return finnhub_client.company_news(ticker, _from=start_date, to=end_date)

!pip install langchain_community

from langchain.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, Tool
from langchain.agents.agent_types import AgentType
from langchain.tools import tool
import re

from langchain.schema import SystemMessage, HumanMessage

import os
# API keys must not be hard-coded in a shared notebook; the key previously
# pasted here is exposed and should be revoked. Read OPENAI_API_KEY from the
# environment, and prompt for it without echo only when it is not already set.
if not os.environ.get("OPENAI_API_KEY"):
    import getpass
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OPENAI_API_KEY: ")

llm = ChatOpenAI(model="gpt-4.1", temperature=0.3)

def generate_explaination(company, predicted_esg):
    """Ask the LLM to explain a predicted ESG risk value for *company*.

    The system prompt bundles the company's ESG series, its first- and
    second-order differences, its sentiment series, recent Finnhub news
    (``get_media_articles``) and the predicted value.

    Args:
        company: Ticker used to index ``esg_core``, ``esg_fo``, ``esg_so`` and ``senti_core``.
        predicted_esg: Predicted ESG risk value the LLM should justify.

    Returns:
        The LLM's answer text. It is also printed, as before.
    """
    # Step 1: Prepare context
    context = f"""
    You are an ESG analysis expert. Given the following information about a company:

ESG risk Time Series
Example: {esg_core[company]}

First Order Differences (Fo = change in ESG risk)
Example: {esg_fo[company]}

Second Order Differences (So = acceleration of change)
Example: {esg_so[company]}

Sentiment Series (monthly average sentiment risk)
Example: {senti_core[company]}

Latest News Articles
Example: {get_media_articles(company)}

Predicted ESG risk: {predicted_esg}

🔍 Based on this data, explain:

Why the predicted ESG risk is {predicted_esg}

Whether the ESG trend is improving, declining, or fluctuating

How sentiment and news may have influenced this prediction

If any turning points or patterns are visible in the ESG evolution

Be analytical and provide a concise ESG trend commentary suitable for ESG analysts and investors.
    """

    # Step 2: Query LLM
    messages = [
        SystemMessage(content=context),
        HumanMessage(content="Explain and Justify the predicted ESG Risk given"),
    ]

    # Calling the model directly (`llm(messages)`) is deprecated in LangChain.
    # `invoke` is the supported Runnable entry point and returns the same
    # message object.
    response = llm.invoke(messages)
    print(response.content)
    return response.content

# Generate explanations for two example ESG predictions, in order.
for ticker, predicted in (("AAPL", 18.55), ("ABBV", 24.14)):
    generate_explaination(ticker, predicted)

"""SIMULATOR"""



"""# POPULATION

```
# This is formatted as code
```


"""



def plot_time_series_comparison_ret_new(companies, esg_dict, kind):
    """Overlay the ``<{kind}_x>`` token series of several companies on one figure.

    Args:
        companies: Tickers to plot; the i-th ticker gets the i-th palette colour.
        esg_dict: Mapping ticker -> token string. Despite the name, any token
            kind works (e.g. ``ret_core`` with ``kind="RET"``).
        kind: Token prefix to extract, e.g. "RET", "ESG", "SENTI".

    Tickers missing from ``esg_dict`` or without ``kind`` tokens are reported
    and skipped. The legend is drawn only if at least one series was plotted.
    """
    # Raw pattern: the old non-raw f-string relied on invalid "\d" escapes
    # (a SyntaxWarning on Python 3.12+). re.escape matches `kind` literally.
    token_re = re.compile(rf"<{re.escape(kind)}_(-?\d+(?:\.\d+)?)>")

    def extract_series(token_str):
        return [float(m) for m in token_re.findall(token_str)]

    plt.figure(figsize=(5, 3))  # small figure

    # Bright qualitative palette (ColorBrewer Set1 hues)
    bright_colors = ['#E41A1C', '#377EB8', '#4DAF4A', '#984EA3',
                     '#FF7F00', '#A65628', '#F781BF', '#999999']

    plotted = 0
    for i, company in enumerate(companies):
        if company not in esg_dict:
            print(f"Skipping {company}: not in input.")
            continue

        series = extract_series(esg_dict[company])
        if not series:
            print(f"Skipping {company}: no {kind} series found.")
            continue

        plt.plot(series,
                 label=f"{company} {kind}",
                 linewidth=3.2,              # Thicker continuous line
                 linestyle='-',              # Solid only
                 color=bright_colors[i % len(bright_colors)],
                 alpha=0.95)
        plotted += 1

    plt.title(f"{kind} Series Comparison", fontsize=14, fontweight='bold')
    plt.xlabel("Time", fontsize=12, fontweight='bold')
    plt.ylabel(f"{kind} Value", fontsize=12, fontweight='bold')

    # Avoid matplotlib's "No artists with labels" warning and an empty legend box.
    if plotted:
        legend = plt.legend(
            loc='upper right',
            fontsize=12,
            frameon=True
        )
        legend.get_frame().set_facecolor('#f0f0f0')
        legend.get_frame().set_alpha(0.6)
        legend.get_frame().set_edgecolor('gray')

    plt.grid(True, linestyle='--', alpha=0.5)

    ax = plt.gca()
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontweight('bold')
        label.set_fontsize(11)

    plt.tight_layout()
    plt.show()

# Overlay RET trajectories for two hand-picked company pairs, one figure each.
for comp_to_compare in (["TMUS", "TEVA"], ["PKG", "GDDY"]):
    plot_time_series_comparison_ret_new(comp_to_compare, ret_core, "RET")

from sklearn.metrics.pairwise import cosine_distances, cosine_similarity

# Compare two company pairs in the RET embedding space: distance, then similarity.
# NOTE: `tmus` / `teva` are reused for each pair, so afterwards they hold the
# LAST pair (PKG / GDDY). A later cell differences exactly these variables.
for first_ticker, second_ticker in (("TMUS", "TEVA"), ("PKG", "GDDY")):
    # Ensure tensors are on CPU and converted to NumPy
    tmus = ret_context_embeddings_from_model[first_ticker].cpu().numpy()
    teva = ret_context_embeddings_from_model[second_ticker].cpu().numpy()

    distance = cosine_distances([tmus], [teva])[0][0]
    print(f"Cosine Distance: {distance:.4f}")

    similarity = cosine_similarity([tmus], [teva])[0][0]
    print(f"Cosine Similarity: {similarity:.4f}")

import numpy as np
from sklearn.metrics.pairwise import cosine_distances
from itertools import combinations

# Ensure all embeddings are numpy arrays
embedding_dict = {
    company: emb.cpu().numpy() if hasattr(emb, "cpu") else emb
    for company, emb in ret_context_embeddings_from_model.items()
}

# Compute every pairwise cosine distance with ONE vectorised call. The
# previous loop made one sklearn call per pair (O(n^2) Python-level calls).
# Pairs are emitted in the same combinations() order as before, so the
# stable sort below orders ties identically.
pair_names = list(embedding_dict)
if pair_names:
    pair_matrix = cosine_distances(np.stack([np.ravel(embedding_dict[c]) for c in pair_names]))
else:
    pair_matrix = np.zeros((0, 0))
pairwise_distances = [
    ((pair_names[i], pair_names[j]), pair_matrix[i, j])
    for i, j in combinations(range(len(pair_names)), 2)
]

# Sort by increasing distance (most similar first)
pairwise_distances.sort(key=lambda x: x[1])

# Write to file. UTF-8 is explicit because every line contains a non-ASCII "→".
output_path = "pairwise_cosine_distances.txt"
with open(output_path, "w", encoding="utf-8") as f:
    f.writelines(
        f"({c1}, {c2}) → Cosine Distance: {dist:.4f}\n"
        for (c1, c2), dist in pairwise_distances
    )

print(f"Saved to {output_path}")

import matplotlib.pyplot as plt
import numpy as np
import re

def plot_first_order_difference_ret(company1, company2, esg_dict, kind="RET"):
    """
    Plot first-order differences of the ``<{kind}_x>`` token series of two companies.

    Args:
        company1: First ticker (key of ``esg_dict``), drawn in bright red.
        company2: Second ticker (key of ``esg_dict``), drawn in bright blue.
        esg_dict: Mapping ticker -> token string (e.g. ``ret_core``).
        kind: Token prefix to difference. Defaults to "RET", which reproduces
            the original RET-only behaviour and labels.

    Uses a small figure, thin lines and bold fonts.
    """
    token_re = re.compile(rf"<{re.escape(kind)}_(-?\d+(?:\.\d+)?)>")

    def extract_series(token_str):
        return [float(m) for m in token_re.findall(token_str)]

    # Compute first-order differences
    series1 = np.diff(extract_series(esg_dict[company1]))
    series2 = np.diff(extract_series(esg_dict[company2]))

    # Gaudy colors
    color1 = "#e41a1c"  # bright red
    color2 = "#377eb8"  # bright blue

    # Plot setup
    plt.figure(figsize=(6, 3))
    plt.plot(series1, label=f"{company1} Δ{kind}", color=color1, linewidth=1.2)
    plt.plot(series2, label=f"{company2} Δ{kind}", color=color2, linewidth=1.2)

    plt.title(f"First-Order Difference of {kind} Series", fontsize=12, fontweight='bold')
    plt.xlabel("Time Step", fontsize=12, fontweight='bold')
    plt.ylabel(f"Δ{kind}", fontsize=12, fontweight='bold')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=12, frameon=True, loc='best')
    plt.xticks(fontsize=12, fontweight='bold')
    plt.yticks(fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.show()

"""SIMILAR"""

# Pairs under the "SIMILAR" heading: overlay the raw RET series, then
# compare their first-order differences.
comp_to_compare =["FRT", "EVT","KKR","APP"]
plot_time_series_comparison_ret_new(comp_to_compare, ret_core,"RET")

plot_first_order_difference_ret("FRT", "EVT", ret_core)

# NOTE(review): EQIX/EA is overlaid here, but its first-order difference is
# plotted after the "DISSIMILAR" heading. Confirm which group it belongs to.
comp_to_compare =["EQIX", "EA"]
plot_time_series_comparison_ret_new(comp_to_compare, ret_core,"RET")

"""DISSIMILAR"""

# Pairs under the "DISSIMILAR" heading: first-order RET differences for
# EQIX/EA, then raw-RET overlay followed by RET differences for each remaining pair.
plot_first_order_difference_ret("EQIX", "EA", ret_core)

for comp_to_compare in (["KKR", "APP"], ["DINRF", "NXLLF"]):
    plot_time_series_comparison_ret_new(comp_to_compare, ret_core, "RET")
    left_ticker, right_ticker = comp_to_compare
    plot_first_order_difference_ret(left_ticker, right_ticker, ret_core)

# Cosine similarity between the element-wise differences of two embedding
# vectors.
# NOTE(review): if the cells above ran in file order, `tmus` / `teva` were last
# assigned the PKG and GDDY RET embeddings (not TMUS / TEVA). Confirm which
# pair is intended. Differencing adjacent embedding dimensions has no obvious
# temporal meaning; verify this comparison is what you want.
tmus_diff = np.diff(tmus)
teva_diff = np.diff(teva)
cosine_similarity([tmus_diff], [teva_diff])[0][0]

!pip install dtaidistance

import numpy as np
import re
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr, pearsonr
from dtaidistance import dtw

# === Helper: Parse numbers from token strings ===
def parse_numeric_series(token_str):
    """Return every signed number in *token_str* as a float, in order of appearance.

    The optional sign now applies to both decimal and integer forms.
    Previously the integer alternative had no sign, so "-2" parsed as 2.0.
    """
    return [float(x) for x in re.findall(r"[-+]?(?:\d*\.\d+|\d+)", token_str)]

# === Compute pairwise distances (Cosine for embeddings, DTW for series) ===
def compute_distances(embeddings_dict, token_series_dict):
    """Correlate embedding cosine distances with DTW distances between raw series.

    Args:
        embeddings_dict: Mapping company -> embedding (torch tensor or array-like).
        token_series_dict: Mapping company -> token string holding the numeric series.

    Returns:
        ``(spearman, pearson, embedding_flat, ts_flat, valid_companies)``, where
        the flat arrays hold the upper-triangle pairwise distances. All five are
        None when fewer than three usable companies remain: correlations need at
        least two distance pairs, and pearsonr rejects shorter inputs.
    """
    # Parse each token string once (it used to be parsed twice).
    series_data = {c: parse_numeric_series(t) for c, t in token_series_dict.items()}
    # Keep companies that have numbers AND an embedding. A missing embedding
    # used to raise KeyError.
    valid_companies = [c for c, s in series_data.items() if s and c in embeddings_dict]
    if len(valid_companies) < 3:
        return None, None, None, None, None

    def _to_numpy(vec):
        # Accept torch tensors (as before) as well as plain arrays / lists.
        if hasattr(vec, "detach"):
            vec = vec.detach().cpu().numpy()
        return np.asarray(vec, dtype=float).ravel()

    # Embedding cosine distances
    embeddings = np.array([_to_numpy(embeddings_dict[c]) for c in valid_companies])
    embedding_dist = squareform(pdist(embeddings, metric='cosine'))

    # DTW distances for raw series (converted to double arrays once)
    arrays = [np.asarray(series_data[c], dtype=np.double) for c in valid_companies]
    n = len(valid_companies)
    ts_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            ts_matrix[i, j] = ts_matrix[j, i] = dtw.distance_fast(arrays[i], arrays[j])

    # Flatten upper triangle
    mask = np.triu_indices(n, k=1)
    embedding_flat = embedding_dist[mask]
    ts_flat = ts_matrix[mask]

    # Correlations
    spearman_corr, _ = spearmanr(embedding_flat, ts_flat)
    pearson_corr, _ = pearsonr(embedding_flat, ts_flat)

    return spearman_corr, pearson_corr, embedding_flat, ts_flat, valid_companies

# === Plot regression line only (AAAI-ready) ===
def plot_regression_line(ax, ts_flat, embedding_flat, spearman_corr, pearson_corr, title):
    """Draw only the fitted line of embedding distance vs DTW distance on *ax*.

    No scatter points and no confidence band are drawn. The panel is
    annotated with the Spearman and Pearson coefficients in its top-left corner.
    """
    line_style = {'color': 'red', 'linewidth': 2}
    sns.regplot(x=ts_flat, y=embedding_flat, scatter=False, line_kws=line_style, ci=None, ax=ax)

    ax.set_title(title, fontsize=12)
    ax.set_xlabel("TS Distance (DTW)")
    ax.set_ylabel("Embedding Distance (Cosine)")

    annotation = f"Spearman: {spearman_corr:.3f}\nPearson: {pearson_corr:.3f}"
    label_box = dict(boxstyle="round,pad=0.3", fc="white", ec="black")
    ax.text(0.05, 0.95, annotation, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=label_box)

# === Master Function: Loop through all token spaces ===
def analyze_all_spaces(spaces):
    """Plot embedding-vs-DTW regression lines for several token spaces and print a correlation table.

    Args:
        spaces: Sequence of ``(name, embeddings_dict, token_series_dict)`` triples.

    Returns:
        Dict name -> {"Spearman": ..., "Pearson": ...} for every space with
        enough data. Spaces without enough data are skipped, and their panel is
        hidden instead of being left as an empty axis. An empty ``spaces``
        returns ``{}`` without plotting (``plt.subplots(1, 0)`` would raise).
    """
    results = {}
    n = len(spaces)
    if n == 0:
        return results

    # squeeze=False always yields a 2-D array of axes, replacing the manual
    # `if n == 1` special case.
    fig, axes = plt.subplots(1, n, figsize=(6 * n, 5), squeeze=False)
    axes = axes[0]

    for idx, (name, emb_dict, tok_dict) in enumerate(spaces):
        spearman_corr, pearson_corr, embedding_flat, ts_flat, valid_companies = compute_distances(emb_dict, tok_dict)
        if spearman_corr is None:
            axes[idx].set_visible(False)
            continue
        results[name] = {"Spearman": spearman_corr, "Pearson": pearson_corr}
        plot_regression_line(axes[idx], ts_flat, embedding_flat, spearman_corr, pearson_corr, f"{name} (N={len(valid_companies)})")

    plt.tight_layout()
    plt.show()

    # Print table
    print("\n=== Correlation Table ===")
    print("{:<10} {:<10} {:<10}".format("Token", "Spearman", "Pearson"))
    for name, vals in results.items():
        print("{:<10} {:<10.3f} {:<10.3f}".format(name, vals["Spearman"], vals["Pearson"]))

    return results

# Token spaces to evaluate, each as (label, company -> embedding dict,
# company -> token-string dict). All dicts are defined in earlier cells; each
# embedding dict is presumably the model's context embeddings for the matching
# token kind (confirm).
spaces = [
    ("RET", ret_context_embeddings_from_model, ret_core),
    ("ESG", esg_context_embeddings_from_model, esg_core),
    ("ESGFO", esg_fo_context_embeddings_from_model, esg_fo),
    ("ESGSO", esg_so_context_embeddings_from_model, esg_so),
    ("GOV", g_context_embeddings_from_model, g_core),
    ("SOC", s_context_embeddings_from_model, s_core),
    ("SENTI", senti_context_embeddings_from_model, senti_core)
]

# One regression panel per space, plus a printed Spearman/Pearson table.
results = analyze_all_spaces(spaces)

import torch
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr, pearsonr
from dtaidistance import dtw
import re

# === Your block setup ===
# Eight equally wide 96-column blocks (8 * 96 = 768 columns) in this fixed order.
block_names = ["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"]
block_sizes = [96] * len(block_names)

# === Helper: group by company ===
def group_by_company(samples):
    """Bucket text lines by the ticker that follows ``Company:``.

    Lines without a ``Company:`` marker are dropped.

    Returns:
        ``defaultdict(list)`` mapping ticker -> its lines, in input order.
    """
    from collections import defaultdict
    ticker_re = re.compile(r"Company:\s*([A-Za-z0-9.]+)")
    buckets = defaultdict(list)
    for sample in samples:
        found = ticker_re.search(sample)
        if found is None:
            continue
        buckets[found.group(1)].append(sample)
    return buckets

# === Step 1: Extract embeddings for each company ===
def extract_embeddings(model, tokenizer, samples, device="cuda"):
    """Mean-pool the last hidden layer per line, then average the line vectors per company.

    Args:
        model: Model whose forward accepts the tokenizer output plus
            ``output_hidden_states=True``. It is switched to eval mode.
        tokenizer: Tokenizer matching ``model`` (inputs truncated to 512 tokens).
        samples: Iterable of text lines containing ``Company: <TICKER>``.
        device: Device the tokenized inputs are moved to.

    Returns:
        ``(embeddings, grouped)``: ticker -> CPU tensor, and the lines grouped by ticker.
    """
    model.eval()
    grouped = group_by_company(samples)

    def _embed_line(text):
        encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        hidden = model(**encoded, output_hidden_states=True).hidden_states[-1]  # [1, seq, hidden]
        return hidden.mean(dim=1).squeeze(0).cpu()  # [hidden_dim]

    embeddings = {}
    with torch.no_grad():
        for company, lines in grouped.items():
            per_line = [_embed_line(text) for text in lines]
            embeddings[company] = torch.stack(per_line).mean(dim=0)
    return embeddings, grouped

# === Step 2: Compute DTW distances ===
def compute_dtw(series_dict):
    """Pairwise DTW distances between each company's concatenated numeric series, scaled by the maximum.

    Args:
        series_dict: Mapping company -> list of text lines (as returned by
            ``group_by_company``). All numbers in a company's joined lines
            form its series.

    Returns:
        ``(ts_matrix, companies)``: symmetric distance matrix divided by its
        maximum (left unscaled when the maximum is 0 or the matrix is empty),
        and the company order of its rows / columns.
    """
    # The optional sign now applies to integers too; the old pattern's
    # `|\d+` branch dropped it ("-2" -> 2).
    number_re = re.compile(r"[-+]?(?:\d*\.\d+|\d+)")
    companies = list(series_dict.keys())
    # Parse each company once; previously seq_j was re-parsed for every (i, j) pair.
    sequences = [
        np.array([float(x) for x in number_re.findall(" ".join(series_dict[c]))], dtype=np.double)
        for c in companies
    ]
    n = len(companies)
    ts_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            ts_matrix[i, j] = ts_matrix[j, i] = dtw.distance_fast(sequences[i], sequences[j])
    # .max() raises on an empty array, so guard before normalising.
    max_dist = ts_matrix.max() if ts_matrix.size else 0.0
    if max_dist > 0:
        ts_matrix /= max_dist
    return ts_matrix, companies

# === Step 3: Compute per-block correlations ===
def compute_blockwise_correlations(embeddings, ts_matrix, companies, sizes=None, names=None):
    """Correlate per-block embedding cosine distances with a precomputed DTW distance matrix.

    Each company's embedding is cut into consecutive column blocks. For every
    block, the upper-triangle pairwise cosine distances are compared with the
    same entries of ``ts_matrix``.

    Args:
        embeddings: Mapping company -> 1-D embedding (torch tensor or array-like).
        ts_matrix: Square distance matrix whose rows/columns follow ``companies``.
        companies: Company order for the rows/columns.
        sizes: Width of each block; defaults to the module-level ``block_sizes``.
        names: Label of each block; defaults to the module-level ``block_names``.

    Returns:
        Dict block name -> {"Spearman": ..., "Pearson": ...}.

    Raises:
        ValueError: If ``sizes`` and ``names`` have different lengths.
    """
    sizes = block_sizes if sizes is None else sizes
    names = block_names if names is None else names
    if len(sizes) != len(names):
        raise ValueError("sizes and names must have the same length")

    def _as_array(vec):
        # Accept torch tensors (the original `.numpy()` path) and plain arrays.
        if hasattr(vec, "detach"):
            vec = vec.detach().cpu().numpy()
        return np.asarray(vec, dtype=float)

    embed_mat = np.array([_as_array(embeddings[c]) for c in companies])
    # Upper-triangle mask and DTW values do not depend on the block: hoisted.
    mask = np.triu_indices(len(companies), k=1)
    ts_flat = np.asarray(ts_matrix)[mask]

    results = {}
    start = 0
    for name, size in zip(names, sizes):
        end = start + size
        block_dist = squareform(pdist(embed_mat[:, start:end], metric="cosine"))
        block_flat = block_dist[mask]
        spearman_corr, _ = spearmanr(block_flat, ts_flat)
        pearson_corr, _ = pearsonr(block_flat, ts_flat)
        results[name] = {"Spearman": spearman_corr, "Pearson": pearson_corr}
        start = end
    return results

# Assume: model (trained), tokenizer, data_set_tokenized_samples (list of all lines)
# Pipeline: per-company embeddings -> DTW matrix over each company's grouped
# lines -> per-block Spearman/Pearson correlation between the two.
# NOTE(review): the second line re-binds the global `companies` (the ticker
# list defined at the top of the notebook) to the order returned by
# compute_dtw. Cells that later iterate `companies` will see this new list.
embeddings, grouped_texts = extract_embeddings(model, tokenizer_extended, data_set_tokenized_samples, device=device)
ts_matrix, companies = compute_dtw(grouped_texts)
results = compute_blockwise_correlations(embeddings, ts_matrix, companies)

# Print nicely
print("{:<10} {:<10} {:<10}".format("Block", "Spearman", "Pearson"))
for block, vals in results.items():
    print("{:<10} {:<10.3f} {:<10.3f}".format(block, vals['Spearman'], vals['Pearson']))

def extract_company_embeddings(model, tokenizer, company_texts, device="cuda"):
    """
    Mean-pool the last hidden layer over tokens, once per company text.

    Args:
        model: GPT-2 / LoRA model; switched to eval mode
        tokenizer: matching tokenizer (inputs truncated to 512 tokens)
        company_texts: dict {company: text}
        device: device the tokenized inputs are moved to
    Returns:
        dict {company: torch.Tensor([hidden_dim])} on CPU
    """
    model.eval()
    with torch.no_grad():
        pooled_by_company = {}
        for name, document in company_texts.items():
            batch = tokenizer(document, return_tensors="pt", truncation=True, max_length=512).to(device)
            final_layer = model(**batch, output_hidden_states=True).hidden_states[-1]  # [1, seq_len, hidden_dim]
            pooled_by_company[name] = final_layer.mean(dim=1).squeeze(0).cpu()
    return pooled_by_company

# Display the raw samples (notebook cell output).
data_set_tokenized_samples

# NOTE(review): extract_company_embeddings expects a {company: text} dict and
# calls `.items()` on it, but an earlier comment in this notebook describes
# data_set_tokenized_samples as a "list of all lines". Confirm its type, e.g.
# group/join the lines per company first. `device` is not passed, so the
# function default is used.
embeddings_dict = extract_company_embeddings(model,tokenizer_extended,data_set_tokenized_samples)

# NOTE(review): analyze_blockwise_spaces and token_series_dict are not defined
# in any of the visible cells. Confirm they exist before running this cell.
results = analyze_blockwise_spaces(embeddings_dict, token_series_dict, title="Final Model")

import numpy as np
import re
from scipy.spatial.distance import pdist, squareform
from dtaidistance import dtw
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# === Helper: Parse numbers from token strings ===
def parse_numeric_series(token_str):
    """Extract all signed decimals and integers from *token_str* as floats, in order.

    The sign group now wraps both alternatives. With the previous pattern,
    integers such as "-2" lost their minus sign.
    """
    return [float(x) for x in re.findall(r"[-+]?(?:\d*\.\d+|\d+)", token_str)]

# === Compute pairwise distances ===
def compute_distance_matrices(embeddings_dict, token_series_dict):
    """Build the embedding cosine-distance matrix and the raw-series DTW distance matrix.

    Args:
        embeddings_dict: Mapping company -> embedding (torch tensor or array-like).
        token_series_dict: Mapping company -> token string holding the numeric series.

    Returns:
        ``(emb_dist, ts_matrix, valid_companies)``. Both matrices are square
        and ordered like ``valid_companies``: the companies that have at least
        one number in their token string AND an embedding. Returns
        ``(None, None, None)`` when no company qualifies.
    """
    # Parse once (previously parsed twice). Companies without an embedding are
    # skipped instead of raising KeyError.
    series_data = {c: parse_numeric_series(t) for c, t in token_series_dict.items()}
    valid_companies = [c for c, s in series_data.items() if s and c in embeddings_dict]
    if not valid_companies:
        return None, None, None

    def _to_numpy(vec):
        if hasattr(vec, "detach"):
            vec = vec.detach().cpu().numpy()
        return np.asarray(vec, dtype=float).ravel()

    # Embedding matrix
    embeddings = np.array([_to_numpy(embeddings_dict[c]) for c in valid_companies])
    emb_dist = squareform(pdist(embeddings, metric='cosine'))

    # DTW for raw series (each converted to a double array once)
    arrays = [np.asarray(series_data[c], dtype=np.double) for c in valid_companies]
    n = len(valid_companies)
    ts_matrix = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            ts_matrix[i, j] = ts_matrix[j, i] = dtw.distance_fast(arrays[i], arrays[j])
    return emb_dist, ts_matrix, valid_companies

# === Cluster + ARI/NMI ===
def clustering_ari_analysis(embeddings_dict, token_series_dict, n_clusters=4):
    """Compare average-linkage clusterings in embedding space and in DTW space.

    Both precomputed distance matrices from ``compute_distance_matrices`` are
    clustered into ``n_clusters`` groups. Agreement is scored with ARI and
    NMI, with the DTW labels passed as the reference labelling.

    Returns:
        Dict with keys "ARI", "NMI", "labels_emb", "labels_ts", "companies".

    Raises:
        ValueError: If ``compute_distance_matrices`` finds no usable company.
    """
    emb_dist, ts_dist, companies = compute_distance_matrices(embeddings_dict, token_series_dict)
    if emb_dist is None:
        raise ValueError("No valid companies with data.")

    def _cluster(distance_matrix):
        clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='precomputed', linkage='average')
        return clusterer.fit_predict(distance_matrix)

    # Embedding space first, then DTW space (same order as before).
    labels_emb = _cluster(emb_dist)
    labels_ts = _cluster(ts_dist)

    return {
        "ARI": adjusted_rand_score(labels_ts, labels_emb),
        "NMI": normalized_mutual_info_score(labels_ts, labels_emb),
        "labels_emb": labels_emb,
        "labels_ts": labels_ts,
        "companies": companies,
    }

# Example: Run for RET
# Clusters RET embeddings and RET series into 4 groups each and reports how
# well the two clusterings agree (ARI and NMI; higher means more agreement).
result = clustering_ari_analysis(ret_context_embeddings_from_model, ret_core, n_clusters=4)
print(f"Adjusted Rand Index (ARI): {result['ARI']:.3f}")
print(f"Normalized Mutual Info (NMI): {result['NMI']:.3f}")

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import spearmanr, pearsonr

# === Inputs ===
# pooled_embeddings: [num_companies, 768] (from trained model)
# dtw_matrix: [num_companies, num_companies] (normalized DTW distances)
# block_names & block_sizes as before
# Eight 96-column blocks (8 * 96 = 768). This re-binds the module-level
# block_sizes / block_names with the same values as the earlier cell.
block_sizes = [96,96,96,96,96,96,96,96]
block_names = ["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"]

# === Step 1: Split embeddings into blocks ===
def get_blockwise_distances(embeddings, block_sizes, block_names):
    """Compute one cosine-distance matrix per column block of *embeddings*.

    Args:
        embeddings: 2-D torch tensor, one row per company (sliced as
            ``embeddings[:, start:end]``).
        block_sizes: Width of each consecutive column block.
        block_names: Label of each block, in step with ``block_sizes``.

    Returns:
        Dict block name -> NumPy array ``1 - cos_sim`` of shape (rows, rows).
    """
    boundaries = [0]
    for width in block_sizes:
        boundaries.append(boundaries[-1] + width)

    distances = {}
    for idx, (lo, hi) in enumerate(zip(boundaries, boundaries[1:])):
        unit_rows = F.normalize(embeddings[:, lo:hi], p=2, dim=1)  # L2-normalised rows
        similarity = unit_rows @ unit_rows.T
        distances[block_names[idx]] = (1 - similarity).detach().cpu().numpy()
    return distances

# === Step 2: Compute blockwise cosine distance matrices ===
# NOTE(review): pooled_embeddings and dtw_matrix are not defined in the visible
# cells. Both must be torch tensors, and dtw_matrix rows must follow the same
# company order as pooled_embeddings rows — confirm.
block_dists = get_blockwise_distances(pooled_embeddings, block_sizes, block_names)
dtw_np = dtw_matrix.detach().cpu().numpy()

# === Step 3: Correlations (Spearman & Pearson) ===
results = {}
triu_idx = np.triu_indices(len(dtw_np), k=1)  # use upper triangle (avoid duplicates)
dtw_flat = dtw_np[triu_idx]

for name, dist_mat in block_dists.items():
    embed_flat = dist_mat[triu_idx]
    spearman_corr = spearmanr(dtw_flat, embed_flat).correlation
    pearson_corr = pearsonr(dtw_flat, embed_flat)[0]
    results[name] = {"Spearman": spearman_corr, "Pearson": pearson_corr}

# === Print results ===
# Iterates block_names, so every name must have a matching results entry.
print("Correlation between DTW and Embedding Distances per Block:")
for name in block_names:
    print(f"{name:6} | Spearman: {results[name]['Spearman']:.3f} | Pearson: {results[name]['Pearson']:.3f}")



import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw

# === Block index mapping ===
# [start, end) column ranges inside the full embedding: RET is the first
# 96-column block and ESG the fourth (index 3), as in block_names.
BLOCKS = dict(
    RET=(0, 96),
    ESG=(3 * 96, 4 * 96),
)

def extract_series(token_str):
    """Return every signed number in *token_str* (e.g. '<RET_0.10> <RET_-0.03>') as a float."""
    import re
    number = re.compile(r"[-+]?\d+(?:\.\d+)?")
    return list(map(float, number.findall(token_str)))

def find_top_dtw_peers(query_company, series_dict, top_k=5):
    """Return the ``top_k`` ``(company, DTW distance)`` pairs closest to *query_company*.

    Raises:
        ValueError: If *query_company* is not a key of *series_dict*.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    reference = extract_series(series_dict[query_company])
    scored = [
        (name, dtw.distance(reference, extract_series(tokens)))
        for name, tokens in series_dict.items()
        if name != query_company
    ]
    scored.sort(key=lambda pair: pair[1])  # ascending distance = most similar first
    return scored[:top_k]

def compute_embedding_similarity(query_company, embedding_dict, block_name):
    """Cosine similarity of *query_company*'s embedding block to every other company's.

    Args:
        query_company: Reference ticker (excluded from the result).
        embedding_dict: Mapping ticker -> torch tensor, sliced with ``BLOCKS[block_name]``.
        block_name: Key of ``BLOCKS`` selecting the column range to compare.

    Returns:
        Dict ticker -> Python float similarity.
    """
    lo, hi = BLOCKS[block_name]

    def _unit(vec):
        return F.normalize(vec[lo:hi].unsqueeze(0), p=2, dim=1)

    anchor = _unit(embedding_dict[query_company])
    return {
        other: torch.cosine_similarity(anchor, _unit(emb)).item()
        for other, emb in embedding_dict.items()
        if other != query_company
    }

def plot_trajectory_and_embedding_similarity(query_company,
                                             series_dict,
                                             embedding_dict,
                                             block_name="RET",
                                             top_k=5):
    """Plot *query_company*'s series with its top-k DTW peers, then their embedding similarities.

    Args:
        query_company: Reference ticker.
        series_dict: Mapping ticker -> token string (e.g. ``ret_core``).
        embedding_dict: Mapping ticker -> embedding tensor, sliced via ``BLOCKS``.
        block_name: ``BLOCKS`` key naming both the embedding block and the y-axis label.
        top_k: Number of nearest DTW peers to show.

    Raises:
        KeyError: If a DTW peer has no entry in ``embedding_dict``.
    """
    # === Step 1: Get top peers by DTW ===
    top_peers = find_top_dtw_peers(query_company, series_dict, top_k=top_k)
    peer_names = [p[0] for p in top_peers]

    # === Step 2: Get embedding similarities ===
    emb_sims = compute_embedding_similarity(query_company, embedding_dict, block_name)
    emb_sims_selected = [emb_sims[name] for name in peer_names]

    # === Step 3: Plot time-series trajectories ===
    plt.figure(figsize=(12,5))
    query_series = extract_series(series_dict[query_company])
    plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3, color='black')
    for name, _ in top_peers:
        plt.plot(extract_series(series_dict[name]), label=name, linestyle='--')
    plt.title(f"{block_name} Trajectories (Top-{top_k} DTW Similar Peers)")
    plt.xlabel("Time")
    plt.ylabel(block_name)
    plt.legend()
    plt.grid(True)
    plt.show()

    # === Step 4: Plot embedding cosine similarities ===
    plt.figure(figsize=(8,4))
    plt.bar(peer_names, emb_sims_selected, color='skyblue')
    plt.title(f"{block_name} GPT-2 Embedding Cosine Similarity with {query_company}")
    plt.ylabel("Cosine Similarity")
    # Cosine similarity lies in [-1, 1]; a fixed (0, 1) range hid negative bars.
    plt.ylim(min(0.0, min(emb_sims_selected, default=0.0)), 1)
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.show()

# AAPL vs. its 5 nearest DTW peers: first in the RET space, then in the ESG space.
plot_trajectory_and_embedding_similarity("AAPL", ret_core, ret_context_embeddings_from_model_12,
                                         block_name="RET", top_k=5)

plot_trajectory_and_embedding_similarity("AAPL", esg_core, esg_context_embeddings_from_model_12,
                                         block_name="ESG", top_k=5)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw

# === Block index mapping ===
# [start, end) column ranges of the RET and ESG blocks inside the full
# embedding, read by compute_embedding_similarity via BLOCKS[block_name].
# This cell re-binds the same values as the earlier BLOCKS definition.
BLOCKS = {
    "RET": (0, 96),
    "ESG": (288, 384)
}

def extract_series(token_str):
    """Parse every signed integer or decimal in *token_str* into a list of floats.

    Example: '<RET_0.10> <RET_-0.03>' -> [0.1, -0.03].
    """
    import re
    values = []
    for hit in re.finditer(r"[-+]?\d+(?:\.\d+)?", token_str):
        values.append(float(hit.group(0)))
    return values

def find_top_dtw_peers(query_company, series_dict, top_k=5, similar=True):
    """Rank the other companies by DTW distance to *query_company*.

    Args:
        query_company: Reference ticker (excluded from the ranking).
        series_dict: Mapping ticker -> token string.
        top_k: Number of peers returned.
        similar: True returns the closest peers (ascending distance); False
            returns the farthest peers (descending distance).

    Returns:
        List of ``(company, distance)`` pairs.

    Raises:
        ValueError: If *query_company* is not a key of *series_dict*.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    query_series = extract_series(series_dict[query_company])
    distances = {
        company: dtw.distance(query_series, extract_series(series))
        for company, series in series_dict.items()
        if company != query_company
    }
    ranked = sorted(distances.items(), key=lambda item: item[1], reverse=not similar)
    return ranked[:top_k]

def compute_embedding_similarity(query_company, embedding_dict, block_name):
    """Compute cosine similarity between query company and all others (block-level).

    Args:
        query_company: Key of the reference embedding in ``embedding_dict``.
        embedding_dict: Mapping company -> torch tensor. Each tensor is sliced
            with ``[start:end]`` and unsqueezed, so presumably 1-D — confirm.
        block_name: Key of the module-level ``BLOCKS`` dict giving the
            ``(start, end)`` column range to compare.

    Returns:
        Dict company -> Python float similarity; the query company is excluded.
    """
    start, end = BLOCKS[block_name]
    query_emb = embedding_dict[query_company][start:end].unsqueeze(0)
    # L2-normalising first is mathematically redundant for non-zero vectors
    # (cosine similarity is scale-invariant); it only affects float rounding.
    query_emb = F.normalize(query_emb, p=2, dim=1)
    similarities = {}
    for company, emb in embedding_dict.items():
        if company == query_company:
            continue
        comp_emb = F.normalize(emb[start:end].unsqueeze(0), p=2, dim=1)
        sim = torch.cosine_similarity(query_emb, comp_emb).item()
        similarities[company] = sim
    return similarities

def plot_trajectory_and_embedding_similarity(query_company,
                                             series_dict,
                                             embedding_dict,
                                             block_name="RET",
                                             top_k=5,
                                             similar=True):
    """Plot the query trajectory with its top-k DTW peers, then a bar chart of
    each peer's block-level embedding cosine similarity to the query.

    Args:
        query_company: Company key used as the reference.
        series_dict: Mapping company -> token string parsed by ``extract_series``.
        embedding_dict: Mapping company -> embedding tensor, sliced via ``BLOCKS``.
        block_name: ``BLOCKS`` key selecting the slice; also used as the y-label.
        top_k: Number of DTW peers to plot.
        similar: True ranks the closest DTW peers, False the farthest.

    Raises:
        ValueError: ``query_company`` is missing from ``series_dict``.
        KeyError: ``query_company`` is missing from ``embedding_dict``.
    """
    # Rank only companies that also have an embedding. Otherwise a DTW peer
    # without one makes Step 2 fail with KeyError. Keeps series_dict order.
    candidates = {name: raw for name, raw in series_dict.items()
                  if name == query_company or name in embedding_dict}

    # === Step 1: Get peers by DTW (similar or dissimilar) ===
    top_peers = find_top_dtw_peers(query_company, candidates, top_k=top_k, similar=similar)
    peer_names = [p[0] for p in top_peers]

    # === Step 2: Get embedding similarities ===
    emb_sims = compute_embedding_similarity(query_company, embedding_dict, block_name)
    emb_sims_selected = [emb_sims[name] for name in peer_names]
    sim_text = "Similar" if similar else "Dissimilar"

    # === Step 3: Plot time-series trajectories ===
    plt.figure(figsize=(5,5))
    query_series = extract_series(series_dict[query_company])
    plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3, color='black')
    for name, _ in top_peers:
        plt.plot(extract_series(series_dict[name]), label=name, linestyle='--')
    plt.title(f"{block_name} Trajectories (Top-{top_k} DTW {sim_text} Peers)")
    plt.xlabel("Time")
    plt.ylabel(block_name)
    plt.legend()
    plt.grid(True)
    plt.show()

    # === Step 4: Plot embedding cosine similarities ===
    plt.figure(figsize=(5,5))
    plt.bar(peer_names, emb_sims_selected, color='salmon' if not similar else 'skyblue')
    plt.title(f"{block_name} GPT-2 Embedding Cosine Similarity with {query_company} ({sim_text})")
    plt.ylabel("Cosine Similarity")
    # Cosine similarity lies in [-1, 1]. Lower the floor below 0 when needed so
    # negative bars are not clipped out of view.
    plt.ylim(min([0.0, *emb_sims_selected]), 1)
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.show()

# RET block, AAPL: five closest DTW peers, then five farthest.
plot_trajectory_and_embedding_similarity(query_company="AAPL", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model_12,
                                         block_name="RET", top_k=5, similar=True)

plot_trajectory_and_embedding_similarity(query_company="AAPL", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model_12,
                                         block_name="RET", top_k=5, similar=False)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw

# === Block index mapping ===
# Embedding slice [start, end) that belongs to each named block.
BLOCKS = dict([("RET", (0, 96)), ("ESG", (288, 384))])

def extract_series(token_str):
    """Parse the numeric payload of a token string into a list of floats.

    '<RET_0.10> <RET_-0.03>' -> [0.1, -0.03]. Only plain signed integers and
    decimals are matched.
    """
    import re
    pattern = r"[-+]?\d+(?:\.\d+)?"
    return [float(match) for match in re.findall(pattern, token_str)]

def find_top_dtw_peers(query_company, series_dict, top_k=5, similar=True):
    """Return the ``top_k`` (company, DTW distance) pairs for ``query_company``.

    ``similar=True`` orders by increasing distance (closest first), and
    ``similar=False`` by decreasing distance. Raises ValueError when the query
    is absent from ``series_dict``.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    reference = extract_series(series_dict[query_company])

    scored = []
    for name, raw in series_dict.items():
        if name != query_company:
            scored.append((name, dtw.distance(reference, extract_series(raw))))

    scored.sort(key=lambda pair: pair[1], reverse=not similar)
    return scored[:top_k]

def compute_embedding_similarity(query_company, embedding_dict, block_name):
    """Cosine similarity between the query's block slice and every other company's.

    Returns a dict company -> float. The query is not included.
    """
    lo, hi = BLOCKS[block_name]

    def _unit(vec):
        # Take the block slice as a (1, width) row vector and L2-normalise it.
        return F.normalize(vec[lo:hi].unsqueeze(0), p=2, dim=1)

    anchor = _unit(embedding_dict[query_company])
    result = {}
    for name, vec in embedding_dict.items():
        if name == query_company:
            continue
        result[name] = torch.cosine_similarity(anchor, _unit(vec)).item()
    return result

def plot_trajectory_and_embedding_similarity(query_company,
                                             series_dict,
                                             embedding_dict,
                                             block_name="RET",
                                             top_k=5,
                                             similar=True):
    """Plot the query trajectory with its top-k DTW peers, then each peer's
    block-level embedding cosine similarity to the query.

    Only companies present in both ``series_dict`` and ``embedding_dict`` are
    ranked. Prints a message and returns None when the query or every peer is
    unavailable.

    Args:
        query_company: Company key used as the reference.
        series_dict: Mapping company -> token string parsed by ``extract_series``.
        embedding_dict: Mapping company -> embedding tensor, sliced via ``BLOCKS``.
        block_name: ``BLOCKS`` key selecting the slice; also used as the y-label.
        top_k: Number of DTW peers to plot.
        similar: True ranks the closest DTW peers, False the farthest.
    """
    # Restrict to companies with both series & embeddings.
    if query_company not in series_dict or query_company not in embedding_dict:
        print(f"{query_company} not found in both series and embeddings.")
        return
    # Build the dict by walking series_dict rather than a set. Set iteration
    # order depends on the string hash seed, which made DTW ties break
    # differently between sessions.
    filtered_series_dict = {k: v for k, v in series_dict.items() if k in embedding_dict}

    # === Step 1: Get peers by DTW (similar or dissimilar) ===
    top_peers = find_top_dtw_peers(query_company, filtered_series_dict, top_k=top_k, similar=similar)
    peer_names = [p[0] for p in top_peers]
    if not peer_names:
        print(f"No valid peers with embeddings found for {query_company}.")
        return

    # === Step 2: Get embedding similarities ===
    emb_sims = compute_embedding_similarity(query_company, embedding_dict, block_name)
    # Every peer has an embedding (filtered above), so bar heights stay aligned
    # one-to-one with peer_names.
    emb_sims_selected = [emb_sims[name] for name in peer_names]
    sim_text = "Similar" if similar else "Dissimilar"

    # === Step 3: Plot time-series trajectories ===
    plt.figure(figsize=(12,5))
    query_series = extract_series(series_dict[query_company])
    plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3, color='black')
    for name in peer_names:
        plt.plot(extract_series(series_dict[name]), label=name, linestyle='--')
    plt.title(f"{block_name} Trajectories (Top-{top_k} DTW {sim_text} Peers)")
    plt.xlabel("Time")
    plt.ylabel(block_name)
    plt.legend()
    plt.grid(True)
    plt.show()

    # === Step 4: Plot embedding cosine similarities ===
    plt.figure(figsize=(8,4))
    plt.bar(peer_names, emb_sims_selected, color='salmon' if not similar else 'skyblue')
    plt.title(f"{block_name} GPT-2 Embedding Cosine Similarity with {query_company} ({sim_text})")
    plt.ylabel("Cosine Similarity")
    # Cosine similarity lies in [-1, 1]. Do not clip negative bars away.
    plt.ylim(min([0.0, *emb_sims_selected]), 1)
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.show()

import matplotlib.pyplot as plt
import numpy as np

def plot_trajectory_and_embedding_similarity(query_company,
                                             series_dict,
                                             embedding_dict,
                                             block_name="RET",
                                             top_k=5,
                                             similar=True):
    """Bold, colour-coded variant: plot the query trajectory with its top-k DTW
    peers, then each peer's block-level embedding cosine similarity.

    Only companies present in both ``series_dict`` and ``embedding_dict`` are
    ranked. Prints a message and returns None when the query or every peer is
    unavailable.

    Args:
        query_company: Company key used as the reference.
        series_dict: Mapping company -> token string parsed by ``extract_series``.
        embedding_dict: Mapping company -> embedding tensor, sliced via ``BLOCKS``.
        block_name: ``BLOCKS`` key selecting the slice; also used as the y-label.
        top_k: Number of DTW peers to plot.
        similar: True ranks the closest DTW peers, False the farthest.
    """
    # Restrict to companies with both series & embeddings.
    if query_company not in series_dict or query_company not in embedding_dict:
        print(f"{query_company} not found in both series and embeddings.")
        return
    # Walk series_dict (not a set) so DTW tie-breaking is identical on every run.
    filtered_series_dict = {k: v for k, v in series_dict.items() if k in embedding_dict}

    # === Step 1: Get peers by DTW (similar or dissimilar) ===
    top_peers = find_top_dtw_peers(query_company, filtered_series_dict, top_k=top_k, similar=similar)
    peer_names = [p[0] for p in top_peers]
    if not peer_names:
        print(f"No valid peers with embeddings found for {query_company}.")
        return

    # === Step 2: Get embedding similarities ===
    emb_sims = compute_embedding_similarity(query_company, embedding_dict, block_name)
    # Every peer has an embedding (filtered above), so heights align with labels.
    emb_sims_selected = [emb_sims[name] for name in peer_names]

    # === Gaudy color palette ===
    gaudy_colors = ['magenta', 'lime', 'cyan', 'orange', 'red', 'blue', 'gold', 'purple']
    sim_text = "Similar" if similar else "Dissimilar"

    # === Step 3: Plot time-series trajectories ===
    plt.figure(figsize=(5,4))
    query_series = extract_series(series_dict[query_company])
    plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3, color='black')
    for idx, name in enumerate(peer_names):
        plt.plot(extract_series(series_dict[name]), label=name, linestyle='--',
                 linewidth=2.5, color=gaudy_colors[idx % len(gaudy_colors)])
    plt.title(f"{block_name} Trajectories (Top-{top_k} DTW {sim_text} Peers)", fontsize=12, fontweight='bold')
    plt.xlabel("Time", fontsize=12, fontweight='bold')
    plt.ylabel(block_name, fontsize=12, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.xticks(fontsize=12, fontweight='bold')
    plt.yticks(fontsize=12, fontweight='bold')
    plt.show()

    # === Step 4: Plot embedding cosine similarities ===
    plt.figure(figsize=(5,4))
    plt.bar(peer_names, emb_sims_selected, color='red' if not similar else 'green', edgecolor='black')
    plt.title(f"{block_name} GPT-2 Embedding Cosine Similarity with {query_company} ({sim_text})", fontsize=12, fontweight='bold')
    plt.ylabel("Cosine Similarity", fontsize=12, fontweight='bold')
    # Cosine similarity lies in [-1, 1]. Extend below 0 rather than clipping.
    plt.ylim(min([0.0, *emb_sims_selected]), 1)
    plt.xticks(fontsize=12, fontweight='bold')
    plt.yticks(fontsize=12, fontweight='bold')
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.show()

# RET block, top-3 peers per company.
# TSLA: farthest, then closest (model_12 embeddings).
plot_trajectory_and_embedding_similarity(query_company="TSLA", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model_12,
                                         block_name="RET", top_k=3, similar=False)

plot_trajectory_and_embedding_similarity(query_company="TSLA", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model_12,
                                         block_name="RET", top_k=3)

# CSCO: closest, then farthest (ret_context_embeddings_from_model).
plot_trajectory_and_embedding_similarity(query_company="CSCO", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model,
                                         block_name="RET", top_k=3)

plot_trajectory_and_embedding_similarity(query_company="CSCO", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model,
                                         block_name="RET", top_k=3, similar=False)

# NOTE(review): unlike the TSLA/CSCO pairs, the VRT pair mixes embedding
# sources (_from_model for "similar", _from_model_12 for "dissimilar") —
# confirm this is intended.
plot_trajectory_and_embedding_similarity(query_company="VRT", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model,
                                         block_name="RET", top_k=3)

plot_trajectory_and_embedding_similarity(query_company="VRT", series_dict=ret_core,
                                         embedding_dict=ret_context_embeddings_from_model_12,
                                         block_name="RET", top_k=3, similar=False)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw
from sklearn.preprocessing import StandardScaler

# === Block index mapping ===
# Embedding slice [start, end) that belongs to each named block.
BLOCKS = {name: span for name, span in (("RET", (0, 96)), ("ESG", (288, 384)))}

def extract_series(token_str):
    """Convert a token string such as '<RET_0.10> <RET_-0.03>' into [0.1, -0.03].

    Every signed integer or decimal literal found is returned as a float.
    """
    import re
    found = re.findall(r"[-+]?\d+(?:\.\d+)?", token_str)
    values = []
    for literal in found:
        values.append(float(literal))
    return values

# === Preprocess embeddings: Center & Whiten for better spread ===
def preprocess_embeddings_blockwise(embedding_dict, block_name):
    """Slice one block out of every embedding and standardize it across companies.

    Each dimension of the ``BLOCKS[block_name]`` slice is z-scored (mean 0,
    unit variance over all companies) with sklearn's ``StandardScaler``. This
    centres and rescales per dimension. It does not decorrelate dimensions,
    so it is not PCA/ZCA whitening.

    Args:
        embedding_dict: Mapping company -> 1-D embedding (torch tensor or array-like).
        block_name: ``BLOCKS`` key selecting the ``[start:end]`` slice.

    Returns:
        dict mapping company -> standardized numpy slice, in input key order.
    """
    start, end = BLOCKS[block_name]
    companies = list(embedding_dict.keys())

    def _block_slice(vec):
        # detach(): .numpy() refuses tensors that still require grad.
        if hasattr(vec, "cpu"):
            return vec.detach().cpu().numpy()[start:end]
        return vec[start:end]

    embeddings = np.array([_block_slice(embedding_dict[c]) for c in companies])
    scaler = StandardScaler()
    processed = scaler.fit_transform(embeddings)
    return dict(zip(companies, processed))

def find_top_dtw_peers(query_company, series_dict, top_k=5, similar=True):
    """Return up to ``top_k`` (company, DTW distance) pairs for ``query_company``.

    Ordering is closest-first when ``similar`` is True and farthest-first
    otherwise. A peer whose series fails to parse or compare is silently left
    out. Raises ValueError when the query is absent from ``series_dict``.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    reference = extract_series(series_dict[query_company])

    def _distance_or_none(raw):
        # Best-effort: any failure for one peer just drops that peer.
        try:
            return dtw.distance(reference, extract_series(raw))
        except Exception:
            return None

    scored = ((name, _distance_or_none(raw))
              for name, raw in series_dict.items() if name != query_company)
    peer_distances = {name: d for name, d in scored if d is not None}
    ranked = sorted(peer_distances.items(), key=lambda item: item[1], reverse=not similar)
    return ranked[:top_k]

def compute_embedding_distance(query_company, embedding_dict, block_name):
    """Cosine distance (1 - cosine similarity) from the query to every other company.

    Args:
        query_company: Key of the reference embedding.
        embedding_dict: Mapping company -> 1-D embedding (array-like or tensor),
            e.g. the output of ``preprocess_embeddings_blockwise``.
        block_name: Not used; kept so existing callers keep working.

    Returns:
        dict mapping every other company to a distance in [0, 2]
        (0 = same direction, 1 = orthogonal, 2 = opposite). Entries that
        cannot be converted or compared are skipped.

    Raises:
        ValueError: ``query_company`` is not in ``embedding_dict``.
    """
    if query_company not in embedding_dict:
        raise ValueError(f"{query_company} not found in embedding dictionary.")
    # as_tensor avoids the copy-construct warning torch.tensor() gives for tensor input.
    query_emb = torch.as_tensor(embedding_dict[query_company], dtype=torch.float32).unsqueeze(0)
    query_emb = F.normalize(query_emb, p=2, dim=1)
    distances = {}
    for company, emb in embedding_dict.items():
        if company == query_company:
            continue
        try:
            comp_emb = torch.as_tensor(emb, dtype=torch.float32).unsqueeze(0)
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            distances[company] = 1 - torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Deliberate best-effort: a malformed entry is skipped, not fatal.
            continue
    return distances

def plot_trajectory_and_embedding_distance(query_company,
                                           series_dict,
                                           embedding_dict,
                                           block_name="RET",
                                           top_k=5,
                                           similar=True):
    """Plot the query trajectory with its top-k DTW peers, then each peer's
    cosine distance to the query on standardized block embeddings.

    Any exception is caught and printed instead of being raised.

    Args:
        query_company: Company key used as the reference.
        series_dict: Mapping company -> token string parsed by ``extract_series``.
        embedding_dict: Mapping company -> embedding, standardized per block by
            ``preprocess_embeddings_blockwise``.
        block_name: ``BLOCKS`` key selecting the slice; also used as the y-label.
        top_k: Number of DTW peers to plot.
        similar: True ranks the closest DTW peers, False the farthest.
    """
    try:
        # === Step 1: Preprocess embeddings (per-dimension standardization) ===
        processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)

        # === Step 2: Get peers by DTW ===
        top_peers = find_top_dtw_peers(query_company, series_dict, top_k=top_k, similar=similar)
        if not top_peers:
            print(f"No valid peers found for {query_company} in {block_name}.")
            return
        peer_names = [p[0] for p in top_peers]

        # === Step 3: Get embedding distances ===
        emb_dists = compute_embedding_distance(query_company, processed_embeddings, block_name)
        # A missing embedding becomes NaN (no bar). The previous 1.0 default
        # was a fabricated value, and not even the maximum (cosine distance
        # reaches 2).
        emb_dists_selected = [emb_dists.get(name, np.nan) for name in peer_names]
        missing = [n for n, d in zip(peer_names, emb_dists_selected) if np.isnan(d)]
        if missing:
            print(f"No embedding distance for: {', '.join(missing)}")

        # === Step 4: Plot time-series trajectories ===
        plt.figure(figsize=(12,5))
        query_series = extract_series(series_dict[query_company])
        plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3, color='black')
        for name, _ in top_peers:
            try:
                plt.plot(extract_series(series_dict[name]), label=name, linestyle='--')
            except Exception:
                continue
        sim_text = "Similar" if similar else "Dissimilar"
        plt.title(f"{block_name} Trajectories (Top-{top_k} DTW {sim_text} Peers)")
        plt.xlabel("Time")
        plt.ylabel(block_name)
        plt.legend()
        plt.grid(True)
        plt.show()

        # === Step 5: Plot embedding distances ===
        plt.figure(figsize=(8,4))
        plt.bar(peer_names, emb_dists_selected, color='salmon' if not similar else 'skyblue')
        plt.title(f"{block_name} GPT-2 Embedding Cosine Distance with {query_company} ({sim_text})")
        plt.ylabel("Cosine Distance (0 = identical, 2 = opposite)")
        # Distances above 1 are common after centering. Grow the axis instead
        # of clipping at 1.
        finite = [d for d in emb_dists_selected if np.isfinite(d)]
        plt.ylim(0, max(1.0, 1.05 * max(finite, default=0.0)))
        plt.grid(axis='y', linestyle='--', alpha=0.6)
        plt.show()

    except Exception as e:
        print(f"Skipping {query_company} for {block_name} due to error: {e}")

# RET block: five closest DTW peers for AAPL, then for TSLA.
plot_trajectory_and_embedding_distance(query_company="AAPL", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=5, similar=True)

plot_trajectory_and_embedding_distance(query_company="TSLA", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=5, similar=True)

# Eight consecutive 96-wide slices [start, end) of the embedding vector.
BLOCKS = {
    name: (96 * i, 96 * (i + 1))
    for i, name in enumerate(("RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"))
}

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw
from sklearn.preprocessing import StandardScaler




# Eight consecutive 96-wide slices [start, end) of the embedding vector.
BLOCKS = dict(zip(
    ["RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"],
    zip(range(0, 768, 96), range(96, 769, 96)),
))
def extract_series(token_str):
    """Return the floats embedded in a token string, in order of appearance.

    '<RET_0.10> <RET_-0.03>' -> [0.1, -0.03]
    """
    import re
    return list(map(float, re.findall(r"[-+]?\d+(?:\.\d+)?", token_str)))

# === Preprocess embeddings: Center & Whiten for better spread ===
def preprocess_embeddings_blockwise(embedding_dict, block_name):
    """Slice the ``BLOCKS[block_name]`` block from every embedding and z-score
    each dimension across companies (sklearn ``StandardScaler``).

    This is per-dimension centring and scaling, not decorrelating whitening.

    Args:
        embedding_dict: Mapping company -> 1-D embedding (torch tensor or array-like).
        block_name: ``BLOCKS`` key selecting the ``[start:end]`` slice.

    Returns:
        dict mapping company -> standardized numpy slice, in input key order.
    """
    start, end = BLOCKS[block_name]
    companies = list(embedding_dict.keys())

    def _block_slice(vec):
        # detach(): .numpy() raises on tensors that still require grad.
        if hasattr(vec, "cpu"):
            return vec.detach().cpu().numpy()[start:end]
        return vec[start:end]

    embeddings = np.array([_block_slice(embedding_dict[c]) for c in companies])
    scaler = StandardScaler()
    processed = scaler.fit_transform(embeddings)
    return dict(zip(companies, processed))

def find_top_dtw_peers(query_company, series_dict, top_k=5, similar=True):
    """Return up to ``top_k`` (company, DTW distance) pairs, closest-first when
    ``similar`` else farthest-first. Peers that fail to parse or compare are
    skipped. Raises ValueError if ``query_company`` has no series.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    reference = extract_series(series_dict[query_company])

    peer_distances = {}
    for name, raw in series_dict.items():
        if name == query_company:
            continue
        try:
            peer_distances[name] = dtw.distance(reference, extract_series(raw))
        except Exception:
            # Best-effort: leave out any peer that cannot be scored.
            continue

    ranked = sorted(peer_distances.items(), key=lambda item: item[1], reverse=not similar)
    return ranked[:top_k]

def compute_embedding_distance(query_company, embedding_dict, block_name):
    """Cosine distance (1 - cosine similarity) from the query to every other company.

    Args:
        query_company: Key of the reference embedding.
        embedding_dict: Mapping company -> 1-D embedding (array-like or tensor).
        block_name: Not used; kept for call compatibility.

    Returns:
        dict company -> distance in [0, 2] (0 = same direction, 2 = opposite).
        Entries that cannot be converted or compared are skipped.

    Raises:
        ValueError: ``query_company`` is not in ``embedding_dict``.
    """
    if query_company not in embedding_dict:
        raise ValueError(f"{query_company} not found in embedding dictionary.")
    # as_tensor avoids torch.tensor()'s copy-construct warning on tensor input.
    query_emb = torch.as_tensor(embedding_dict[query_company], dtype=torch.float32).unsqueeze(0)
    query_emb = F.normalize(query_emb, p=2, dim=1)
    distances = {}
    for company, emb in embedding_dict.items():
        if company == query_company:
            continue
        try:
            comp_emb = torch.as_tensor(emb, dtype=torch.float32).unsqueeze(0)
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            distances[company] = 1 - torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Deliberate best-effort: a malformed entry is skipped.
            continue
    return distances

def plot_trajectory_and_embedding_distance(query_company,
                                           series_dict,
                                           embedding_dict,
                                           block_name="RET",
                                           top_k=5,
                                           similar=True):
    """Bold, colour-coded plot of the query trajectory with its top-k DTW peers,
    followed by each peer's cosine distance on standardized block embeddings.

    Any exception is caught and printed instead of being raised.

    Args:
        query_company: Company key used as the reference.
        series_dict: Mapping company -> token string parsed by ``extract_series``.
        embedding_dict: Mapping company -> embedding, standardized per block by
            ``preprocess_embeddings_blockwise``.
        block_name: ``BLOCKS`` key selecting the slice; also used as the y-label.
        top_k: Number of DTW peers to plot.
        similar: True ranks the closest DTW peers, False the farthest.
    """
    try:
        # === Step 1: Preprocess embeddings ===
        processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)

        # === Step 2: Get peers by DTW ===
        top_peers = find_top_dtw_peers(query_company, series_dict, top_k=top_k, similar=similar)
        if not top_peers:
            print(f"No valid peers found for {query_company} in {block_name}.")
            return
        peer_names = [p[0] for p in top_peers]

        # === Step 3: Get embedding distances ===
        emb_dists = compute_embedding_distance(query_company, processed_embeddings, block_name)
        # Missing embedding -> NaN (no bar) rather than a made-up 1.0.
        emb_dists_selected = [emb_dists.get(name, np.nan) for name in peer_names]
        missing = [n for n, d in zip(peer_names, emb_dists_selected) if np.isnan(d)]
        if missing:
            print(f"No embedding distance for: {', '.join(missing)}")

        # === Colors & Style ===
        gaudy_colors = ['magenta', 'lime', 'cyan', 'orange', 'red', 'blue', 'gold', 'purple']
        sim_text = "Similar" if similar else "Dissimilar"

        # === Step 4: Plot time-series trajectories ===
        plt.figure(figsize=(5,5))
        query_series = extract_series(series_dict[query_company])
        plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3.5, color='black')
        for idx, (name, _) in enumerate(top_peers):
            try:
                plt.plot(extract_series(series_dict[name]),
                         label=name,
                         linestyle='--',
                         linewidth=3,
                         color=gaudy_colors[idx % len(gaudy_colors)])
            except Exception:
                continue
        # The title was previously truncated with an unbalanced "(".
        plt.title(f"{block_name} Trajectories (Top-{top_k} {sim_text})", fontsize=12, fontweight='bold')
        plt.xlabel("Time", fontsize=12, fontweight='bold')
        plt.ylabel(block_name, fontsize=12, fontweight='bold')
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.xticks(fontsize=12, fontweight='bold')
        plt.yticks(fontsize=12, fontweight='bold')
        plt.show()

        # === Step 5: Plot embedding distances ===
        plt.figure(figsize=(5,5))
        bar_colors = [gaudy_colors[idx % len(gaudy_colors)] for idx in range(len(peer_names))]
        plt.bar(peer_names,
                emb_dists_selected,
                color=bar_colors,
                edgecolor='black',
                linewidth=1.2)
        plt.title(f"{block_name} GPT-2 Embedding Cosine Distance",
                  fontsize=12, fontweight='bold')
        plt.ylabel("Cosine Distance", fontsize=12, fontweight='bold')
        # Cosine distance spans [0, 2]. Grow the axis instead of clipping at 1.
        finite = [d for d in emb_dists_selected if np.isfinite(d)]
        plt.ylim(0, max(1.0, 1.05 * max(finite, default=0.0)))
        plt.xticks(fontsize=12, fontweight='bold', rotation=15)
        plt.yticks(fontsize=12, fontweight='bold')
        plt.grid(axis='y', linestyle='--', alpha=0.6)
        plt.show()

    except Exception as e:
        print(f"Skipping {query_company} for {block_name} due to error: {e}")

# RET block examples (closest DTW peers, model_12 embeddings).
plot_trajectory_and_embedding_distance(query_company="TSLA", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=5, similar=True)

plot_trajectory_and_embedding_distance(query_company="CSCO", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=2, similar=True)

plot_trajectory_and_embedding_distance(query_company="DE", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=8, similar=True)

plot_trajectory_and_embedding_distance(query_company="REG", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=8, similar=True)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw
from sklearn.preprocessing import StandardScaler
import pandas as pd

# === Block index mapping ===

# Eight consecutive 96-wide slices [start, end) of the embedding vector.
_block_names = ("RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI")
BLOCKS = {name: (96 * i, 96 * i + 96) for i, name in enumerate(_block_names)}
del _block_names

def extract_series(token_str):
    """Pull the signed integer/decimal values out of a token string as floats.

    Example: '<RET_0.10> <RET_-0.03>' -> [0.1, -0.03].
    """
    import re
    matches = re.findall(r"[-+]?\d+(?:\.\d+)?", token_str)
    return [float(text) for text in matches]

# === Preprocess embeddings: Center & Whiten for better spread ===
def preprocess_embeddings_blockwise(embedding_dict, block_name):
    """Slice the ``BLOCKS[block_name]`` block from every embedding and z-score
    each dimension across companies (sklearn ``StandardScaler``).

    This is per-dimension centring and scaling, not decorrelating whitening.

    Args:
        embedding_dict: Mapping company -> 1-D embedding (torch tensor or array-like).
        block_name: ``BLOCKS`` key selecting the ``[start:end]`` slice.

    Returns:
        dict mapping company -> standardized numpy slice, in input key order.
    """
    start, end = BLOCKS[block_name]
    companies = list(embedding_dict.keys())

    def _block_slice(vec):
        # detach(): .numpy() raises on tensors that still require grad.
        if hasattr(vec, "cpu"):
            return vec.detach().cpu().numpy()[start:end]
        return vec[start:end]

    embeddings = np.array([_block_slice(embedding_dict[c]) for c in companies])
    scaler = StandardScaler()
    processed = scaler.fit_transform(embeddings)
    return dict(zip(companies, processed))

def find_top_dtw_peers(query_company, series_dict, top_k=5, similar=True):
    """Rank peers of ``query_company`` by DTW distance.

    Returns a pair ``(top, all_distances)``:
      * ``top``: up to ``top_k`` (company, distance) tuples, closest-first when
        ``similar`` is True, farthest-first otherwise;
      * ``all_distances``: dict of every successfully scored peer.
    Peers that fail to parse or compare are skipped. Raises ValueError if the
    query has no series.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    reference = extract_series(series_dict[query_company])

    all_distances = {}
    for name, raw in series_dict.items():
        if name == query_company:
            continue
        try:
            all_distances[name] = dtw.distance(reference, extract_series(raw))
        except Exception:
            continue

    ordered = sorted(all_distances.items(), key=lambda item: item[1], reverse=not similar)
    return ordered[:top_k], all_distances

def compute_embedding_distance(query_company, embedding_dict, block_name):
    """Cosine distance (1 - cosine similarity) from the query to every other company.

    Args:
        query_company: Key of the reference embedding.
        embedding_dict: Mapping company -> 1-D embedding (array-like or tensor).
        block_name: Not used; kept for call compatibility.

    Returns:
        dict company -> distance in [0, 2] (0 = same direction, 2 = opposite).
        Entries that cannot be converted or compared are skipped.

    Raises:
        ValueError: ``query_company`` is not in ``embedding_dict``.
    """
    if query_company not in embedding_dict:
        raise ValueError(f"{query_company} not found in embedding dictionary.")
    # as_tensor avoids torch.tensor()'s copy-construct warning on tensor input.
    query_emb = torch.as_tensor(embedding_dict[query_company], dtype=torch.float32).unsqueeze(0)
    query_emb = F.normalize(query_emb, p=2, dim=1)
    distances = {}
    for company, emb in embedding_dict.items():
        if company == query_company:
            continue
        try:
            comp_emb = torch.as_tensor(emb, dtype=torch.float32).unsqueeze(0)
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            distances[company] = 1 - torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Deliberate best-effort: a malformed entry is skipped.
            continue
    return distances

def plot_trajectory_and_embedding_distance(query_company,
                                           series_dict,
                                           embedding_dict,
                                           block_name="RET",
                                           top_k=5,
                                           similar=True):
    """Print a DTW/cosine table for the top-k DTW peers, then plot their
    trajectories and their cosine distances on standardized block embeddings.

    Any exception is caught and printed instead of being raised.

    Args:
        query_company: Company key used as the reference.
        series_dict: Mapping company -> token string parsed by ``extract_series``.
        embedding_dict: Mapping company -> embedding, standardized per block by
            ``preprocess_embeddings_blockwise``.
        block_name: ``BLOCKS`` key selecting the slice; also used as the y-label.
        top_k: Number of DTW peers to show.
        similar: True ranks the closest DTW peers, False the farthest.
    """
    try:
        # === Step 1: Preprocess embeddings ===
        processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)

        # === Step 2: Get peers by DTW ===
        top_peers, all_dtw_distances = find_top_dtw_peers(query_company, series_dict, top_k=top_k, similar=similar)
        if not top_peers:
            print(f"No valid peers found for {query_company} in {block_name}.")
            return
        peer_names = [p[0] for p in top_peers]

        # === Step 3: Get embedding distances ===
        emb_dists = compute_embedding_distance(query_company, processed_embeddings, block_name)
        # A peer without an embedding gets NaN. The old default of 1.0 was
        # reported as a real "cosine similarity 0" in the table.
        emb_dists_selected = [emb_dists.get(name, np.nan) for name in peer_names]

        # === Step 4: Print DTW & Similarity values ===
        dtw_selected = [all_dtw_distances.get(name, np.nan) for name in peer_names]
        cosine_sims = [1 - d for d in emb_dists_selected]  # NaN stays NaN

        df = pd.DataFrame({
            "Company": peer_names,
            "DTW_Distance": dtw_selected,
            "Cosine_Similarity": cosine_sims
        })
        print("\n=== Peer Distance & Similarity Metrics ===")
        print(df.to_string(index=False, float_format="%.4f"))

        # === Colors & Style ===
        gaudy_colors = ['magenta', 'lime', 'cyan', 'orange', 'red', 'blue', 'gold', 'purple']
        sim_text = "Similar" if similar else "Dissimilar"

        # === Step 5: Plot time-series trajectories ===
        plt.figure(figsize=(5,5))
        query_series = extract_series(series_dict[query_company])
        plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3.5, color='black')
        for idx, (name, _) in enumerate(top_peers):
            try:
                plt.plot(extract_series(series_dict[name]),
                         label=name,
                         linestyle='--',
                         linewidth=3,
                         color=gaudy_colors[idx % len(gaudy_colors)])
            except Exception:
                continue
        plt.title(f"{block_name} Trajectories (Top-{top_k} {sim_text})", fontsize=12, fontweight='bold')
        plt.xlabel("Time", fontsize=12, fontweight='bold')
        plt.ylabel(block_name, fontsize=12, fontweight='bold')
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.xticks(fontsize=12, fontweight='bold')
        plt.yticks(fontsize=12, fontweight='bold')
        plt.show()

        # === Step 6: Plot embedding distances ===
        plt.figure(figsize=(5,5))
        bar_colors = [gaudy_colors[idx % len(gaudy_colors)] for idx in range(len(peer_names))]
        plt.bar(peer_names,
                emb_dists_selected,
                color=bar_colors,
                edgecolor='black',
                linewidth=1.2)
        plt.title(f"{block_name} GPT-2 Embedding Cosine Distance",
                  fontsize=12, fontweight='bold')
        plt.ylabel("Cosine Distance", fontsize=12, fontweight='bold')
        # Cosine distance spans [0, 2]. Grow the axis instead of clipping at 1.
        finite = [d for d in emb_dists_selected if np.isfinite(d)]
        plt.ylim(0, max(1.0, 1.05 * max(finite, default=0.0)))
        plt.xticks(fontsize=12, fontweight='bold', rotation=15)
        plt.yticks(fontsize=12, fontweight='bold')
        plt.grid(axis='y', linestyle='--', alpha=0.6)
        plt.show()

    except Exception as e:
        print(f"Skipping {query_company} for {block_name} due to error: {e}")

# RET block: REG against its 20 farthest DTW peers.
plot_trajectory_and_embedding_distance(query_company="REG", series_dict=ret_core,
                                       embedding_dict=ret_context_embeddings_from_model_12,
                                       block_name="RET", top_k=20, similar=False)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw
from sklearn.preprocessing import StandardScaler
import pandas as pd

# === Block index mapping ===

# Eight consecutive 96-wide slices [start, end) of the embedding vector.
BLOCKS = dict(
    (name, (start, start + 96))
    for name, start in zip(("RET", "SOC", "GOV", "ESG", "ESGFO", "ESGSO", "ENV", "SENTI"),
                           range(0, 768, 96))
)

def extract_series(token_str):
    """Parse the numbers in a token string such as '<RET_0.10> <RET_-0.03>'.

    Returns the values as a list of floats, or None (not an empty list) when
    the string contains no number at all.
    """
    import re
    numbers = re.findall(r"[-+]?\d+(?:\.\d+)?", token_str)
    if not numbers:
        return None
    return [float(tok) for tok in numbers]

# === Preprocess embeddings: Center & Whiten for better spread ===
def preprocess_embeddings_blockwise(embedding_dict, block_name):
    """Slice the ``BLOCKS[block_name]`` block from every embedding and z-score
    each dimension across companies (sklearn ``StandardScaler``).

    This is per-dimension centring and scaling, not decorrelating whitening.

    Args:
        embedding_dict: Mapping company -> 1-D embedding (torch tensor or array-like).
        block_name: ``BLOCKS`` key selecting the ``[start:end]`` slice.

    Returns:
        dict mapping company -> standardized numpy slice, in input key order.
    """
    start, end = BLOCKS[block_name]
    companies = list(embedding_dict.keys())

    def _block_slice(vec):
        # detach(): .numpy() raises on tensors that still require grad.
        if hasattr(vec, "cpu"):
            return vec.detach().cpu().numpy()[start:end]
        return vec[start:end]

    embeddings = np.array([_block_slice(embedding_dict[c]) for c in companies])
    scaler = StandardScaler()
    processed = scaler.fit_transform(embeddings)
    return dict(zip(companies, processed))

def find_top_dtw_peers(query_company, series_dict, top_k=5, similar=True):
    """Rank peers of ``query_company`` by DTW distance, skipping bad data.

    Peers with no parseable numbers, a failing DTW computation, or an infinite
    distance are left out.

    Returns:
        ``(top, all_distances)``: ``top`` holds up to ``top_k`` (company,
        distance) tuples, closest-first when ``similar`` else farthest-first.
        ``all_distances`` maps every retained peer to its distance.

    Raises:
        ValueError: the query is missing or its series has no numbers.
    """
    if query_company not in series_dict:
        raise ValueError(f"{query_company} not found in series.")
    reference = extract_series(series_dict[query_company])
    if reference is None:
        raise ValueError(f"No valid series for {query_company}.")

    all_distances = {}
    for name, raw in series_dict.items():
        if name == query_company:
            continue
        candidate = extract_series(raw)
        if candidate is None:
            continue
        try:
            d = dtw.distance(reference, candidate)
            usable = not np.isinf(d)
        except Exception:
            continue
        if usable:
            all_distances[name] = d

    ordered = sorted(all_distances.items(), key=lambda item: item[1], reverse=not similar)
    return ordered[:top_k], all_distances

def compute_embedding_distance(query_company, embedding_dict, block_name):
    """Cosine distance (1 - cosine similarity) from the query to every other company.

    Peers whose embedding is all zeros (direction undefined) are skipped, as
    are entries that cannot be converted or compared.

    Args:
        query_company: Key of the reference embedding.
        embedding_dict: Mapping company -> 1-D embedding (array-like or tensor).
        block_name: Not used; kept for call compatibility.

    Returns:
        dict company -> distance in [0, 2] (0 = same direction, 2 = opposite).

    Raises:
        ValueError: ``query_company`` is not in ``embedding_dict``.
    """
    if query_company not in embedding_dict:
        raise ValueError(f"{query_company} not found in embedding dictionary.")
    # as_tensor avoids torch.tensor()'s copy-construct warning on tensor input.
    query_emb = torch.as_tensor(embedding_dict[query_company], dtype=torch.float32).unsqueeze(0)
    query_emb = F.normalize(query_emb, p=2, dim=1)
    distances = {}
    for company, emb in embedding_dict.items():
        if company == query_company:
            continue
        if np.allclose(emb, 0):  # skip zero embeddings
            continue
        try:
            comp_emb = torch.as_tensor(emb, dtype=torch.float32).unsqueeze(0)
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            distances[company] = 1 - torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Deliberate best-effort: a malformed entry is skipped.
            continue
    return distances

def plot_trajectory_and_embedding_distance(query_company,
                                           series_dict,
                                           embedding_dict,
                                           block_name="RET",
                                           top_k=5,
                                           similar=True):
    """Plot DTW-selected peers' trajectories next to their embedding cosine distances.

    Steps:
      1. Preprocess ``embedding_dict`` with ``preprocess_embeddings_blockwise``.
      2. Pick the ``top_k`` peers of ``query_company`` by DTW distance on
         ``series_dict`` (closest when ``similar`` is True, farthest otherwise).
      3. Look up each peer's cosine distance to the query embedding; peers
         missing either distance are dropped. Print the resulting table.
      4. Plot the query series (black) with the kept peers' series overlaid.
      5. Bar-plot each kept peer's embedding cosine distance (1 - similarity).

    Each peer gets one color that is shared by both figures. Any exception is
    caught and reported as "Skipping ..." so batch loops keep going; the
    function returns None.
    """
    try:
        # === Step 1: Preprocess embeddings ===
        processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)

        # === Step 2: Get peers by DTW ===
        top_peers, all_dtw_distances = find_top_dtw_peers(
            query_company, series_dict, top_k=top_k, similar=similar)
        if not top_peers:
            print(f"No valid peers found for {query_company} in {block_name}.")
            return
        peer_names = [p[0] for p in top_peers]

        # === Step 3: Get embedding distances (filtering out missing ones) ===
        emb_dists = compute_embedding_distance(query_company, processed_embeddings, block_name)
        emb_dists_selected = [emb_dists.get(name, np.nan) for name in peer_names]
        dtw_selected = [all_dtw_distances.get(name, np.nan) for name in peer_names]
        cosine_sims = [1 - d if not np.isnan(d) else np.nan for d in emb_dists_selected]

        # === Filter out any rows with NaN (dropna keeps the DTW rank order) ===
        df = pd.DataFrame({
            "Company": peer_names,
            "DTW_Distance": dtw_selected,
            "Cosine_Similarity": cosine_sims
        }).dropna()

        if df.empty:
            print(f"No valid data to display for {query_company}.")
            return

        print("\n=== Peer Distance & Similarity Metrics ===")
        print(df.to_string(index=False, float_format="%.4f"))

        # === Colors & Style ===
        gaudy_colors = ['magenta', 'lime', 'cyan', 'orange', 'red', 'blue', 'gold', 'purple']
        sim_text = "Similar" if similar else "Dissimilar"
        # Color by DTW rank *before* filtering so a company has the same color in
        # the trajectory plot and the bar chart even when some rows were dropped.
        peer_colors = {name: gaudy_colors[idx % len(gaudy_colors)]
                       for idx, name in enumerate(peer_names)}
        shown = list(df["Company"])

        # === Step 4: Plot time-series trajectories ===
        plt.figure(figsize=(5,5))
        query_series = extract_series(series_dict[query_company])
        plt.plot(query_series, label=f"{query_company} (Query)", linewidth=3.5, color='black')
        for name in shown:
            try:
                plt.plot(extract_series(series_dict[name]),
                         label=name,
                         linestyle='--',
                         linewidth=3,
                         color=peer_colors[name])
            except Exception:
                continue
        plt.title(f"{block_name} Trajectories (Top-{top_k} {sim_text})", fontsize=12, fontweight='bold')
        plt.xlabel("Time", fontsize=12, fontweight='bold')
        plt.ylabel(block_name, fontsize=12, fontweight='bold')
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.xticks(fontsize=12, fontweight='bold')
        plt.yticks(fontsize=12, fontweight='bold')
        plt.show()

        # === Step 5: Plot embedding distances ===
        cos_dist = 1 - df["Cosine_Similarity"]
        plt.figure(figsize=(5,5))
        plt.bar(df["Company"],
                cos_dist,
                color=[peer_colors[name] for name in shown],
                edgecolor='black',
                linewidth=1.2)
        plt.title(f"{block_name} GPT-2 Embedding Cosine Distance",
                  fontsize=12, fontweight='bold')
        plt.ylabel("Cosine Distance", fontsize=12, fontweight='bold')
        # Cosine distance lies in [0, 2]; opposed (standardized) embeddings exceed 1,
        # so a fixed 0-1 axis would clip exactly the bars of dissimilar peers.
        plt.ylim(0, min(2.0, max(1.0, float(cos_dist.max()) * 1.05)))
        plt.xticks(fontsize=12, fontweight='bold', rotation=15)
        plt.yticks(fontsize=12, fontweight='bold')
        plt.grid(axis='y', linestyle='--', alpha=0.6)
        plt.show()

    except Exception as e:
        print(f"Skipping {query_company} for {block_name} due to error: {e}")

# RET block for REG: top-20 dissimilar peers, top-20 similar peers, then the
# dissimilar view again (the third run repeats the first one exactly).
for _similar in (False, True, False):
    plot_trajectory_and_embedding_distance(
        query_company="REG",
        series_dict=ret_core,
        embedding_dict=ret_context_embeddings_from_model_12,
        block_name="RET",
        top_k=20,
        similar=_similar,
    )
del _similar

from scipy.stats import spearmanr, pearsonr

def company_dtw_vs_cosine_plot(query_company,
                               series_dict,
                               embedding_dict,
                               block_name="RET"):
    """Scatter DTW distance against embedding cosine similarity for one company.

    Every other company that has both a series in ``series_dict`` and an
    embedding after ``preprocess_embeddings_blockwise(embedding_dict, block_name)``
    is compared with ``query_company``. The x-axis is the DTW distance between
    series, the y-axis the cosine similarity between L2-normalized embeddings.
    Spearman and Pearson correlations are printed and shown in the title.

    Returns:
        ``(df, rho, r)``: ``df`` has ``Company``, ``DTW_Distance`` and
        ``Cosine_Similarity`` columns sorted by DTW distance; ``rho`` is the
        Spearman and ``r`` the Pearson correlation. Returns ``None`` (after
        printing the reason) when the query cannot be compared or fewer than
        two peers are comparable (correlation needs at least two points).
    """
    # === Step 1: Preprocess embeddings ===
    processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)

    if query_company not in series_dict or query_company not in processed_embeddings:
        print(f"{query_company} not found in series or embedding dictionary.")
        return

    query_series = extract_series(series_dict[query_company])
    if query_series is None:
        print(f"No valid series for {query_company}")
        return

    # Query embedding, shape (1, dim), unit L2 norm
    query_emb = torch.tensor(processed_embeddings[query_company]).unsqueeze(0).float()
    query_emb = F.normalize(query_emb, p=2, dim=1)

    # `peer_names` (not `companies`) so the module-level ticker list isn't shadowed.
    dtw_vals, cos_sims, peer_names = [], [], []

    # === Step 2: Compare to all others ===
    for company, series in series_dict.items():
        if company == query_company or company not in processed_embeddings:
            continue
        comp_series = extract_series(series)
        if comp_series is None:
            continue
        try:
            dist = dtw.distance(query_series, comp_series)
            if not np.isfinite(dist):  # skip inf and NaN distances
                continue
            # Cosine similarity
            comp_emb = torch.tensor(processed_embeddings[company]).unsqueeze(0).float()
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            sim = torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Best-effort skip of a failing pair; a bare `except:` would also
            # swallow KeyboardInterrupt.
            continue
        dtw_vals.append(dist)
        cos_sims.append(sim)
        peer_names.append(company)

    if not dtw_vals:
        print(f"No valid comparisons for {query_company}")
        return
    if len(dtw_vals) < 2:
        # pearsonr raises ValueError for fewer than two points.
        print(f"Need at least two comparisons to correlate for {query_company}")
        return

    # === Step 3: Correlation ===
    rho, _ = spearmanr(dtw_vals, cos_sims)
    r, _ = pearsonr(dtw_vals, cos_sims)
    print(f"{query_company}: Spearman ρ = {rho:.3f}, Pearson r = {r:.3f}")

    # === Step 4: Scatter plot ===
    plt.figure(figsize=(7,6))
    plt.scatter(dtw_vals, cos_sims, s=40, alpha=0.7, color="blue", edgecolor="black")
    plt.title(f"{block_name}: {query_company} vs All\nDTW vs Cosine Similarity (ρ={rho:.2f})",
              fontsize=14, fontweight="bold")
    plt.xlabel("DTW Distance (Trajectory Dissimilarity)", fontsize=12, fontweight="bold")
    plt.ylabel("Cosine Similarity (Embedding Proximity)", fontsize=12, fontweight="bold")
    plt.grid(alpha=0.4)
    plt.show()

    # === Step 5: Return as DataFrame ===
    df = pd.DataFrame({
        "Company": peer_names,
        "DTW_Distance": dtw_vals,
        "Cosine_Similarity": cos_sims
    }).sort_values(by="DTW_Distance")
    return df, rho, r

# AAPL vs all other companies on the RET block. company_dtw_vs_cosine_plot
# returns None (after printing why) when no correlation can be computed, so
# check before unpacking instead of crashing with a TypeError.
_result = company_dtw_vs_cosine_plot(
    query_company="AAPL",
    series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model,
    block_name="RET"
)
if _result is not None:
    df, spearman_corr, pearson_corr = _result
    print(df.head(10))

from scipy.stats import spearmanr, pearsonr
from sklearn.preprocessing import MinMaxScaler

def company_dtw_vs_cosine_plot_normalized(query_company,
                                          series_dict,
                                          embedding_dict,
                                          block_name="RET"):
    """Scatter min-max-normalized DTW distance against normalized cosine distance.

    Every other company that has both a series in ``series_dict`` and an
    embedding after ``preprocess_embeddings_blockwise(embedding_dict, block_name)``
    is compared with ``query_company``: DTW distance between series and cosine
    *distance* (1 - similarity) between L2-normalized embeddings. Each of the
    two distance lists is rescaled to [0, 1] with its own ``MinMaxScaler``
    before the correlations are computed and the scatter plot is drawn.

    Returns:
        ``(df, rho, r)``: ``df`` has ``Company``, ``DTW_Distance_Norm`` and
        ``Cosine_Distance_Norm`` columns sorted by normalized DTW distance;
        ``rho``/``r`` are the Spearman/Pearson correlations. Returns ``None``
        (after printing the reason) when the query cannot be compared or fewer
        than two peers are comparable.
    """
    # === Step 1: Preprocess embeddings ===
    processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)

    if query_company not in series_dict or query_company not in processed_embeddings:
        print(f"{query_company} not found in series or embedding dictionary.")
        return

    query_series = extract_series(series_dict[query_company])
    if query_series is None:
        print(f"No valid series for {query_company}")
        return

    # Query embedding, shape (1, dim), unit L2 norm
    query_emb = torch.tensor(processed_embeddings[query_company]).unsqueeze(0).float()
    query_emb = F.normalize(query_emb, p=2, dim=1)

    # `peer_names` (not `companies`) so the module-level ticker list isn't shadowed.
    dtw_vals, cos_dists, peer_names = [], [], []

    # === Step 2: Compare to all others ===
    for company, series in series_dict.items():
        if company == query_company or company not in processed_embeddings:
            continue
        comp_series = extract_series(series)
        if comp_series is None:
            continue
        try:
            dist = dtw.distance(query_series, comp_series)
            if not np.isfinite(dist):  # skip inf and NaN distances
                continue
            # Cosine distance (1 - similarity)
            comp_emb = torch.tensor(processed_embeddings[company]).unsqueeze(0).float()
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            sim = torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Best-effort skip of a failing pair; a bare `except:` would also
            # swallow KeyboardInterrupt.
            continue
        dtw_vals.append(dist)
        cos_dists.append(1 - sim)
        peer_names.append(company)

    if not dtw_vals:
        print(f"No valid comparisons for {query_company}")
        return
    if len(dtw_vals) < 2:
        # pearsonr raises ValueError for fewer than two points.
        print(f"Need at least two comparisons to correlate for {query_company}")
        return

    # === Step 3: Normalize both distances to [0,1] (independent scalers) ===
    dtw_vals_norm = MinMaxScaler().fit_transform(np.array(dtw_vals).reshape(-1, 1)).flatten()
    cos_dists_norm = MinMaxScaler().fit_transform(np.array(cos_dists).reshape(-1, 1)).flatten()

    # === Step 4: Correlation ===
    # Min-max scaling is a positive affine map, so rho and r equal those of the
    # raw distances; the normalization only changes the plot axes.
    rho, _ = spearmanr(dtw_vals_norm, cos_dists_norm)
    r, _ = pearsonr(dtw_vals_norm, cos_dists_norm)
    print(f"{query_company}: Spearman ρ = {rho:.3f}, Pearson r = {r:.3f}")

    # === Step 5: Scatter plot ===
    plt.figure(figsize=(7,6))
    plt.scatter(dtw_vals_norm, cos_dists_norm, s=40, alpha=0.7, color="blue", edgecolor="black")
    plt.title(f"{block_name}: {query_company} vs All\nNormalized DTW vs Cosine Distance (ρ={rho:.2f})",
              fontsize=14, fontweight="bold")
    plt.xlabel("Normalized DTW Distance", fontsize=12, fontweight="bold")
    plt.ylabel("Normalized Cosine Distance", fontsize=12, fontweight="bold")
    plt.grid(alpha=0.4)
    plt.show()

    # === Step 6: Return DataFrame ===
    df = pd.DataFrame({
        "Company": peer_names,
        "DTW_Distance_Norm": dtw_vals_norm,
        "Cosine_Distance_Norm": cos_dists_norm
    }).sort_values(by="DTW_Distance_Norm")
    return df, rho, r

# AAPL vs all other companies on the RET block, both distances min-max scaled.
# The function returns None (after printing why) when nothing can be
# correlated, so check before unpacking instead of crashing with a TypeError.
_result = company_dtw_vs_cosine_plot_normalized(
    query_company="AAPL",
    series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model,
    block_name="RET"
)
if _result is not None:
    df, spearman_corr, pearson_corr = _result
    print(df)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd

def company_similarity_band_barplot(query_company,
                                    series_dict,
                                    embedding_dict,
                                    block_name="RET",
                                    num_bins=5,
                                    sim_range=(0.0, 1.0)):
    """Bar-plot mean normalized DTW distance of peers grouped by cosine-similarity band.

    Every other company that has both a series in ``series_dict`` and an
    embedding after ``preprocess_embeddings_blockwise(embedding_dict, block_name)``
    is compared with ``query_company``: DTW distance between series and cosine
    similarity between L2-normalized embeddings. DTW distances are min-max
    scaled to [0, 1]. Peers are then split into ``num_bins`` equal-width
    similarity bands over ``sim_range``, and the mean scaled DTW distance per
    band is plotted.

    Args:
        query_company: company compared against all others.
        series_dict: mapping company -> raw series (passed to ``extract_series``).
        embedding_dict: raw embeddings, preprocessed block-wise first.
        block_name: block label for preprocessing and the plot title.
        num_bins: number of similarity bands.
        sim_range: ``(low, high)`` interval split into bands, both ends
            inclusive. Cosine similarity can be negative; peers outside the
            interval are counted and reported, not plotted. Pass
            ``(-1.0, 1.0)`` to include every peer.

    Returns:
        DataFrame with ``Cosine_Band``, ``Avg_DTW_Distance`` (NaN for an empty
        band) and ``Count`` columns, or ``None`` (after printing the reason)
        when the query cannot be compared.
    """
    # === Step 1: Preprocess embeddings ===
    processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)
    if query_company not in series_dict or query_company not in processed_embeddings:
        print(f"{query_company} not found in series or embedding dictionary.")
        return

    query_series = extract_series(series_dict[query_company])
    if query_series is None:
        print(f"No valid series for {query_company}")
        return

    query_emb = torch.tensor(processed_embeddings[query_company]).unsqueeze(0).float()
    query_emb = F.normalize(query_emb, p=2, dim=1)

    dtw_vals, cos_sims = [], []

    # === Step 2: Compare to all others ===
    for company, series in series_dict.items():
        if company == query_company or company not in processed_embeddings:
            continue
        comp_series = extract_series(series)
        if comp_series is None:
            continue
        try:
            dist = dtw.distance(query_series, comp_series)
            if not np.isfinite(dist):  # skip inf and NaN distances
                continue
            comp_emb = torch.tensor(processed_embeddings[company]).unsqueeze(0).float()
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            sim = torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Best-effort skip of a failing pair; a bare `except:` would also
            # swallow KeyboardInterrupt.
            continue
        dtw_vals.append(dist)
        cos_sims.append(sim)

    if not dtw_vals:
        print(f"No valid comparisons for {query_company}")
        return

    # === Step 3: Normalize DTW to [0, 1] ===
    dtw_vals = MinMaxScaler().fit_transform(np.array(dtw_vals).reshape(-1, 1)).flatten()

    # === Step 4: Bin by cosine similarity ===
    cos_sims = np.array(cos_sims)
    low, high = sim_range
    bins = np.linspace(low, high, num_bins + 1)  # default: 0–0.2, 0.2–0.4, ...
    in_range = (cos_sims >= low) & (cos_sims <= high)
    # np.digitize puts x == high past the last edge; clip so the top band is closed.
    bin_indices = np.clip(np.digitize(cos_sims, bins) - 1, 0, num_bins - 1)

    n_outside = int(np.count_nonzero(~in_range))
    if n_outside:
        print(f"{query_company}: {n_outside} of {len(cos_sims)} peers have cosine similarity "
              f"outside [{low}, {high}] and are excluded; widen sim_range to include them.")

    band_labels = [f"{bins[i]:.1f}–{bins[i+1]:.1f}" for i in range(num_bins)]
    band_means, band_counts = [], []
    for i in range(num_bins):
        mask = in_range & (bin_indices == i)
        count = int(np.count_nonzero(mask))
        band_counts.append(count)
        # NaN, not 0, for an empty band: a 0 bar would read as "identical trajectories".
        band_means.append(float(np.mean(dtw_vals[mask])) if count else np.nan)

    # === Step 5: Bar Plot ===
    plt.figure(figsize=(8,6))
    plt.bar(band_labels, band_means, color="skyblue", edgecolor="black")
    plt.title(f"{block_name}: {query_company}\nDTW Distance by Cosine Similarity Band",
              fontsize=14, fontweight="bold")
    plt.xlabel("Cosine Similarity Band", fontsize=12, fontweight="bold")
    plt.ylabel("Average Normalized DTW Distance", fontsize=12, fontweight="bold")
    plt.grid(axis="y", linestyle="--", alpha=0.6)
    plt.show()

    # === Return as DataFrame ===
    df = pd.DataFrame({
        "Cosine_Band": band_labels,
        "Avg_DTW_Distance": band_means,
        "Count": band_counts,
    })
    return df

"""```
#REG```

#REG
"""

# REG: RET series vs. ret_context_embeddings_from_model.
# NOTE(review): company_similarity_band_barplot_n is defined further down in this
# file; run that definition cell first, otherwise these calls raise NameError.
# Each call overwrites df_bands, so only the last call's result is kept.
df_bands = company_similarity_band_barplot_n(
    query_company="REG",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

# REG: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="REG",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

"""#TSLA"""

# TSLA: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="TSLA",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

# TSLA: ESG series vs. esg_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="TSLA",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model_12,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

# WBA: RET series vs. ret_context_embeddings_from_model_12 (placed under the TSLA heading).
df_bands = company_similarity_band_barplot_n(
    query_company="WBA",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

"""#MGM"""

# MGM: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="MGM",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# MGM: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="MGM",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

"""#ALB"""

# ALB: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="ALB",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# ALB: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="ALB",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

"""#IVZ"""

# IVZ: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="IVZ",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# IVZ: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="IVZ",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

"""#HRL"""

# HRL: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="HRL",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HRL: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="HRL",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

"""#NWSA"""

# NWSA: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="NWSA",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# NWSA: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="NWSA",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from dtaidistance import dtw
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import pandas as pd

def company_similarity_band_barplot_n(query_company,
                                    series_dict,
                                    embedding_dict,
                                    block_name="RET",
                                    num_bins=5,
                                    sim_range=(0.0, 1.0)):
    """Color-coded bar + trend plot of mean normalized DTW distance per similarity band.

    Every other company that has both a series in ``series_dict`` and an
    embedding after ``preprocess_embeddings_blockwise(embedding_dict, block_name)``
    is compared with ``query_company``: DTW distance between series and cosine
    similarity between L2-normalized embeddings. DTW distances are min-max
    scaled to [0, 1]. Peers are split into ``num_bins`` equal-width similarity
    bands over ``sim_range``. Each band's mean scaled DTW distance is drawn as
    a bar colored by the band's midpoint (>= 0.66 limegreen, >= 0.33
    dodgerblue, otherwise red), with a black line connecting the band means.

    Args:
        query_company: company compared against all others.
        series_dict: mapping company -> raw series (passed to ``extract_series``).
        embedding_dict: raw embeddings, preprocessed block-wise first.
        block_name: block label for preprocessing and the plot title.
        num_bins: number of similarity bands.
        sim_range: ``(low, high)`` interval split into bands, both ends
            inclusive. Cosine similarity can be negative; peers outside the
            interval are counted and reported, not plotted. Pass
            ``(-1.0, 1.0)`` to include every peer.

    Returns:
        DataFrame with ``Cosine_Band``, ``Avg_DTW_Distance`` (NaN for an empty
        band, which leaves a gap in the plot) and ``Count`` columns, or
        ``None`` (after printing the reason) when the query cannot be compared.
    """
    # === Step 1: Preprocess embeddings ===
    processed_embeddings = preprocess_embeddings_blockwise(embedding_dict, block_name)
    if query_company not in series_dict or query_company not in processed_embeddings:
        print(f"{query_company} not found in series or embedding dictionary.")
        return

    query_series = extract_series(series_dict[query_company])
    if query_series is None:
        print(f"No valid series for {query_company}")
        return

    query_emb = torch.tensor(processed_embeddings[query_company]).unsqueeze(0).float()
    query_emb = F.normalize(query_emb, p=2, dim=1)

    dtw_vals, cos_sims = [], []

    # === Step 2: Compare to all others ===
    for company, series in series_dict.items():
        if company == query_company or company not in processed_embeddings:
            continue
        comp_series = extract_series(series)
        if comp_series is None:
            continue
        try:
            dist = dtw.distance(query_series, comp_series)
            if not np.isfinite(dist):  # skip inf and NaN distances
                continue
            comp_emb = torch.tensor(processed_embeddings[company]).unsqueeze(0).float()
            comp_emb = F.normalize(comp_emb, p=2, dim=1)
            sim = torch.cosine_similarity(query_emb, comp_emb).item()
        except Exception:
            # Best-effort skip of a failing pair; a bare `except:` would also
            # swallow KeyboardInterrupt.
            continue
        dtw_vals.append(dist)
        cos_sims.append(sim)

    if not dtw_vals:
        print(f"No valid comparisons for {query_company}")
        return

    # === Step 3: Normalize DTW to [0, 1] ===
    dtw_vals = MinMaxScaler().fit_transform(np.array(dtw_vals).reshape(-1, 1)).flatten()

    # === Step 4: Bin by cosine similarity ===
    cos_sims = np.array(cos_sims)
    low, high = sim_range
    bins = np.linspace(low, high, num_bins + 1)
    in_range = (cos_sims >= low) & (cos_sims <= high)
    # np.digitize puts x == high past the last edge; clip so the top band is closed.
    bin_indices = np.clip(np.digitize(cos_sims, bins) - 1, 0, num_bins - 1)

    n_outside = int(np.count_nonzero(~in_range))
    if n_outside:
        print(f"{query_company}: {n_outside} of {len(cos_sims)} peers have cosine similarity "
              f"outside [{low}, {high}] and are excluded; widen sim_range to include them.")

    band_labels = [f"{bins[i]:.1f}–{bins[i+1]:.1f}" for i in range(num_bins)]
    band_means, band_counts = [], []
    for i in range(num_bins):
        mask = in_range & (bin_indices == i)
        count = int(np.count_nonzero(mask))
        band_counts.append(count)
        # NaN, not 0, for an empty band: a 0 bar would read as "identical trajectories".
        band_means.append(float(np.mean(dtw_vals[mask])) if count else np.nan)

    # === Step 5: Color code bands by their midpoint similarity ===
    band_colors = []
    for mean_cos in (bins[:-1] + (bins[1]-bins[0])/2):
        if mean_cos >= 0.66:
            band_colors.append("limegreen")  # High similarity
        elif mean_cos >= 0.33:
            band_colors.append("dodgerblue")  # Medium similarity
        else:
            band_colors.append("red")  # Low similarity

    # === Step 6: Bar + Line Plot ===
    plt.figure(figsize=(8,6))
    plt.bar(band_labels, band_means, color=band_colors, edgecolor="black")
    plt.plot(band_labels, band_means, color="black", marker="o", linewidth=2)  # Trend line
    plt.title(f"{block_name}: {query_company}\nDTW Distance by Cosine Similarity Band",
              fontsize=12, fontweight="bold")
    plt.xlabel("Cosine Similarity Band", fontsize=12, fontweight="bold")
    plt.ylabel("Average Normalized DTW Distance", fontsize=12, fontweight="bold")
    plt.grid(axis="y", linestyle="--", alpha=0.6)
    plt.xticks(fontsize=12, fontweight="bold")
    plt.yticks(fontsize=12, fontweight="bold")
    plt.show()

    # === Return as DataFrame ===
    df = pd.DataFrame({
        "Cosine_Band": band_labels,
        "Avg_DTW_Distance": band_means,
        "Count": band_counts,
    })
    return df

# NWSA: RET series vs. ret_context_embeddings_from_model_12 (repeats the earlier NWSA/RET run).
# Each call below overwrites df_bands, so only the last result is kept.
df_bands = company_similarity_band_barplot_n(
    query_company="NWSA",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# TSLA: RET series vs. ret_context_embeddings_from_model_12 (repeats the earlier TSLA/RET run).
df_bands = company_similarity_band_barplot_n(
    query_company="TSLA",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AIZ: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="AIZ",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AIZ: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="AIZ",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# INCY: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="INCY",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# BXP: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="BXP",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AES: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="AES",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AES: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="AES",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HSIC: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="HSIC",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HSIC: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="HSIC",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HAS: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="HAS",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# APA: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="APA",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# APA: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="APA",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# CPB: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="CPB",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# CPB: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="CPB",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# FRT: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="FRT",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HRL: RET series vs. ret_context_embeddings_from_model_12 (repeats the earlier HRL/RET run).
df_bands = company_similarity_band_barplot_n(
    query_company="HRL",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HRL: ESG series vs. esg_context_embeddings_from_model (repeats the earlier HRL/ESG run).
df_bands = company_similarity_band_barplot_n(
    query_company="HRL",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# HAS: RET series vs. ret_context_embeddings_from_model_12 (same arguments as the HAS call just above).
df_bands = company_similarity_band_barplot_n(
    query_company="HAS",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AOS: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="AOS",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AOS: ESG series vs. esg_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="AOS",
     series_dict=esg_core,
    embedding_dict=esg_context_embeddings_from_model,
    block_name="ESG",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# AOS: "SOC" block (s_core series) vs. s_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="AOS",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model_12,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# REG: RET series vs. ret_context_embeddings_from_model_12.
df_bands = company_similarity_band_barplot_n(
    query_company="REG",
     series_dict=ret_core,
    embedding_dict=ret_context_embeddings_from_model_12,
    block_name="RET",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

# GOOGL: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="GOOGL",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

# MSFT: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="MSFT",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# XOM: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="XOM",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                           # Number of similarity bands (default=5)
)

# ABBV: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="ABBV",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                          # Number of similarity bands (default=5)
)

# SYK: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="SYK",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                        # Number of similarity bands (default=5)
)

# ETN: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="ETN",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                        # Number of similarity bands (default=5)
)

# LOW: "SOC" block (s_core series) vs. s_context_embeddings_from_model.
df_bands = company_similarity_band_barplot_n(
    query_company="LOW",
     series_dict=s_core,
    embedding_dict=s_context_embeddings_from_model,
    block_name="SOC",                    # "RET" for returns, "ESG" for ESG
    num_bins=5                       # Number of similarity bands (default=5)
)

from transformers import GPT2TokenizerFast

# Reference vocabulary: the stock GPT-2 fast tokenizer.
base_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
base_vocab = set(base_tokenizer.get_vocab())

# Vocabulary of the extended tokenizer defined elsewhere in this notebook.
custom_vocab = set(tokenizer_extended.get_vocab())

# Base token strings that are still present in the extended vocabulary.
preserved_tokens = base_vocab & custom_vocab
print(f"Preserved tokens: {len(preserved_tokens)}/{len(base_vocab)} ({len(preserved_tokens)/len(base_vocab)*100:.2f}%)")

from transformers import GPT2LMHeadModel

# Stock GPT-2 weights, used as the reference point for the drift check below.
base_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# Compare the first len(base_vocab) rows of both token-embedding matrices.
# NOTE(review): a row-by-row comparison assumes the extended tokenizer kept
# GPT-2's original ids for the base tokens (new tokens appended after) - confirm.
original_emb = base_model.transformer.wte.weight.data[:len(base_vocab)]
new_emb = model.transformer.wte.weight.data[:len(base_vocab)]

diff = (new_emb - original_emb).abs().mean().item()
print(f"Mean difference in preserved token embeddings: {diff}")

# Side-by-side tokenization of an ordinary English sentence.
sample_text = "The quick brown fox jumps over the lazy dog."
print(base_tokenizer.tokenize(sample_text))
print(tokenizer_extended.tokenize(sample_text))

import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def embed_text(model, tokenizer, text):
    """
    Return a mean-pooled last-hidden-layer embedding of ``text``.

    The model is moved to the module-level ``device`` (as before). The forward
    pass runs in eval mode, so dropout is off and the result is deterministic,
    and under ``torch.no_grad()``, so no autograd graph is built. The model's
    previous train/eval mode is restored afterwards.

    Args:
        model: a model that accepts the tokenizer output plus
            ``output_hidden_states=True`` and returns ``hidden_states``.
        tokenizer: callable producing PyTorch tensors (``return_tensors='pt'``).
        text: input string.

    Returns:
        torch.Tensor of shape [batch, hidden]: the average of the last hidden
        state over the non-padding tokens (all tokens if no attention mask is
        returned by the tokenizer).
    """
    model = model.to(device)
    inputs = tokenizer(text, return_tensors='pt').to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
    finally:
        model.train(was_training)  # do not leak a mode change to the caller

    last_hidden = outputs.hidden_states[-1]  # [batch, seq, hidden]
    mask = inputs.get("attention_mask")
    if mask is None:
        return last_hidden.mean(dim=1)
    # Mask-aware mean: padding positions (mask == 0) do not dilute the average.
    weights = mask.unsqueeze(-1).to(last_hidden.dtype)
    return (last_hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)

# Cosine similarity between the mean-pooled last-layer states of both models.
text = "ESG risk impacts corporate sustainability."
base_emb = embed_text(base_model, base_tokenizer, text)
new_emb = embed_text(model, tokenizer_extended, text)
sim = F.cosine_similarity(base_emb, new_emb).item()
print(f"Cosine similarity between base & extended model outputs: {sim:.4f}")

# 20-token continuation of one prompt from each model, each encoded with its own
# tokenizer. No sampling flags are passed, so each model's default generation
# config decides the decoding strategy.
prompt = "The company's ESG risk is"
inputs = base_tokenizer(prompt, return_tensors="pt").to(device)
inputs_ext = tokenizer_extended(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    base_output = base_model.generate(**inputs, max_new_tokens=20)
    extended_output = model.generate(**inputs_ext, max_new_tokens=20)

# Raw token ids first, then the decoded text.
print(base_output)
print(extended_output)
print("Base Model:", base_tokenizer.decode(base_output[0]))
print("Extended Model:", tokenizer_extended.decode(extended_output[0]))

# Tokens present in the extended tokenizer but not in stock GPT-2. Sorted so the
# printed sample (and the blocklist order) is reproducible: iteration order of a
# set of strings depends on the per-process hash seed.
special_tokens_added = sorted(custom_vocab - base_vocab)
print(f"Custom special tokens ({len(special_tokens_added)}):", special_tokens_added[:20])  # show first 20
special_token_ids = [tokenizer_extended.convert_tokens_to_ids(tok) for tok in special_tokens_added]
print("Special token IDs:", special_token_ids[:20])

# One single-token "bad word" per added token => generate() never emits them.
bad_words_ids = [[tid] for tid in special_token_ids]

# `model` pairs with `tokenizer_extended`, so feed it that tokenizer's encoding
# (`inputs_ext`, including its attention mask) rather than the base-GPT-2 ids in
# `inputs`. generate() rejects an empty bad_words_ids list, hence the `or None`.
output = model.generate(
    **inputs_ext,
    max_new_tokens=50,
    bad_words_ids=bad_words_ids or None,  # Block all special tokens
)
print("Extended Model (special tokens blocked):", tokenizer_extended.decode(output[0]))

prompt = "The company's ESG risk is"

# Encode for each tokenizer
inputs = base_tokenizer(prompt, return_tensors="pt").to(device)
inputs_ext = tokenizer_extended(prompt, return_tensors="pt").to(device)

# === Generate ===
with torch.no_grad():
    base_output = base_model.generate(**inputs, max_new_tokens=20)
    # temperature / top_p only apply when sampling is enabled. Without
    # do_sample=True, generate() decodes greedily and ignores them.
    extended_output = model.generate(
        **inputs_ext,
        max_new_tokens=40,
        bad_words_ids=bad_words_ids or None,  # block all added tokens (None if there are none)
        do_sample=True,
        repetition_penalty=2.0,
        no_repeat_ngram_size=3,
        temperature=0.8,
        top_p=0.9,
    )
    print(base_output)
    print(extended_output)

# === Decode ===
print("Base Model:", base_tokenizer.decode(base_output[0]))
print("Extended Model:", tokenizer_extended.decode(extended_output[0]))

def dual_mode_generate(model, tokenizer, prompt, mode="forecast", max_new_tokens=50):
    """
    Sample a continuation of ``prompt`` and return it decoded (prompt included).

    Args:
        model: model exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: tokenizer used to encode ``prompt`` and decode the output.
        prompt: text to continue.
        mode: ``"text"`` adds a logits processor that adds -100 to every token
            whose vocabulary string starts with ``"<"``. Any other value
            (default ``"forecast"``) adds no extra processor.
        max_new_tokens: upper bound on generated tokens.

    Returns:
        str: ``tokenizer.decode`` of the first generated sequence (special
        tokens are not skipped).

    Sampling settings are fixed: do_sample, temperature 0.8, top_p 0.9,
    repetition_penalty 2.0, no_repeat_ngram_size 3.
    """
    from transformers import LogitsProcessor, LogitsProcessorList

    class RegexBan(LogitsProcessor):
        """Add ``penalty`` to the logit of every token whose vocab string starts with ``prefix``."""

        def __init__(self, tokenizer, prefix="<", penalty=-100):
            self.prefix = prefix
            self.penalty = penalty
            # Resolve matching ids once, instead of walking the whole vocabulary
            # in Python at every decoding step.
            self.banned_ids = torch.tensor(
                sorted(idx for tok, idx in tokenizer.get_vocab().items() if tok.startswith(prefix)),
                dtype=torch.long,
            )

        def __call__(self, input_ids, scores):
            scores[:, self.banned_ids.to(scores.device)] += self.penalty
            return scores

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    logits_processors = LogitsProcessorList()
    if mode == "text":
        logits_processors.append(RegexBan(tokenizer))

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        logits_processor=logits_processors,
        do_sample=True,
        repetition_penalty=2.0,
        no_repeat_ngram_size=3,
        temperature=0.8,
        top_p=0.9,
    )
    return tokenizer.decode(output[0])

# Plain-text generation with the extended model; tokens starting with "<" are penalized.
# The prompt previously read "english.The" (missing space after the period).
print(dual_mode_generate(model, tokenizer_extended, "Write plain English. The company's ESG risk is", mode="text"))

from transformers import LogitsProcessor
import re

from transformers import LogitsProcessor
import re

class RegexBanLogitsProcessor(LogitsProcessor):
    """
    Penalize every token whose vocabulary string matches ``pattern``.

    Matching token ids are resolved once at construction time. Before, the
    whole vocabulary (tens of thousands of entries) was regex-matched in Python
    on every generation step.

    Args:
        tokenizer: object exposing ``get_vocab()`` -> {token_string: id}.
        pattern: regex applied with ``re.match`` (anchored at the token start).
        penalty: value added to the logits of matching tokens. The default -100
            makes them very unlikely but, being additive, does not strictly
            forbid them.

    Attributes:
        vocab, pattern, penalty: kept as before for backward compatibility.
        blocked_ids: LongTensor of the matching token ids.
    """

    def __init__(self, tokenizer, pattern, penalty=-100):
        self.vocab = tokenizer.get_vocab()
        self.pattern = re.compile(pattern)
        self.penalty = penalty
        self.blocked_ids = torch.tensor(
            sorted(idx for tok, idx in self.vocab.items() if self.pattern.match(tok)),
            dtype=torch.long,
        )

    def __call__(self, input_ids, scores):
        # In-place penalty on the precomputed columns (same effect as the old per-token loop).
        scores[:, self.blocked_ids.to(scores.device)] += self.penalty
        return scores

# Penalize tokens that start with "<" and contain a later ">", plus any token
# containing a digit or one of | # ! * = ~.
regex_ban = RegexBanLogitsProcessor(
    tokenizer_extended,
    pattern=r"^<.*>|.*[\d\|#\!\*\=\~].*$",
    penalty=-100,
)

inputs = tokenizer_extended("The company's ESG risk is", return_tensors="pt").to(model.device)

# Nucleus sampling with repetition controls and the regex penalty applied.
output = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=2.0,
    no_repeat_ngram_size=3,
    max_new_tokens=50,
    logits_processor=[regex_ban],
)
print(tokenizer_extended.decode(output[0]))

from transformers import LogitsProcessor
import re
import torch

class FastRegexBanLogitsProcessor(LogitsProcessor):
    """
    Penalize tokens whose vocabulary string matches a regex.

    The matching token ids are collected once in ``__init__`` via ``re.match``,
    so no regex work happens during generation. Each call adds ``penalty`` to
    those columns of ``scores`` in place and returns the same tensor.
    """

    def __init__(self, tokenizer, pattern, penalty=-100):
        is_match = re.compile(pattern).match
        matching = [token_id for token, token_id in tokenizer.get_vocab().items() if is_match(token)]
        self.penalty = penalty
        self.blocked_ids = torch.tensor(matching, dtype=torch.long)

    def __call__(self, input_ids, scores):
        columns = self.blocked_ids
        scores[:, columns] += self.penalty
        return scores

# Same ban pattern, using the precomputed-id processor.
regex_ban = FastRegexBanLogitsProcessor(
    tokenizer_extended,
    pattern=r"^<.*>|.*[\d\|#\!\*\=\~].*$",
    penalty=-100,
)

# Instruction-style prefix, intended to steer the model toward prose.
prompt = "Write a clean plain English explanation: The company's ESG risk is"
inputs = tokenizer_extended(prompt, return_tensors="pt").to(model.device)

output = model.generate(
    **inputs,
    do_sample=True,
    temperature=1.0,
    top_p=0.8,
    repetition_penalty=2.0,
    no_repeat_ngram_size=3,
    max_new_tokens=50,
    logits_processor=[regex_ban],
)

print(tokenizer_extended.decode(output[0]))

from transformers import LogitsProcessor
import torch, re

# === Whitelist processor ===
class WhitelistLogitsProcessor(LogitsProcessor):
    """
    Allow only tokens whose text matches ``allowed_pattern``.

    Every other token's logit is replaced by ``mask_value`` (default -100.0).

    GPT-2's byte-level BPE vocabulary spells a leading space as "Ġ" (the vocab
    entry for " the" is "Ġthe"). Matching raw vocab strings against a pattern
    that expects ordinary whitespace therefore rejected every space-prefixed
    word piece, and the output came out without spaces. By default
    (``match_decoded=True``) each entry is first converted back to the text it
    produces with ``tokenizer.convert_tokens_to_string``. If that method is
    missing or fails for a token, the raw vocab string is used instead.

    Args:
        tokenizer: exposes ``get_vocab()`` ({token_string: id}) and ideally
            ``convert_tokens_to_string``.
        allowed_pattern: regex applied with ``re.match``.
        mask_value: logit assigned to disallowed tokens. -100.0 makes them very
            unlikely; ``float("-inf")`` forbids them outright.
        match_decoded: match decoded token text (True) or raw vocab strings
            (False, the previous behaviour).
    """

    def __init__(self, tokenizer, allowed_pattern, mask_value=-100.0, match_decoded=True):
        matcher = re.compile(allowed_pattern).match
        to_text = getattr(tokenizer, "convert_tokens_to_string", None) if match_decoded else None
        allowed = [
            idx
            for tok, idx in tokenizer.get_vocab().items()
            if matcher(self._token_text(tok, to_text))
        ]
        self.mask_value = mask_value
        self.allowed_ids = torch.tensor(sorted(allowed), dtype=torch.long)

    @staticmethod
    def _token_text(tok, to_text):
        """Return the text ``tok`` decodes to, or ``tok`` itself if decoding is unavailable or fails."""
        if to_text is None:
            return tok
        try:
            return to_text([tok])
        except (KeyError, ValueError, TypeError):
            # e.g. slow byte-level tokenizers raise KeyError on characters outside their byte map
            return tok

    def __call__(self, input_ids, scores):
        ids = self.allowed_ids.to(scores.device)
        mask = torch.full_like(scores, self.mask_value)  # mask everything
        mask[:, ids] = scores[:, ids]  # keep the allowed logits untouched
        return mask

# === Safe generate ===
def safe_generate(model, tokenizer, prompt, mode="forecast", max_new_tokens=60):
    """
    Generate a continuation of ``prompt`` with one (optionally LoRA-tuned) model.

    Args:
        model: model exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: tokenizer used to encode ``prompt`` and decode the output.
        prompt: text to continue.
        mode:
            "text" - plain-English mode. A whitelist logits processor keeps only
                tokens made of letters, whitespace and . , ' -. An instruction
                prefix is prepended to the prompt, and stricter sampling is used
                (temperature 1.2, top_p 0.85, repetition_penalty 2.5,
                no_repeat_ngram_size 4).
            anything else (default "forecast") - unrestricted sampling with
                neutral settings (temperature 1.0, top_p 1.0, no repetition
                penalty, no n-gram blocking).
        max_new_tokens: upper bound on generated tokens.

    Returns:
        str: the decoded prompt plus continuation, with special tokens skipped.
        In "text" mode the internal instruction prefix is stripped, so the
        result starts with the caller's own prompt. Previously that prefix
        leaked into the returned text.
    """
    text_mode = mode == "text"
    steering = "Write a coherent plain English sentence: "

    logits_processors = None
    if text_mode:
        # Build whitelist: only English letters, spaces, punctuation
        whitelist = WhitelistLogitsProcessor(
            tokenizer,
            allowed_pattern=r"^[A-Za-z\.,'\-\s]+$"
        )
        logits_processors = [whitelist]
        # Steering prefix to push the model toward natural-language output
        prompt = steering + prompt

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        logits_processor=logits_processors,
        do_sample=True,
        temperature=1.2 if text_mode else 1.0,
        top_p=0.85 if text_mode else 1.0,
        repetition_penalty=2.5 if text_mode else 1.0,
        no_repeat_ngram_size=4 if text_mode else 0,
        max_new_tokens=max_new_tokens
    )
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    if text_mode and text.startswith(steering):
        text = text[len(steering):]
    return text

import contextlib

# Whitelist: letters, whitespace and . , ' - only.
processor = WhitelistLogitsProcessor(
    tokenizer_extended,
    allowed_pattern=r"^[A-Za-z\.,'\-\s]+$",
)
prompt = "Write a clean plain English explanation: The company's ESG risk is"
inputs = tokenizer_extended(prompt, return_tensors="pt").to(model.device)

# Generate with the LoRA adapter switched off. PEFT's `disable_adapter()` is a
# context manager: calling it without `with` only creates the manager and
# leaves the adapter active, so enter it explicitly. If the model has no such
# method, or it does not return a context manager, fall back to a no-op.
adapter_off = model.disable_adapter() if hasattr(model, "disable_adapter") else None
if not (hasattr(adapter_off, "__enter__") and hasattr(adapter_off, "__exit__")):
    adapter_off = contextlib.nullcontext()

with adapter_off:
    output = model.generate(
        **inputs,
        logits_processor=[processor],
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_p=0.9,
    )

text_output = tokenizer_extended.decode(output[0], skip_special_tokens=True)
print("=== Clean Text Output ===")
print(text_output)

# Leaving the `with` block re-enables a PEFT adapter. This call covers models
# that instead expose an explicit enable_adapter().
if hasattr(model, "enable_adapter"):
    model.enable_adapter()

text_output

from transformers import GPT2LMHeadModel, GPT2Tokenizer

def create_base_gpt2_model():
    """
    Return a freshly loaded, unmodified GPT-2 LM and its tokenizer.

    Intended for free-text generation: the vocabulary is not resized and no
    LoRA adapters are attached. GPT-2 has no PAD token, so EOS is reused for
    padding on both the tokenizer and the model config.

    Returns:
        tuple: ``(model_b, tokenizer)``, with the model switched to eval mode
        (it is left on the default device; it is not moved).
    """
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    model_b = GPT2LMHeadModel.from_pretrained("gpt2")
    model_b.config.pad_token_id = tokenizer.pad_token_id
    model_b.eval()

    return (model_b, tokenizer)

# === Example Usage ===
# NOTE: this rebinds the notebook-level `base_tokenizer` (previously the fast
# GPT-2 tokenizer) to the slow GPT2Tokenizer returned here.
model_b, base_tokenizer = create_base_gpt2_model()

def generate_text(model, tokenizer, prompt, max_length=50):
    """
    Sample a free-text continuation of ``prompt``.

    Args:
        model: model exposing ``.device`` and ``.generate(**kwargs)``.
        tokenizer: tokenizer with ``eos_token_id``, used to encode and decode.
        prompt: text to continue.
        max_length: cap on the TOTAL sequence length in tokens, prompt
            included (``generate``'s ``max_length`` semantics).

    Returns:
        str: decoded prompt + continuation, with special tokens skipped.
    """
    # Put the inputs on the model's device so this also works for a GPU model
    # (consistent with the other generation helpers in this notebook).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Smoke test: sample a short free-text continuation from the stock GPT-2 model.
prompt = "The future of ESG investing is"
print(generate_text(model=model_b, tokenizer=base_tokenizer, prompt=prompt))