'''                                        
Copyright 2024 Image Processing Research Group of University Federico
II of Naples ('GRIP-UNINA'). All rights reserved.
                        
Licensed under the Apache License, Version 2.0 (the "License");       
you may not use this file except in compliance with the License. 
You may obtain a copy of the License at                    
                                           
    http://www.apache.org/licenses/LICENSE-2.0
                                                      
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,    
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.                         
See the License for the specific language governing permissions and
limitations under the License.
''' 

import torch
import os
import pandas
import numpy as np
import tqdm
import glob
import sys
import yaml
from PIL import Image
import random
import io

from torchvision.transforms  import CenterCrop, Resize, Compose, InterpolationMode
from utils.processing import make_normalize
from utils.fusion import apply_fusion
from networks import create_architecture, load_weights

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def get_config(model_name, weights_dir='./weights'):
    """
    Load the settings of one detector from ``<weights_dir>/<model_name>/config.yaml``.

    Returns a 5-tuple ``(model_name, model_path, arch, norm_type, patch_size)``:
    ``model_name`` is the name stored inside the YAML file (not the argument),
    and ``model_path`` is the config's ``weights_file`` joined onto the model
    directory. A missing key raises ``KeyError``.
    """
    model_dir = os.path.join(weights_dir, model_name)
    # NOTE: yaml.FullLoader is kept as in the original; config files are
    # assumed to be trusted, locally provided files.
    with open(os.path.join(model_dir, 'config.yaml')) as stream:
        cfg = yaml.load(stream, Loader=yaml.FullLoader)

    weights_path = os.path.join(model_dir, cfg['weights_file'])
    return (
        cfg['model_name'],
        weights_path,
        cfg['arch'],
        cfg['norm_type'],
        cfg['patch_size'],
    )

def jpeg_compress_pil(img, quality_range=(65, 100)):
    """
    Round-trip a PIL image through in-memory JPEG encoding.

    The JPEG quality is drawn with ``random.randint(*quality_range)``, so both
    bounds are inclusive. The result is the image re-opened from the encoded
    bytes with ``Image.open``. Pillow decodes that lazily, on first pixel access.
    """
    encoded = io.BytesIO()
    chosen_quality = random.randint(*quality_range)
    img.save(encoded, format='JPEG', quality=chosen_quality)
    # Rewind so Image.open reads the JPEG from its first byte.
    encoded.seek(0)
    decoded = Image.open(encoded)
    return decoded

