import torch
import os
import time
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader
from tqdm import tqdm

# other modules
# YAML config reader
from config_parser import get_run_config
# Functions to evaluate reconstruction quality of the codec in clear and adversarial scenarios
from quality_evaluators import FR_COLS, NR_COLS, evaluate_codec_image_quality, evaluate_reference_codec
# Dataset class to iterate over image folder
from dataloaders import ImageFolderDataset, image_folder_collate_fn
# IQA metrics implementations
from metrics import niqe
# Various helper functions
from helpers import apply_attack, apply_codec, save_image, \
     load_attack_params_json, fill_df_metadata
# Functions to calculate final robustness and quality scores for codecs based on raw metric values
from codec_scoring_methods import calc_scores_codec
# All implemented attack objectives
from codec_losses import loss_name_2_func
# A list of column names in the raw data for all statistics collected
from raw_data_scheme import RAW_RESULTS_COLS
# Setup NIC model and possible defence from config
from setup_modules import setup_codec, setup_defence, setup_files, setup_attack_presets

# Seed NumPy's global RNG from the wall clock (whole seconds), so the per-image
# seeds drawn later with np.random.randint differ from one launch to the next.
np.random.seed(int(time.time()))


def evaluate_codec(
        run_cfg,
        is_main,
        images,
        attacked_images,
        torch_seed,
        nr_models,
        image_name,
        fn,
        global_i,
        reconstructed_save_path=None,
        attacked_save_path=None):
    """Run defended and undefended codecs on a clear/attacked pair and collect metrics.

    The codec pair is taken from ``run_cfg``: ``def_main_codec`` /
    ``undef_main_codec`` when ``is_main`` is true, otherwise ``def_model`` /
    ``undef_model``. Both inputs are passed through both codecs, image-quality
    metrics are computed, the reference codec is evaluated when a clear
    reconstruction from the defended codec exists, and images are optionally
    written to disk.

    Args:
        run_cfg: Run configuration dict. Keys read here: ``device``, the codec
            keys above, ``is_jpegai``, ``output_range``, ``input_range``,
            ``codec``, ``loss_name``, ``defence_name``, ``dump_path``,
            ``batch_size``, ``dump_freq`` and ``save_freq``.
        is_main: Selects the "main" codec pair and the ``main_rec_att`` file tag.
        images: Clear input tensor, still in the codec's input range.
        attacked_images: Adversarial counterpart of ``images``.
        torch_seed: Seed forwarded unchanged to every ``apply_codec`` call.
        nr_models: No-reference metric models forwarded to the evaluators.
        image_name: Image name; only its final path component is stored.
        fn: Image path; forwarded to ``apply_codec`` and its stem names saved files.
        global_i: Running sample index, used for the dump/save frequency checks.
        reconstructed_save_path: Directory for reconstructions, or None to skip.
        attacked_save_path: Directory for attacked inputs, or None to skip.

    Returns:
        Tuple ``(row, delta_time)``: ``row`` is a dict of metadata, bpp, timing
        and quality values; ``delta_time`` is the NaN-ignoring mean of the four
        codec times.
    """
    device = run_cfg['device']
    codec_keys = ('def_main_codec', 'undef_main_codec') if is_main else ('def_model', 'undef_model')
    model = run_cfg[codec_keys[0]]
    undefended_model = run_cfg[codec_keys[1]]

    def run_codec(batch, codec, save_name):
        # One place for the argument list shared by all four codec passes.
        return apply_codec(batch, codec, is_main, run_cfg['is_jpegai'], fn, torch_seed, device,
                           run_cfg['output_range'], mainc_save_name=save_name)

    with torch.no_grad():
        attacked_images = attacked_images.to(device)
        images = images.to(device)
        print(f'DATA RANGE images: {images.min()}, {images.max()}')
        print(f'DATA RANGE attacked_images: {attacked_images.min()}, {attacked_images.max()}')

        # Call order is kept fixed: defended clear/attacked, then undefended clear/attacked.
        def_clear_outs = run_codec(images, model, 'clear')
        def_attacked_outs = run_codec(attacked_images, model, 'attacked')
        undef_clear_outs = run_codec(images, undefended_model, 'undef_clear')
        undef_attacked_outs = run_codec(attacked_images, undefended_model, 'undef_attacked')

        # Bring the inputs back from the codec's input range (divide by input_range).
        images = images / run_cfg['input_range']
        attacked_images = attacked_images / run_cfg['input_range']

        img_dict = {}
        img_dict['clear'] = images
        img_dict['attacked'] = attacked_images
        img_dict['rec_def_clear'] = def_clear_outs['rec_img']
        img_dict['rec_def_attacked'] = def_attacked_outs['rec_img']
        img_dict['rec_undef_clear'] = undef_clear_outs['rec_img']
        img_dict['rec_undef_attacked'] = undef_attacked_outs['rec_img']

        row = {
            'image_name': Path(image_name).name,
            'codec_name': run_cfg['codec'],
            'loss_name': run_cfg['loss_name'],
            'defence_name': run_cfg['defence_name'],
        }
        # Column suffix -> codec outputs, in the column order of the result row.
        outs_by_variant = {
            'defended-clear': def_clear_outs,
            'undefended-clear': undef_clear_outs,
            'defended-attacked': def_attacked_outs,
            'undefended-attacked': undef_attacked_outs,
        }
        # BPPs, then real BPPs, then codec times.
        for col_prefix, out_key in (('bpp', 'bpp'), ('real-bpp', 'real_bpp'), ('codec-time', 'codec_time')):
            for variant, outs in outs_by_variant.items():
                row[f'{col_prefix}_{variant}'] = outs[out_key]

        iqa_vals = evaluate_codec_image_quality(img_dict, nr_models, device)
        row.update(iqa_vals)

        if img_dict['rec_def_clear'] is None:
            # No clear reconstruction: the reference codec has no quality target.
            ref_codec_reconstructions = None
        else:
            target_quality = [iqa_vals['psnr_clear_defended-rec-clear']]
            ref_codec_res, ref_codec_reconstructions = evaluate_reference_codec(
                img_dict, target_quality, nr_models, def_clear_outs['bpp'], device, run_cfg['dump_path'])
            row.update(ref_codec_res)

    rec_attacked = def_attacked_outs['rec_img']

    def save_reconstructions(out_dir, with_attacked_input):
        # Write the attacked reconstruction (plus reference-codec outputs for the
        # non-main codec) into out_dir; optionally also the attacked input itself.
        stem = Path(fn).stem
        if is_main:
            save_image(rec_attacked, out_dir, img_name=stem, img_type='main_rec_att')
            return
        if with_attacked_input:
            save_image(attacked_images, out_dir, img_name=stem, img_type='att')
        save_image(rec_attacked, out_dir, img_name=stem, img_type='rec_att')
        if ref_codec_reconstructions is not None:
            save_image(ref_codec_reconstructions['ref_codec_attacked'], out_dir,
                       img_name=stem, img_type='rec_att_jpeg')
            save_image(ref_codec_reconstructions['ref_codec_fix_bpp_attacked'], out_dir,
                       img_name=stem, img_type='rec_att_jpeg_fix')

    # Debug dump: only for batch_size == 1 and every dump_freq-th sample.
    if (run_cfg['dump_path'] is not None
            and run_cfg['batch_size'] == 1
            and global_i % run_cfg['dump_freq'] == 0
            and rec_attacked is not None):
        save_reconstructions(run_cfg['dump_path'], with_attacked_input=True)

    # Attacked inputs are stored once, by the non-main evaluation.
    if attacked_save_path is not None and global_i % run_cfg['save_freq'] == 0 and not is_main:
        save_image(attacked_images, attacked_save_path, img_name=Path(fn).stem, img_type='att')

    if (reconstructed_save_path is not None
            and global_i % run_cfg['save_freq'] == 0
            and rec_attacked is not None):
        save_reconstructions(reconstructed_save_path, with_attacked_input=False)

    codec_times = [
        outs['codec_time']
        for outs in (def_clear_outs, def_attacked_outs, undef_clear_outs, undef_attacked_outs)
    ]
    delta_time = np.nanmean(np.array(codec_times))
    return row, delta_time


