import os
import time
import shutil
import math

import torch
import numpy as np
from torch.optim import SGD, Adam
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode 
import torch.nn.functional as F

from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix

from tensorboardX import SummaryWriter
import logging
import argparse
import json
import pickle

from eval import *


class Averager():
    """Running weighted mean of the values handed to ``add``."""

    def __init__(self):
        # n: total weight accumulated so far, v: current weighted mean
        self.n = 0.0
        self.v = 0.0

    def add(self, v, n=1.0):
        """Fold value ``v`` with weight ``n`` into the running mean."""
        weighted_total = self.v * self.n + v * n
        self.v = weighted_total / (self.n + n)
        self.n += n

    def item(self):
        """Return the current mean (0.0 while nothing has been added)."""
        return self.v


class Timer():
    """Minimal stopwatch built on ``time.time()``."""

    def __init__(self):
        # v holds the wall-clock timestamp the timer was (re)started at
        self.s()

    def s(self):
        """Restart the timer from the current moment."""
        self.v = time.time()

    def t(self):
        """Return the seconds elapsed since construction or the last ``s()``."""
        started_at = self.v
        return time.time() - started_at


def time_text(t):
    """Render a duration in seconds as a short text with one decimal.

    Durations of at least an hour are shown in hours ('h'), at least a
    minute in minutes ('m'), everything else in seconds ('s').
    """
    # (seconds per unit, template), checked from the largest unit down
    units = (
        (3600, '{:.1f}h'),
        (60, '{:.1f}m'),
    )
    for seconds_per_unit, template in units:
        if t >= seconds_per_unit:
            return template.format(t / seconds_per_unit)
    return '{:.1f}s'.format(t)


def setup_console():
    """Make the root logger write to a single console handler.

    Every handler currently attached to the root logger is dropped and a
    fresh ``StreamHandler`` (INFO level, timestamped format) is installed.
    """
    root_logger = logging.getLogger('')
    # discard whatever handlers were attached before
    root_logger.handlers = []

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    root_logger.addHandler(stream_handler)

    
def setup_logging(log_file, console=True, filemode='w'):
    """Configure file logging and, optionally, an extra console handler.

    Args:
        log_file: path handed to ``logging.basicConfig`` as ``filename``
        console: when true, a ``StreamHandler`` (INFO level) is also attached
            to the root logger
        filemode: open mode handed to ``logging.basicConfig``
    Return:
        the ``logging`` module itself

    NOTE(review): per the standard-library documentation ``basicConfig``
    is a no-op when the root logger already has handlers, and every call
    with ``console=True`` appends one more console handler.
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        filename=log_file,
                        filemode=filemode)
    if not console:
        return logging

    # mirror the records to the console with the same layout as the file
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logging.getLogger('').addHandler(stream_handler)
    return logging

# _log_path = None


# def set_log_path(path):
#     global _log_path
#     _log_path = path


# def log(obj, filename='log.txt'):
#     print(obj)
#     if _log_path is not None:
#         with open(os.path.join(_log_path, filename), 'a') as f:
#             print(obj, file=f)

def dict2namespace(config):
    """Recursively turn a (nested) dict into an ``argparse.Namespace``.

    Only values that are themselves dicts are converted; every other value
    (including lists that contain dicts) is stored unchanged.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        converted = dict2namespace(value) if isinstance(value, dict) else value
        setattr(namespace, key, converted)
    return namespace





def make_args_data(config):
    """Build the data-related part of the experiment tag.

    Args:
        config: namespace-style config; reads ``data``,
            ``train_dataset.wrapper.args`` and, if present, ``train_dataset_2``
    Return:
        args_data: str, "<data>-TSM<scale_max>", followed by "-MD" when a
            second training dataset is configured and by the band settings
            when the wrapper args define ``band_path``
    """
    data_name = config.data
    wrapper_args = config.train_dataset.wrapper.args

    def _optional(attr_name):
        # stringified attribute value, or "" when the attribute is absent
        if hasattr(wrapper_args, attr_name):
            return str(getattr(wrapper_args, attr_name))
        return ""

    tag = "{data:s}-TSM{train_scale_max:d}".format(
        data=data_name,
        train_scale_max=wrapper_args.scale_max,
    )
    if hasattr(config, 'train_dataset_2'):
        tag += "-MD"
    if hasattr(wrapper_args, 'band_path'):
        tag += "-BANDMIN{num_band_min:s}-MAX{num_band_max:s}-SAM{num_band_sample:s}".format(
            num_band_min=_optional('num_band_min'),
            num_band_max=_optional('num_band_max'),
            num_band_sample=_optional('num_band_sample'),
        )

    return tag