def runnig_tests(input_csv, weights_dir, models_list, device, batch_size = 64):
    """
    Score every image listed in ``input_csv`` with each model in ``models_list``.

    Image paths in the CSV's ``file_name`` column are resolved relative to the
    directory that contains the CSV. For each model, a column named after the
    model is added to the returned table. It holds one logit per image:
    channel 0 for single-channel outputs, or channel 1 minus channel 0 for
    two-channel outputs. Either is averaged over any remaining non-batch axes.

    Args:
        input_csv: path to a CSV file with at least a ``file_name`` column.
        weights_dir: directory with one sub-directory (holding ``config.yaml``)
            per model.
        models_list: names of the models to run.
        device: torch device the models and input batches are moved to.
        batch_size: number of images per forward pass.

    Returns:
        pandas.DataFrame with the ``file_name`` column plus one score column
        per model.

    Raises:
        ValueError: for an unsupported ``patch_size`` in a model config, or a
            model output whose channel dimension is neither 1 nor 2.
    """
    table = pandas.read_csv(input_csv)[['file_name',]]
    rootdataset = os.path.dirname(os.path.abspath(input_csv))

    models_dict = dict()
    transform_dict = dict()
    print("Models:")
    for model_name in models_list:
        print(model_name, flush=True)
        _, model_path, arch, norm_type, patch_size = get_config(model_name, weights_dir=weights_dir)

        model = load_weights(create_architecture(arch), model_path)
        model = model.to(device).eval()

        transform = list()
        if patch_size is None:
            print('input none', flush=True)
            transform_key = 'none_%s' % norm_type
        elif patch_size=='Clip224':
            print('input resize:', 'Clip224', flush=True)
            transform.append(Resize(224, interpolation=InterpolationMode.BICUBIC))
            transform.append(CenterCrop((224, 224)))
            transform_key = 'Clip224_%s' % norm_type
        elif isinstance(patch_size, (tuple, list)):
            print('input resize:', patch_size, flush=True)
            transform.append(Resize(*patch_size))
            transform.append(CenterCrop(patch_size[0]))
            transform_key = 'res%d_%s' % (patch_size[0], norm_type)
        elif patch_size > 0:
            print('input crop:', patch_size, flush=True)
            transform.append(CenterCrop(patch_size))
            transform_key = 'crop%d_%s' % (patch_size, norm_type)
        else:
            # Without this branch `transform_key` would be unbound for the first
            # model, or would silently reuse the previous model's key later on.
            raise ValueError('Unsupported patch_size: %r' % (patch_size,))

        transform.append(make_normalize(norm_type))
        transform_dict[transform_key] = Compose(transform)
        models_dict[model_name] = (transform_key, model)
        print(flush=True)

    do_models = list(models_dict.keys())
    do_transforms = set(models_dict[m][0] for m in do_models)

    def score_batch(batch_img, batch_id):
        """Run every model on one batch and write its logits into `table`."""
        stacked = {k: torch.stack(batch_img[k], 0) for k in do_transforms}
        for model_name in do_models:
            transform_key, model = models_dict[model_name]
            # clone(): models sharing a transform key must not see each other's
            # potential in-place edits of the input (to(device) is a no-op on CPU).
            out_tens = model(stacked[transform_key].clone().to(device)).cpu().numpy()

            if out_tens.shape[1] == 1:
                out_tens = out_tens[:, 0]
            elif out_tens.shape[1] == 2:
                out_tens = out_tens[:, 1] - out_tens[:, 0]
            else:
                raise ValueError('Unexpected model output shape: %s' % (out_tens.shape,))

            if out_tens.ndim > 1:
                # Average the per-location logits over every non-batch axis.
                logit1 = np.mean(out_tens, tuple(range(1, out_tens.ndim)))
            else:
                logit1 = out_tens

            for ii, logit in zip(batch_id, logit1):
                table.loc[ii, model_name] = logit

    ### test
    with torch.no_grad():
        print(do_models)
        print(do_transforms)
        print(flush=True)

        print("Running the Tests")
        batch_img = {k: list() for k in transform_dict}
        batch_id = list()
        for index in tqdm.tqdm(table.index, total=len(table)):
            file_name = os.path.join(rootdataset, table.loc[index, 'file_name'])
            # Decode the file once and close it right away; convert() returns a
            # fully loaded copy that every transform can reuse.
            with Image.open(file_name) as src:
                img = src.convert('RGB')
            for k in transform_dict:
                batch_img[k].append(transform_dict[k](img))
            batch_id.append(index)

            if len(batch_id) >= batch_size:
                score_batch(batch_img, batch_id)
                batch_img = {k: list() for k in transform_dict}
                batch_id = list()

        # Score the final, possibly partial, batch (also handles an empty CSV).
        if batch_id:
            score_batch(batch_img, batch_id)

    return table


