import os
import time
import shutil
import math

import torch
import numpy as np
from torch.optim import SGD, Adam, AdamW
import torch.nn as nn
from torchvision import transforms
# from torchvision.transforms import InterpolationMode
import torch.nn.functional as F

from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix

from tensorboardX import SummaryWriter
import logging
import argparse
import json
import pickle
import PIL.Image
from eval import *


class Averager():
    '''Running weighted average of the values passed to `add`.'''

    def __init__(self):
        # n: accumulated weight, v: current weighted mean
        self.n = 0.0
        self.v = 0.0

    def add(self, v, n=1.0):
        '''
        Fold a new value into the running mean.

        Args:
            v: the new value
            n: the weight of the new value, 1.0 by default
        '''
        total_weight = self.n + n
        weighted_sum = self.v * self.n + v * n
        self.v = weighted_sum / total_weight
        self.n += n

    def item(self):
        '''Return the current weighted mean (0.0 before the first `add`).'''
        return self.v


class Timer():
    '''Wall-clock stopwatch built on `time.time()`.'''

    def __init__(self):
        # v: timestamp of the most recent (re)start
        self.v = time.time()

    def s(self):
        '''Restart the stopwatch from the current time.'''
        self.v = time.time()

    def t(self):
        '''Return the seconds elapsed since construction or the last `s()`.'''
        now = time.time()
        return now - self.v


def time_text(t):
    '''
    Format a duration as a short string with one decimal.

    Args:
        t: the duration in seconds
    Return:
        '<x.x>h' when t >= 3600, '<x.x>m' when t >= 60, '<x.x>s' otherwise
    '''
    if t >= 3600:
        return f'{t / 3600:.1f}h'
    if t >= 60:
        return f'{t / 60:.1f}m'
    return f'{t:.1f}s'


def setup_console():
    '''
    Reset the root logger so that it only logs to the console.

    Every handler currently attached to the root logger is dropped and a
    single INFO-level StreamHandler with a timestamped format is attached.
    '''
    root_logger = logging.getLogger('')
    # forget all previously registered handlers
    root_logger.handlers = []

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    root_logger.addHandler(stream_handler)

    
def setup_logging(log_file, console=True, filemode='w'):
    '''
    Configure the root logger to write INFO records to a file and,
    optionally, to the console.

    NOTE: `logging.basicConfig` is a no-op when the root logger already has
    handlers, so the file handler is only installed on a fresh root logger.

    Args:
        log_file: the log file path handed to `logging.basicConfig`
        console: if True, additionally attach an INFO-level StreamHandler
        filemode: the open mode of the log file, 'w' by default
    Return:
        the `logging` module itself
    '''
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=log_format,
                        filename=log_file,
                        filemode=filemode)
    if not console:
        return logging

    # mirror the records on the console with the same format
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(log_format))
    logging.getLogger('').addHandler(stream_handler)
    return logging

# _log_path = None


# def set_log_path(path):
#     global _log_path
#     _log_path = path


# def log(obj, filename='log.txt'):
#     print(obj)
#     if _log_path is not None:
#         with open(os.path.join(_log_path, filename), 'a') as f:
#             print(obj, file=f)

def dict2namespace(config):
    '''
    Recursively convert a (nested) dict into an argparse.Namespace.

    Args:
        config: a dict; values that are dicts themselves are converted too,
            every other value (including lists) is stored unchanged
    Return:
        namespace: argparse.Namespace with one attribute per key
    '''
    namespace = argparse.Namespace()
    for key, value in config.items():
        converted = dict2namespace(value) if isinstance(value, dict) else value
        setattr(namespace, key, converted)
    return namespace





