import torch
import torch.nn as nn
from torchvision import models
from PIL import Image
import numpy as np
import os
import pandas as pd
from tqdm import tqdm
import cv2


# Per-condition bounds used for min-max scaling of ground-truth conditions in
# get_metrics(). The 7 entries line up with the last 7 columns of the
# annotation CSV and with the 7 outputs of Predictor; the physical meaning and
# units of each column are not visible in this file.
MT_COND_MAX = np.array([1200, 1400, 20, 60, 60, 1800, 0.3])
MT_COND_MIN = np.array([500, 800, 1.5, 4, 25, 1200, 0.04])

# Per-condition mean/std used to z-score ground-truth conditions in
# get_metrics(), selected by dataset type ('om' or 'sem').
MT_OM_COND_MEAN = np.array([946.2, 1042.3, 12.47, 34.27, 38.53, 1417, 0.1588])
MT_OM_COND_STD = np.array([77.97, 79.29, 3.914, 12.32, 5.298, 72.65, 0.04888])
MT_SEM_COND_MEAN = np.array([938.9, 1034, 12.36, 34.12, 38.74, 1417, 0.1584])
# NOTE(review): the last SEM std (0.4698) is ~10x the OM value (0.04888) even
# though the two means agree closely (0.1584 vs 0.1588) -- possibly a typo for
# 0.04698. Confirm against the statistics the metric model was trained with.
MT_SEM_COND_STD = np.array([95.39, 84.13, 4.105, 11.06, 5.306, 73.36, 0.4698])

# Per-channel image mean/std applied in predict() after pixels are scaled to
# [0, 1]. All three channels share one value because predict() feeds a single
# CLAHE grayscale channel replicated to 3 channels.
MT_OM_IMAGE_MEAN = np.array([0.577, 0.577, 0.577])
MT_OM_IMAGE_STD = np.array([0.241, 0.241, 0.241])
MT_SEM_IMAGE_MEAN = np.array([0.523, 0.523, 0.523])
MT_SEM_IMAGE_STD = np.array([0.195, 0.195, 0.195])


class Predictor(nn.Module):
    """ResNet-18 backbone followed by a two-layer MLP head with 7 outputs.

    Every child of ``torchvision.models.resnet18`` except its final ``fc``
    classifier is applied in order. The pooled 512-d feature is flattened,
    and ``out_layer`` maps it to 7 values.

    ``encoder_layers`` is deliberately a plain Python list. Its modules are
    already owned by ``self.encoder``, so parameters are registered (and
    appear in ``state_dict``) under the ``encoder.*`` prefix only. Turning the
    list into an ``nn.Module`` container would add new state-dict keys.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        self.encoder = backbone
        # Keep the conv stages + global pooling; drop the ImageNet ``fc`` head.
        self.encoder_layers = list(backbone.children())[:-1]
        hidden_width = 256
        self.out_layer = nn.Sequential(
            nn.Linear(512, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 7),
        )

    def forward(self, x):
        """Return a ``(batch, 7)`` tensor of predictions for the input batch ``x``."""
        features = x
        for stage in self.encoder_layers:
            features = stage(features)
        # Collapse everything after the batch dimension into one feature axis.
        features = features.view(features.shape[0], -1)
        return self.out_layer(features)


def CLAHE(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply OpenCV CLAHE and return a single-channel ``uint8`` image.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        Image to equalize. It is converted with ``astype(np.uint8)``, so float
        inputs are truncated, not rescaled.
        - 3-D inputs with 3 or 4 channels go through ``cv2.COLOR_RGB2GRAY``.
        - ``(H, W, 1)`` and 2-D inputs are used directly as the gray channel.
    clip_limit : float
        ``clipLimit`` passed to ``cv2.createCLAHE`` (default 2.0).
    tile_grid_size : tuple of (int, int)
        ``tileGridSize`` passed to ``cv2.createCLAHE`` (default ``(8, 8)``).

    Returns
    -------
    numpy.ndarray
        2-D ``uint8`` array with the same height and width as ``image``.

    Raises
    ------
    ValueError
        If ``image`` is neither 2-D nor 3-D.
    """
    image = np.asarray(image).astype(np.uint8)
    if image.ndim == 3 and image.shape[2] == 1:
        # Already single-channel; RGB2GRAY would reject it. The reshape keeps
        # the array C-contiguous for OpenCV.
        image = image.reshape(image.shape[0], image.shape[1])
    elif image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    elif image.ndim != 2:
        raise ValueError(
            f"CLAHE expects a 2-D or 3-D image, got array with shape {image.shape}"
        )
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(image)


