import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
import matplotlib.pyplot as plt
from sympy.physics.quantum.circuitplot import Line2D

from me_load import NeoLoader, _load_he
from me_shared import DEVICE, REPORT_OUTS_DIR
from me_util import get_attr, format_score


class Core:
    """Per-layer "semantic decomposition" probe for a language model.

    For a prompt, the last-token hidden state of every transformer layer is
    re-expressed as a cosine-similarity-weighted sum of anchor vectors, and the
    cosine similarity between the hidden state and that reconstruction is
    reported. Anchors are either the input embedding matrix ('input-side') or
    the pseudo-inverse of the transposed LM-head weight ('output-side'),
    selected with ``switch_xxx``.
    """

    # Fixed colour / legend label per known series key, so the legend always
    # matches the curves regardless of how the keys sort.
    _SERIES_STYLES = {
        'input-side': ('blue', 'Input-Side Semantic Bases'),
        'output-side': ('red', 'Output-Side Semantic Bases'),
    }
    # Colours handed out, in order, to any other series keys.
    _FALLBACK_COLORS = ('orange', 'green', 'purple', 'brown', 'olive', 'cyan')

    def __init__(self, model_name):
        """Load tokenizer and model for *model_name*, move the model to ``DEVICE``
        and cache both anchor matrices.

        The active anchors default to the input side (the same default as
        ``switch_xxx``), so ``pipeline`` can be called without a prior switch.
        """
        self.tokenizer = NeoLoader.load_tokenizer(model_name)

        self.config, self.model, self.attrs = NeoLoader.load_model(model_name)
        self.model.to(DEVICE)

        model_layers = get_attr(self.model, self.attrs['layers'])
        self.num_layers = len(model_layers)

        self.embeddings_matrix = self.get_embedding_matrix()
        self.lm_head_matrix = self.get_lm_head_matrix()

        # Without a default, pipeline() before switch_xxx() raised AttributeError.
        self.switch_xxx('input-side')

    def get_embedding_matrix(self):
        """Return the detached ``weight`` of the module at ``self.attrs['embedding']``."""
        embedding = get_attr(self.model, self.attrs['embedding'])
        return embedding.weight.detach()

    def get_lm_head_matrix(self):
        """Return ``pinv(W.T)`` for the LM-head weight ``W`` (detached, in ``W``'s dtype).

        If ``W`` is ``(vocab, hidden)`` as for ``nn.Linear`` (assumption — depends
        on the loaded model), the result is ``(vocab, hidden)`` too, i.e. the same
        shape as the embedding matrix, so either can serve as anchors.
        """
        lm_head = get_attr(self.model, self.attrs['lm_head'])
        return self._pinv_transposed(lm_head.weight.detach())

    @staticmethod
    def _pinv_transposed(weight):
        """Return ``torch.linalg.pinv(weight.T)`` in ``weight``'s dtype."""
        matrix = weight.T
        # torch.linalg.pinv only supports float/double/cfloat/cdouble, so a
        # float16/bfloat16 model would crash here; solve in float32 and cast back.
        if matrix.dtype in (torch.float16, torch.bfloat16):
            return torch.linalg.pinv(matrix.float()).to(matrix.dtype)
        return torch.linalg.pinv(matrix)

    def switch_xxx(self, option='input-side'):
        """Select the anchor matrix used by ``pipeline``.

        Args:
            option: ``'input-side'`` for the embedding matrix, or
                ``'output-side'`` for the LM-head pseudo-inverse.

        Raises:
            NotImplementedError: for any other *option*.
        """
        if option == 'input-side':
            self.layer_anchors = self.embeddings_matrix
        elif option == 'output-side':
            self.layer_anchors = self.lm_head_matrix
        else:
            raise NotImplementedError(f'unknown anchor option: {option!r}')

    def pipeline(self, prompt):
        """Score how well each layer's last-token hidden state is reconstructed from the anchors.

        For each layer, the hidden state ``h`` is compared with every row of
        ``self.layer_anchors`` by cosine similarity. Those similarities weight
        the anchors into ``s = simis @ anchors``, and the layer's score is
        ``cos(h, s)``.

        Args:
            prompt: text to encode (no special tokens, truncated by the tokenizer).

        Returns:
            A list of ``self.num_layers`` floats, layer 1 first
            (``hidden_states[0]`` is skipped).

        Raises:
            ValueError: if *prompt* tokenizes to zero tokens.
        """
        encoding = self.tokenizer.encode_plus(prompt, add_special_tokens=False, truncation=True, return_tensors="pt")
        input_ids = encoding["input_ids"].to(DEVICE)
        if input_ids.shape[-1] == 0:
            raise ValueError(f'prompt tokenized to zero tokens: {prompt!r}')

        # Pure inference: without no_grad every forward pass records an autograd
        # graph that is never used, wasting memory across the whole dataset.
        with torch.no_grad():
            outputs = self.model(input_ids, output_hidden_states=True, return_dict=True)

        # Last-token hidden state of every entry; entry 0 is skipped below
        # (presumably the embedding output, per the HF convention).
        embeds_hats = [hidden_state[:, -1, :] for hidden_state in outputs.hidden_states]
        logger.debug(f'{len(embeds_hats)=}')

        anchors = self.layer_anchors
        layered_simis = []
        for layer_idx in range(self.num_layers):
            embeds_hat = embeds_hats[layer_idx + 1]
            # Similarity of this (single-prompt) hidden state to every anchor row -> (1, num_anchors).
            simis = F.cosine_similarity(embeds_hat, anchors).unsqueeze(0)
            # Similarity-weighted sum of the anchors.
            semantic_parts = simis @ anchors

            logger.debug(f'{embeds_hat.shape=}')
            logger.debug(f'{semantic_parts.shape=}')
            simi = F.cosine_similarity(embeds_hat, semantic_parts, dim=-1).item()
            logger.debug(f'{layer_idx=}, {format_score(simi)=}')
            layered_simis.append(simi)
        return layered_simis

    @classmethod
    def _series_styles(cls, labels):
        """Map each series key in *labels* to ``(color, legend_label)``; unknown keys take fallback colours in order."""
        styles = {}
        n_fallback = 0
        for label in labels:
            if label in cls._SERIES_STYLES:
                styles[label] = cls._SERIES_STYLES[label]
            else:
                color = cls._FALLBACK_COLORS[n_fallback % len(cls._FALLBACK_COLORS)]
                styles[label] = (color, str(label))
                n_fallback += 1
        return styles

    def plot_annotated_curve(self, data_dict, filename=None):
        """Plot one faint curve per datum for every series in *data_dict*, then show the figure.

        Args:
            data_dict: maps a series key (e.g. ``'input-side'``) to a list of
                per-datum curves, each a sequence of per-layer scores as
                returned by ``pipeline``.
            filename: if given, the figure is also saved as
                ``REPORT_OUTS_DIR / 'plot' / filename`` before being shown.
        """
        # Import from matplotlib directly instead of relying on the module-level
        # Line2D, which is only re-exported by sympy's circuitplot module.
        from matplotlib.lines import Line2D

        num_layers = self.num_layers
        plt.figure(figsize=(10, 6))

        # Dashed guide line at every layer position.
        for i in range(num_layers):
            plt.axvline(x=i, color='gray', linestyle='--', linewidth=0.5)

        plt.xlabel('Model-Layer Index', fontsize=14, labelpad=10)
        plt.ylabel('Cosine Similarity', fontsize=14, labelpad=10)
        # NumPy recommends linspace over arange for non-integer steps; this
        # yields exactly the 11 ticks 0.0, 0.1, ..., 1.0.
        plt.yticks(np.linspace(0.0, 1.0, 11), fontsize=14)
        plt.ylim(-0.1, 1.0)
        plt.grid(True)

        # Curve point i is layer i + 1, because pipeline skips hidden_states[0].
        plt.xticks(ticks=range(num_layers),
                   labels=[f'Layer {i + 1}' for i in range(num_layers)],
                   fontsize=14,
                   rotation=45,
                   ha='right')

        labels = sorted(data_dict)
        styles = self._series_styles(labels)
        legend_handles = []
        for label in labels:
            color, legend_label = styles[label]
            # One faint line per datum so that dense regions show up darker.
            for reprs in data_dict[label]:
                plt.plot(reprs, color=color, alpha=0.1, linewidth=1)
            legend_handles.append(Line2D([0], [0], color=color, linewidth=3, label=legend_label))
        if legend_handles:
            plt.legend(handles=legend_handles, loc='best', fontsize=14)

        subfolder = REPORT_OUTS_DIR / 'plot'
        subfolder.mkdir(parents=True, exist_ok=True)

        plt.tight_layout()
        if filename is not None:
            plt.savefig(subfolder / filename, bbox_inches='tight')
        plt.show()


if __name__ == '__main__':
    # _load_he yields parallel sequences of prompts and code; only the prompts are scored.
    texts, codes = _load_he()
    total = len(texts)
    core = Core('Qwen/Qwen3-0.6B')

    # For each anchor side, collect one per-layer score curve per datum.
    dict_simis = {}
    for option in ('input-side', 'output-side'):
        core.switch_xxx(option)
        per_datum = []
        for position, (text, _code) in enumerate(zip(texts, codes), start=1):
            logger.success('@' * 9 + f'{position}/{total}' + '@' * 9)
            per_datum.append(core.pipeline(text))
        dict_simis[option] = per_datum

    core.plot_annotated_curve(dict_simis)