def run_robustness_evaluation(run_cfg,
        dataset,
        attack_params=None,
        attacked_save_path=None,
        reconstructed_save_path=None):
    """Attack every item of ``dataset`` and evaluate the codec(s) on the result.

    For each item the attack from ``run_cfg['attack_callback']`` is run against
    ``run_cfg['def_model']``, then ``evaluate_codec`` is called for the regular
    codec and, when ``run_cfg['def_main_codec']`` is not None, for the main
    codec as well.

    Args:
        run_cfg: Run configuration dict. Keys read here: ``device``,
            ``def_main_codec``, ``input_range``, ``def_model``,
            ``attack_callback``, ``is_jpegai``, ``loss_name``, ``batch_size``.
        dataset: Iterable yielding dicts with ``images``, ``image_paths`` and
            ``image_names`` lists.
        attack_params: Attack parameters forwarded to ``apply_attack`` as
            ``variable_params``. None (the default) means an empty dict.
        attacked_save_path: Forwarded to ``evaluate_codec``.
        reconstructed_save_path: Forwarded to ``evaluate_codec``.

    Returns:
        None when ``dataset`` yields no items. Otherwise the tuple
        ``(results_df, mean_codec_time, total_attack_time, main_codec_df,
        mean_main_codec_time)``; the last two entries are None when no main
        codec is configured.

    Raises:
        ValueError: If ``apply_attack`` returns None for an image.
    """
    # A `{}` default would be one dict shared by every call; create it per call.
    if attack_params is None:
        attack_params = {}

    nr_models = {}
    nr_models['niqe'] = niqe(run_cfg['device'])

    time_sum = 0
    attack_num = 0
    total_codec_time = 0
    num_codec_evals = 0
    cur_result_df = pd.DataFrame(columns=RAW_RESULTS_COLS)
    main_codec_result_df = None
    main_codec_total_time = None
    num_main_codec_evals = None
    if run_cfg['def_main_codec'] is not None:
        main_codec_result_df = pd.DataFrame(columns=RAW_RESULTS_COLS)
        main_codec_total_time = 0
        num_main_codec_evals = 0

    global_i = 0
    for data in tqdm(dataset):
        # NOTE(review): only the first element of each batch is used, while
        # global_i advances by batch_size; with batch_size > 1 the remaining
        # items are skipped. Confirm that batch_size is always 1.
        images = data['images'][0] # batch_size == 1, list of torch tensors of shape [1,3,H,W]
        fn = data['image_paths'][0] # list of strings
        image_name = data['image_names'][0] # list of strings
        images = images.to(run_cfg['device'])

        # random seed equal for both clear and attacked
        torch_seed = np.random.randint(low=0, high=999999)

        # images and attacked_images will have [0,255] range for JPEGAI, [0,1] otherwise
        images = images * run_cfg['input_range']

        attack_result = apply_attack(
            run_cfg['def_model'], # it is the defended model
            run_cfg['attack_callback'],
            images.clone().contiguous(),
            device=run_cfg['device'],
            variable_params=attack_params,
            seed=torch_seed,
            is_jpegai=run_cfg['is_jpegai'],
            loss_func=loss_name_2_func[run_cfg['loss_name']],
            loss_func_name=run_cfg['loss_name'],
            )
        if attack_result is None:
            raise ValueError(f'Attack failed on image {fn}')

        attacked_images, attack_time = attack_result

        time_sum += attack_time
        attack_num += 1

        # Clones keep the evaluation from touching the tensors reused below.
        row, delta_time = evaluate_codec(
                                        run_cfg=run_cfg, is_main=False, images=images.clone(), attacked_images=attacked_images.clone(),
                                        torch_seed=torch_seed, nr_models=nr_models, image_name=image_name, fn=fn, global_i=global_i,
                                        reconstructed_save_path=reconstructed_save_path, attacked_save_path=attacked_save_path)
        total_codec_time += delta_time
        num_codec_evals += 1
        cur_result_df.loc[len(cur_result_df)] = row

        if run_cfg['def_main_codec'] is not None:
            # Same clear/attacked pair and seed, evaluated on the main codec.
            main_codec_row, delta_time_main = evaluate_codec(
                                        run_cfg=run_cfg, is_main=True, images=images.clone(), attacked_images=attacked_images.clone(),
                                        torch_seed=torch_seed, nr_models=nr_models, image_name=image_name, fn=fn, global_i=global_i,
                                        reconstructed_save_path=reconstructed_save_path, attacked_save_path=attacked_save_path)
            main_codec_result_df.loc[len(main_codec_result_df)] = main_codec_row
            main_codec_total_time += delta_time_main
            num_main_codec_evals += 1

        global_i += run_cfg['batch_size']

    # Nothing was processed: there are no means to compute.
    if attack_num == 0:
        return None
    mean_time_mainc = None if main_codec_total_time is None else main_codec_total_time / num_main_codec_evals
    return cur_result_df, total_codec_time / num_codec_evals, time_sum, main_codec_result_df, mean_time_mainc