def make_args_data(config):
    '''
    Build the dataset part of the experiment name.

    Args:
        config: namespace-style config; reads `data`, the optional
            `train_dataset_2` and `train_dataset.wrapper.args`
    Return:
        args_data: "<data>-TSM<scale_max>", followed by "-MD" when
            `train_dataset_2` exists and by the band sampling settings when
            the wrapper args define `band_path`
    '''
    pieces = [f"{config.data:s}-TSM{config.train_dataset.wrapper.args.scale_max:d}"]

    if hasattr(config, 'train_dataset_2'):
        pieces.append("-MD")

    wrapper_args = config.train_dataset.wrapper.args
    if hasattr(wrapper_args, 'band_path'):
        # a missing band setting is rendered as an empty string
        num_band_min = str(getattr(wrapper_args, 'num_band_min', ""))
        num_band_max = str(getattr(wrapper_args, 'num_band_max', ""))
        num_band_sample = str(getattr(wrapper_args, 'num_band_sample', ""))
        pieces.append(
            f"-BANDMIN{num_band_min:s}-MAX{num_band_max:s}-SAM{num_band_sample:s}")

    return "".join(pieces)

def _format_hidden_list(hidden_list):
    '''Join the hidden layer sizes with "_", e.g. [256, 256] -> "256_256".'''
    return "_".join([str(x) for x in hidden_list])


def _format_banddec_args(dec_spec):
    '''Format the "-<fedec args>-<bandenc args>" suffix of a "banddec" decoder spec.'''
    fedec_spec = dec_spec.args.fedec_spec
    bandenc_spec = dec_spec.args.bandenc_spec

    args_fedec = "{dec:s}-{fedec_name:s}-H{dec_hdim:s}-{fedec_out_dim:d}".format(
            dec = dec_spec.name,
            fedec_name = fedec_spec.name,
            dec_hdim = _format_hidden_list(fedec_spec.args.hidden_list),
            fedec_out_dim = fedec_spec.args.out_dim,
            )
    args_bandenc = "{bandenc_name:s}-{bandposenc_type:s}-{freq:d}-{max_radius:.2f}-{min_radius:.6f}-H{band_hdim:s}".format(
            bandenc_name = bandenc_spec.name,
            bandposenc_type = bandenc_spec.args.bandposenc_type,
            freq = bandenc_spec.args.freq,
            max_radius = bandenc_spec.args.max_radius,
            min_radius = bandenc_spec.args.min_radius,
            band_hdim = _format_hidden_list(bandenc_spec.args.hidden_list),
            )
    return "-{args_fedec:s}-{args_bandenc:s}".format(
            args_fedec = args_fedec,
            args_bandenc = args_bandenc
            )


def _format_bandnerf_args(imnet_spec):
    '''Format the "-<bandnerf args>" suffix used with the 'band_rb_mlp' band encoder.'''
    bandnerf_spec = imnet_spec.args.bandnerf_spec
    args_bandnerf = "{bandnerf_name:s}-{bandnerf_type:s}-{hidden_list:s}-{act:s}-{num_band_int_sample:d}-{resp_func_type:s}-{resp_func_norm_type:s}".format(
            bandnerf_name = bandnerf_spec.name,
            bandnerf_type = bandnerf_spec.args.bandnerf_type,
            hidden_list = _format_hidden_list(bandnerf_spec.args.hidden_list),
            act = bandnerf_spec.args.act,
            num_band_int_sample = imnet_spec.args.num_band_int_sample,
            resp_func_type = imnet_spec.args.resp_func_type,
            resp_func_norm_type = imnet_spec.args.resp_func_norm_type,
            )
    # the sampling type is always appended, also for the "uniform" default
    args_bandnerf += "-{band_int_sample_type:s}".format(
            band_int_sample_type = imnet_spec.args.band_int_sample_type
            )
    return "-{args_bandnerf:s}".format(args_bandnerf = args_bandnerf)