def runnig_tests_dw2(input_csv, weights_dir, models_list, device, batch_size=64):
    """
    Like ``runnig_tests`` but score randomly degraded copies of the images.

    Each image goes through the same pipeline before the per-model transforms:
    1. If its longer side exceeds 1024 px, it is shrunk with ``thumbnail``.
    2. It is randomly cropped to 62.5%-100% of each side.
    3. It is resized to 200x200 (bilinear).
    4. It is JPEG-compressed at a random quality in [65, 100].

    Files that fail anywhere in that pipeline are reported and skipped, so
    their rows get no score.

    Args:
        input_csv: path to a CSV file with a ``file_name`` column; paths are
            relative to the CSV's directory.
        weights_dir: directory with one sub-directory per model.
        models_list: names of the models to run.
        device: torch device for models and batches.
        batch_size: number of images per forward pass.

    Returns:
        pandas.DataFrame with ``file_name`` plus one score column per model.

    Raises:
        ValueError: for an unsupported ``patch_size`` or an unexpected model
            output shape.
    """
    table = pandas.read_csv(input_csv)[['file_name']]
    rootdataset = os.path.dirname(os.path.abspath(input_csv))

    models_dict = dict()
    transform_dict = dict()
    print("Models:")
    for model_name in models_list:
        print(model_name, flush=True)
        _, model_path, arch, norm_type, patch_size = get_config(model_name, weights_dir=weights_dir)

        model = load_weights(create_architecture(arch), model_path)
        model = model.to(device).eval()

        transform = list()
        if patch_size is None:
            print('input none', flush=True)
            transform_key = f'none_{norm_type}'
        elif patch_size == 'Clip224':
            print('input resize: Clip224', flush=True)
            transform.append(Resize(224, interpolation=InterpolationMode.BICUBIC))
            transform.append(CenterCrop((224, 224)))
            transform_key = f'Clip224_{norm_type}'
        elif isinstance(patch_size, (tuple, list)):
            print('input resize:', patch_size, flush=True)
            transform.append(Resize(*patch_size))
            transform.append(CenterCrop(patch_size[0]))
            transform_key = f'res{patch_size[0]}_{norm_type}'
        elif patch_size > 0:
            print('input crop:', patch_size, flush=True)
            transform.append(CenterCrop(patch_size))
            transform_key = f'crop{patch_size}_{norm_type}'
        else:
            # Otherwise `transform_key` would be unbound or stale from the previous model.
            raise ValueError(f'Unsupported patch_size: {patch_size!r}')

        transform.append(make_normalize(norm_type))
        transform_dict[transform_key] = Compose(transform)
        models_dict[model_name] = (transform_key, model)
        print(flush=True)

    do_models = list(models_dict.keys())
    do_transforms = set(models_dict[m][0] for m in do_models)

    def score_batch(batch_img, batch_id):
        """Run every model on one batch and write its logits into `table`."""
        stacked = {k: torch.stack(batch_img[k], 0) for k in do_transforms}
        for model_name in do_models:
            transform_key, model = models_dict[model_name]
            inputs = stacked[transform_key].clone().to(device)
            out_tens = model(inputs).cpu().numpy()

            if out_tens.shape[1] == 1:
                out_tens = out_tens[:, 0]
            elif out_tens.shape[1] == 2:
                out_tens = out_tens[:, 1] - out_tens[:, 0]
            else:
                raise ValueError("Unexpected model output shape")

            # Average the per-location logits over every non-batch axis.
            logits = (np.mean(out_tens, axis=tuple(range(1, out_tens.ndim)))
                      if out_tens.ndim > 1 else out_tens)

            for ii, logit in zip(batch_id, logits):
                table.loc[ii, model_name] = logit

    with torch.no_grad():
        print(do_models)
        print(do_transforms)
        print("Running the Tests", flush=True)

        batch_img = {k: [] for k in transform_dict}
        batch_id = []

        for index in tqdm.tqdm(table.index, total=len(table)):
            file_name = os.path.join(rootdataset, table.loc[index, 'file_name'])

            # Load and degrade image safely
            try:
                with Image.open(file_name) as src:
                    img = src.convert('RGB')
                if max(img.size) > 1024:
                    img.thumbnail((1024, 1024), Image.BILINEAR)

                # Random crop (scale from 5/8 to full size); never below 1 px so
                # tiny images do not produce an empty crop.
                w, h = img.size
                crop_scale = random.uniform(0.625, 1.0)
                crop_w = max(1, int(w * crop_scale))
                crop_h = max(1, int(h * crop_scale))
                left = random.randint(0, w - crop_w) if w > crop_w else 0
                top = random.randint(0, h - crop_h) if h > crop_h else 0
                img = img.crop((left, top, left + crop_w, top + crop_h))

                # Resize to 200x200
                img = img.resize((200, 200), Image.BILINEAR)

                # JPEG compression
                img = jpeg_compress_pil(img, quality_range=(65, 100))

            except Exception as e:
                # Deliberate best-effort: report the unreadable file and move on.
                print(f"⚠️ Skipping {file_name} due to error: {e}")
                continue

            for k in transform_dict:
                batch_img[k].append(transform_dict[k](img))
            batch_id.append(index)

            if len(batch_id) >= batch_size:
                score_batch(batch_img, batch_id)
                batch_img = {k: [] for k in transform_dict}
                batch_id = []

        # Score the final partial batch. Flushing here instead of at the last
        # CSV row means a skipped last file cannot drop the remaining images.
        if batch_id:
            score_batch(batch_img, batch_id)

    return table


