import os
from typing import Optional, Sequence

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models
from tqdm import tqdm


# Per-column bounds used to min-max normalise the 7 condition values in
# get_metrics: normalised = (cond - MT_COND_MIN) / (MT_COND_MAX - MT_COND_MIN).
MT_COND_MAX = np.array((1200, 1400, 20, 60, 60, 1800, 0.3))
MT_COND_MIN = np.array((500, 800, 1.5, 4, 25, 1200, 0.04))

# Per-channel RGB mean / std applied in predict after pixels are scaled to
# [0, 1]; these are the standard ImageNet statistics used by torchvision's
# pretrained backbones.
MT_IMAGE_MEAN = np.array((0.485, 0.456, 0.406))
MT_IMAGE_STD = np.array((0.229, 0.224, 0.225))


class Predictor(nn.Module):
    """ResNet-18 feature extractor followed by a small MLP regression head.

    The head maps the 512-dim pooled backbone features to 7 outputs
    (512 -> 256 -> ReLU -> 7).
    """

    def __init__(self):
        super().__init__()
        # Registered submodule: every backbone parameter (including the
        # unused final ``fc``) lives in state_dict under ``encoder.*``.
        self.encoder = models.resnet18(pretrained=True)
        # Plain Python list (deliberately NOT nn.ModuleList) of all backbone
        # children except the last one (torchvision's ``fc`` classifier).
        # The parameters are already registered through ``self.encoder``;
        # registering them again would change the state_dict keys and break
        # loading of existing checkpoints.
        self.encoder_layers = list(self.encoder.children())[:-1]
        head = [
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 7),
        ]
        self.out_layer = nn.Sequential(*head)

    def forward(self, x):
        """Run ``x`` through the truncated backbone, flatten, and apply the head.

        ``x`` is presumably an (B, 3, H, W) image batch such as the one built
        in ``predict``; the result has shape (B, 7).
        """
        features = x
        for stage in self.encoder_layers:
            features = stage(features)
        # Collapse every non-batch dimension (pooled maps -> 512-vector).
        flattened = features.view(x.shape[0], -1)
        return self.out_layer(flattened)


def predict(
    image_list: Sequence[str],
    metric: "Predictor",
    image_size: Optional[int] = 256,
    batch_size: int = 32,
):
    """Run ``metric`` over image files in mini-batches and collect its outputs.

    Each file is opened with PIL, converted to RGB, optionally resized to
    ``image_size`` x ``image_size``, scaled to [0, 1], normalised with
    ``MT_IMAGE_MEAN`` / ``MT_IMAGE_STD`` and fed to the model as an
    (B, 3, H, W) float32 tensor under ``torch.no_grad()``.

    Args:
        image_list: Paths of the images to score; output rows follow this order.
        metric: Model to evaluate. It is put in eval mode and moved to
            ``cuda:0`` when CUDA is available, otherwise to the CPU.
        image_size: Side length every image is resized to, or ``None`` to keep
            original sizes (images in one batch must then share a size so they
            can be stacked).
        batch_size: Number of images per forward pass.

    Returns:
        NumPy array of model outputs, one row per image, concatenated in
        input order.

    Raises:
        ValueError: If ``image_list`` is empty.
    """
    num_images = len(image_list)
    if num_images == 0:
        # np.concatenate([]) would otherwise fail with an unhelpful message.
        raise ValueError("predict() requires at least one image path")

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    metric.eval()
    metric = metric.to(device)

    pred_cond_array = []
    for start in tqdm(range(0, num_images, batch_size)):
        batch_paths = image_list[start:start + batch_size]

        arrays = []
        for image_path in batch_paths:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded copy that stays valid afterwards.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
            if image_size is not None:
                image = image.resize((image_size, image_size))
            arrays.append(np.array(image).astype(np.float32))

        # normalize: [0, 255] -> [0, 1] -> per-channel standardisation
        batch = np.stack(arrays, axis=0) / 255
        batch = (batch - MT_IMAGE_MEAN) / MT_IMAGE_STD
        # NHWC (PIL/NumPy layout) -> NCHW (torch conv layout)
        batch = np.transpose(batch, [0, 3, 1, 2])
        batch_tensor = torch.tensor(batch).to(dtype=torch.float32, device=device)

        with torch.no_grad():
            pred_cond = metric(batch_tensor)
        pred_cond_array.append(pred_cond.to('cpu').numpy())

    return np.concatenate(pred_cond_array, axis=0)


