import warnings
from collections import defaultdict

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from loguru import logger
from tqdm import tqdm
from knowledge_neurons import KnowledgeNeurons

# from me_attr import Attr
from me_cfg import Configure
from me_load import NeoLoader, Metric
from me_shared import DEVICE
from me_util import get_attr, compute_sss, freeze_params, constrained_lstsq, obtain_loss, Whitener, PCAProjector, \
    whitening, align, align2, find_most_attributed_layer
from me_util2 import get_blocks, inner_optimize_block, make_sft_batch, obtain_kwargs_blk, list_target_linears, OPTIMIZE


class Core:
    def __init__(self, cfg: Configure):
        self.cfg = cfg
        # self.tokenizer = NeoLoader.load_tokenizer(cfg.model_name)
        #
        # self.config, self.model, self.attrs = NeoLoader.load_model(cfg.model_name)
        # self.model.to(DEVICE)
        # # self.attr = Attr(self.cfg, self.model, self.tokenizer, self.attrs)
        # print(self.model)

        self.src_tokenizer = NeoLoader.load_tokenizer(cfg.src_model_name)
        src_config, src_model, src_attrs = NeoLoader.load_model(cfg.src_model_name)
        self.model_large = src_model.to(DEVICE)

        self.tgt_tokenizer = NeoLoader.load_tokenizer(cfg.tgt_model_name)
        tgt_config, tgt_model, tgt_attrs = NeoLoader.load_model(cfg.tgt_model_name)
        self.model_small = tgt_model.to(DEVICE)

        self.src_hidden_size = src_model.config.hidden_size
        self.src_embeddings_matrix = self.get_input_semantic_bases(src_model, src_attrs)
        self.src_lm_head_matrix = self.get_output_semantic_bases(src_model, src_attrs)
        self.tgt_hidden_size = tgt_model.config.hidden_size
        self.tgt_embeddings_matrix = self.get_input_semantic_bases(tgt_model, tgt_attrs)
        self.tgt_lm_head_matrix = self.get_output_semantic_bases(tgt_model, tgt_attrs)

        lm_head_matrix = get_attr(src_model, src_attrs['lm_head'])
        self.src_lm_head = lm_head_matrix.weight.detach().T  # .cpu()

        lm_head_matrix = get_attr(tgt_model, tgt_attrs['lm_head'])
        self.tgt_lm_head = lm_head_matrix.weight.detach().T  # .cpu()

        self.switch_things('small')
        # TODO
        # B = F.normalize(B, dim=1)  # 保证每个 token 方向单位范数
        # self.in_bases = whitening(self.get_input_semantic_bases())
        # self.out_bases = whitening(self.get_output_semantic_bases())

        # # ...白化
        # self.whitener = Whitener(k=self.tgt_hidden_size, l2norm=False)
        # self.whitener2 = Whitener(k=self.tgt_hidden_size, l2norm=False)
        # 仅中心化
        self.whitener = PCAProjector(k=self.tgt_hidden_size)
        self.whitener2 = PCAProjector(k=self.tgt_hidden_size)
        # # 中心化+缩放
        # self.whitener = PCAProjector(k=self.tgt_hidden_size, scale=True)
        # self.whitener2 = PCAProjector(k=self.tgt_hidden_size, scale=True)

        # # ... for self._reinit()
        # self.inplace_weights = {
        #     name: parameters
        #     for name, parameters in self.model.named_parameters()
        #     if self.attrs['ff_output'] + '.weight' in name
        # }
        # self.backup_weights = {k: v.detach().clone() for k, v in self.inplace_weights.items()}

    # def _reinit(self):
    #     with torch.no_grad():
    #         for k, v in self.inplace_weights.items():
    #             v[...] = self.backup_weights[k]

    def switch_things(self, tag='small'):
        if tag == 'small':
            self.tokenizer = self.tgt_tokenizer
            self.model = self.model_small
        elif tag == 'large':
            self.tokenizer = self.src_tokenizer
            self.model = self.model_large
        else:
            raise NotImplementedError

    def inference(self, prompt, max_new_tokens=1):
        if "CodeQwen" in self.cfg.model_name or "Qwen2.5-Coder" in self.cfg.model_name:
            inputs = self.tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(DEVICE)
        else:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)

        with torch.no_grad():  # 禁用梯度计算，节省内存和计算资源
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        # generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # 获取token ID序列 (Tensor格式)
        label_sequence = generated_ids[0].tolist()[0]
        label_sequence = label_sequence[-max_new_tokens:]
        # print("Token IDs:", label_sequence)
        # 获取文本序列
        # token_sequence = self.tokenizer.decode(label_sequence, skip_special_tokens=True)
        token_sequence = [self.tokenizer.decode(label, skip_special_tokens=True) for label in label_sequence]
        # print("Text:", token_sequence)
        return label_sequence, token_sequence

    def get_input_semantic_bases(self, model, attrs):
        embeddings_matrix = get_attr(model, attrs['embedding'])
        embeddings_matrix = embeddings_matrix.weight.detach()  # .cpu()
        return embeddings_matrix

    def get_output_semantic_bases(self, model, attrs):
        lm_head_matrix = get_attr(model, attrs['lm_head'])
        lm_head_matrix = lm_head_matrix.weight.detach()  # .cpu()
        return torch.linalg.pinv(lm_head_matrix).T

    def run_align(self):
        SAMPLE_TEXT = "The capital of France is"
        SAMPLE_GT = "Paris"
        prompts = [SAMPLE_TEXT]
        gts = [SAMPLE_GT]

        # 加载模型
        tokenizer, model_large, model_small = load_models()

        # 寻找关键层
        print("\nFinding most attributed layers...")
        teacher_kn = KnowledgeNeurons(
            model_large,
            tokenizer,
            model_type="llama",
            is_teacher=True,
            device=DEVICE)
        student_kn = KnowledgeNeurons(
            model_small,
            tokenizer,
            model_type="llama",
            is_teacher=False,
            device=DEVICE)

        top_ffn_layer_idx = teacher_kn.get_top_attribute_layers(prompts, gts, top_cnt=3)
        top_student_ffn_layer_idx = student_kn.get_top_attribute_layers(prompts, gts, top_cnt=3)
        top_ffn_layer_idx = sorted(top_ffn_layer_idx)
        top_student_ffn_layer_idx = sorted(top_student_ffn_layer_idx)
        # 1
        m_layer = top_ffn_layer_idx[0]
        n_layer = top_student_ffn_layer_idx[0]
        # 2
        m_layer = find_most_attributed_layer(model_large, SAMPLE_TEXT, tokenizer)
        n_layer = find_most_attributed_layer(model_small, SAMPLE_TEXT, tokenizer)
        print(f"Large model key layer: {m_layer}, Small model key layer: {n_layer}")
        # 3
        m_layer = -1
        n_layer = -1

    # option1: teacher-LM last layer => student-LM last layer
    # option2: teacher-LM each layer => student-LM each layer
    def workflow(self, source: str, target: str, epoch_num=10):
        # we reinit model to avoid interferes
        # self._reinit()
        self.model_small.eval()

        target_labels = [label for label in self.tokenizer.encode(target, add_special_tokens=False)]
        target_tokens = [self.tokenizer.decode(label, skip_special_tokens=True) for label in target_labels]
        logger.warning(f'___ {target_labels=}')
        logger.warning(f'___ {target_tokens=}')

        truncation_lens = len(target_labels)
        argmax_labels, argmax_tokens = self.inference(source, max_new_tokens=truncation_lens)
        pre_gen_tokens = argmax_tokens

        logger.info(f'___ {argmax_tokens=}')
        logger.info(f'___ {argmax_labels=}')

        pairs = [(source, target)]
        # 一次性构建 batch（dict，供 model(**batch) 使用）
        batch = make_sft_batch(self.tokenizer, pairs, DEVICE)

        # self.switch_things('large')
        # self.model_large.eval()
        # logger.info('checking the performance of LLMs')
        # argmax_labels, argmax_tokens = self.inference(source, max_new_tokens=truncation_lens)
        # post_gen_tokens = argmax_tokens
        # Metric.contrast_scoring([pre_gen_tokens], [post_gen_tokens], [target_tokens])
        # # Metric.contrast_scoring2([pre_gen_tokens], [post_gen_tokens], [target_tokens])

        # self.switch_things('small')
        # self.model_large.eval()
        # logger.info('checking the performance of LLMs')
        # argmax_labels, argmax_tokens = self.inference(source, max_new_tokens=truncation_lens)
        # post_gen_tokens = argmax_tokens
        # Metric.contrast_scoring([pre_gen_tokens], [post_gen_tokens], [target_tokens])
        # # Metric.contrast_scoring2([pre_gen_tokens], [post_gen_tokens], [target_tokens])

        # B = 1
        # L = 128
        # H = 1024

        # ...
        llm_blocks = get_blocks(self.model_large)
        slm_blocks = get_blocks(self.model_small)
        with torch.no_grad():
            outputs = self.model_large(**batch, output_hidden_states=True, use_cache=False)
            llm_reprs = outputs.hidden_states
        with torch.no_grad():
            outputs = self.model_small(**batch, output_hidden_states=True, use_cache=False)
            slm_reprs = outputs.hidden_states
        # 2) Prepare inputs exactly like HF once per batch and reuse
        slm_kwargs_blk = obtain_kwargs_blk(self.model_small, batch)

        # ...
        a_last_layer_outputs = llm_reprs[-1]
        a_outputs = a_last_layer_outputs
        b_last_layer_outputs = slm_reprs[-1]
        b_outputs = b_last_layer_outputs
        b_last_layer_inputs = slm_reprs[-2]
        h = b_last_layer_inputs



        # ...
        large_output_bases = self.src_lm_head_matrix.numpy()
        small_output_bases = self.tgt_lm_head_matrix.numpy()
        # # ...
        # input_ids = self.src_tokenizer(source, return_tensors="pt").input_ids.to(DEVICE)
        # # # ...
        # # m_layer = -1
        # # n_layer = -1
        # # _, a_outputs = collect_ffn_data(self.model_large, m_layer, input_ids)
        # # _, b_outputs = collect_ffn_data(self.model_small, n_layer, input_ids)
        # # ...
        # reducer = self.whitener
        # # ...
        # aligned_output_out = align(large_output_bases, small_output_bases, a_outputs, reducer)
        # sim_out = torch.cosine_similarity(aligned_output_out.flatten(), b_outputs.flatten(), dim=0)

        # # ...
        # reducer = self.whitener
        # # ...
        # reducer2 = self.whitener2
        # small_output_bases = torch.from_numpy(small_output_bases)
        # small_output_bases = reducer2.fit_transform(small_output_bases)
        # small_output_bases = small_output_bases.numpy()
        #
        # X_reduced, a_outputs = align(large_output_bases, small_output_bases, a_outputs, reducer)



        a_outputs = a_outputs.squeeze(0)
        b_outputs = b_outputs.squeeze(0)
        a_outputs = a_outputs[-truncation_lens:, :]
        b_outputs = b_outputs[-truncation_lens:, :]



        a_logits = F.softmax(a_outputs @ self.src_lm_head, dim=-1)
        #
        b_logits = F.softmax(b_outputs @ self.tgt_lm_head, dim=-1)
        # ...
        semantic_targets = [large_output_bases[label] for label in target_labels]
        semantic_targets = torch.tensor(semantic_targets)
        # logits = semantic_targets @ self.src_lm_head
        logits = F.softmax(semantic_targets @ self.src_lm_head, dim=-1)
        # ...
        semantic_targets = [small_output_bases[label] for label in target_labels]
        semantic_targets = torch.tensor(semantic_targets)
        # logits = semantic_targets @ self.tgt_lm_head
        logits = F.softmax(semantic_targets @ self.tgt_lm_head, dim=-1)

        argmax_labels = torch.argmax(logits, dim=-1).tolist()
        argmax_tokens = [self.tokenizer.decode(label, skip_special_tokens=True) for label in argmax_labels]
        logger.debug(f'__T: {argmax_labels=}')
        logger.debug(f'__T: {argmax_tokens=}')
        post_gen_tokens = argmax_tokens
        print(f'{len(pre_gen_tokens)=}')
        print(f'{len(post_gen_tokens)=}')
        print(f'{len(target_tokens)=}')
        Metric.contrast_scoring([pre_gen_tokens], [post_gen_tokens], [target_tokens])
        Metric.contrast_scoring2([pre_gen_tokens], [post_gen_tokens], [target_tokens])
        logger.info('=!' * 20)

        # logger.warning(f'{target_labels=}')
        # logger.warning(f'{target_tokens=}')

        # label_s = 2023
        # label_t = 369
        # for label in [label_s, label_t]:
        #     semantic_targets = [large_output_bases.numpy()[label]]
        #     semantic_targets = torch.tensor(semantic_targets)
        #     logits = F.softmax(semantic_targets @ self.src_lm_head, dim=-1)
        #     prob_s = logits[0, label_s]
        #     prob_t = logits[0, label_t]
        #     print(f'{prob_s=}, {prob_t=}')
        # # argmax_labels = torch.argmax(logits, dim=-1).tolist()

        large_output_bases = torch.from_numpy(large_output_bases)
        small_output_bases = torch.from_numpy(small_output_bases)

        # self.whitener = PCAProjector(scale=False)
        # large_output_bases = self.whitener.fit_transform(large_output_bases)
        # self.whitener2 = PCAProjector(scale=False)
        # small_output_bases = self.whitener2.fit_transform(small_output_bases)

        # large_output_bases = whitening(large_output_bases)
        # small_output_bases = whitening(small_output_bases)

        # logger.info(f"before {a_outputs.shape=}")
        # a_outputs = self.whitener.transform(a_outputs)  # Assumes same feature space!
        # logger.info(f"after {a_outputs.shape=}")

        # TODO optimize the efficiency?
        # TODO run experiments with flexible layer-mapping?
        # TODO decide the RQs
        # TODO submit the experiments
        aligned_output_out = align2(a_outputs, large_output_bases, small_output_bases)
        # large_output_bases = large_output_bases.numpy()
        # small_output_bases = small_output_bases.numpy()

        # # old ...
        # aligned_output_out = align_legacy(large_output_bases, small_output_bases, a_outputs, reducer)
        tt = aligned_output_out
        tt = tt[-truncation_lens:, :]
        # TODO ...
        # tt = semantic_targets

        block = slm_blocks[-1]


        # 直接用 SGD 更新需要优化的 block 参数
        # pick modules to adapt
        name2mod = {name: mod for name, mod in list_target_linears(block) if OPTIMIZE.get(name, False)}
        params_to_update = [getattr(mod, 'weight') for mod in name2mod.values() if hasattr(mod, 'weight')]
        # TODO how to find a proper LR ... (xxx)
        # optimizer = torch.optim.SGD(params_to_update, lr=5e-2)
        optimizer = torch.optim.SGD(params_to_update, lr=1e-2)



        self.model_small.train()
        for epoch_idx in range(epoch_num * 10):
            logger.success('=' * 9 + f'{epoch_idx}' + '=' * 9)

            optimizer.zero_grad()

            # # Sequential inner loops with chaining
            # for idx, block in enumerate(slm_blocks):
            #     if idx == 0:
            #         h = self.model_small.model.embed_tokens(batch["input_ids"])
            #
            #     # it means, the dependency required being captured (backward dependency => forward dependency)
            #     # TODO decide which model layer to optimize (...)
            #     h = inner_optimize_block(
            #         block, h, tt,
            #         truncation_lens=truncation_lens,
            #         kwargs_blk=slm_kwargs_blk,
            #     )

            # TODO using semantic bases as supervision signals... (tt => xtt)
            yy = inner_optimize_block(
                block, h, tt,
                a_outputs,
                b_outputs,
                self.src_lm_head,
                self.tgt_lm_head,
                target_labels,
                truncation_lens=truncation_lens,
                kwargs_blk=slm_kwargs_blk,
            )

            optimizer.step()



            # 输入可导
            yy = yy @ self.tgt_lm_head
            logits = F.softmax(yy, dim=-1)
            argmax_labels = torch.argmax(logits, dim=-1).tolist()
            argmax_tokens = [self.tokenizer.decode(label, skip_special_tokens=True) for label in argmax_labels]
            logger.debug(f'epoch {argmax_labels=}')
            logger.debug(f'epoch {argmax_tokens=}')
            post_gen_tokens = argmax_tokens
            Metric.contrast_scoring([pre_gen_tokens], [post_gen_tokens], [target_tokens])
            Metric.contrast_scoring2([pre_gen_tokens], [post_gen_tokens], [target_tokens])


        self.model_small.eval()
        argmax_labels, argmax_tokens = self.inference(source, max_new_tokens=truncation_lens)
        post_gen_tokens = argmax_tokens

        return pre_gen_tokens, post_gen_tokens, target_tokens
