import os
import time
import shutil
import math

import torch
import numpy as np
from torch.optim import SGD, Adam
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode 
import torch.nn.functional as F
import torch.optim as optim

from tensorboardX import SummaryWriter
import logging
import argparse
import json
import pickle

from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, accuracy_score
import record

# from eval import *


class Averager:
    """Running weighted mean of scalar values.

    Each ``add(v, n)`` folds value ``v`` in with weight ``n`` so that
    ``item()`` returns sum(v_i * n_i) / sum(n_i) over all added values.
    """

    def __init__(self):
        # n: total weight accumulated so far; v: current weighted mean
        self.n = 0.0
        self.v = 0.0

    def add(self, v, n=1.0):
        """Fold value ``v`` with weight ``n`` into the running mean."""
        total = self.n + n
        weighted_sum = self.v * self.n + v * n
        self.v = weighted_sum / total
        self.n = total

    def item(self):
        """Return the current weighted mean."""
        return self.v


class Timer:
    """Simple wall-clock stopwatch based on ``time.time()``."""

    def __init__(self):
        self.s()

    def s(self):
        """(Re)start the stopwatch at the current time."""
        self.v = time.time()

    def t(self):
        """Return the seconds elapsed since the last (re)start."""
        now = time.time()
        return now - self.v


def time_text(t):
    """Format a duration ``t`` in seconds as hours, minutes or seconds with one decimal."""
    for threshold, template in ((3600, '{:.1f}h'), (60, '{:.1f}m')):
        if t >= threshold:
            return template.format(t / threshold)
    return '{:.1f}s'.format(t)


def setup_console():
    """Replace all root-logger handlers with a single INFO-level console handler."""
    root_logger = logging.getLogger('')
    # drop any previously attached handlers so messages are not duplicated
    root_logger.handlers = []
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    root_logger.addHandler(stream_handler)

    
def setup_logging(log_file, console=True, filemode='w'):
    """Configure INFO-level root logging to ``log_file`` and optionally the console.

    The file handler is installed through ``logging.basicConfig``. Per the
    stdlib docs, basicConfig does nothing if the root logger already has
    handlers. When ``console`` is true, an INFO-level stream handler with the
    same format is appended to the root logger on every call.

    Args:
        log_file: path of the log file passed to ``basicConfig``.
        console: whether to also attach a console (stream) handler.
        filemode: file open mode forwarded to ``basicConfig``.
    Returns:
        The ``logging`` module itself.
    """
    fmt = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO,
                        format=fmt,
                        filename=log_file,
                        filemode=filemode)
    if not console:
        return logging
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(fmt))
    logging.getLogger('').addHandler(stream_handler)
    return logging

# _log_path = None


# def set_log_path(path):
#     global _log_path
#     _log_path = path


# def log(obj, filename='log.txt'):
#     print(obj)
#     if _log_path is not None:
#         with open(os.path.join(_log_path, filename), 'a') as f:
#             print(obj, file=f)

def dict2namespace(config):
    """Recursively convert a (nested) dict into an ``argparse.Namespace``.

    Values that are dicts become nested Namespaces; all other values are
    stored unchanged as attributes named after their keys.
    """
    converted = {
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    }
    return argparse.Namespace(**converted)









def ensure_path(path, remove=True):
    """Make sure directory ``path`` exists, optionally wiping an existing one.

    A missing ``path`` is created (with parents). An existing ``path`` is
    deleted and recreated empty only when ``remove`` is true AND either its
    basename starts with ``'_'`` or the user answers ``'y'`` at the prompt;
    otherwise it is left untouched.
    """
    name = os.path.basename(path.rstrip('/'))
    if not os.path.exists(path):
        os.makedirs(path)
        return
    if not remove:
        return
    # underscore-prefixed dirs are treated as scratch and wiped without asking
    if name.startswith('_') or input('{} exists, remove? (y/[n]): '.format(path)) == 'y':
        shutil.rmtree(path)
        os.makedirs(path)


def set_save_path(save_path, remove=True):
    """Prepare ``save_path`` via ``ensure_path`` and return a TensorBoard writer.

    The writer logs into the ``tensorboard`` subdirectory of ``save_path``.
    """
    ensure_path(save_path, remove=remove)
    tensorboard_dir = os.path.join(save_path, 'tensorboard')
    return SummaryWriter(tensorboard_dir)