def make_args_model(config):
    """Build the model-related part of the experiment tag.

    The tag always starts with "<model name>-<encoder name>-<encoder width>".
    For an "mlp" decoder the decoder name and hidden sizes are appended; for
    a "banddec" decoder the feature-decoder and band-encoder settings are
    appended, plus the band-NeRF settings when the band position encoder
    type is 'band_rb_mlp'. Any other decoder name adds nothing.

    Args:
        config: namespace-style config; reads ``config.model``
    Return:
        args_model: str
    """
    def _join_dims(dims):
        # [256, 256] -> "256_256"
        return "_".join([str(d) for d in dims])

    model_cfg = config.model
    encoder = model_cfg.args.encoder_spec
    tag = "{name:s}-{enc:s}-{enc_out_dim:d}".format(
        name=model_cfg.name,
        enc=encoder.name,
        # encoders whose args have no G0 are tagged with 64
        enc_out_dim=getattr(encoder.args, 'G0', 64),
    )

    imnet = model_cfg.args.imnet_spec
    if imnet.name == "mlp":
        return tag + "-{dec:s}-H{dec_hdim:s}".format(
            dec=imnet.name,
            dec_hdim=_join_dims(imnet.args.hidden_list),
        )
    if imnet.name != "banddec":
        return tag

    dec_args = imnet.args
    fedec = dec_args.fedec_spec
    bandenc = dec_args.bandenc_spec

    fedec_tag = "{dec:s}-{fedec_name:s}-H{dec_hdim:s}-{fedec_out_dim:d}".format(
        dec=imnet.name,
        fedec_name=fedec.name,
        dec_hdim=_join_dims(fedec.args.hidden_list),
        fedec_out_dim=fedec.args.out_dim,
    )
    bandenc_tag = "{bandenc_name:s}-{bandposenc_type:s}-{freq:d}-{max_radius:.2f}-{min_radius:.6f}-H{band_hdim:s}".format(
        bandenc_name=bandenc.name,
        bandposenc_type=bandenc.args.bandposenc_type,
        freq=bandenc.args.freq,
        max_radius=bandenc.args.max_radius,
        min_radius=bandenc.args.min_radius,
        band_hdim=_join_dims(bandenc.args.hidden_list),
    )
    tag += "-{args_fedec:s}-{args_bandenc:s}".format(
        args_fedec=fedec_tag,
        args_bandenc=bandenc_tag,
    )

    if bandenc.args.bandposenc_type != 'band_rb_mlp':
        return tag

    bandnerf = dec_args.bandnerf_spec
    bandnerf_tag = "{bandnerf_name:s}-{bandnerf_type:s}-{hidden_list:s}-{act:s}-{num_band_int_sample:d}-{resp_func_type:s}-{resp_func_norm_type:s}".format(
        bandnerf_name=bandnerf.name,
        bandnerf_type=bandnerf.args.bandnerf_type,
        hidden_list=_join_dims(bandnerf.args.hidden_list),
        act=bandnerf.args.act,
        num_band_int_sample=dec_args.num_band_int_sample,
        resp_func_type=dec_args.resp_func_type,
        resp_func_norm_type=dec_args.resp_func_norm_type,
    )
    # the sampling type is always appended, whatever its value
    bandnerf_tag += "-{band_int_sample_type:s}".format(
        band_int_sample_type=dec_args.band_int_sample_type,
    )
    return tag + "-{args_bandnerf:s}".format(args_bandnerf=bandnerf_tag)