from torch.utils.data import Dataset, DataLoader
from functools import partial

class CSVImageDataset(Dataset):
    """
    Map-style dataset over the image files listed in a DataFrame.

    Each item is ``(row_label, PIL_image)``. The image is the file at
    ``os.path.join(root_dir, file_name)``, converted to RGB. If its longer side
    exceeds ``early_max``, it is shrunk in place with ``Image.thumbnail``.

    Items are addressed by position (0..len-1, as a DataLoader sampler
    produces). The returned label is the DataFrame's index label for that row,
    so scores can be written back with ``table.loc[label, ...]`` even when the
    index is not a plain 0..n-1 range.
    """
    def __init__(self, table, root_dir, early_max=1024):
        self.table = table
        self.root_dir = root_dir
        self.early_max = early_max  # longest allowed side (px) before thumbnailing

    def __len__(self):
        return len(self.table)

    def _locate(self, idx):
        """Return ``(row_label, file_path)`` for positional index ``idx``."""
        label = self.table.index[idx]
        file_rel = self.table['file_name'].iloc[idx]
        return label, os.path.join(self.root_dir, file_rel)

    def __getitem__(self, idx):
        label, file_path = self._locate(idx)
        # `with` releases the file handle immediately (important with many
        # DataLoader workers); convert() returns a fully loaded copy.
        with Image.open(file_path) as src:
            img = src.convert('RGB')
        if max(img.size) > self.early_max:
            img.thumbnail((self.early_max, self.early_max), Image.BILINEAR)
        return label, img


from torchvision import transforms
from torchvision.transforms import InterpolationMode

# Side length (px) that `collate_and_transform` resizes every image to, so all
# tensors produced for a batch share the same spatial size.
FORCE_SIZE = 224
# Bicubic resize to an exact FORCE_SIZE x FORCE_SIZE square (aspect ratio is not kept).
_force_resize = transforms.Resize(
    size=(FORCE_SIZE, FORCE_SIZE),
    interpolation=InterpolationMode.BICUBIC,
)

def collate_and_transform(batch, transform_dict):
    """
    Collate ``(index, PIL_image)`` pairs into one stacked tensor per transform.

    Every image is first resized with the module-level ``_force_resize``
    (FORCE_SIZE x FORCE_SIZE), so that every transform yields equally sized
    tensors. Then each transform in ``transform_dict`` is applied to every image.

    Returns:
        indices: list of the first elements of the batch items, in order.
        batch_tensors: dict mapping each ``transform_dict`` key to a tensor
            stacked along a new leading batch dimension.
    """
    indices = []
    resized = []
    for item in batch:
        indices.append(item[0])
        resized.append(_force_resize(item[1]))

    # torch.stack itself rejects tensors of mismatched shape.
    batch_tensors = {
        key: torch.stack([tfm(picture) for picture in resized], dim=0)
        for key, tfm in transform_dict.items()
    }
    return indices, batch_tensors