def _format_imnet_args(imnet_spec):
    '''Format the decoder suffix of a config that provides `imnet_spec`.'''
    if imnet_spec.name == "mlp":
        return "-{dec:s}-H{dec_hdim:s}".format(
                dec = imnet_spec.name,
                dec_hdim = _format_hidden_list(imnet_spec.args.hidden_list),
                )
    if imnet_spec.name == "banddec":
        suffix = _format_banddec_args(imnet_spec)
        if imnet_spec.args.bandenc_spec.args.bandposenc_type == 'band_rb_mlp':
            suffix += _format_bandnerf_args(imnet_spec)
        return suffix
    # any other decoder type adds nothing to the name
    return ""


def make_args_model(config):
    '''
    Build the model part of the experiment name.

    The decoder description is read from `config.model.args.imnet_spec`.
    When that entry (or one of the entries it needs) is missing, the
    function falls back to `config.model.args.band_nerf`.

    Args:
        config: namespace-style config, reads `model.name` and `model.args`
    Return:
        args_model: "<model name>-<encoder name>-<G0 or 64>" followed by the
            decoder specific suffix
    Raise:
        ValueError: the fallback `band_nerf` decoder is not "banddec"
        AttributeError: neither `imnet_spec` nor `band_nerf` is usable
    '''
    model_args = config.model.args
    encoder_spec = model_args.encoder_spec

    args_model = "{name:s}-{enc:s}-{enc_out_dim:d}".format(
            name = config.model.name,
            enc = encoder_spec.name,
            enc_out_dim = encoder_spec.args.G0 if hasattr(encoder_spec.args, 'G0') else 64,
            )

    # The suffix is built separately so that a failure half-way through the
    # `imnet_spec` branch cannot leave a partially extended name behind.
    try:
        args_decoder = _format_imnet_args(model_args.imnet_spec)
    except (AttributeError, KeyError):
        # Only a missing config entry triggers the fallback. Other errors
        # (e.g. a wrongly typed value) are real bugs and must surface.
        band_nerf = model_args.band_nerf
        if band_nerf.name != "banddec":
            raise ValueError(
                "Unsupported band_nerf decoder: {!r}".format(band_nerf.name))
        args_decoder = _format_banddec_args(band_nerf)

    return args_model + args_decoder

def make_args_opt(config):
    '''
    Build the optimization part of the experiment name.

    Args:
        config: namespace-style config, reads `optimizer.args.lr` and the
            optional `loss_fn` ("L1" when absent)
    Return:
        args_opt: "LR<lr with 6 decimals>-<loss name>"
    '''
    learning_rate = config.optimizer.args.lr
    loss_name = getattr(config, 'loss_fn', "L1")
    return f"LR{learning_rate:.6f}-{loss_name:s}"

def make_model_file_param_args(config_, args):
    '''
    Build the full experiment name used for model files.

    Args:
        config_: the nested config dict
        args: not read by this function
    Return:
        args_str: "<data part>-<model part>-<optimization part>"
    '''
    config = dict2namespace(config_)
    parts = (
        make_args_data(config),
        make_args_model(config),
        make_args_opt(config),
    )
    return "-".join(parts)

def make_wandb_tags(config_):
    '''
    Build the list of experiment tags.

    Args:
        config_: the nested config dict
    Return:
        a list [data part, model part, optimization part]
    '''
    config = dict2namespace(config_)
    tag_builders = (make_args_data, make_args_model, make_args_opt)
    return [build_tag(config) for build_tag in tag_builders]


def ensure_path(path, remove=True):
    '''
    Make sure the directory `path` exists.

    A missing directory is created. An existing directory is left alone
    unless `remove` is True, in which case it is deleted and recreated when
    its name starts with '_' or when the user answers 'y' to the prompt.

    Args:
        path: the directory path
        remove: whether an existing directory may be wiped
    '''
    basename = os.path.basename(path.rstrip('/'))

    if not os.path.exists(path):
        os.makedirs(path)
        return
    if not remove:
        return

    # directories starting with '_' are wiped without asking
    confirmed = basename.startswith('_') or \
        input('{} exists, remove? (y/[n]): '.format(path)) == 'y'
    if confirmed:
        shutil.rmtree(path)
        os.makedirs(path)


