import os, cv2, argparse
import torch
from iwssim import *
from modules import *
from old_models import *
import numpy as np
from time import time

from tqdm import tqdm

from torch.utils.data import DataLoader
from model import model_eval
from glob import glob
import pandas as pd
import matplotlib.pyplot as plt
# from iqadataset import load_dataset_pytorch, load_dataset
device = torch.device("cuda:3")
from default_datasets import *
import multiprocessing 
from torchvision.models import resnet18
from get_model import get_model # type: ignore

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)



def model_eval(model, device, dataset_args, dataset='LIVE'):
    """Evaluate ``model`` on one IQA dataset and return PLCC/SROCC against its scores.

    Args:
        model: Callable scoring model. It is called as ``model(maps, dct)`` when
            ``dataset_args['dct']`` is set, as ``model(maps, sal)`` when
            ``dataset_args['saliency']`` is set, and as ``model(sample)`` otherwise.
        device: Device forwarded to ``IQADatasetPyTorch``.
        dataset_args: Options dict forwarded to ``IQADatasetPyTorch``. Its
            ``'dct'`` and ``'saliency'`` keys also choose how the model is called.
        dataset: Name of the dataset to evaluate on (e.g. ``'LIVE'``).

    Returns:
        ``(c1, c2)``: the Pearson and Spearman correlations between the
        dataset scores and the model predictions.
    """
    from contextlib import nullcontext

    eval_data = IQADatasetPyTorch(dataset, device=device, args=dataset_args)
    # `.item()` below needs a single element per batch, so keep batch_size=1.
    # Shuffling does not change the correlations (they are order-invariant).
    dataloader = DataLoader(eval_data, batch_size=1, shuffle=True)
    print("evaluating correlations")

    # The saliency path runs with autograd left as the caller has it (no
    # torch.no_grad()); every other path disables gradients.
    # NOTE(review): presumably the saliency inputs/model need gradients — confirm.
    grad_context = nullcontext() if dataset_args['saliency'] else torch.no_grad()
    with grad_context:
        df = []
        for sample, scores in tqdm(dataloader):
            if dataset_args['dct']:
                maps, dct = sample
                predicted_scores = model(maps, dct)
            elif dataset_args['saliency']:
                maps, sal = sample
                predicted_scores = model(maps, sal)
            else:
                predicted_scores = model(sample)

            df.append([scores.item(), predicted_scores.item()])
        df = pd.DataFrame(df)
        # Column 0 holds ground-truth scores, column 1 the predictions.
        c1 = df.corr(method='pearson')[0][1].item()
        c2 = df.corr(method='spearman')[0][1].item()
        print(c1, c2)
        return c1, c2


def device_index(device):
    """Return the integer index of a device spec such as ``'cuda:3'``.

    The spec is parsed with ``torch.device`` instead of reading its last
    character, so multi-digit indices like ``'cuda:10'`` work.

    Raises:
        ValueError: if the spec has no explicit index (e.g. ``'cpu'`` or ``'cuda'``).
    """
    index = torch.device(device).index
    if index is None:
        raise ValueError(f"device {device!r} has no explicit index; expected e.g. 'cuda:0'")
    return index


def find_checkpoint(model_name, epoch, root='models'):
    """Return the checkpoint file for ``epoch`` in ``<root>/<model_name>/``.

    Files are matched with the glob ``*<epoch>*``. That glob also matches
    other epochs sharing the same digits (epoch 1 matches ``..._10.pth`` and
    ``..._21.pth``), so files whose name contains the epoch number not
    adjacent to other digits are preferred. Candidates are sorted so the
    choice is deterministic.

    Raises:
        FileNotFoundError: if no file matches the glob.
    """
    import re

    pattern = os.path.join(root, model_name, f'*{epoch}*')
    candidates = sorted(glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"no checkpoint matching {pattern!r}")
    standalone = re.compile(rf'(?<!\d){re.escape(str(epoch))}(?!\d)')
    exact = [path for path in candidates if standalone.search(os.path.basename(path))]
    return (exact or candidates)[0]


def append_results(log_path, model_name, epoch, datasets, results):
    """Append one row per dataset to the CSV log at ``log_path``; return the full log.

    Rows have columns ``model``, ``dataset``, ``epoch``, ``PLCC``, ``SROCC``,
    where ``results[i]`` is the ``(PLCC, SROCC)`` pair for ``datasets[i]``.
    The log file and its parent directory are created if they do not exist.
    """
    new_rows = pd.DataFrame([
        {'model': model_name, 'dataset': dataset, 'epoch': epoch, 'PLCC': plcc, 'SROCC': srocc}
        for dataset, (plcc, srocc) in zip(datasets, results)
    ])
    try:
        log = pd.read_csv(log_path)
    except FileNotFoundError:
        log = new_rows
    else:
        log = pd.concat([log, new_rows], ignore_index=True)
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    log.to_csv(log_path, index=False)
    return log


if __name__ == "__main__":
    # Workers receive the (CUDA) model as an argument, so use "spawn" rather than fork.
    multiprocessing.set_start_method("spawn")

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True, help='Name of the model')
    parser.add_argument('--epoch', type=int, required=True)
    parser.add_argument('--device', type=str, required=True)

    args = parser.parse_args()

    device = torch.device(args.device)
    device_id = device_index(args.device)

    c_path = find_checkpoint(args.model, args.epoch)

    dataset_args = {'return_path': False,
                    'simple': False,
                    'flip': True,
                    'vif': False,
                    'dlm': False,
                    'dct': False,
                    'saliency_model': None,
                    'saliency': False,
                    'lbp': False,
                    'part': 'train'}

    # get_model may also adjust dataset_args for the chosen architecture.
    model, dataset_args = get_model(args.model, dataset_args, device)

    print(type(model).__name__)
    # The checkpoint's state-dict keys carry the DataParallel "module." prefix,
    # so wrap the model before loading it.
    model = torch.nn.DataParallel(model, device_ids=[device_id], output_device=device_id)
    # map_location keeps the loaded tensors off whichever GPU they were saved from.
    model.load_state_dict(torch.load(c_path, map_location=device))
    model.to(device)

    # NOTE(review): model.eval() is not called, so dropout/batch-norm layers
    # run in training mode during evaluation — confirm this is intended.

    # 'PIPAL' can also be added to this list.
    datasets = ['LIVE', 'TID2013', 'CSIQ']
    # One worker per dataset; extra workers would only re-import torch and sit idle.
    with multiprocessing.Pool(len(datasets)) as pool:
        results = pool.starmap(model_eval, [(model, device, dataset_args, dataset) for dataset in datasets])

    append_results('runs/logs.csv', args.model, args.epoch, datasets, results)


# python evaluate_model.py  --model 'gatedsalbpiqa try0 KADID-10k' --epoch 6 --device cuda:0
# python evaluate_model.py  --model 'lbpiqa try0 KADID-10k' --epoch 27 --device cuda:4
# python evaluate_model.py  --model 'gatedlbpiqa2 try0 KADID-10k' --epoch 3 --device cuda:0