def compute_num_params(model, text=False):
    """Count the parameters of ``model``.

    Returns the integer total, or, when ``text`` is true, a one-decimal string
    in millions (``'1.2M'``) for totals >= 1e6 and in thousands (``'3.4K'``)
    otherwise.
    """
    tot = int(sum(p.numel() for p in model.parameters()))
    if not text:
        return tot
    if tot >= 1e6:
        return '{:.1f}M'.format(tot / 1e6)
    return '{:.1f}K'.format(tot / 1e3)


def make_optimizer(param_list, optimizer_spec, load_sd=False):
    """Build an optimizer from a spec dict.

    Args:
        param_list: parameters (or param groups) to optimize.
        optimizer_spec: dict with ``'name'`` (``'sgd'`` or ``'adam'``),
            ``'args'`` (keyword arguments for the optimizer) and, when
            ``load_sd`` is true, ``'sd'`` (a state dict to restore).
        load_sd: whether to load ``optimizer_spec['sd']`` after construction.
    Returns:
        The constructed optimizer.
    """
    optimizer_cls = {'sgd': SGD, 'adam': Adam}[optimizer_spec['name']]
    optimizer = optimizer_cls(param_list, **optimizer_spec['args'])
    if load_sd:
        optimizer.load_state_dict(optimizer_spec['sd'])
    return optimizer


def get_loss_function(loss_fn):
    """Return the loss module selected by name.

    Args:
        loss_fn: ``"L1"`` for ``nn.L1Loss`` or ``"L2"`` for ``nn.MSELoss``.
    Returns:
        A freshly constructed loss module.
    Raises:
        NotImplementedError: if ``loss_fn`` is not one of the supported names.
    """
    if loss_fn == "L1":
        return nn.L1Loss()
    if loss_fn == "L2":
        return nn.MSELoss()
    # same exception type as before, but say what was requested and what is allowed
    raise NotImplementedError(
        "Unsupported loss function {!r}; expected 'L1' or 'L2'.".format(loss_fn))


def get_band_interval_by_mid_wave(s_min, s_max, num_band):
    '''
    Build equal-width wavelength intervals around evenly spaced band centres.

    Args:
        s_min: centre wavelength of the first band
        s_max: centre wavelength of the last band
        num_band: number of bands; needs to be at least 2, because the band
            width is taken from the spacing of the first two centres
    Return:
        s_intervals: shape (num_band, 2), (start, end) of each band
    '''
    centres = np.linspace(s_min, s_max, num_band)
    half_width = (centres[1] - centres[0]) / 2
    return np.stack([centres - half_width, centres + half_width], axis=-1)


def get_band_interval(s_min, s_max, num_band):
    '''
    Split [s_min, s_max] into ``num_band`` contiguous, equal-width intervals.

    Args:
        s_min: the minimum wavelength, left endpoint of the 1st wavelength interval
        s_max: the maximum wavelength, right endpoint of the last wavelength interval
        num_band: number of bands
    Return:
        s_intervals: shape (num_band, 2), the band interval, (start, end) of each band
    Raises:
        AssertionError: if s_max < s_min
    '''
    assert s_max >= s_min

    # num_band + 1 evenly spaced edges; consecutive edge pairs form the intervals
    edges = np.linspace(s_min, s_max, num=num_band + 1)
    return np.stack([edges[:-1], edges[1:]], axis=-1)

def make_band_coords(s_intervals, s_min = None, s_max = None):
    '''
    Linearly rescale wavelength intervals so that s_min maps to -1 and s_max to 1.

    Args:
        s_intervals: array-like of shape (num_band, 2), the (start, end) of each band;
            lists and other array-likes are accepted and converted with np.asarray
        s_min: wavelength mapped to -1; defaults to the smallest value in s_intervals
        s_max: wavelength mapped to +1; defaults to the largest value in s_intervals
    Return:
        s_coords: float32 array with the same shape as s_intervals, values in
            [-1, 1] when s_intervals lies within [s_min, s_max]
    Note:
        if s_max == s_min the normalization divides by zero.
    '''
    s_intervals = np.asarray(s_intervals)

    if s_min is None:
        s_min = np.min(s_intervals)
    if s_max is None:
        s_max = np.max(s_intervals)

    # min-max normalization to [0, 1], then shift/scale to [-1, 1]
    s_coords = (s_intervals - s_min) / (s_max - s_min)
    s_coords = s_coords * 2 - 1

    return s_coords.astype(np.float32)


