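"""
Run a trained SIF model over a folder of images and save its predictions
(full spectral cubes plus three-band previews) as .npy files for visualization.
"""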

import argparse
import os
import shutil
import torch
from glob import glob


import models
import utils
from misc import empatches
import scipy.io as sio
import numpy as np
from utils import to_pixel_samples, get_band_interval, make_band_coords, resize_fn, interpolate_bands, make_coord

def batched_predict(model, inp, coord, cell, bsize, band_coord):
    '''
    Run ``model`` on all query coordinates, at most ``bsize`` queries at a time.

    The encoder runs once (``model.gen_feat(inp)``). Then ``model.query_rgb`` is
    called on consecutive slices of the query axis (dim 1) of ``coord`` and
    ``cell``, and the partial results are concatenated along dim 1. Everything
    runs under ``torch.no_grad()``.

    Args:
        model: object exposing ``gen_feat(inp)`` and
            ``query_rgb(coord, cell, band_coord)``
        inp: image tensor, shape (N, c, H_l, W_l)
        coord: shape (N, H_h * W_h, 2)
        cell: shape (N, H_h * W_h, 2)
        bsize: int, maximum number of query points per ``query_rgb`` call;
            must be >= 1
        band_coord: shape (N, C or num_band_sample, 2)
    Return:
        pred: shape (N, H_h * W_h, C or num_band_sample)
    Raises:
        ValueError: if ``bsize`` < 1 (a non-positive chunk size would never
            advance through the queries).
    '''
    if bsize < 1:
        raise ValueError(f"bsize must be a positive integer, got {bsize}")
    with torch.no_grad():
        model.gen_feat(inp)
        n = coord.shape[1]
        # Slicing past the end is clipped, so the last chunk may be shorter.
        preds = [
            model.query_rgb(coord[:, ql:ql + bsize, :], cell[:, ql:ql + bsize, :], band_coord)
            for ql in range(0, n, bsize)
        ]
        pred = torch.cat(preds, dim=1)
    return pred

def create_folder(path):
    """
    (Re)create ``path`` as an empty directory.

    Anything already at ``path`` is removed first with ``shutil.rmtree``.
    Failures of either step are reported on stdout and otherwise ignored, so
    the caller is never interrupted by an ``OSError`` from here.
    """
    def _report(err):
        print(f"Error: {path} : {err.strerror}")

    if os.path.exists(path):
        print(f"Folder '{path}' already exists. Deleting...")
        try:
            shutil.rmtree(path)
        except OSError as err:
            _report(err)

    print(f"Creating folder '{path}'...")
    try:
        os.makedirs(path)
    except OSError as err:
        _report(err)
    else:
        print(f"Folder '{path}' created successfully.")


def load_img(filepath, load_img_tag='msi'):
    """
    Read one array variable from a MATLAB ``.mat`` file as float64.

    Args:
        filepath: path of the ``.mat`` file, read with ``scipy.io.loadmat``.
        load_img_tag: name of the variable to extract (``'msi'`` by default;
            ``make_predictions`` passes ``"RGB"``).

    Returns:
        The stored array converted to ``np.float64``. It is expected to be an
        (H, W, C) image, e.g. (512, 512, 31) for a hyperspectral MSI; the
        shape is not checked here.
    """
    mat_contents = sio.loadmat(filepath)
    return mat_contents[load_img_tag].astype(np.float64)

def analyze_band_information(band_path, num_band):
    """
    Build band coordinates for ``num_band`` bands spanning the stored spectral range.

    Args:
        band_path: path to a ``.npy`` array of band intervals. Element [0, 0]
            is read as the spectral minimum and [-1, -1] as the maximum
            (presumably an array of shape (num_source_bands, 2) — confirm).
        num_band: number of target bands, passed to ``get_band_interval``.

    Returns:
        float32 torch tensor of band-interval coordinates produced by
        ``make_band_coords`` (expected shape (num_band, 2)).

    Raises:
        ValueError: if the spectral minimum is not <= the maximum
            (including NaN limits).
    """
    band_intervals = np.load(band_path)
    # The overall spectral range is taken from the first and last intervals;
    # it is used to interpolate arbitrary band counts.
    spec_min = band_intervals[0, 0]
    spec_max = band_intervals[-1, -1]
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    # ``not (a <= b)`` also rejects NaN limits, as the assert did.
    if not spec_min <= spec_max:
        raise ValueError(
            f"invalid band intervals in {band_path!r}: "
            f"min {spec_min} is not <= max {spec_max}")

    # cur_band_intervals: shape (num_b, 2), the current band intervals
    cur_band_intervals = get_band_interval(s_min=spec_min, s_max=spec_max,
                                           num_band=num_band)
    # band_coord: shape (num_b, 2), the band interval coordinates
    band_coord = make_band_coords(s_intervals=cur_band_intervals,
                                  s_min=spec_min, s_max=spec_max)
    return torch.from_numpy(band_coord).float()



def make_predictions(file_paths, band_path, input_size, model, save_path, scale, num_band, format_="mat",
                     rgb_bands=(-1, 15, 4)):
    """
    Predict ``num_band``-band images for every input file and save them as ``.npy``.

    Each input image is cut into overlapping ``input_size`` x ``input_size``
    patches with ``empatches`` (10% overlap). Every patch is resized to
    ``int(input_size / scale)`` with ``resize_fn`` and queried by ``model`` on
    an ``input_size`` x ``input_size`` grid. The predicted patches are then
    averaged back into one image.

    Two files are written per input into ``save_path``:
      ``<name>_scale{scale}_band{num_band}.npy``      full prediction (H, W, num_band)
      ``<name>_rgb_scale{scale}_band{num_band}.npy``  the ``rgb_bands`` channels (H, W, 3)

    Args:
        file_paths: iterable of input image paths.
        band_path: ``.npy`` file with band intervals (see ``analyze_band_information``).
        input_size: patch side length and per-patch output resolution.
        model: model exposing ``gen_feat`` / ``query_rgb`` (see ``batched_predict``).
        save_path: existing directory the results are written to.
        scale: factor by which each patch is shrunk before inference.
        num_band: number of spectral bands to predict.
        format_: ``"mat"`` reads the ``"RGB"`` variable of a ``.mat`` file;
            any other value reads the file with ``np.load``.
        rgb_bands: three band indices (in output channel order) stacked into
            the preview file. Each must be a valid index for ``num_band`` bands.
    """
    em = empatches.EMPatches()
    band_coord = analyze_band_information(band_path, num_band=num_band)

    # The query grid, cell sizes and band coordinates depend only on the
    # arguments, so prepare the batched GPU tensors once instead of per patch.
    coord = make_coord((input_size, input_size)).cuda()
    cell = torch.ones_like(coord)
    cell[:, 0] *= 2 / input_size
    cell[:, 1] *= 2 / input_size
    coord_batch = coord.unsqueeze(0).float()
    cell_batch = cell.unsqueeze(0).float()
    band_batch = band_coord.unsqueeze(0).float().cuda()
    low_res_size = int(input_size / scale)
    rgb_index = list(rgb_bands)

    for file_path in file_paths:
        # splitext keeps inner dots ("a.b.npy" -> "a.b"); split(".")[0] would
        # cut this to "a" and let different inputs overwrite each other's output.
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        if format_ == "mat":
            image = load_img(file_path, "RGB")
        else:
            image = np.load(file_path)

        data_patches, indices = em.extract_patches(image, input_size, overlap=0.1)
        preds = []
        for patch in data_patches:
            # (H, W, C) numpy patch -> (1, C, H, W) tensor, then downsample.
            patch = torch.from_numpy(patch).permute(2, 0, 1).unsqueeze(0)
            patch = resize_fn(patch, low_res_size)

            with torch.no_grad():
                # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]; the output is mapped
                # back the same way and clamped to [0, 1].
                pred = batched_predict(model, ((patch.float() - 0.5) / 0.5).cuda(),
                                       coord_batch, cell_batch, bsize=250,
                                       band_coord=band_batch)[0]
                pred = (pred * 0.5 + 0.5).clamp(0, 1).view(input_size, input_size, num_band).cpu()
            preds.append(pred)

        final_pred = np.asarray(em.merge_patches(preds, indices, mode="avg"))
        np.save(os.path.join(save_path, file_name + f"_scale{scale}_band{num_band}.npy"), final_pred)
        # Three selected bands stacked in order as a pseudo-colour preview.
        np.save(os.path.join(save_path, file_name + f"_rgb_scale{scale}_band{num_band}.npy"),
                final_pred[:, :, rgb_index])


def compute_model_eval_by_scale_band(file_paths, band_path, model_path, scale, num_band, save_path, dataset):
    """
    Load the checkpoint at ``model_path`` and save predictions for one (scale, num_band) setting.

    ``dataset == "cave"`` selects 512-pixel patches read from ``.mat`` files.
    Any other value selects 128-pixel patches read from ``.npy`` files.

    Args:
        file_paths: input image paths.
        band_path: ``.npy`` band-interval file for the dataset.
        model_path: checkpoint file whose ``'model'`` entry is passed to
            ``models.make(..., load_sd=True)``.
        scale: downsampling factor applied before inference.
        num_band: number of spectral bands to predict.
        save_path: directory the predictions are written to.
        dataset: dataset name; only ``"cave"`` is special-cased here.
    """
    if dataset == "cave":
        input_size, format_ = 512, "mat"
    else:
        input_size, format_ = 128, "npy"

    model_spec = torch.load(model_path)['model']
    model = utils.to_cuda(models.make(model_spec, load_sd=True))
    # Pure inference: put layers such as dropout / batch-norm into eval mode.
    model.eval()
    print("Load model")

    make_predictions(file_paths, band_path, input_size, model, save_path, scale=scale, num_band=num_band, format_=format_)
    print("predictions!done!")

def make_args_parser():
    """
    Build the command-line parser for this prediction/visualization script.

    Options:
        --dataset        "pavia" (default) or "cave"
        --model          folder containing ``epoch-best.pth``; predictions are
                         written to ``<model>/predictions``
        --file_path      folder with the input images (``*.npy`` for pavia,
                         ``*.mat`` for cave)
        --gpu            value assigned to ``CUDA_VISIBLE_DEVICES``
        --scale_list     one or more integer downsampling factors
        --num_band_list  one or more integer band counts to predict

    Returns:
        argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser()
    # Only these two names are handled consistently by the rest of the script.
    parser.add_argument('--dataset', default="pavia", choices=["pavia", "cave"],
                        help="dataset name: 'pavia' or 'cave'")
    parser.add_argument('--model', default="",
                        help="folder containing epoch-best.pth; outputs go to <model>/predictions")
    parser.add_argument('--file_path', default="",
                        help="folder containing the input images")
    parser.add_argument('--gpu', default='0',
                        help="value for CUDA_VISIBLE_DEVICES")
    # type=int so values given on the command line are ints, like the defaults.
    parser.add_argument("--scale_list", nargs="+", type=int, default=[2, 3, 4, 8, 10, 12, 14],
                        help="downsampling factors to evaluate")
    parser.add_argument("--num_band_list", nargs="+", type=int, default=[102],
                        help="numbers of spectral bands to predict")

    return parser


if __name__ == '__main__':

    parser = make_args_parser()
    args = parser.parse_args()

    # Restrict the visible GPU(s) before any CUDA work below.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    scale_list = [int(x) for x in args.scale_list]
    num_band_list = [int(x) for x in args.num_band_list]

    # Input extension and band-interval file per dataset. compute_model_eval_by_scale_band
    # only distinguishes "cave" from everything else, so unknown names are rejected here
    # rather than mixing CAVE inputs with Pavia settings.
    if args.dataset == "pavia":
        pattern, band_path = "*.npy", "./dataset/Pavia_Centre/waves_102.npy"
    elif args.dataset == "cave":
        pattern, band_path = "*.mat", "./dataset/CAVE/CAVEdata/waves_31.npy"
    else:
        parser.error(f"unknown --dataset {args.dataset!r}; expected 'pavia' or 'cave'")

    # os.path.join avoids turning an empty --file_path into the filesystem root ("/*.npy").
    file_paths = glob(os.path.join(args.file_path, pattern))
    if not file_paths:
        print(f"Warning: no {pattern} files found in '{args.file_path}'")

    save_path = os.path.join(args.model, "predictions")
    create_folder(save_path)
    # Same checkpoint for every (scale, num_band) combination.
    model_path = os.path.join(args.model, "epoch-best.pth")
    for scale in scale_list:
        for num_band in num_band_list:
            compute_model_eval_by_scale_band(file_paths=file_paths,
                                             model_path=model_path,
                                             scale=scale,
                                             dataset=args.dataset,
                                             save_path=save_path,
                                             num_band=num_band,
                                             band_path=band_path)
