import os
import json
import numpy as np
from onnx import save
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from pycocotools import mask as mask_utils
import cv2
from region_features.region_utils import show_image
from transformers import AutoTokenizer

# Per-sample logs of the two runs being compared. Both files are parsed as
# JSON further down: "loss" is read from each, and "regions" and "idx" are
# read from path1 only.
path1 = "checkpoints/region-llava-debug/logs.json"
path2 = "checkpoints/llava-debug/logs.json"
# JSON collection of dataset records. Records are looked up with the values
# stored under "idx", and each record's "image" field is joined onto
# img_folder to locate the picture.
source = "playground/data/LLaVA/LLaVA-Pretrain/blip_laion_cc_sbu_8k.json"
img_folder = "./playground/data/LLaVA/LLaVA-Pretrain/images"
# Holds one "<image path without extension>.json" file per image; each file
# is a list of masks whose "segmentation" is decoded with pycocotools.
mask_folder = "./playground/data/regions/LLaVA-Pretrain/regions-mixed"

# Load the logits saved for each run and report whether either array
# contains NaN or infinite entries before they are compared.
logits1: np.ndarray = np.load("checkpoints/region-llava-debug/logits.npy")
logits2: np.ndarray = np.load("checkpoints/llava-debug/logits.npy")
print(*(np.isnan(arr).any() for arr in (logits1, logits2)))
print(*(np.isinf(arr).any() for arr in (logits1, logits2)))

# Difference is region-llava-debug minus llava-debug, ranked ascending.
# NOTE(review): the indices are fed to tokenizer.decode below, so this
# presumably expects a 1-D array indexed by token id -- confirm.
logits_diff = logits1 - logits2
sorted_idx = logits_diff.argsort()
tokenizer = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.5-7b")

# First the ten smallest differences, then the ten largest: decoded tokens,
# raw indices and the difference values, in that order.
for extreme in (sorted_idx[:10], sorted_idx[-10:]):
    print([tokenizer.decode(i) for i in extreme])
    print(extreme)
    print(logits_diff[extreme])


def _load_json(path):
    """Parse the JSON file at *path* and return the decoded object.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes, and it is decoded as UTF-8 (the JSON interchange
    encoding) instead of whatever the platform locale defaults to.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


data1 = _load_json(path1)
data2 = _load_json(path2)
img_data = _load_json(source)

# Per-sample loss gap, path1 run minus path2 run; a negative value means the
# path1 run logged the lower loss for that sample. zip() stops at the shorter
# of the two loss lists.
loss_diff = [l1 - l2 for l1, l2 in zip(data1["loss"], data2["loss"])]
# Only the first element of every "regions" entry is kept.
regions = [r[0] for r in data1["regions"]]
# Values used later to look up each sample's record in img_data.
idx = data1["idx"]

# Rank samples by loss difference in ascending order: the front of sorted_idx
# holds the most negative differences, the back the most positive ones.
sorted_idx = np.argsort(loss_diff)

# 1-based ranks exported from each end of the ranking.
select_idx = [1,2,3,4,5,100,150,200,250,300,350,400,450,500,600,700,800,900,1000]


def _export_outliers(ranked, ranks, save_dir):
    """Export images, mask overlays and metadata for the requested ranks.

    For every rank the raw image is written as ``img_<rank>.png``, the mask
    overlay rendered by ``show_image`` as ``mask_<rank>.png``, and all records
    are finally dumped to ``outliers.json`` inside *save_dir*.

    Reads the module-level ``img_data``, ``idx``, ``loss_diff``, ``regions``,
    ``img_folder`` and ``mask_folder``.

    Args:
        ranked: Sample positions ordered so that rank 1 is ``ranked[0]``.
        ranks: 1-based ranks to export.
        save_dir: Output directory; created when it does not exist.

    Returns:
        The exported records, in the order of *ranks*.

    Raises:
        IndexError: If a rank is outside ``1..len(ranked)``.
        FileNotFoundError: If OpenCV cannot read a sample's image.
    """
    os.makedirs(save_dir, exist_ok=True)
    records = []
    for rank in ranks:
        if not 1 <= rank <= len(ranked):
            raise IndexError(
                f"rank {rank} is outside the {len(ranked)} ranked samples"
            )
        i = int(ranked[rank - 1])

        # Work on a copy so the shared img_data entry is left untouched.
        record = dict(img_data[idx[i]])
        record["loss_diff"] = loss_diff[i]
        record["order"] = rank
        record["region"] = regions[i]
        records.append(record)

        img_path = os.path.join(img_folder, record["image"])
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"could not read image: {img_path}")
        cv2.imwrite(os.path.join(save_dir, f"img_{rank}.png"), img)
        # OpenCV loads BGR; convert to RGB before handing over to show_image.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        masks_file = os.path.join(
            mask_folder, os.path.splitext(record["image"])[0] + ".json"
        )
        with open(masks_file, encoding="utf-8") as f:
            masks = json.load(f)
        # Replace each encoded segmentation with its decoded boolean mask.
        for mask in masks:
            mask["segmentation"] = mask_utils.decode(mask["segmentation"]).astype(bool)
        show_image(img, masks, os.path.join(save_dir, f"mask_{rank}.png"))

    with open(os.path.join(save_dir, "outliers.json"), "w") as f:
        json.dump(records, f, indent=4)
    return records


# Best: smallest (most negative) loss differences, rank 1 = sorted_idx[0].
print("Best outliers: ")
save_dir = "./playground/analysis/best"
outliers = _export_outliers(sorted_idx, select_idx, save_dir)

# Worst: largest loss differences, rank 1 = sorted_idx[-1].
print("Worst outliers: ")
save_dir = "./playground/analysis/worst"
outliers = _export_outliers(sorted_idx[::-1], select_idx, save_dir)

# visualize
import matplotlib.pyplot as plt

# Column vectors for the region counts (feature) and loss gaps (target).
X = np.reshape(regions, (-1, 1))
y = np.reshape(loss_diff, (-1, 1))

# Least-squares line through the points, plus the region count at which the
# fitted line crosses a loss difference of zero.
reg = LinearRegression().fit(X, y)
print(reg.coef_, reg.intercept_)
slope = reg.coef_[0][0]
offset = reg.intercept_[0]
print("Expected #Regions for 0 loss difference: ", -offset / slope)

# Pearson correlation between the two flattened series.
correlation, p_value = pearsonr(X.ravel(), y.ravel())
print(f"Pearson correlation coefficient: {correlation}")
print(f"P-value: {p_value}")

# Scatter of the raw samples with the fitted line drawn on top in red.
fig, ax = plt.subplots(figsize=(16, 9))
ax.scatter(regions, loss_diff, s=5)
ax.set_xlabel('#Regions')
ax.set_ylabel('Loss Difference')
ax.set_title('Loss Difference by #Regions')
ax.plot(X, reg.predict(X), color='red')

# save fig
fig.tight_layout()
fig.savefig("./playground/analysis/loss_diff.png")