def _prepare_output_dir(base_path, run_cfg, cur_preset, test_dataset):
    """Create and return the per-preset/per-dataset directory under base_path, or None if unset."""
    if base_path is None or base_path == '':
        return None
    out_dir = (Path(base_path) / run_cfg["loss_name"] / run_cfg['defence_name'] / str(cur_preset)
               / run_cfg["attack"] / run_cfg["codec"] / test_dataset)
    out_dir.mkdir(parents=True, exist_ok=True)
    return str(out_dir)


def test_main(attack_callback):
    """Run the full robustness benchmark for one attack and write the result CSVs.

    Loads the run configuration, sets up codec and defence, then for every
    attack preset and every test dataset runs ``run_robustness_evaluation``,
    computes scores with ``calc_scores_codec`` and writes per-dataset CSVs.
    Aggregated scores and raw results over all presets and datasets are
    written at the end. Main-codec CSVs are written only when
    ``run_cfg['undef_main_codec']`` is not None and main-codec results exist.

    Args:
        attack_callback: Attack implementation; stored in
            ``run_cfg['attack_callback']`` and used by the evaluation loop.

    Raises:
        RuntimeError: If no image was evaluated in any preset/dataset.
    """
    # load all params
    run_cfg = get_run_config()
    run_cfg = setup_codec(run_cfg)
    run_cfg = setup_defence(run_cfg)
    setup_files(run_cfg)
    run_cfg['attack_callback'] = attack_callback
    list_of_presets = setup_attack_presets(run_cfg)
    print(f'Using defence: {run_cfg["defence_name"]}')

    score_cols = ['attack','attack_preset', 'defence_preset', 'score', 'value']
    raw_results_df = pd.DataFrame(columns=RAW_RESULTS_COLS)
    main_codec_full_raw_df = pd.DataFrame(columns=RAW_RESULTS_COLS)
    scores_df = pd.DataFrame(columns=score_cols)
    scores_main_codec_df = pd.DataFrame(columns=score_cols)

    mean_codec_times = []
    mean_main_codec_times = []

    total_attack_time = 0
    total_attack_calls = 0

    # Main-codec files are written only when the undefended main codec is configured.
    write_main_codec = run_cfg['undef_main_codec'] is not None

    for cur_preset in list_of_presets:
        print(f'======== Current Preset: {cur_preset} ========')
        # load attack configs
        if cur_preset != -1:
            attack_params = load_attack_params_json(cur_preset, run_cfg["attack"])
        else:
            attack_params = {}
        print(f'Loaded preset {cur_preset}: {attack_params}')

        for test_dataset, dataset_path in zip(run_cfg["test_datasets"], run_cfg["dataset_paths"]):
            attacked_dset_path = _prepare_output_dir(
                run_cfg["attacked_dataset_path"], run_cfg, cur_preset, test_dataset)
            reconstructed_dataset_path = _prepare_output_dir(
                run_cfg["reconstructed_dataset_path"], run_cfg, cur_preset, test_dataset)

            ds = ImageFolderDataset(root_dir=dataset_path, )
            dset_dloader = DataLoader(ds, batch_size=run_cfg['batch_size'], shuffle=False, collate_fn=image_folder_collate_fn)

            eval_result = run_robustness_evaluation(
                            run_cfg=run_cfg,
                            dataset=dset_dloader,
                            attack_params=attack_params,
                            attacked_save_path=attacked_dset_path,
                            reconstructed_save_path=reconstructed_dataset_path
                        )
            # The evaluation returns None for an empty dataset; unpacking it would raise TypeError.
            if eval_result is None:
                print(f'No images were evaluated for dataset {test_dataset}; skipping it')
                continue
            cur_raw_results, mean_time, attack_time, main_codec_raw_df, main_codec_mean_time = eval_result
            mean_codec_times.append(mean_time)

            total_attack_time += attack_time
            total_attack_calls += len(cur_raw_results)

            # Merge raw results
            cur_raw_results = fill_df_metadata(cur_raw_results, test_dataset, run_cfg["attack"], cur_preset, run_cfg["defence_preset"])
            raw_results_df = pd.concat([raw_results_df, cur_raw_results]).reset_index(drop=True)

            # Calculate scores
            cur_scores = calc_scores_codec(cur_raw_results)
            cur_scores.loc[len(cur_scores)] = {'score':'mean_time', 'value':mean_time}
            cur_scores.loc[len(cur_scores)] = {'score':'mean_attack_time', 'value':attack_time / len(cur_raw_results)}
            cur_scores = fill_df_metadata(cur_scores, test_dataset, run_cfg["attack"], cur_preset, run_cfg["defence_preset"])
            scores_df = pd.concat([scores_df, cur_scores]).reset_index(drop=True)

            cur_scores_main = None
            if main_codec_raw_df is not None:
                mean_main_codec_times.append(main_codec_mean_time)
                main_codec_raw_df = fill_df_metadata(main_codec_raw_df, test_dataset, run_cfg["attack"], cur_preset, run_cfg["defence_preset"])
                main_codec_full_raw_df = pd.concat([main_codec_full_raw_df, main_codec_raw_df]).reset_index(drop=True)

                # Calculate scores
                cur_scores_main = calc_scores_codec(main_codec_raw_df)
                cur_scores_main.loc[len(cur_scores_main)] = {'score':'mean_time', 'value':main_codec_mean_time}
                cur_scores_main.loc[len(cur_scores_main)] = {'score':'mean_attack_time', 'value':attack_time / len(main_codec_raw_df)}
                cur_scores_main = fill_df_metadata(cur_scores_main, test_dataset, run_cfg["attack"], cur_preset, run_cfg["defence_preset"])
                scores_main_codec_df = pd.concat([scores_main_codec_df, cur_scores_main]).reset_index(drop=True)

            # SAVE INDEPENDENT CSV FILES FOR EACH DATASET
            cur_dset_log_name = f'{test_dataset}_log.csv'
            cur_dset_rawdata_name = f'{run_cfg["codec"]}_{test_dataset}_test.csv'
            cur_scores.reset_index(drop=True).to_csv(os.path.join(run_cfg["log_file"], cur_dset_log_name))
            cur_raw_results.reset_index(drop=True).to_csv(os.path.join(run_cfg["save_path"], cur_dset_rawdata_name))

            # Write main-codec files only when this dataset produced main-codec
            # results; otherwise the frames below would be missing or stale.
            if write_main_codec and cur_scores_main is not None:
                cur_dset_mainc_log_name = f'mainc_{test_dataset}_log.csv'
                cur_dset_mainc_rawdata_name = f'mainc_{test_dataset}_test.csv'
                cur_scores_main.reset_index(drop=True).to_csv(os.path.join(run_cfg["mainc_log_file"], cur_dset_mainc_log_name))
                main_codec_raw_df.reset_index(drop=True).to_csv(os.path.join(run_cfg["mainc_save_path"], cur_dset_mainc_rawdata_name))

    # Fail with a clear message instead of a ZeroDivisionError in the averages below.
    if total_attack_calls == 0:
        raise RuntimeError('No images were evaluated in any preset/dataset; nothing to score or save')

    # SAVE FULL CSVS
    total_scores = calc_scores_codec(raw_results_df)
    total_scores.loc[len(total_scores)] = {'score':'mean_time', 'value':np.mean(mean_codec_times)}
    # Same label as the per-dataset rows, so attack time is not confused with codec time.
    total_scores.loc[len(total_scores)] = {'score':'mean_attack_time', 'value':total_attack_time / total_attack_calls}
    total_scores = fill_df_metadata(total_scores, 'total', 'total', 'total', run_cfg["defence_preset"])
    scores_df = pd.concat([total_scores, scores_df]).reset_index(drop=True)

    # Main-codec totals need at least one main-codec evaluation to aggregate.
    has_main_results = write_main_codec and bool(mean_main_codec_times)
    if has_main_results:
        total_scores_main = calc_scores_codec(main_codec_full_raw_df)
        total_scores_main.loc[len(total_scores_main)] = {'score':'mean_time', 'value':np.mean(mean_main_codec_times)}
        total_scores_main.loc[len(total_scores_main)] = {'score':'mean_attack_time', 'value':total_attack_time / total_attack_calls}
        total_scores_main = fill_df_metadata(total_scores_main, 'total', 'total', 'total', run_cfg["defence_preset"])
        scores_main_codec_df = pd.concat([total_scores_main, scores_main_codec_df]).reset_index(drop=True)

    # Save CSVs
    log_name = 'log.csv'
    rawdata_name = f'{run_cfg["codec"]}_test.csv'
    scores_df.reset_index(drop=True).to_csv(os.path.join(run_cfg["log_file"], log_name))
    raw_results_df.reset_index(drop=True).to_csv(os.path.join(run_cfg["save_path"], rawdata_name))

    if has_main_results:
        mainc_log_name = 'mainc_log.csv'
        mainc_rawdata_name = 'mainc_test.csv'
        scores_main_codec_df.reset_index(drop=True).to_csv(os.path.join(run_cfg["mainc_log_file"], mainc_log_name))
        main_codec_full_raw_df.reset_index(drop=True).to_csv(os.path.join(run_cfg["mainc_save_path"], mainc_rawdata_name))

    