def set_save_path(save_path, remove=True):
    '''
    Prepare the save directory and open a tensorboard writer inside it.

    Args:
        save_path: the experiment output directory
        remove: forwarded to `ensure_path`
    Return:
        writer: SummaryWriter writing to "<save_path>/tensorboard"
    '''
    ensure_path(save_path, remove=remove)
    tensorboard_dir = os.path.join(save_path, 'tensorboard')
    return SummaryWriter(tensorboard_dir)


def compute_num_params(model, text=False):
    '''
    Count the parameters of a model.

    Args:
        model: an object whose `parameters()` yields items with a `.shape`
        text: if True, return a short string instead of the integer
    Return:
        the total count as int, or '<x.x>M' (count >= 1e6) / '<x.x>K' text
    '''
    tot = int(sum(np.prod(param.shape) for param in model.parameters()))
    if not text:
        return tot
    if tot >= 1e6:
        return f'{tot / 1e6:.1f}M'
    return f'{tot / 1e3:.1f}K'


def make_optimizer(param_list, optimizer_spec, load_sd=False):
    '''
    Create an optimizer from a spec dict.

    Args:
        param_list: the parameters handed to the optimizer
        optimizer_spec: dict with
            'name': one of 'sgd', 'adam', 'adamw'
            'args': keyword arguments of the optimizer constructor
            'sd': optimizer state dict, only read when load_sd is True
        load_sd: whether to restore the state from optimizer_spec['sd']
    Return:
        optimizer: the constructed torch optimizer
    '''
    optimizer_classes = {'sgd': SGD, 'adam': Adam, "adamw": AdamW}
    optimizer_cls = optimizer_classes[optimizer_spec['name']]
    optimizer = optimizer_cls(param_list, **optimizer_spec['args'])
    if load_sd:
        optimizer.load_state_dict(optimizer_spec['sd'])
    return optimizer


def get_loss_function(loss_fn):
    '''
    Map a loss name to a loss module.

    Args:
        loss_fn: "L1" or "L2"
    Return:
        nn.L1Loss() for "L1", nn.MSELoss() for "L2"
    Raise:
        NotImplementedError: for any other name
    '''
    if loss_fn == "L1":
        return nn.L1Loss()
    if loss_fn == "L2":
        return nn.MSELoss()
    raise NotImplementedError


def get_band_interval_by_mid_wave(s_min, s_max, num_band):
    '''
    Make equally wide band intervals from the first and last band centers.

    Args:
        s_min: the minimum wavelength, middle pt of the 1st wavelength interval
        s_max: the maximum wavelength, middle pt of the last wavelength interval
        num_band: number of bands; the band width is taken from the first
            two centers, so at least 2 bands are needed
    Return:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
    '''
    # centers: shape (num_band, ), the middle wavelength of every band
    centers = np.linspace(s_min, s_max, num_band)
    half_width = (centers[1] - centers[0]) / 2

    return np.stack([centers - half_width, centers + half_width], axis = -1)


def get_band_interval(s_min, s_max, num_band):
    '''
    Split the wavelength range [s_min, s_max] into equally wide bands.

    Args:
        s_min: the minimum wavelength, left endpoint of the 1st wavelength interval
        s_max: the maximum wavelength, right endpoint of the last wavelength interval
        num_band: number of bands, must be at least 1
    Return:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
    Raise:
        AssertionError: s_max < s_min
        ValueError: num_band < 1
    '''
    assert s_max >= s_min
    # Previously only an unused `step = (s_max - s_min)/num_band` rejected
    # num_band == 0, and only by accident; check it explicitly instead.
    if num_band < 1:
        raise ValueError(
            "num_band must be a positive integer, got {}".format(num_band))

    # edges: shape (num_band + 1, ), the endpoints shared by neighbouring bands
    edges = np.linspace(s_min, s_max, num = num_band+1)

    s_intervals = np.stack([edges[:-1], edges[1:]], axis = -1)
    return s_intervals