def get_metrics(
    pred_image_list,
    gt_cond_list,
    metric_path: str,
    thresholds=(0.1, 0.2, 0.3),
    pred_save_path: Optional[str] = None,
    load_pred_from: Optional[str] = None,
):
    """Compute threshold accuracies of predicted conditions vs. ground truth.

    Ground-truth conditions are min-max normalised with ``MT_COND_MIN`` /
    ``MT_COND_MAX`` and compared element-wise with the predictions. The
    predictions are used as-is, so the model is presumably trained to output
    values on that same normalised scale.

    Args:
        pred_image_list: Image paths passed to ``predict``; unused when
            ``load_pred_from`` is given.
        gt_cond_list: Sequence of per-sample condition vectors in raw units.
            It is not modified.
        metric_path: Path to a ``Predictor`` state_dict (``~`` is expanded).
            It is loaded onto the CPU; ``predict`` moves the model to its
            device. Only read when predictions are computed.
        thresholds: Absolute-error thresholds to evaluate.
        pred_save_path: If given, raw predictions are written here with
            ``np.save``; missing parent directories are created.
        load_pred_from: If given, predictions are loaded from this ``.npy``
            file instead of running the model.

    Returns:
        dict mapping each threshold to the fraction of samples whose every
        condition error is strictly below that threshold.

    Raises:
        ValueError: If predictions and ground truth have different shapes.
    """
    # Normalise into a fresh array. The original in-place update of the
    # caller's list double-normalised on repeated calls.
    gt_cond_array = (np.stack(gt_cond_list, axis=0) - MT_COND_MIN) / (MT_COND_MAX - MT_COND_MIN)

    if load_pred_from is not None:
        pred_cond_array = np.load(os.path.expanduser(load_pred_from))
    else:
        # Only build/load the model when it is actually needed.
        metric = Predictor()
        state_dict = torch.load(os.path.expanduser(metric_path), map_location='cpu')
        metric.load_state_dict(state_dict)
        pred_cond_array = predict(pred_image_list, metric)

    if pred_save_path is not None:
        pred_save_path = os.path.expanduser(pred_save_path)
        pardir = os.path.dirname(pred_save_path)
        # dirname() is '' for a bare filename; os.makedirs('') would raise.
        if pardir:
            os.makedirs(pardir, exist_ok=True)
        np.save(pred_save_path, pred_cond_array)

    # Debug preview of the first rows.
    print(pred_cond_array[:10])
    print(gt_cond_array[:10])

    # Refuse to compare mismatched arrays instead of silently broadcasting.
    if pred_cond_array.shape != gt_cond_array.shape:
        raise ValueError(
            f"prediction shape {pred_cond_array.shape} does not match "
            f"ground-truth shape {gt_cond_array.shape}"
        )

    cond_diff = np.abs(pred_cond_array - gt_cond_array)
    return {
        threshold: np.all(cond_diff < threshold, axis=1).sum() / cond_diff.shape[0]
        for threshold in thresholds
    }


if __name__ == '__main__':
    # os.listdir / torch.load do not expand '~', so resolve home paths up front.
    pred_root = os.path.expanduser('~/Project/TitaniumDiff/logs/eval/tai-sem-ldm-vq-f8')
    raw_ann_path = os.path.expanduser('~/Project/TitaniumDiff/tai_data/sum1126.csv')
    metric_path = os.path.expanduser('~/Project/TitaniumDiff/metric/sem_metric.pth')
    pred_save_path = os.path.expanduser('~/Project/TitaniumDiff/logs/pred/metric_test.npy')

    n_samples_per_cond = 100  # images generated per annotation row
    n_eval = 3000  # number of leading images / conditions to evaluate

    # The last 7 CSV columns hold the condition values.
    raw_ann = pd.read_csv(raw_ann_path).iloc[:, -7:]
    # Repeat each annotation row once per generated sample.
    # NOTE(review): assumes the sorted image filenames are grouped by
    # annotation row, n_samples_per_cond per row -- confirm with the generator.
    gt_cond_list = [
        row for row in raw_ann.to_numpy() for _ in range(n_samples_per_cond)
    ]

    # Sorting full paths gives the same order as sorting the bare names,
    # since every path shares the pred_root prefix.
    pred_image_list = sorted(
        os.path.join(pred_root, name) for name in os.listdir(pred_root)
    )

    acc_dict = get_metrics(
        pred_image_list[:n_eval],
        gt_cond_list[:n_eval],
        metric_path=metric_path,
        pred_save_path=pred_save_path,
    )
    print(acc_dict)