def runnig_tests_dw(input_csv, weights_dir, models_list, device, batch_size=64, num_workers=4,
                    root_dir=None):
    """
    Score every image in ``input_csv`` with each model, loading via a DataLoader.

    Images are read by ``CSVImageDataset``, which shrinks anything larger than
    1024 px. ``collate_and_transform`` then forces every image to a fixed square
    size before each model's transform, so results can differ from
    ``runnig_tests`` for models that expect other input sizes.

    Args:
        input_csv: CSV file with ``file_name``, ``label`` and ``model`` columns.
        weights_dir: directory with one sub-directory per model.
        models_list: names of the models to run.
        device: torch device for models and batches.
        batch_size: number of images per forward pass.
        num_workers: DataLoader worker processes.
        root_dir: directory the ``file_name`` entries are relative to. Defaults
            to ``$SCRATCH/OpenFake``, or ``OpenFake`` relative to the working
            directory when ``SCRATCH`` is unset.

    Returns:
        pandas.DataFrame with the three input columns plus one score column per
        model.

    Raises:
        ValueError: for an unsupported ``patch_size`` or an unexpected model
            output shape.
    """
    cols_needed = ["file_name", "label", "model"]
    table = pandas.read_csv(input_csv)[cols_needed]
    if root_dir is None:
        # Historical default location of the dataset.
        root_dir = os.path.join(os.environ.get('SCRATCH', ''), "OpenFake")
    rootdataset = root_dir

    # ---- Build models and transforms ----
    models_dict = {}
    transform_dict = {}
    print("Models:")
    for model_name in models_list:
        print(model_name, flush=True)
        _, model_path, arch, norm_type, patch_size = get_config(model_name, weights_dir=weights_dir)

        model = load_weights(create_architecture(arch), model_path)
        model = model.to(device).eval()

        t_list = []
        if patch_size is None:
            print('input none', flush=True)
            transform_key = f'none_{norm_type}'
        elif patch_size == 'Clip224':
            print('input resize: Clip224', flush=True)
            t_list.append(Resize(224, interpolation=InterpolationMode.BICUBIC))
            t_list.append(CenterCrop((224, 224)))
            transform_key = f'Clip224_{norm_type}'
        elif isinstance(patch_size, (tuple, list)):
            print('input resize:', patch_size, flush=True)
            t_list.append(Resize(*patch_size))
            t_list.append(CenterCrop(patch_size[0]))
            transform_key = f'res{patch_size[0]}_{norm_type}'
        elif patch_size > 0:
            print('input crop:', patch_size, flush=True)
            t_list.append(CenterCrop(patch_size))
            transform_key = f'crop{patch_size}_{norm_type}'
        else:
            raise ValueError("Unsupported patch_size type/value")

        t_list.append(make_normalize(norm_type))
        transform = Compose(t_list)

        # Models with the same key share one transform (and one stacked batch).
        if transform_key not in transform_dict:
            transform_dict[transform_key] = transform

        models_dict[model_name] = (transform_key, model)
        print(flush=True)

    # ---- Dataset & DataLoader ----
    ds = CSVImageDataset(table, rootdataset, early_max=1024)
    collate_fn = partial(collate_and_transform, transform_dict=transform_dict)
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned host memory only helps copies to a CUDA device; on CPU it just warns.
        pin_memory=torch.device(device).type == 'cuda',
        collate_fn=collate_fn
    )

    # ---- Run inference ----
    do_models = list(models_dict.keys())
    do_transforms = set(models_dict[m][0] for m in do_models)
    print(do_models)
    print(do_transforms)
    print("Running the Tests", flush=True)

    with torch.no_grad():
        # len(loader) is the exact number of batches (ceil(len(ds) / batch_size)).
        for indices, batch_tensors in tqdm.tqdm(loader, total=len(loader)):
            for model_name in do_models:
                transform_key, model = models_dict[model_name]
                inputs = batch_tensors[transform_key].to(device, non_blocking=True)
                out_tens = model(inputs).cpu().numpy()

                # Same post-processing as runnig_tests: one logit per image.
                if out_tens.shape[1] == 1:
                    out_tens = out_tens[:, 0]
                elif out_tens.shape[1] == 2:
                    out_tens = out_tens[:, 1] - out_tens[:, 0]
                else:
                    raise ValueError("Unexpected model output shape")

                if len(out_tens.shape) > 1:
                    logit1 = np.mean(out_tens, (1, 2))
                else:
                    logit1 = out_tens

                for ii, logit in zip(indices, logit1):
                    table.loc[ii, model_name] = logit

    return table