def make_args_opt(config):
    """Build the optimization-related part of the experiment tag.

    Args:
        config: namespace-style config; reads ``optimizer.args.lr`` and the
            optional ``loss_fn`` (tagged as "L1" when absent)
    Return:
        args_opt: str, "LR<lr with 6 decimals>-<loss>"
    """
    learning_rate = config.optimizer.args.lr
    loss_name = getattr(config, 'loss_fn', "L1")
    return "LR{lr:.6f}-{loss:s}".format(lr=learning_rate, loss=loss_name)

def make_model_file_param_args(config_, args):
    """Build the full "<data>-<model>-<opt>" tag for a config dict.

    Args:
        config_: (nested) config dict, converted with ``dict2namespace``
        args: accepted but not used by this function
    Return:
        args_str: str, the three tag parts joined with "-"
    """
    config = dict2namespace(config_)
    tag_parts = [
        make_args_data(config),
        make_args_model(config),
        make_args_opt(config),
    ]
    return "-".join(tag_parts)

def make_wandb_tags(config_):
    """Return the [data, model, opt] tag strings for a config dict.

    Args:
        config_: (nested) config dict, converted with ``dict2namespace``
    Return:
        list of three str, in the order data / model / optimization
    """
    config = dict2namespace(config_)
    tag_builders = (make_args_data, make_args_model, make_args_opt)
    return [build(config) for build in tag_builders]


def ensure_path(path, remove=True):
    """Make sure the directory ``path`` exists.

    A missing directory is created. An existing one is left alone unless
    ``remove`` is true AND either its base name starts with '_' or the user
    answers 'y' at the interactive prompt; it is then deleted and recreated
    empty.
    """
    dir_name = os.path.basename(path.rstrip('/'))
    if not os.path.exists(path):
        os.makedirs(path)
        return
    if not remove:
        return
    # underscore-prefixed directories are wiped without asking
    if dir_name.startswith('_') or \
            input('{} exists, remove? (y/[n]): '.format(path)) == 'y':
        shutil.rmtree(path)
        os.makedirs(path)


def set_save_path(save_path, remove=True):
    """Prepare ``save_path`` and open a tensorboard writer inside it.

    Args:
        save_path: output directory, prepared with ``ensure_path``
        remove: forwarded to ``ensure_path``
    Return:
        writer: ``SummaryWriter`` logging to "<save_path>/tensorboard"
    """
    ensure_path(save_path, remove=remove)
    tensorboard_dir = os.path.join(save_path, 'tensorboard')
    return SummaryWriter(tensorboard_dir)


def compute_num_params(model, text=False):
    """Count the parameters of ``model``.

    Args:
        model: object whose ``parameters()`` yields items with a ``shape``
        text: if true, return a short string instead of an int
    Return:
        int total, or a string such as '1.5M' (>= 1e6) / '2.5K' (otherwise)
    """
    total = int(sum(np.prod(param.shape) for param in model.parameters()))
    if not text:
        return total
    if total >= 1e6:
        return '{:.1f}M'.format(total / 1e6)
    return '{:.1f}K'.format(total / 1e3)


def make_optimizer(param_list, optimizer_spec, load_sd=False):
    """Instantiate the optimizer described by ``optimizer_spec``.

    Args:
        param_list: parameters handed to the optimizer constructor
        optimizer_spec: dict with 'name' ('sgd' or 'adam'), 'args' (keyword
            arguments for the constructor) and, when ``load_sd`` is true,
            'sd' (a state dict to restore)
        load_sd: whether to load ``optimizer_spec['sd']`` into the optimizer
    Return:
        the constructed optimizer
    """
    supported = {'sgd': SGD, 'adam': Adam}
    optimizer_cls = supported[optimizer_spec['name']]
    optimizer = optimizer_cls(param_list, **optimizer_spec['args'])
    if not load_sd:
        return optimizer
    optimizer.load_state_dict(optimizer_spec['sd'])
    return optimizer


