import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
import matplotlib.pyplot as plt
from rich.containers import Lines
from sympy.physics.quantum.circuitplot import Line2D

from me_cfg import Configure
from me_load import NeoLoader, Metric
from me_shared import DEVICE, REPORT_OUTS_DIR
from me_util import get_attr, format_score, compute_kernel_bias_torch, transform_and_normalize_torch, whitening, \
    PCAProjector, Whitener


def theta_lerp(v1_group, v2_group, erp):
    """Spherically interpolate from ``v1_group`` towards ``v2_group``.

    Both inputs are L2-normalized row-wise. Each row of ``v1_group`` is then
    rotated, inside the plane it spans with the matching row of ``v2_group``,
    by ``erp`` times the angle between the two rows.

    Args:
        v1_group: Start vectors, one per row. A 1-D tensor is treated as a
            single vector and broadcast against every row of ``v2_group``.
        v2_group: 2-D tensor of target vectors, one per row.
        erp: Ratio applied to the angle. ``0`` yields the normalized start
            vectors and ``1`` the normalized targets; values outside
            ``[0, 1]`` rotate past the endpoints.

    Returns:
        Tensor with one interpolated, re-normalized vector per row.

    Note:
        For exactly anti-parallel rows the rotation plane is undefined: the
        perpendicular component normalizes to zero and the result degenerates
        to a scaled copy of the start vector.
    """
    # A single start vector is shared by every target row.
    if v1_group.ndim == 1:
        v1_group = v1_group.unsqueeze(0).expand(v2_group.shape[0], -1)

    # Step 1: work with unit vectors on both sides.
    v1_unit = F.normalize(v1_group, p=2, dim=1)
    v2_unit = F.normalize(v2_group, p=2, dim=1)

    # Step 2: row-wise cosine and angle. Rounding can push the dot product of
    # two unit vectors marginally outside [-1, 1], where acos returns NaN, so
    # clamp before taking the angle.
    cosine_sim = torch.sum(v1_unit * v2_unit, dim=1).clamp(-1.0, 1.0)
    theta = torch.acos(cosine_sim)

    # Step 3: linear interpolation of the angle.
    theta_t = erp * theta

    # Step 4: unit component of v2 orthogonal to v1 (zero when the rows are
    # parallel, because F.normalize maps a zero vector to zero).
    v2_perpendicular = F.normalize(v2_unit - v1_unit * cosine_sim.unsqueeze(1), p=2, dim=1)

    # Rotate v1 by theta_t inside the (v1, v2_perpendicular) plane.
    v_t = torch.cos(theta_t).unsqueeze(1) * v1_unit + torch.sin(theta_t).unsqueeze(1) * v2_perpendicular
    return F.normalize(v_t, p=2, dim=1)  # Guard against accumulated rounding drift