def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.

    Along an axis with ``n`` cells spanning ``[v_min, v_max]`` the i-th
    coordinate is ``x_i = v_min + r + 2r*i``, where ``r = (v_max - v_min) / (2n)``
    is half of the cell size.

    Args:
        shape: per-axis sizes, e.g. (num_dim1, num_dim2)
        ranges: optional per-axis (min, max) pairs, one per entry of ``shape``;
            defaults to (-1, 1) on every axis
        flatten: whether or not to flatten the spatial dimensions
    Return:
        ret: float tensor of cell-centre coordinates
            if flatten = True:
                shape (num_dim1 * num_dim2, 2)
            else:
                shape (num_dim1, num_dim2, 2)
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        # r: half of the cell size along the current axis
        r = (v1 - v0) / (2 * n)
        # seq: v0 + r + 2r*i for i in [0, n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)

    # Pass indexing='ij' explicitly: newer torch warns when it is omitted.
    try:
        grids = torch.meshgrid(*coord_seqs, indexing='ij')
    except TypeError:
        # torch < 1.10 has no ``indexing`` argument; its only behaviour is 'ij'
        grids = torch.meshgrid(*coord_seqs)

    # ret: shape (num_dim1, num_dim2, 2), the pixel-centre coordinate grid
    ret = torch.stack(grids, dim=-1)
    if flatten:
        # ret: shape (num_dim1 * num_dim2, 2)
        ret = ret.view(-1, ret.shape[-1])
    return ret


def to_pixel_samples(img):
    """ Convert the image to coord-value pairs.

    Args:
        img: Tensor, (C, H, W)
    Return:
        coord: (H * W, 2) pixel-centre coordinates produced by ``make_coord([H, W])``
        rgb: (H * W, C) per-pixel values, flattened row-major over (H, W) so
            row k of ``rgb`` lines up with row k of ``coord``
    """
    C, H, W = img.shape
    rgb = img.reshape(C, -1).t()
    coord = make_coord([H, W])
    return coord, rgb





def resize_fn(img, size):
    '''
    Resize ``img`` to ``size`` using torchvision's ``transforms.Resize`` with
    bicubic interpolation; ``img`` is passed to the transform unchanged.
    '''
    resize = transforms.Resize(size, interpolation=InterpolationMode.BICUBIC)
    return resize(img)


def np_interpolate_bands(img, num_band):
    '''
    Resample the band (last) axis of a batch of NumPy images to ``num_band`` bands.

    Each image is resampled independently with ``interpolate_bands``.

    Args:
        img: NumPy array of shape (B, H, W, C)
        num_band: number of output bands
    Return:
        img_lr: NumPy array of shape (B, H, W, num_band)
    '''
    B, H, W, C = img.shape
    if B == 0:
        # np.stack cannot stack an empty list; return a correctly shaped empty batch
        return np.empty((0, H, W, num_band), dtype=img.dtype)

    # interpolate_bands is channel-first: (C, H, W) -> (num_band, H, W)
    chw_batch = np.transpose(img, (0, 3, 1, 2))
    resampled = [
        # back to channel-last: (H, W, num_band)
        np.transpose(
            interpolate_bands(torch.from_numpy(frame), bands=num_band).numpy(),
            (1, 2, 0))
        for frame in chw_batch
    ]
    return np.stack(resampled, axis=0)

def interpolate_bands(img, bands):
    '''
    Bicubically resample the band (first) axis of a single image.

    Args:
        img: tensor of shape (C, H, W)
        bands: number of output bands
    Return:
        tensor of shape (bands, H, W)
    '''
    _, _, width = img.shape
    # Move H into the "channel" slot so the 2-D resize acts on the (C, W)
    # plane; W is resized to its own size, so only the band axis changes.
    stacked = img.permute(1, 0, 2).unsqueeze(0)  # (1, H, C, W)
    resized = F.interpolate(stacked,
                            size=(bands, width), mode='bicubic',
                            align_corners=False, recompute_scale_factor=False)
    return resized.squeeze(0).permute(1, 0, 2)  # (bands, H, W)