def get_loss_function(loss_fn):
    """Return a new loss module for the given name.

    "L1" maps to ``nn.L1Loss`` and "L2" to ``nn.MSELoss``; any other value
    raises ``NotImplementedError``.
    """
    known_losses = (
        ("L1", nn.L1Loss),
        ("L2", nn.MSELoss),
    )
    for loss_name, loss_cls in known_losses:
        if loss_fn == loss_name:
            return loss_cls()
    raise NotImplementedError


def get_band_interval_by_mid_wave(s_min, s_max, num_band):
    '''
    Build equal-width band intervals from the first and last band centers.

    Args:
        s_min: the minimum wavelength, middle pt of the 1st wavelength interval
        s_max: the maximum wavelength, middle pt of the last wavelength interval
        num_band: number of band (the spacing is taken from the first two
            centers, so at least two bands are needed)
    Return:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
    '''
    # middle wavelength of every band
    centers = np.linspace(s_min, s_max, num_band)

    # every band extends half of the center spacing to each side
    half_width = (centers[1] - centers[0]) / 2

    return np.stack([centers - half_width, centers + half_width], axis=-1)


def get_band_interval(s_min, s_max, num_band):
    '''
    Split [s_min, s_max] into num_band contiguous, equal-width band intervals.

    Args:
        s_min: the minimum wavelength, left endpoint of the 1st wavelength interval
        s_max: the maximum wavelength, right endpoint of the last wavelength interval
        num_band: number of band, must be at least 1
    Return:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
    Raises:
        ValueError: if s_max < s_min or num_band < 1
    '''
    # explicit checks instead of `assert`, which is stripped under `python -O`
    if not s_max >= s_min:
        raise ValueError(
            "s_max ({}) must not be smaller than s_min ({})".format(s_max, s_min))
    if num_band < 1:
        raise ValueError(
            "num_band must be at least 1, got {}".format(num_band))

    # num_band + 1 evenly spaced edges delimit num_band contiguous intervals
    edges = np.linspace(s_min, s_max, num = num_band + 1)

    # pair every edge with its right neighbour: (start, end) per band
    s_intervals = np.stack([edges[:-1], edges[1:]], axis = -1)

    return s_intervals

def make_band_coords(s_intervals, s_min = None, s_max = None):
    '''
    Min-max normalize band intervals to the coordinate range [-1, 1].

    Args:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
        s_min: the minimum wavelength, left endpoint of the 1st wavelength interval;
            defaults to the smallest value in s_intervals
        s_max: the maximum wavelength, right endpoint of the last wavelength interval;
            defaults to the largest value in s_intervals
    Return:
        s_coords: shape (num_band, 2), float32, the band interval coordinates,
            (start, end) of each band, [-1, 1]
    '''
    lower = np.min(s_intervals) if s_min is None else s_min
    upper = np.max(s_intervals) if s_max is None else s_max

    # scale to [0, 1] first, then stretch to [-1, 1]
    unit_coords = (s_intervals - lower) / (upper - lower)
    return (unit_coords * 2 - 1).astype(np.float32)


def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.

    Along every axis the i-th center is  x_i = x_min + r + 2r*i,
    where r is half of the pixel size on that axis.

    Args:
        shape: number of pixels per axis, e.g. (num_dim1, num_dim2)
        ranges: per-axis (min, max) pairs; every axis spans (-1, 1) when None
        flatten: whether or not to flatten the spatial dimensions
    Return:
        ret: the pixel-center coordinate grid,
            shape (num_dim1 * num_dim2, 2) if flatten is True,
            otherwise shape (num_dim1, num_dim2, 2)
    """
    axis_centers = []
    for axis, num_px in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[axis]
        # half of the pixel size along this axis
        half_px = (hi - lo) / (2 * num_px)
        centers = lo + half_px + (2 * half_px) * torch.arange(num_px).float()
        axis_centers.append(centers)

    # combine the per-axis centers into one grid; last dim indexes the axis
    grid = torch.stack(torch.meshgrid(*axis_centers), dim=-1)
    if not flatten:
        return grid
    return grid.view(-1, grid.shape[-1])


def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.

    Args:
        img: Tensor, (C, H, W)
    Return:
        coord: shape (H * W, 2), pixel-center coordinates from ``make_coord``
        rgb: shape (H * W, C), the pixel values in the same order
    """
    num_channels, height, width = img.shape
    coord = make_coord([height, width])
    # flatten the spatial dims, then put the channel axis last
    rgb = img.reshape(num_channels, -1).transpose(0, 1)
    return coord, rgb