def _load_preprocessed_batch(paths, image_size):
    """Load ``paths`` as a float32 ``(N, H, W, 3)`` array of CLAHE images scaled to [0, 1]."""
    images = []
    for path in paths:
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        if image_size is not None:
            rgb = rgb.resize((image_size, image_size))
        gray = CLAHE(rgb)
        # Replicate the single CLAHE channel so the 3-channel ResNet stem accepts it.
        images.append(np.stack([gray, gray, gray], axis=-1).astype(np.float32))
    return np.stack(images, axis=0) / 255


def predict(dataset_type: str, image_list: list, metric: Predictor, image_size=256, batch_size=32):
    """Run ``metric`` over image files in batches and return stacked predictions.

    Each image goes through these steps:
    1. Load as RGB.
    2. Optionally resize to ``image_size`` x ``image_size``.
    3. Equalize with ``CLAHE`` (single gray channel) and replicate to 3 channels.
    4. Scale to [0, 1].
    5. Normalize with the image mean/std for ``dataset_type``.

    Parameters
    ----------
    dataset_type : str
        ``'om'`` or ``'sem'``; selects the image normalization constants.
    image_list : list of str
        Paths of the image files to evaluate.
    metric : Predictor
        Model to run. It is put in eval mode and moved to ``cuda:0`` when
        CUDA is available, otherwise to CPU. This mutates the model passed in.
    image_size : int or None
        Side length of the square resize. ``None`` keeps the native size, in
        which case all images in a batch must share one size.
    batch_size : int
        Number of images per forward pass.

    Returns
    -------
    numpy.ndarray
        Model outputs with one row per entry of ``image_list``, in order.

    Raises
    ------
    ValueError
        If ``dataset_type`` is not ``'om'``/``'sem'`` or ``batch_size < 1``.
    """
    normalization = {
        'om': (MT_OM_IMAGE_MEAN, MT_OM_IMAGE_STD),
        'sem': (MT_SEM_IMAGE_MEAN, MT_SEM_IMAGE_STD),
    }
    if dataset_type not in normalization:
        # Previously an unknown type silently skipped normalization.
        raise ValueError(f"dataset_type must be 'om' or 'sem', got {dataset_type!r}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    image_mean, image_std = normalization[dataset_type]

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    metric.eval()
    metric = metric.to(device)

    num_images = len(image_list)
    pred_cond_array = []
    for start in tqdm(range(0, num_images, batch_size)):
        image_batch = _load_preprocessed_batch(image_list[start:start + batch_size], image_size)
        image_batch = (image_batch - image_mean) / image_std
        # NHWC -> NCHW for the conv backbone.
        image_batch = np.transpose(image_batch, [0, 3, 1, 2])
        image_batch = torch.tensor(image_batch).to(dtype=torch.float32, device=device)

        with torch.no_grad():
            pred_cond_array.append(metric(image_batch).to('cpu').numpy())

    if not pred_cond_array:
        # np.concatenate([]) raises; Predictor's head has 7 outputs.
        return np.zeros((0, 7), dtype=np.float32)
    return np.concatenate(pred_cond_array, axis=0)


def _infer_dataset_type(metric_path):
    """Return ``'om'`` or ``'sem'`` by searching the path components of ``metric_path``.

    Components are scanned from the file name upward. This stops an unrelated
    parent directory (e.g. ``/home``, which contains "om") from overriding a
    file name such as ``sem_metric.pth``. Within one component "om" is tested
    before "sem".
    """
    for part in reversed(os.path.normpath(metric_path).split(os.sep)):
        if 'om' in part:
            return 'om'
        if 'sem' in part:
            return 'sem'
    raise ValueError(
        f"cannot infer dataset type ('om' or 'sem') from metric_path {metric_path!r}; "
        "pass dataset_type explicitly"
    )


def _normalize_conditions(cond_list, dataset_type):
    """Return a new ``(N, C)`` array of normalized conditions; ``cond_list`` is not modified."""
    if dataset_type == 'om':
        mean, std = MT_OM_COND_MEAN, MT_OM_COND_STD
    else:
        mean, std = MT_SEM_COND_MEAN, MT_SEM_COND_STD
    # NOTE(review): this z-scores and THEN min-max scales with MT_COND_MIN/MAX.
    # Every mean lies inside [MIN, MAX], so those bounds appear to be in raw
    # units, and applying them to z-scores looks like a double normalization.
    # Kept unchanged because it must match how the metric model's targets were
    # produced -- confirm against the training code.
    normalized = [
        ((np.asarray(cond) - mean) / std - MT_COND_MIN) / (MT_COND_MAX - MT_COND_MIN)
        for cond in cond_list
    ]
    return np.stack(normalized, axis=0)


def get_metrics(pred_image_list, gt_cond_list, metric_path: str, thresholds=(0.1, 0.2, 0.3), pred_save_path: str = None, load_pred_from: str = None, dataset_type: str = None):
    """Score predicted conditions against ground truth at several error thresholds.

    Parameters
    ----------
    pred_image_list : list of str
        Image paths passed to ``predict`` (unused when ``load_pred_from`` is given).
    gt_cond_list : sequence of array-like
        One ground-truth condition vector per image. Not modified.
    metric_path : str
        Path to the ``Predictor`` state dict. Unless ``dataset_type`` is given,
        'om'/'sem' is also inferred from this path. ``~`` is expanded.
    thresholds : iterable of float
        Absolute-error thresholds on the normalized conditions.
    pred_save_path : str, optional
        If given, predictions are saved there with ``np.save``; missing parent
        directories are created. ``~`` is expanded.
    load_pred_from : str, optional
        If given, predictions are loaded from this ``.npy`` file instead of
        running the model. ``~`` is expanded.
    dataset_type : str, optional
        ``'om'`` or ``'sem'``; overrides inference from ``metric_path``.

    Returns
    -------
    dict
        Maps each threshold to the fraction of samples whose every condition
        has absolute error strictly below that threshold.

    Raises
    ------
    ValueError
        - ``dataset_type`` is invalid or cannot be inferred.
        - ``gt_cond_list`` is empty.
        - The prediction and ground-truth shapes differ.
    """
    metric_path = os.path.expanduser(metric_path)
    if dataset_type is None:
        dataset_type = _infer_dataset_type(metric_path)
    elif dataset_type not in ('om', 'sem'):
        raise ValueError(f"dataset_type must be 'om' or 'sem', got {dataset_type!r}")
    if len(gt_cond_list) == 0:
        raise ValueError("gt_cond_list is empty")

    gt_cond_array = _normalize_conditions(gt_cond_list, dataset_type)

    if load_pred_from is not None:
        pred_cond_array = np.load(os.path.expanduser(load_pred_from))
    else:
        metric = Predictor()
        # map_location lets GPU-saved checkpoints load on CPU-only machines;
        # predict() moves the model to the right device afterwards.
        metric.load_state_dict(torch.load(metric_path, map_location='cpu'))
        pred_cond_array = predict(dataset_type, pred_image_list, metric)

    if pred_save_path is not None:
        pred_save_path = os.path.expanduser(pred_save_path)
        pardir = os.path.dirname(pred_save_path)
        if pardir:  # os.makedirs('') would raise for a bare file name
            os.makedirs(pardir, exist_ok=True)
        np.save(pred_save_path, pred_cond_array)

    if pred_cond_array.shape != gt_cond_array.shape:
        raise ValueError(
            f"prediction shape {pred_cond_array.shape} does not match "
            f"ground-truth shape {gt_cond_array.shape}"
        )

    cond_diff = np.abs(pred_cond_array - gt_cond_array)
    num_samples = cond_diff.shape[0]
    # Debug dump of every 500th sample (up to 4000), limited to existing rows.
    sample_idx = [i for i in range(0, 4001, 500) if i < num_samples]
    print(cond_diff[sample_idx, :])
    acc_dict = dict()
    for threshold in thresholds:
        within = np.all(cond_diff < threshold, axis=1)
        print(within[sample_idx])
        acc_dict[threshold] = within.sum() / num_samples
    return acc_dict


if __name__ == '__main__':
    # os.listdir / np.save do not expand '~', so expand those paths up front.
    pred_root = os.path.expanduser('~/tai/sdcopy/logs/eval/tai-sem-ldm-vq-f8-intra-10000')
    raw_ann_path = os.path.expanduser('~/tai/sdcopy/tai_data/annotation.csv')
    pred_save_path = os.path.expanduser('~/Project/TitaniumDiff/logs/pred/metric_test.npy')
    # metric_path is passed as written because get_metrics selects 'om'/'sem'
    # stats from the path text.
    metric_path = '~/tai/sdcopy/metric/sem_metric.pth'

    n_samples_per_cond = 1  # alternative setting: 100
    n_eval = 100  # only the first n_eval (image, condition) pairs are scored

    # The last 7 CSV columns are the conditions. Each annotation row is repeated
    # n_samples_per_cond times so it lines up with the sorted image list.
    raw_ann = pd.read_csv(raw_ann_path).iloc[:, -7:]
    gt_cond_list = list(np.repeat(raw_ann.to_numpy(), n_samples_per_cond, axis=0))

    # NOTE(review): plain lexicographic sort ('10.png' < '2.png'); confirm the
    # file naming keeps images aligned with annotation rows.
    pred_image_list = [
        os.path.join(pred_root, name) for name in sorted(os.listdir(pred_root))
    ]

    acc_dict = get_metrics(
        pred_image_list[:n_eval],
        gt_cond_list[:n_eval],
        metric_path=metric_path,
        pred_save_path=pred_save_path,
    )
    print(acc_dict)