import json
import torch
from pathlib import Path
from tqdm import tqdm
import numpy as np

import scipy.spatial
from scipy.stats import entropy
import math
import json

from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
import torch
import torch.nn.functional as F
import numpy as np
# Input JSONL: read line by line by the processing loop; each line is parsed as
# one JSON record.
jsonl_path = Path("/home/atuin/b232dd/b232dd25/DeepSeek-R1-Distill-Qwen-7B/GSM8K/GSM8K_test_output.jsonl")
# Directory of per-record tensors named ``GSM8K_<id>.pt``; this one feeds
# compute_dynamics_with_scores.
pt_dir = Path("/home/atuin/b232dd/b232dd25/DeepSeek-R1-Distill-Qwen-7B/GSM8K/GSM8K_all_tokens")
# Directory of per-record tensors fed to compute_coe_metrics.
# NOTE(review): currently the very same directory as ``pt_dir`` — confirm this
# is intended rather than a copy/paste leftover.
pt_dir_2  = Path("/home/atuin/b232dd/b232dd25/DeepSeek-R1-Distill-Qwen-7B/GSM8K/GSM8K_all_tokens")
# Output JSONL: opened in write mode by the processing loop, so an existing
# file at this path is overwritten.
output_path = Path("/home/hpc/b232dd/b232dd23/CoT-Kinetics/GSM8K_7B_without.jsonl")

def _mean_row_norm(diff: torch.Tensor) -> float:
    """Return the mean L2 norm (over dim 1) of the rows of *diff*, or 0.0 if it has no rows."""
    if diff.size(0) == 0:
        # ``.mean()`` of an empty tensor is NaN; a term with no samples
        # contributes nothing to the score instead of poisoning it.
        return 0.0
    return torch.norm(diff, dim=1).mean().item()


def compute_dynamics_with_scores(hidden: torch.Tensor, ppl, entropy_id, maxprob, *, t: float = 1.0):
    """Compute the "CoT-Kinetics" score of a hidden-state trajectory.

    The score is ``mean ||d1|| + mean ||d2|| - t * entropy_id`` where ``d1``
    and ``d2`` are the first and second finite differences of ``hidden`` along
    dim 0, both divided by the L2 norm of ``hidden[-1] - hidden[0]``.

    Args:
        hidden: Tensor whose dim 0 indexes consecutive states; norms are taken
            over dim 1 (assumed shape ``[L, D]`` — TODO confirm with caller).
        ppl: Unused; kept so existing call sites keep working.
        entropy_id: Scalar subtracted from the motion terms, weighted by ``t``.
        maxprob: Unused; kept so existing call sites keep working.
        t: Weight of the entropy term. Defaults to 1.0, the value that was
            previously hard-coded.

    Returns:
        A dict with the single key ``"CoT-Kinetics"``.

    Raises:
        IndexError: If ``hidden`` has no rows.
    """
    # Normalise by the start-to-end displacement; the epsilon keeps the
    # division finite when the first and last states coincide.
    displacement = torch.norm(hidden[-1] - hidden[0], p=2) + 1e-8

    first_diff = (hidden[1:] - hidden[:-1]) / displacement
    # Central second difference; empty when there are fewer than three rows.
    second_diff = (hidden[2:] - 2 * hidden[1:-1] + hidden[:-2]) / displacement

    d1_mean = _mean_row_norm(first_diff)
    d2_mean = _mean_row_norm(second_diff)

    return {
        "CoT-Kinetics": d1_mean + d2_mean - t * entropy_id
    }