def calc_psnr(sr, hr, dataset=None, scale=1, rgb_range=1):
    '''
    Compute the PSNR between a prediction and its target.

    Without `dataset` the PSNR is computed per (sample, channel) over the
    last two dims and averaged; an MSE of exactly 0 is replaced by 1e-10,
    which caps that entry at 100 dB. With `dataset` a border of `shave`
    pixels is cropped and one global PSNR is returned.

    Args:
        sr: shape (B, C, H_h, W_h)
        hr: shape (B, C, H_h, W_h)
        dataset: None, 'benchmark' (shave = scale, inputs with more than one
            channel are reduced to one gray channel) or 'div2k'
            (shave = scale + 6); other values raise NotImplementedError
        scale: used only to derive the shave width
        rgb_range: the difference is divided by this value
    Return:
        scalar tensor, the PSNR
    '''
    diff = (sr - hr) / rgb_range

    if dataset is None:
        # per_channel_mse: shape (B, C)
        per_channel_mse = torch.mean(diff.pow(2), dim = (-2,-1))
        per_channel_mse[per_channel_mse==0] = np.power(10, -100/10)
        return (-10 * torch.log10(per_channel_mse)).mean()

    if dataset == 'benchmark':
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    elif dataset == 'div2k':
        shave = scale + 6
    else:
        raise NotImplementedError

    # drop the border before measuring the error
    cropped = diff[..., shave:-shave, shave:-shave]
    return -10 * torch.log10(cropped.pow(2).mean())

def calc_psnr_old(sr, hr, dataset=None, scale=1, rgb_range=1):
    '''
    This is LIIF implementation of PSNR, we do not use it.

    One global MSE is taken over all elements (after cropping a border of
    `shave` pixels when `dataset` is given) and converted to dB.
    `dataset` may be None, 'benchmark' (shave = scale, inputs with more than
    one channel are reduced to one gray channel) or 'div2k'
    (shave = scale + 6); other values raise NotImplementedError.
    '''
    diff = (sr - hr) / rgb_range

    if dataset is None:
        return -10 * torch.log10(diff.pow(2).mean())

    if dataset == 'benchmark':
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    elif dataset == 'div2k':
        shave = scale + 6
    else:
        raise NotImplementedError

    # drop the border before measuring the error
    cropped = diff[..., shave:-shave, shave:-shave]
    return -10 * torch.log10(cropped.pow(2).mean())


def calc_eval_metric(sr, hr, dataset=None, scale=1, rgb_range=1, 
    ratio_ergas=1.0/2, eval_metric_flag = None):
    '''
    Evaluate a prediction against its ground truth with `eval_img_metric`.

    Args:
        sr: shape (B, C, H_h, W_h), predict
        hr: shape (B, C, H_h, W_h), ground truth
        dataset, scale, rgb_range: accepted but not used by this function
        ratio_ergas: forwarded to `eval_img_metric`
        eval_metric_flag: dict of metric switches; None enables every metric
            (see `update_eval_metric_flag`)
    Return:
        psnr, ergas, sam, ssim: the four values returned by `eval_img_metric`
    '''
    # the metric code is given numpy arrays on the host
    pred_np = sr.cpu().numpy()
    gt_np = hr.cpu().numpy()

    metric_flags = update_eval_metric_flag(eval_metric_flag)

    psnr, ergas, sam, ssim = eval_img_metric(
        y_gt = gt_np,
        x_pred = pred_np,
        ratio_ergas = ratio_ergas,
        eval_metric_flag = metric_flags)
    return psnr, ergas, sam, ssim