def get_sample_idx_per_class(train_labels_, sample_size = 2000):
    '''
    Randomly draw ``sample_size`` indices for every class label except -1.

    Classes with fewer than ``sample_size`` members are sampled with
    replacement; all others are sampled without replacement.

    Args:
        train_labels_: label array, presumably 1-D (get_label2idx passes a
            flattened array); indices come from the first axis of np.where
        sample_size: number of indices to draw per class
    Return:
        dict mapping each label (excluding -1) to an array of ``sample_size``
        indices into ``train_labels_``
    '''
    label2idx = {}
    for label in np.unique(train_labels_):
        if label == -1:
            continue
        members = np.where(train_labels_ == label)[0]
        needs_replacement = sample_size > len(members)
        label2idx[label] = np.random.choice(
            members, size=sample_size, replace=needs_replacement)
    return label2idx

def get_label2idx(labels, save_dir, sample_size, data_flag = "train"):
    '''
    Return a per-class index sample, loading it from a pickle cache if present.

    The cache file is ``{save_dir}/{data_flag}_label2sampleidx_{sample_size}.pkl``.
    When it does not exist, ``labels`` is flattened, sampled with
    ``get_sample_idx_per_class`` and the result is written to the cache.
    '''
    cache_path = f"{save_dir}/{data_flag}_label2sampleidx_{sample_size}.pkl"
    if os.path.exists(cache_path):
        return pickle_load(cache_path)
    label2idx = get_sample_idx_per_class(labels.reshape(-1), sample_size = sample_size)
    pickle_dump(label2idx, cache_path)
    return label2idx


def get_idxs_from_label2idx(label2dix):
    '''
    Flatten a {label: indices} mapping into a single NumPy array of indices,
    concatenated in the dict's iteration order.
    '''
    return np.array([idx for idxs in label2dix.values() for idx in idxs])

def get_stratified_sample_idx(labels, save_dir, prop, data_flag = "train", classes = None):
    '''
    Return a stratified per-class index sample, loading it from a pickle cache if present.

    Args:
        labels: label array; it is flattened before sampling
        save_dir: directory holding the cache file
        prop: per-class sampling proportion forwarded to ``stratified_sampling``
        data_flag: prefix used in the cache file name
        classes: labels to sample from; ``None`` (default) means no classes,
            which yields (and caches) an empty dict
    Return:
        dict mapping each label in ``classes`` to an array of sampled indices
    Note:
        the cache path does not depend on ``classes``, so an existing cache
        file is reused even if a different ``classes`` is passed.
    '''
    # avoid a shared mutable default argument
    if classes is None:
        classes = []
    # the "straifiedsample" spelling is kept so existing cache files are still found
    label2idx_path = f"{save_dir}/{data_flag}_label2straifiedsample_{prop}.pkl"
    if os.path.exists(label2idx_path):
        return pickle_load(label2idx_path)
    label2idx = stratified_sampling(ground_truth = labels.reshape(-1), proportion = prop, classes = classes)
    pickle_dump(label2idx, label2idx_path)
    return label2idx

def stratified_sampling(ground_truth, proportion, classes):
    '''
    Randomly select a fraction of the indices of each requested class.

    For every label in ``classes``, the matching indices of ``ground_truth`` are
    shuffled and the first ``max(int(proportion * count), 3)`` are kept (all of
    them if the class has fewer). ``proportion == 1`` selects no indices.
    NOTE(review): the intent of the ``proportion == 1`` case is unclear; confirm.

    Args:
        ground_truth: label array, presumably 1-D (indices come from the first
            axis of np.where)
        proportion: fraction of each class to select
        classes: iterable of labels to sample from
    Return:
        dict mapping each label in ``classes`` to an array of selected indices
    '''
    id2sample = {}
    for uid in classes:
        indexes = np.where(ground_truth == uid)[0]
        np.random.shuffle(indexes)
        if proportion != 1:
            # at least 3 per class (slicing caps this at the class size)
            nb_val = max(int(proportion * len(indexes)), 3)
        else:
            nb_val = 0
        id2sample[uid] = indexes[:nb_val]
    return id2sample

def to_cuda(obj):
    """Return ``obj.cuda()`` when CUDA is available, otherwise ``obj`` unchanged."""
    return obj.cuda() if torch.cuda.is_available() else obj