if __name__ == "__main__":

    import argparse
    from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, average_precision_score

    parser = argparse.ArgumentParser()
    parser.add_argument("--in_csv"     , '-i', type=str, help="The path of the input csv file with the list of images")
    parser.add_argument("--out_csv"    , '-o', type=str, help="The path of the output csv file", default="./results.csv")
    parser.add_argument("--weights_dir", '-w', type=str, help="The directory to the networks weights", default="./weights")
    parser.add_argument("--models"     , '-m', type=str, help="List of models to test", default='clipdet_latent10k_plus,Corvi2023')
    parser.add_argument("--fusion"     , '-f', type=str, help="Fusion function", default='soft_or_prob')
    parser.add_argument("--device"     , '-d', type=str, help="Torch device", default='cuda:0')
    parser.add_argument("--pred_col"   , '-p', type=str, help="Score column to evaluate", default="Corvi2023")
    parser.add_argument("--batch_size" , type=int, help="Batch size used with --run_inference", default=64)
    parser.add_argument("--num_workers", type=int, help="DataLoader workers used with --run_inference", default=4)
    parser.add_argument("--run_inference", action="store_true",
                        help="Run the detectors on --in_csv and write --out_csv before evaluating "
                             "(by default an existing --out_csv is only evaluated)")
    args = vars(parser.parse_args())

    if args['models'] is None:
        args['models'] = os.listdir(args['weights_dir'])
    else:
        args['models'] = args['models'].split(',')

    output_csv = args['out_csv']
    if args['run_inference']:
        if args['in_csv'] is None:
            parser.error("--in_csv is required with --run_inference")
        table = runnig_tests_dw(
            args['in_csv'],
            args['weights_dir'],
            args['models'],
            args['device'],
            batch_size=args['batch_size'],
            num_workers=args['num_workers'],
        )
        if args['fusion'] is not None:
            table['fusion'] = apply_fusion(table[args['models']].values, args['fusion'], axis=-1)
        os.makedirs(os.path.dirname(os.path.abspath(output_csv)), exist_ok=True)
        table.to_csv(output_csv, index=False)  # save the results as csv file
    else:
        # Evaluation only: reuse scores produced by an earlier run.
        table = pandas.read_csv(output_csv)

    # ─── evaluate one prediction column ──────────────────────────────────────
    pred_col = args['pred_col']

    y_true   = (table["label"] == "fake").astype(int)
    y_score  = table[pred_col].values          # real/fake score (logit or prob)

    # Threshold 0 on logits (use 0.5 instead if the column holds probabilities).
    y_pred   = (y_score > 0).astype(int)
    auc      = roc_auc_score(y_true, y_score)
    acc      = accuracy_score(y_true, y_pred)
    f1       = f1_score(y_true, y_pred)
    n_fake   = np.sum(y_true == 1)
    # Guard the division when the table contains no fake samples.
    tpr = np.sum((y_true == 1) & (y_pred == 1)) / n_fake if n_fake else float('nan')
    auc_pr = average_precision_score(y_true, y_score)

    print("\n==== OVERALL METRICS ({}) ====".format(pred_col))
    print(f"ROC AUC : {auc:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"F1 score: {f1:.4f}")
    print(f"TPR     : {tpr:.4f}")
    print(f"AUC PR  : {auc_pr:.4f}")

    # Accuracy per generator (the CSV column 'model'), computed once and reused below.
    per_generator = {}
    for gen_name, grp in table.groupby("model"):
        gen_true = (grp["label"] == "fake").astype(int)
        gen_pred = (grp[pred_col].values > 0).astype(int)
        per_generator[gen_name] = (accuracy_score(gen_true, gen_pred), len(grp))

    print("\n==== PER-GENERATOR ACCURACY ({}) ====".format(pred_col))
    for gen_name, (gen_acc, n_rows) in per_generator.items():
        print(f"{str(gen_name):30s}  Acc: {gen_acc:.4f}  (n={n_rows})")

    # Save the metrics next to the results CSV.
    metrics_path = os.path.splitext(output_csv)[0] + "_metrics.txt"
    with open(metrics_path, "w", encoding="utf-8") as f:
        f.write(f"AUC  = {auc:.6f}\n")
        f.write(f"ACC  = {acc:.6f}\n")
        f.write(f"F1   = {f1:.6f}\n")
        f.write(f"TPR  = {tpr:.6f}\n")
        f.write(f"AP   = {auc_pr:.6f}\n")
        for gen_name, (gen_acc, _) in per_generator.items():
            f.write(f"{gen_name} {gen_acc:.6f}\n")