def update_eval_metric_flag(eval_metric_flag):
    """Return ``eval_metric_flag``, or the all-enabled default when it is None.

    The default switches on 'psnr', 'ssim', 'ergas' and 'sam'. A dict that
    is passed in is returned as is (same object).
    """
    if eval_metric_flag is not None:
        return eval_metric_flag
    return {metric: True for metric in ('psnr', 'ssim', 'ergas', 'sam')}


def resize_fn(img, size):
    """Resize ``img`` to ``size`` with torchvision's bicubic ``Resize``.

    The transform is applied to ``img`` directly, without a round trip
    through a PIL image.
    """
    bicubic_resize = transforms.Resize(size, InterpolationMode.BICUBIC)
    return bicubic_resize(img)


def np_interpolate_bands(img, num_band):
    '''
    Resample the band (last) axis of a batch of images to num_band bands.

    Args:
        img: shape (B, H, W, C)
    Return:
        img_lr: shape (B, H, W, num_band)
    '''
    # channels_first: shape (B, C, H, W)
    channels_first = np.transpose(img, (0,3,1,2))

    resampled = []
    for sample in channels_first:
        # bands_first: shape (num_band, H, W)
        bands_first = interpolate_bands(torch.from_numpy(sample), bands = num_band).numpy()
        # back to shape (H, W, num_band)
        resampled.append(np.transpose(bands_first, (1,2,0)))

    # stack the per-sample results along a new leading batch axis
    return np.array(resampled)

def interpolate_bands(img, bands):
    '''
    Bicubically resample the band (first) axis of an image.

    Args:
        img: shape (C, H, W)
        bands: final bands number
    Return:
        shape (bands, H, W), band interpolate image
    '''
    _, _, width = img.shape

    # move the band axis into an interpolated position: shape (1, H, C, W)
    batched = img.permute(1, 0, 2).unsqueeze(0)

    # the target keeps W, so only the band axis is resampled:
    # shape (1, H, bands, W)
    resampled = F.interpolate(batched,
                size= (bands, width), mode='bicubic',
                align_corners=False, recompute_scale_factor=False)

    # back to shape (bands, H, W)
    return resampled.squeeze(0).permute(1, 0, 2)


def apply_pca(train_data, test_data):
    '''
    Fit a PCA along the spectral (last) dimension of the training data and
    apply the same transform to the test data. All C components are kept,
    so both outputs have the shape of their inputs.

    Return:
        pca: the fitted PCA model
        train_data_pca: same shape as train_data, e.g. (H, W, C) -> (1096, 715, 102)
        test_data_pca: same shape as test_data, e.g. (B, h, w, C) -> (8, 128, 128, 102)
    '''
    num_channels = train_data.shape[-1]
    assert num_channels == test_data.shape[-1]

    pca = PCA(n_components=num_channels)

    # every pixel becomes one sample: shape (num_pixels, C)
    train_pixels = train_data.reshape(-1, num_channels)
    train_data_pca = pca.fit_transform(train_pixels).reshape(train_data.shape)

    test_pixels = test_data.reshape(-1, num_channels)
    test_data_pca = pca.transform(test_pixels).reshape(test_data.shape)
    return pca, train_data_pca, test_data_pca


def scale_data(train_data, test_data):
    '''
    Standardize the last (channel) dimension with a StandardScaler that is
    fitted on the training data only.

    Return:
        scaler: the fitted StandardScaler
        train_data_norm: same shape as train_data
        test_data_norm: same shape as test_data
    '''
    num_channels = train_data.shape[-1]
    assert num_channels == test_data.shape[-1]

    # every pixel becomes one sample: shape (num_pixels, C)
    train_pixels = train_data.reshape(-1, num_channels)
    test_pixels = test_data.reshape(-1, num_channels)

    scaler = preprocessing.StandardScaler().fit(train_pixels)

    train_data_norm = scaler.transform(train_pixels).reshape(train_data.shape)
    test_data_norm = scaler.transform(test_pixels).reshape(test_data.shape)
    return scaler, train_data_norm, test_data_norm