def make_band_coords(s_intervals, s_min = None, s_max = None):
    '''
    Min-max normalize band intervals to the coordinate range [-1, 1].

    Args:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
        s_min: the minimum wavelength, left endpoint of the 1st wavelength interval;
            defaults to the smallest value in s_intervals
        s_max: the maximum wavelength, right endpoint of the last wavelength interval;
            defaults to the largest value in s_intervals
    Return:
        s_coords: shape (num_band, 2), float32, the band interval coordinates,
            (start, end) of each band
    '''
    lower = np.min(s_intervals) if s_min is None else s_min
    upper = np.max(s_intervals) if s_max is None else s_max

    # [lower, upper] -> [0, 1] -> [-1, 1]
    unit_coords = (s_intervals - lower)/(upper - lower)
    return (unit_coords * 2 - 1).astype(np.float32)


def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.

    Along every axis: x_i = x_min + r + 2r*i
    where r is half of the pixel size along that axis

    Args:
        shape: the number of pixels along each axis, e.g. (num_dim1, num_dim2)
        ranges: one (min, max) pair per axis; (-1, 1) is used for every
            axis when ranges is None
        flatten: whether or not to flatten the spatial dimensions
    Return:
        the pixel center coordinate grid
            if flatten = True:
                shape (num_dim1 * num_dim2, 2)
            else:
                shape (num_dim1, num_dim2, 2)
    """
    coord_seqs = []
    for axis, num_pixel in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[axis]
        # half of the pixel size along the current axis
        half_pixel = (v1 - v0) / (2 * num_pixel)
        pixel_index = torch.arange(num_pixel).float()
        coord_seqs.append(v0 + half_pixel + (2 * half_pixel) * pixel_index)

    # grid: shape (num_dim1, num_dim2, 2); the last axis holds one
    # coordinate per input axis
    grid = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if not flatten:
        return grid
    return grid.view(-1, grid.shape[-1])


def to_pixel_samples(img):
    """ Convert the image to coord-value pairs.

    Args:
        img: Tensor, (C, H, W)
    Return:
        coord: shape (H * W, 2), the pixel center coordinates from make_coord
        rgb: shape (H * W, C), the image values in the same pixel order
    """
    num_channel, height, width = img.shape
    coord = make_coord([height, width])
    # (C, H, W) -> (C, H * W) -> (H * W, C)
    rgb = img.reshape(num_channel, -1).transpose(0, 1)
    return coord, rgb


def calc_psnr(sr, hr, dataset=None, scale=1, rgb_range=1):
    '''
    Compute the PSNR between a prediction and its ground truth.

    Without `dataset`, the PSNR is computed per image and per channel and
    then averaged. With `dataset`, a single PSNR is computed over the
    border-shaved difference.

    Args:
        sr: shape (B, C, H_h, W_h)
        hr: shape (B, C, H_h, W_h)
        dataset: None, 'benchmark' or 'div2k'
        scale: used to derive the number of shaved border pixels
        rgb_range: the value range the difference is divided by
    Return:
        the PSNR as a scalar tensor
    Raise:
        NotImplementedError: for an unknown dataset name
    '''
    diff = (sr - hr) / rgb_range

    if dataset is None:
        # mse: shape (B, C)
        mse = torch.mean(diff.pow(2), dim = (-2,-1))
        # an exact match would give log10(0); cap it at a PSNR of 100
        mse[mse==0] = np.power(10, -100/10)
        # psnr: shape (B, C)
        psnr = -10 * torch.log10(mse)
        return psnr.mean()

    if dataset == 'benchmark':
        shave = scale
        if diff.size(1) > 1:
            # weight the 3 channels into a single gray channel
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    elif dataset == 'div2k':
        shave = scale + 6
    else:
        raise NotImplementedError

    # drop `shave` pixels at every border before averaging
    valid = diff[..., shave:-shave, shave:-shave]
    return -10 * torch.log10(valid.pow(2).mean())

def calc_psnr_old(sr, hr, dataset=None, scale=1, rgb_range=1):
    '''
    This is LIIF implementation of PSNR, we do not use it

    A single PSNR is computed from the mean squared difference over all
    elements (after border shaving when `dataset` is given).

    Args:
        sr: the prediction
        hr: the ground truth, same shape as sr
        dataset: None, 'benchmark' or 'div2k'
        scale: used to derive the number of shaved border pixels
        rgb_range: the value range the difference is divided by
    Raise:
        NotImplementedError: for an unknown dataset name
    '''
    diff = (sr - hr) / rgb_range

    if dataset is None:
        valid = diff
    else:
        if dataset == 'benchmark':
            shave = scale
            if diff.size(1) > 1:
                # weight the 3 channels into a single gray channel
                gray_coeffs = [65.738, 129.057, 25.064]
                convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
                diff = diff.mul(convert).sum(dim=1)
        elif dataset == 'div2k':
            shave = scale + 6
        else:
            raise NotImplementedError
        valid = diff[..., shave:-shave, shave:-shave]

    mean_sq_err = valid.pow(2).mean()
    return -10 * torch.log10(mean_sq_err)


def calc_eval_metric(sr, hr, dataset=None, scale=1, rgb_range=1,
    ratio_ergas=1.0/2, eval_metric_flag = None):
    '''
    Compute the image quality metrics of a prediction.

    Args:
        sr: shape (B, C, H_h, W_h), predict
        hr: shape (B, C, H_h, W_h), ground truth
        dataset, scale, rgb_range: not read by this function
        ratio_ergas: forwarded to eval_img_metric
        eval_metric_flag: dict selecting the metrics to compute; all four
            metrics are enabled when it is None
    Return:
        psnr, ergas, sam, ssim: as returned by eval_img_metric
            (presumably provided by `from eval import *` - TODO confirm)
    '''
    # the metric code works on numpy arrays on the host
    x_pred = sr.cpu().numpy()
    y_gt = hr.cpu().numpy()

    metric_flags = update_eval_metric_flag(eval_metric_flag)

    psnr, ergas, sam, ssim = eval_img_metric(
        y_gt = y_gt, x_pred = x_pred,
        ratio_ergas = ratio_ergas, eval_metric_flag = metric_flags)
    return psnr, ergas, sam, ssim


def update_eval_metric_flag(eval_metric_flag):
    '''
    Fill in the default metric selection.

    Args:
        eval_metric_flag: dict of metric name -> bool, or None
    Return:
        the given dict unchanged, or a new dict enabling 'psnr', 'ssim',
        'ergas' and 'sam' when None was given
    '''
    if eval_metric_flag is not None:
        return eval_metric_flag
    return {
        'psnr': True,
        'ssim': True,
        'ergas': True,
        'sam': True
    }


def resize_fn(img, size):
    '''
    Resize an image with bicubic interpolation.

    Args:
        img: the image handed to torchvision's `transforms.Resize`
        size: the target size handed to `transforms.Resize`
    Return:
        the resized image
    '''
    bicubic_resize = transforms.Resize(size, PIL.Image.BICUBIC)
    return bicubic_resize(img)


def np_interpolate_bands(img, num_band):
    '''
    Interpolate a batch of channel-last images along the band dimension.

    Args:
        img: shape (B, H, W, C)
        num_band: the number of bands after interpolation
    Return:
        img_lr: shape (B, H, W, num_band)
    '''
    # img: shape (B, C, H, W)
    img = np.transpose(img, (0,3,1,2))
    B, C, H, W = img.shape

    img_lr_list = [
        # (C, H, W) -> (num_band, H, W) -> (H, W, num_band)
        np.transpose(
            interpolate_bands(torch.from_numpy(img[i]), bands = num_band).numpy(),
            (1,2,0))
        for i in range(B)
    ]

    # img_lr: shape (B, H, W, num_band)
    return np.array(img_lr_list)

def interpolate_bands(img, bands):
    '''
    Bicubically resample an image along its band (channel) dimension.

    Args:
        img: shape (C, H, W)
        bands: final bands number
    Return:
        shape (bands, H, W), band interpolate image
    '''
    _, _, width = img.shape

    # Move the band axis into a spatial slot of a 4D batch so that
    # F.interpolate can resize it: (C, H, W) -> (H, C, W) -> (1, H, C, W)
    band_as_rows = img.permute(1, 0, 2).unsqueeze(0)

    # resized: shape (1, H, bands, W); the width is kept as is
    resized = F.interpolate(band_as_rows,
                size= (bands, width), mode='bicubic',
                align_corners=False, recompute_scale_factor=False)

    # (1, H, bands, W) -> (H, bands, W) -> (bands, H, W)
    return resized.squeeze(0).permute(1, 0, 2)


def apply_pca(train_data, test_data):
    '''
    Compute PCA along the spectral dimension for training data
    and transform test data

    Args:
        train_data: channel-last array, e.g. (H, W, C)
        test_data: channel-last array with the same C, e.g. (B, h, w, C)
    Return:
        pca: PCA model fitted on the training pixels, keeps all C components
        train_data_pca: (H, W, C) -> (1096, 715, 102)
        test_data_pca: (B, h, w, C) -> (8, 128, 128, 102)
    '''
    assert train_data.shape[-1] == test_data.shape[-1]
    num_channel = train_data.shape[-1]

    pca = PCA(n_components=num_channel)

    # every pixel is one sample: shape (num_pixel, C)
    train_pixels = train_data.reshape(-1, num_channel)
    train_data_pca = pca.fit_transform(train_pixels).reshape(train_data.shape)

    test_pixels = test_data.reshape(-1, num_channel)
    test_data_pca = pca.transform(test_pixels).reshape(test_data.shape)
    return pca, train_data_pca, test_data_pca


def scale_data(train_data, test_data):
    '''
    Standardize every channel with statistics of the training data.

    Args:
        train_data: channel-last array, shape (..., C)
        test_data: channel-last array with the same C
    Return:
        scaler: StandardScaler fitted on the training pixels
        train_data_norm: the scaled training data, same shape as train_data
        test_data_norm: the scaled test data, same shape as test_data
    '''
    assert train_data.shape[-1] == test_data.shape[-1]
    num_channel = train_data.shape[-1]

    # every pixel is one sample: shape (num_pixel, C)
    train_pixels = train_data.reshape(-1, num_channel)
    test_pixels = test_data.reshape(-1, num_channel)

    scaler = preprocessing.StandardScaler().fit(train_pixels)

    train_data_norm = scaler.transform(train_pixels).reshape(train_data.shape)
    test_data_norm = scaler.transform(test_pixels).reshape(test_data.shape)
    return scaler, train_data_norm, test_data_norm


def load_rs_dataset(num_band = 102,
                    data_dir = "../dataset_preprocess/dataset/Pavia_Centre",
                    num_test = 8):
    '''
    Load the Pavia Centre train image and test tiles from .npy files.

    Args:
        num_band: the number of HSI bands; 102 selects the untagged files,
            any other value selects the "_band<num_band>" variants
        data_dir: the dataset root containing the train/ and test/ folders
        num_test: the number of test tiles to load (indexed 0..num_test-1)
    Return:
        train_msi: shape (H, W, c) -> (1096, 715, 4)
        train_hsi: shape (H, W, C) -> (1096, 715, 102)
        train_labels: shape (H, W) -> (1096, 715)

        test_msi: shape (8, 128, 128, 4)
        test_hsi: shape (8, 128, 128, 102)
        test_labels: shape (8, 128, 128)

        classes: the sorted label values found in train_labels, without -1
            (presumably the "unlabeled" marker - TODO confirm)
        num_class: the number of entries in classes
    '''
    band_tag = "" if num_band == 102 else f"_band{num_band}"

    print("Load data")
    train_hsi_dir = f"{data_dir}/train/HSI{band_tag}/"
    train_msi_dir = f"{data_dir}/train/MSI/"
    train_gt_dir = f"{data_dir}/train/GT/"

    test_hsi_dir = f"{data_dir}/test/HSI{band_tag}/"
    test_msi_dir = f"{data_dir}/test/MSI/"
    test_gt_dir = f"{data_dir}/test/GT/"

    def load_test_tiles(tile_dir, prefix, suffix = ""):
        # stack the tiles "<prefix><i><suffix>.npy" along a new first axis
        tiles = [load_np_file(tile_dir, f"{prefix}{i}{suffix}.npy")
                 for i in range(num_test)]
        return np.stack(tiles, axis = 0)

    train_msi = load_np_file(train_msi_dir, "pavia_centre-msi_train.npy")
    train_hsi = load_np_file(train_hsi_dir, f"pavia_centre-hsi_train{band_tag}.npy")
    train_labels = load_np_file(train_gt_dir, "pavia_centre-gt_train.npy")

    # the stored label tiles carry a trailing axis; keep its first entry
    test_labels = load_test_tiles(test_gt_dir, "pavia_centre-gt_test_")[:,:,:,0]
    test_hsi = load_test_tiles(test_hsi_dir, "pavia_centre-hsi_test_", band_tag)
    test_msi = load_test_tiles(test_msi_dir, "pavia_centre-msi_test_")

    # np.unique returns the sorted distinct labels
    labels_present = np.unique(train_labels)
    classes = labels_present[labels_present != -1]
    num_class = len(classes)

    return train_msi, train_hsi, train_labels, \
            test_msi, test_hsi, test_labels, classes, num_class


def to_cuda(obj):
    '''
    Move `obj` to the GPU when CUDA is available.

    Args:
        obj: an object with a `.cuda()` method, e.g. a tensor or a module
    Return:
        `obj.cuda()` if CUDA is available, otherwise `obj` itself
    '''
    return obj.cuda() if torch.cuda.is_available() else obj

def load_np_file(dir, filename):
    '''
    Load a numpy file from a directory.

    Args:
        dir: the directory of the file
        filename: the file name inside `dir`
    Return:
        whatever `np.load` returns for "<dir>/<filename>"
    '''
    file_path = os.path.join(dir, filename)
    return np.load(file_path)

def json_load(filepath):
    '''
    Read a JSON file.

    Args:
        filepath: the path of the JSON file
    Return:
        the decoded JSON content
    '''
    with open(filepath, "r") as json_file:
        return json.load(json_file)

def json_dump(data, filepath, pretty_format = True):
    '''
    Write data to a JSON file.

    Args:
        data: the JSON serializable object
        filepath: the path of the output file, overwritten if it exists
        pretty_format: if True, indent by 2 spaces and sort the keys;
            otherwise use the default compact `json.dump` layout
    '''
    dump_options = {'indent': 2, 'sort_keys': True} if pretty_format else {}
    with open(filepath, 'w') as fw:
        json.dump(data, fw, **dump_options)

def pickle_dump(obj, pickle_filepath):
    '''
    Pickle an object to a file with pickle protocol 2.

    Args:
        obj: the object to serialize
        pickle_filepath: the path of the output file, overwritten if it exists
    '''
    with open(pickle_filepath, "wb") as pickle_file:
        pickle.dump(obj, pickle_file, 2)

def pickle_load(pickle_filepath):
    '''
    Load a pickled object from a file.

    NOTE: unpickling can execute arbitrary code; only use this on files
    from a trusted source.

    Args:
        pickle_filepath: the path of the pickle file
    Return:
        the unpickled object
    '''
    with open(pickle_filepath, "rb") as pickle_file:
        return pickle.load(pickle_file)