class Core:
    """Measure how well per-layer hidden states are reconstructed from anchors.

    The class loads a model and tokenizer, caches two candidate anchor
    matrices (the input embedding weights and the pseudo-inverse of the
    transposed LM-head weights), and scores each layer's last-position hidden
    state against a similarity-weighted sum of the currently selected anchors.
    """

    def __init__(self, cfg: Configure):
        """Load the tokenizer and model named by ``cfg.model_name``.

        Args:
            cfg: Configuration object; only ``cfg.model_name`` is read here.
        """
        self.cfg = cfg
        self.tokenizer = NeoLoader.load_tokenizer(self.cfg.model_name)

        self.config, self.model, self.attrs = NeoLoader.load_model(self.cfg.model_name)
        self.model.to(DEVICE)

        model_layers = get_attr(self.model, self.attrs['layers'])
        self.num_layers = len(model_layers)

        self.embeddings_matrix = self.get_embedding_matrix()
        self.lm_head_matrix = self.get_lm_head_matrix()
        # NOTE: whitening / PCA projection of both matrices (Whitener,
        # PCAProjector from me_util) was tried here and is currently disabled.

        # Anchor state is selected later via `switch_xxx`. Initialising it here
        # lets `pipeline` report a missing selection explicitly instead of
        # failing with an AttributeError.
        self.layer_anchors = None
        self.dict_layer_anchors = dict()

    def get_embedding_matrix(self):
        """Return the detached weight of the model's input embedding module."""
        embedding_module = get_attr(self.model, self.attrs['embedding'])
        return embedding_module.weight.detach()

    def get_lm_head_matrix(self):
        """Return the pseudo-inverse of the transposed, detached LM-head weight."""
        lm_head_module = get_attr(self.model, self.attrs['lm_head'])
        lm_head_weight = lm_head_module.weight.detach()
        return torch.linalg.pinv(lm_head_weight.T)

    def switch_xxx(self, option='input-side'):
        """Select which anchors `pipeline` compares hidden states against.

        Args:
            option: ``'input-side'`` uses the embedding matrix for every layer,
                ``'output-side'`` uses the LM-head matrix for every layer, and
                ``'xpolation'`` precomputes one anchor matrix per layer by
                angular interpolation between the two.

        Raises:
            NotImplementedError: If ``option`` is not one of the values above.
        """
        if option == 'input-side':
            self.layer_anchors = self.embeddings_matrix
        elif option == 'output-side':
            self.layer_anchors = self.lm_head_matrix
        elif option == 'xpolation':
            # Linear assumption: layer i (0-based) sits at ratio
            # (i + 1) / num_layers, so the last layer lands on the
            # output-side anchors. A power-law scaling of the ratios was
            # tried earlier and dropped.
            self.dict_layer_anchors = {
                layer_idx: theta_lerp(
                    self.embeddings_matrix,
                    self.lm_head_matrix,
                    (layer_idx + 1) / self.num_layers,
                )
                for layer_idx in range(self.num_layers)
            }
        else:
            raise NotImplementedError(f'unknown option: {option!r}')

    def pipeline(self, prompt, option):
        """Score every layer's last-position hidden state against the anchors.

        For each layer, the hidden state at the final sequence position is
        compared (cosine similarity) with every anchor row. Those similarities
        weight a sum of the anchors, and the cosine similarity between the
        hidden state and that weighted sum is recorded.

        Args:
            prompt: Text handed to the tokenizer. The final ``.item()`` call
                requires a single-element result, i.e. a batch of one.
            option: When ``'xpolation'``, the per-layer anchors prepared by
                `switch_xxx` are used (and assigned to ``self.layer_anchors``
                layer by layer). Any other value uses whatever
                ``self.layer_anchors`` currently holds.

        Returns:
            List of floats, one per layer, in layer order.

        Raises:
            RuntimeError: If the anchors required for ``option`` have not been
                prepared with `switch_xxx`.
        """
        if option == 'xpolation':
            if not self.dict_layer_anchors:
                raise RuntimeError("call switch_xxx('xpolation') before pipeline(..., option='xpolation')")
        elif self.layer_anchors is None:
            raise RuntimeError('call switch_xxx(...) to select anchors before pipeline(...)')

        # prepare the data (or the dataloader...)
        encoding = self.tokenizer.encode_plus(prompt, add_special_tokens=False, truncation=True, return_tensors="pt")
        input_ids = encoding["input_ids"].to(DEVICE)

        # Inference only: every hidden state is detached below, so there is no
        # reason to record an autograd graph during the forward pass.
        with torch.no_grad():
            outputs = self.model(input_ids, output_hidden_states=True, return_dict=True)

        # TODO consider doing the layer-by-layer stuff ??? (no when the erp for target/argmax labels is the same)
        hidden_states = outputs.hidden_states
        embeds_hats = [hidden_state[:, -1, :] for hidden_state in hidden_states]
        logger.warning(f'{len(embeds_hats)=}')

        layered_simis = list()
        for layer_idx in range(self.num_layers):
            # Index 0 of hidden_states is skipped; layer i reads entry i + 1.
            embeds_hat = embeds_hats[layer_idx + 1].detach()

            if option == 'xpolation':
                self.layer_anchors = self.dict_layer_anchors[layer_idx]

            # Similarity of the hidden state to every anchor row, reshaped to
            # a row vector so it can weight the anchors via matmul.
            simis = F.cosine_similarity(embeds_hat, self.layer_anchors).unsqueeze(0)

            logger.debug(f'{simis.shape=}')
            logger.debug(f'{self.layer_anchors.shape=}')
            semantic_parts = simis @ self.layer_anchors

            logger.debug(f'{embeds_hat.shape=}')
            logger.debug(f'{semantic_parts.shape=}')
            simi = F.cosine_similarity(embeds_hat, semantic_parts, dim=-1).item()
            logger.debug(f'{layer_idx=}, {format_score(simi)=}')
            layered_simis.append(simi)

        return layered_simis

    def plot_annotated_curve(self, data_dict):
        """Plot per-layer similarity curves grouped by label and show the figure.

        Args:
            data_dict: Mapping from label to an iterable of curves; each curve
                is passed to ``plt.plot`` as-is. Labels are processed in sorted
                order and coloured blue, red, orange in that order; labels
                beyond the third are not plotted.
        """
        # Import Line2D from its real home; the module-level name comes from a
        # sympy module that only re-exports it.
        from matplotlib.lines import Line2D

        plt.figure(figsize=(10, 6))

        # One dashed vertical guide per layer position on the x-axis.
        num_layers = self.num_layers
        for i in range(num_layers):
            plt.axvline(x=i, color='gray', linestyle='--', linewidth=0.5)

        # Axis labels, ticks and limits.
        plt.xlabel('Model-Layer Index', fontsize=14, labelpad=10)
        plt.ylabel('Cosine Similarity', fontsize=14, labelpad=10)
        plt.yticks(np.arange(0.0, 1.1, 0.1), fontsize=14)
        plt.ylim(-0.1, 1.0)
        plt.grid(True)

        # x-axis tick labels are 1-based layer names.
        plt.xticks(ticks=range(num_layers),
                   labels=[f'Layer {i + 1}' for i in range(num_layers)],
                   fontsize=14,
                   rotation=45,  # Rotate the labels
                   ha='right')  # Align the labels to the right

        # One thin line per curve, coloured by its (sorted) label.
        sorted_data = sorted(data_dict.items())
        colors = ['blue', 'red', 'orange']
        for (label, batched_reprs), color in zip(sorted_data, colors):
            for reprs in batched_reprs:
                plt.plot(reprs, color=color, alpha=0.5, linewidth=1)

        # Fixed legend for the first two colours.
        # NOTE(review): there is no legend entry for the third (orange) series.
        legend_elements = [
            Line2D([0], [0], color='blue', linewidth=3, label='Input-Side Semantic Bases'),
            Line2D([0], [0], color='red', linewidth=3, label='Output-Side Semantic Bases'),
        ]
        plt.legend(handles=legend_elements, loc='best', fontsize=14)

        # The output folder is still created, but saving is currently disabled.
        subfolder = REPORT_OUTS_DIR / 'plot'
        subfolder.mkdir(parents=True, exist_ok=True)
        save_figure_file = subfolder / f'{self.cfg.model_name}_empirical_validation.png'

        # Lay out once, after all artists (rotated tick labels, legend) exist.
        plt.tight_layout()
        # plt.savefig(save_figure_file, dpi=300, bbox_inches='tight')
        plt.show()