def load_rs_dataset(num_band = 102,
                    data_dir = "../dataset_preprocess/dataset/Pavia_Centre",
                    num_test = 8):
    '''
    Load the Pavia Centre train image and the test tiles from .npy files.

    Args:
        num_band: number of HSI bands; any value other than 102 selects the
            "_band<num_band>" variant of the HSI folders and files
        data_dir: dataset root that contains the train/ and test/ folders
        num_test: number of test tiles to load (tile indices 0..num_test-1)
    Return:
        train_msi: shape (H, W, c) -> (1096, 715, 4)
        train_hsi: shape (H, W, C) -> (1096, 715, 102)
        train_labels: shape (H, W) -> (1096, 715)

        test_msi: shape (8, 128, 128, 4)
        test_hsi: shape (8, 128, 128, 102)
        test_labels: shape (8, 128, 128)

        classes: sorted unique values of train_labels, without -1
        num_class: number of entries in classes

        (the concrete shapes are the ones for the default arguments)
    '''
    band_tag = "" if num_band == 102 else f"_band{num_band}"

    print("Load data")
    train_dir = f"{data_dir}/train"
    test_dir = f"{data_dir}/test"

    def _load_test_tiles(sub_dir, stem, suffix = ""):
        # load "<stem>_<i><suffix>.npy" for every tile and stack the tiles
        # along a new leading axis
        tiles = [
            load_np_file(f"{test_dir}/{sub_dir}/", f"{stem}_{i}{suffix}.npy")
            for i in range(num_test)
        ]
        return np.stack(tiles, axis = 0)

    train_msi = load_np_file(f"{train_dir}/MSI/", "pavia_centre-msi_train.npy")
    train_hsi = load_np_file(f"{train_dir}/HSI{band_tag}/",
                             f"pavia_centre-hsi_train{band_tag}.npy")
    train_labels = load_np_file(f"{train_dir}/GT/", "pavia_centre-gt_train.npy")

    # only index 0 of the last axis of the stacked label tiles is kept
    test_labels = _load_test_tiles("GT", "pavia_centre-gt_test")[:,:,:,0]
    test_hsi = _load_test_tiles(f"HSI{band_tag}", "pavia_centre-hsi_test",
                                suffix = band_tag)
    test_msi = _load_test_tiles("MSI", "pavia_centre-msi_test")

    # -1 is excluded from the set of classes
    classes = np.sort(list(set(np.unique(train_labels)) - set([-1])))
    num_class = len(classes)

    return train_msi, train_hsi, train_labels, \
            test_msi, test_hsi, test_labels, classes, num_class


def to_cuda(obj):
    """Return ``obj.cuda()`` when CUDA is available, otherwise ``obj`` itself."""
    if not torch.cuda.is_available():
        return obj
    return obj.cuda()

def load_np_file(dir, filename):
    """Load the numpy file ``filename`` located in directory ``dir``."""
    file_path = os.path.join(dir, filename)
    return np.load(file_path)

def json_load(filepath):
    """Read the JSON file at ``filepath`` and return the decoded object."""
    with open(filepath, "r") as json_file:
        return json.load(json_file)

def json_dump(data, filepath, pretty_format = True):
    """Write ``data`` as JSON to ``filepath``.

    With ``pretty_format`` the output is indented by 2 and the keys are
    sorted; otherwise the ``json.dump`` defaults are used.
    """
    dump_options = {'indent': 2, 'sort_keys': True} if pretty_format else {}
    with open(filepath, 'w') as fw:
        json.dump(data, fw, **dump_options)

def pickle_dump(obj, pickle_filepath):
    """Pickle ``obj`` to ``pickle_filepath`` using pickle protocol 2."""
    with open(pickle_filepath, "wb") as pickle_file:
        pickle.dump(obj, pickle_file, protocol=2)

def pickle_load(pickle_filepath):
    """Load and return the object pickled at ``pickle_filepath``.

    NOTE(review): unpickling can execute arbitrary code, so this should
    only be pointed at files from a trusted source.
    """
    with open(pickle_filepath, "rb") as pickle_file:
        return pickle.load(pickle_file)