def load_np_file(dir, filename):
    """Load ``filename`` from directory ``dir`` with ``np.load`` and return it."""
    path = os.path.join(dir, filename)
    return np.load(path)

def json_load(filepath):
    """Read and return the JSON document stored at ``filepath``."""
    with open(filepath, "r") as handle:
        return json.load(handle)

def json_dump(data, filepath, pretty_format = True):
    """Write ``data`` to ``filepath`` as JSON.

    With ``pretty_format`` (the default) the output is indented by 2 spaces
    with sorted keys; otherwise ``json.dump``'s default formatting is used.
    """
    dump_options = {'indent': 2, 'sort_keys': True} if pretty_format else {}
    with open(filepath, 'w') as fw:
        json.dump(data, fw, **dump_options)

def pickle_dump(obj, pickle_filepath):
    """Serialize ``obj`` to ``pickle_filepath`` using pickle protocol 2."""
    payload = pickle.dumps(obj, protocol=2)
    with open(pickle_filepath, "wb") as f:
        f.write(payload)

def pickle_load(pickle_filepath):
    """Unpickle and return the object stored at ``pickle_filepath``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(pickle_filepath, "rb") as f:
        return pickle.load(f)

def get_optim(net, opt, lr):
    '''
    Build an optimizer over ``net.parameters()``.

    Args:
        net: module whose parameters are optimized
        opt: optimizer name, 'adam' or 'diffgrad'
        lr: learning rate
    Return:
        the optimizer, configured with betas=(0.9, 0.999), eps=1e-8, weight_decay=0
    Raises:
        ValueError: if ``opt`` is not a supported optimizer name
    '''
    if opt == 'diffgrad':
        # NOTE(review): ``optim2`` is not imported in this module, so this
        # branch raises NameError unless ``optim2`` (presumably a package that
        # provides DiffGrad) is bound elsewhere. TODO confirm/import it.
        return optim2.DiffGrad(
            net.parameters(),
            lr=lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0)  # weight_decay=0.0001)
    if opt == 'adam':
        return optim.Adam(
            net.parameters(),
            lr=lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0)
    # previously an unknown name fell through to an UnboundLocalError
    raise ValueError(
        "Unsupported optimizer {!r}; expected 'adam' or 'diffgrad'.".format(opt))

def get_save_path(args):
    """Build the model directory path that encodes the run's hyper-parameters.

    ``args`` must provide ``dataset`` and ``optimizer`` (str), ``sample_size``,
    ``num_band``, ``patch`` and ``kernel`` (int), and ``test_prop`` and ``lr``
    (numeric). The patch side length ``2 * patch + 1`` is encoded after ``P``.
    """
    img_rows = 2 * args.patch + 1
    return (
        f"./models/S3KAIResNet_{args.dataset:s}"
        f"_SAM{args.sample_size:d}"
        f"_{args.test_prop:.3f}"
        f"_band{args.num_band:d}"
        f"_P{img_rows:d}"
        f"_kernel{args.kernel:d}"
        f"_LR{args.lr:.6f}"
        f"_{args.optimizer:s}"
    )

def compute_eval_metric(model_pred):
    '''
    Compute overall accuracy, average (per-class) accuracy and Cohen's kappa.

    Args:
        model_pred: dict with "pred" (predicted labels) and "gt" (ground-truth
            labels) arrays; both are flattened before scoring
    Return:
        (overall_acc, average_acc, kappa)
    '''
    pred = model_pred["pred"].reshape(-1)
    gt = model_pred["gt"].reshape(-1)

    # sklearn expects (y_true, y_pred): row i of the confusion matrix then holds
    # the samples whose TRUE label is i, so a row-wise per-class accuracy is
    # recall, as "average accuracy" requires. Passing (pred, gt) transposed it.
    # NOTE(review): assumes record.aa_and_each_accuracy divides the diagonal by
    # row sums; confirm against record.py.
    overall_acc = metrics.accuracy_score(gt, pred)
    conf_mat = metrics.confusion_matrix(gt, pred)
    each_acc, average_acc = record.aa_and_each_accuracy(conf_mat)
    kappa = metrics.cohen_kappa_score(gt, pred)
    return overall_acc, average_acc, kappa