def _angle_between(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the angle in radians between the 1-D tensors *a* and *b*."""
    cos_sim = torch.dot(a, b) / (torch.norm(a) * torch.norm(b) + 1e-8)
    # Rounding can push the cosine marginally outside [-1, 1], where acos
    # is undefined.
    return math.acos(torch.clamp(cos_sim, -1.0, 1.0).item())


def compute_coe_metrics(hidden_states: torch.Tensor):
    """Compute magnitude/angle change metrics along a hidden-state trajectory.

    For every pair of consecutive rows the function measures
    * the L2 norm of their difference, divided by the L2 norm of
      ``hidden_states[-1] - hidden_states[0]``, and
    * the angle between them, divided by the angle between the first and
      last rows.

    Args:
        hidden_states: 2-D tensor of shape ``[L, D]``; each row is one state.

    Returns:
        A dict with
        ``score_coe_mag`` – mean normalised step magnitude,
        ``score_coe_ang`` – mean normalised step angle,
        ``score_coe_r``   – ``score_coe_mag - score_coe_ang``,
        ``score_coe_c``   – length of the mean of the per-step vectors
        ``(mag * cos(ang), mag * sin(ang))``.

    Raises:
        IndexError: If ``hidden_states`` has no rows.
    """
    rep_start = hidden_states[0]
    rep_end = hidden_states[-1]

    # Magnitude of every step, relative to the overall displacement.
    magnitude_scale = torch.norm(rep_end - rep_start, p=2) + 1e-8
    step_deltas = hidden_states[1:] - hidden_states[:-1]  # [L-1, D]
    al_repdiff_norm = torch.norm(step_deltas, dim=1) / magnitude_scale  # [L-1]
    al_repdiff_ave = al_repdiff_norm.mean().item()

    # Angle of every step, relative to the overall rotation. The overall
    # angle is exactly 0.0 when the first and last rows are parallel; floor it
    # so the division below cannot raise ZeroDivisionError. Non-zero angles
    # are far above the floor, so ordinary results are unaffected.
    angle_scale = max(_angle_between(rep_start, rep_end), 1e-8)
    al_semdiff = torch.tensor(
        [
            _angle_between(nxt, cur) / angle_scale
            for cur, nxt in zip(hidden_states[:-1], hidden_states[1:])
        ],
        # Keep the angles next to the magnitudes so the products below work
        # for non-CPU inputs as well.
        device=hidden_states.device,
    )
    al_semdiff_ave = al_semdiff.mean().item()

    score_coe_r = al_repdiff_ave - al_semdiff_ave

    # Treat each step as a polar vector (magnitude, angle) and take the
    # length of the average vector.
    x_ave = (al_repdiff_norm * torch.cos(al_semdiff)).mean().item()
    y_ave = (al_repdiff_norm * torch.sin(al_semdiff)).mean().item()
    score_coe_c = math.sqrt(x_ave ** 2 + y_ave ** 2)

    return {
        "score_coe_mag": al_repdiff_ave,
        "score_coe_ang": al_semdiff_ave,
        "score_coe_r": score_coe_r,
        "score_coe_c": score_coe_c,
    }

# Stream the input JSONL, attach both metric sets to every record and write the
# enriched record to the output JSONL. Records whose tensor file is missing, or
# whose processing raises, are reported on stdout and left out of the output.
# JSON text is UTF-8, so do not depend on the locale's default encoding.
with open(jsonl_path, "r", encoding="utf-8") as fin, \
        open(output_path, "w", encoding="utf-8") as fout:
    for line in tqdm(fin, desc="Processing"):
        # A blank line (typically a trailing newline at end of file) is not a
        # record; parsing it would abort the whole run with a JSONDecodeError.
        if not line.strip():
            continue
        record = json.loads(line)
        idx = record["id"]
        pt_file = pt_dir / f"GSM8K_{idx}.pt"
        coe_file = pt_dir_2 / f"GSM8K_{idx}.pt"
        if not coe_file.exists():
            print(f"[!] coe Missing .pt file for id={idx}")
            continue
        if not pt_file.exists():
            print(f"[!] pt Missing .pt file for id={idx}")
            continue
        try:
            ppl = record["ppl"]
            entropy_id = record["entropy"]
            maxprob = record["maxprob"]
            # NOTE(review): torch.load unpickles the file; only point this
            # script at tensor files from a trusted source.
            hidden = torch.load(pt_file, map_location="cpu")
            # Both directories can resolve to the same file; the metric
            # functions only read their input, so one load can serve both.
            if coe_file == pt_file:
                hidden_coe = hidden
            else:
                hidden_coe = torch.load(coe_file, map_location="cpu")
            metrics = compute_dynamics_with_scores(hidden, ppl, entropy_id, maxprob)
            record.update(metrics)
            coe_metrics = compute_coe_metrics(hidden_coe)
            record.update(coe_metrics)
            fout.write(json.dumps(record) + "\n")
        except Exception as e:
            # Best effort: one bad record must not stop the batch.
            print(f"[!] Error at {idx}: {e}")
