import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
import argparse
import os
import time
from pathlib import Path

import numpy as np
import pandas as pd
import rasterio
import torch

import architechtures
import test_config as config
from utils import get_logger, to_float, get_datetime_str, handle_labels


def do_args(jupyter=False):
    """Parse command-line options for model testing and prepare the output directory.

    Option defaults come from the ``test_config`` module. Besides the parsed
    options, two derived attributes are added:

    * ``model_path``: ``model_dir / model_name``.
    * ``output_dir``: ``model_dir / "Test_<model_name without its extension>"``,
      created (with parents) if it does not exist yet.

    Args:
        jupyter: when True, parse an empty argument list instead of ``sys.argv``
            so that only the defaults are used (handy inside a notebook).

    Returns:
        argparse.Namespace holding the parsed and derived options.
    """
    parser = argparse.ArgumentParser(
        description="Wrapper utility for training and testing land cover models.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        help="Path to data directory containing the CSV files",
        default=config.DATA_DIR,
    )
    parser.add_argument(
        "--test-states",
        nargs="+",
        type=str,
        help="States to test",
        default=config.TEST_STATES,
    )
    parser.add_argument(
        "--model-dir",
        action="store",
        type=str,
        default=config.MODEL_DIR,
        help="Path to the directory where Pytorch .pth model file exists"
    )
    parser.add_argument(
        "--model-name",
        action="store",
        type=str,
        default=config.MODEL_NAME,
        help="Name of the Pytorch .pth model to use"
    )
    parser.add_argument(
        "--output-nclasses", type=int, default=config.OUTPUT_NCLASSES, help="Number of target classes",
    )
    parser.add_argument(
        "--input-nchannels", type=int, help="Number of input image channels", default=config.INPUT_NCHANNELS
    )
    parser.add_argument(
        "--batch-size", type=int, help="Batch size when doing prediction using the model", default=config.BATCH_SIZE
    )
    parser.add_argument(
        "--input-size", type=int, help="Model input size when doing prediction using the model",
        default=config.INPUT_SIZE
    )
    parser.add_argument(
        "--hr-label-key",
        type=str,
        help="Path to map from cheasepeake 6 classes label to hr 4 classes label",
        default=config.HR_LABEL_KEY,
    )
    parser.add_argument(
        "--net-model",
        type=str,
        default=config.NET_MODEL,
        choices=["unet", "a-unet"],
        help="Model architecture to use",
    )
    parser.add_argument(
        "--att-hidden", type=int, help="Dimension of attention hidden", default=config.ATT_HIDDEN
    )

    args = parser.parse_args([]) if jupyter else parser.parse_args()
    args.model_path = str(Path(args.model_dir) / Path(args.model_name))
    # Strip the actual file extension rather than blindly dropping four characters,
    # which was only correct for ".pth" names (e.g. "model.pt" became "mode").
    model_stem = Path(args.model_name).with_suffix("")
    args.output_dir = str(Path(args.model_dir) / Path(f"Test_{model_stem}"))
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    return args


def compute_acc(lc, output):
    """Pixel accuracy between a label map and a predicted class map.

    Only pixels where both ``lc`` and ``output`` are ``> 0`` are evaluated, so
    class 0 is ignored on both sides.

    Args:
        lc: ``H x W`` ground-truth class map (uint8).
        output: ``H x W`` predicted class map (uint8), same shape as ``lc``.

    Returns:
        ``(accuracy, acc_sum, acc_num)``. Here ``acc_num`` is the number of evaluated
        pixels and ``acc_sum`` is how many of them match. ``accuracy`` is
        ``acc_sum / acc_num``. When no pixel can be evaluated, the result is
        ``(-1, 0, 0)``; the -1 is a sentinel, not an accuracy.
    """
    roi = (lc > 0) & (output > 0)
    acc_num = np.sum(roi)
    if acc_num == 0:
        return -1, 0, 0

    # ``roi`` is already a boolean mask: index with it directly and count matches once.
    acc_sum = np.sum(lc[roi] == output[roi])
    return acc_sum / acc_num, acc_sum, acc_num


def compute_jaccard(lc, output, nclasses=5):
    """Mean per-class Jaccard index (IoU) over classes ``1 .. nclasses - 1``.

    Class 0 is skipped. Intersection and union counts are each offset by 1e-6.
    A class absent from both maps therefore scores about 1.0 instead of
    dividing by zero.

    Args:
        lc: ``H x W`` ground-truth class map.
        output: ``H x W`` predicted class map, same shape as ``lc``.
        nclasses: number of classes including class 0.

    Returns:
        ``(mean_jaccard, intersections, unions)``. The two lists hold the
        smoothed float64 counts for classes ``1 .. nclasses - 1``, in order.
    """
    eps = 1e-6
    intersections, unions = [], []
    for cls in range(1, nclasses):
        in_label = lc == cls
        in_pred = output == cls
        intersections.append(np.sum(in_label & in_pred).astype(np.float64) + eps)
        unions.append(np.sum(in_label | in_pred).astype(np.float64) + eps)

    # Accumulate the per-class ratios left to right, starting from 0.
    ratio_total = sum(inter / uni for inter, uni in zip(intersections, unions))
    return ratio_total / (nclasses - 1), intersections, unions


class TestDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs image patches with label patches by position.

    Item ``k`` is ``(batches[k], lc_batches[k])``. The dataset length is
    ``len(batches)``.
    """

    def __init__(self, batches, lc_batches):
        self.batches, self.lc_batches = batches, lc_batches

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, item):
        image_patch = self.batches[item]
        label_patch = self.lc_batches[item]
        return image_patch, label_patch


class Test:
    """Evaluate a trained land-cover model on test tiles and save its predictions.

    The model is run on each tile listed in the per-state CSV files, using
    overlapping sliding windows. For each tile this class writes:

    * the blended scores as a uint8 GeoTIFF with one band per class, under
      ``<output_dir>/probability``;
    * the argmax classes as a single-band GeoTIFF, under ``<output_dir>/class``.

    Accuracy and Jaccard are accumulated per model input patch ("batch"), per
    tile and overall, and logged at the end.
    """

    def __init__(self, args):
        """Store the options, create the logger, and use CUDA when available, else CPU."""
        self.args = args
        self.logger = get_logger(__name__, None, args.output_dir)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        # Totals over all tiles; the 1e-6 seeds keep the final ratios finite.
        self.overall_acc_sum = 1e-6
        self.overall_acc_num = 1e-6
        self.overall_intersection = []
        self.overall_union = []

        # Totals for the tile currently being processed (reset at the start of each tile).
        self._reset_tile_metrics()

        # One metric value per tile.
        self.acc_tiles = []
        self.jac_tiles = []

        # One metric value per model input patch. Accuracy is -1 when a patch has
        # no pixel where both label and prediction are non-zero (see compute_acc).
        self.acc_batches = []
        self.jac_batches = []

    def _reset_tile_metrics(self):
        """Reset the per-tile accumulators that model_predict adds to."""
        self.tile_acc_sum = 1e-6
        self.tile_acc_num = 1e-6
        self.tile_intersection = []
        self.tile_union = []

    @staticmethod
    def _window_origins(extent, window, stride):
        """Return start offsets of sliding windows that cover ``extent`` pixels.

        Windows advance by ``stride``. One final window is placed flush with the
        far edge, so the whole extent is covered.

        Raises:
            ValueError: if ``extent`` is smaller than ``window``. The offset would
                be negative and would silently produce truncated windows.
        """
        if extent < window:
            raise ValueError(f"Tile extent {extent} is smaller than the model input size {window}")
        return list(range(0, extent - window, stride)) + [extent - window]

    def run_on_one_tile(self, model, naip_tile, lc_tile):
        """Predict per-class scores for a whole tile using overlapping sliding windows.

        Args:
            model: network applied by model_predict to ``input_size`` square patches.
            naip_tile: channels-first image array ``C x H x W``.
            lc_tile: ``H x W`` label array aligned with ``naip_tile``. It is only
                used for the patch metrics recorded by model_predict.

        Returns:
            ``output_nclasses x H x W`` float32 array. The per-window model outputs
            are weighted by a centre-heavy kernel, summed, and divided by the
            accumulated weights.

        Raises:
            ValueError: if the tile is smaller than ``input_size``, or if
                ``input_size`` does not exceed twice the down-weighted border.
        """
        input_size = self.args.input_size
        down_weight_padding = 40  # border width (px) of each window that gets reduced weight
        height = naip_tile.shape[1]
        width = naip_tile.shape[2]

        stride = input_size - down_weight_padding * 2
        if stride <= 0:
            raise ValueError(
                f"input_size ({input_size}) must exceed twice the down-weight padding ({down_weight_padding})"
            )

        output = np.zeros((self.args.output_nclasses, height, width), dtype=np.float32)
        # Tiny offset keeps the final division finite.
        counts = np.zeros((height, width), dtype=np.float32) + 0.000000001
        # Blending weights: 0.1 on the outer 10 px, 1 inside that border, and 5 on
        # the central stride x stride square, so window centres dominate overlaps.
        kernel = np.ones((input_size, input_size), dtype=np.float32) * 0.1
        kernel[10:-10, 10:-10] = 1
        kernel[down_weight_padding: down_weight_padding + stride,
               down_weight_padding: down_weight_padding + stride] = 5

        y_origins = self._window_origins(height, input_size, stride)
        x_origins = self._window_origins(width, input_size, stride)

        patches, lc_patches, origins = [], [], []
        for y_index in y_origins:
            for x_index in x_origins:
                patches.append(naip_tile[:, y_index: y_index + input_size, x_index: x_index + input_size])
                lc_patches.append(lc_tile[y_index: y_index + input_size, x_index: x_index + input_size])
                origins.append((y_index, x_index))

        model_output = self.model_predict(model, patches, lc_patches)

        for i, (y, x) in enumerate(origins):
            output[:, y: y + input_size, x: x + input_size] += model_output[i] * kernel[np.newaxis, ...]
            counts[y: y + input_size, x: x + input_size] += kernel

        return output / counts[np.newaxis, ...]

    def model_predict(self, model, batches, lc_batches):
        """Run ``model`` over image patches and record per-patch metrics.

        Args:
            model: network to evaluate; it is switched to eval mode here.
            batches: sequence of image patches.
            lc_batches: label patches aligned index-for-index with ``batches``.

        Returns:
            numpy array of all model outputs, concatenated along axis 0 in input
            order, or ``None`` if ``batches`` is empty.

        Side effects:
            Appends per-patch accuracy and Jaccard to ``acc_batches`` and
            ``jac_batches``, and adds the patch counts to the current tile totals.
            NOTE(review): windows overlap, so pixels in overlap regions are counted
            once per window in the tile/overall totals. These metrics are therefore
            not computed from the stitched prediction that is written to disk.
            NOTE(review): compute_jaccard uses its default nclasses rather than
            ``args.output_nclasses``; confirm the two agree.
        """
        test_dl = torch.utils.data.DataLoader(TestDataset(batches, lc_batches),
                                              batch_size=self.args.batch_size, shuffle=False,
                                              num_workers=4)
        outputs = []
        model.eval()
        with torch.no_grad():
            for x, lc in test_dl:
                y_pred_np = model(x.to(self.device)).cpu().numpy()
                # Collect the outputs and concatenate once at the end; re-concatenating
                # after every batch would copy the growing array each time.
                outputs.append(y_pred_np)

                # Per-patch metrics from the argmax over the class axis.
                lc = lc.numpy().astype(np.uint8)
                output_class = np.argmax(y_pred_np, axis=1).astype(np.uint8)

                for lc_patch, pred_patch in zip(lc, output_class):
                    acc, acc_sum, acc_num = compute_acc(lc_patch, pred_patch)
                    jac, intersections, unions = compute_jaccard(lc_patch, pred_patch)

                    self.acc_batches.append(acc)
                    self.jac_batches.append(jac)

                    self.tile_acc_sum += acc_sum
                    self.tile_acc_num += acc_num
                    self.tile_intersection.append(intersections)
                    self.tile_union.append(unions)

        return np.concatenate(outputs, axis=0) if outputs else None

    def load_tiles(self):
        """Read the (image, label) file pairs for every state in ``args.test_states``.

        Each state's ``{state}_extended-test_tiles.csv`` in ``args.data_dir`` is read.
        Its ``naip-new_fn`` and ``lc_fn`` columns are taken as the image and label
        paths, which run_on_tiles later joins onto ``args.data_dir``.

        Returns:
            ``(n_tiles, 2)`` array of file names, in state order.

        Raises:
            ValueError: if ``args.test_states`` is empty.
        """
        per_state = []
        for state in self.args.test_states:
            path = str(Path(self.args.data_dir) / Path(f"{state}_extended-test_tiles.csv"))
            df = pd.read_csv(path)
            fns = df[["naip-new_fn", "lc_fn"]].values
            per_state.append(fns)
            self.logger.info(f"Loaded {fns.shape[0]} test tiles from {state}_extended-test_tiles.csv")

        if not per_state:
            raise ValueError("No test states given; there are no tiles to evaluate")
        all_tiles_path = np.concatenate(per_state, axis=0)
        self.logger.info(f"Loaded {all_tiles_path.shape[0]} test tiles in total")

        return all_tiles_path

    def load_model(self):
        """Build the network chosen by ``args.net_model`` and load the checkpoint weights.

        The checkpoint at ``args.model_path`` must contain a ``'model_state_dict'`` entry.

        Returns:
            The model, moved to ``self.device``.

        Raises:
            ValueError: if ``args.net_model`` is not ``'unet'`` or ``'a-unet'``.
        """
        if self.args.net_model == 'unet':
            model = architechtures.U_Net(in_ch=self.args.input_nchannels,
                                         out_ch=self.args.output_nclasses)
        elif self.args.net_model == "a-unet":
            model = architechtures.Attention_U_Net(in_ch=self.args.input_nchannels,
                                                   out_ch=self.args.output_nclasses,
                                                   att_hidden=self.args.att_hidden,
                                                   save_memory=True)
        else:
            raise ValueError(f"Unsupported net model: {self.args.net_model}")

        # map_location lets a checkpoint saved on GPU load on the CPU fallback device.
        checkpoint = torch.load(self.args.model_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)
        self.logger.info(f"Loaded model state dict from {self.args.model_path}")

        return model

    def _read_tile(self, naip_fn, lc_fn):
        """Read one image tile and its label tile.

        Returns:
            ``(naip_tile, naip_profile, lc_tile)``:
            - ``naip_tile``: the image read channels-first (``C x H x W``), cast to
              float32 and passed through ``to_float``;
            - ``naip_profile``: a copy of the image raster's metadata;
            - ``lc_tile``: the squeezed uint8 labels, remapped by ``handle_labels``.
        """
        # Context managers close the rasters even if reading fails.
        with rasterio.open(naip_fn, "r") as naip_fid:
            naip_profile = naip_fid.meta.copy()
            naip_tile = to_float(naip_fid.read().astype(np.float32))
        with rasterio.open(lc_fn, "r") as lc_fid:
            lc_tile = np.squeeze(lc_fid.read()).astype(np.uint8)
        lc_tile = handle_labels(lc_tile, self.args.hr_label_key)
        return naip_tile, naip_profile, lc_tile

    @staticmethod
    def _geotiff_profile(naip_profile, count):
        """Copy the source raster profile, set up as an LZW-compressed uint8 GeoTIFF with ``count`` bands."""
        profile = naip_profile.copy()
        profile.update(driver="GTiff", dtype="uint8", count=count, compress="lzw")
        return profile

    @staticmethod
    def _quantize_probabilities(output):
        """Map scores in ``[0, 1]`` to uint8 values ``0..255``.

        ``np.digitize(..., right=True)`` with bins ``k / 255`` returns the smallest
        ``k`` such that ``score <= k / 255``; scores ``<= 0`` map to 0. A score
        above 1 would get index 256, which wraps to 0 in a uint8 cast. The index
        is therefore clipped to 255 first.
        """
        bins = np.arange(256) / 255.0
        return np.clip(np.digitize(output, bins=bins, right=True), 0, 255).astype(np.uint8)

    def _write_probabilities(self, output, naip_fn, naip_profile):
        """Write the quantized class scores to ``<output_dir>/probability/<stem>_prob.tif``, one band per class."""
        output_prob = self._quantize_probabilities(output)
        output_prob_dir = Path(self.args.output_dir) / "probability"
        output_prob_dir.mkdir(parents=True, exist_ok=True)
        output_prob_fn = str(output_prob_dir / (Path(naip_fn).stem + "_prob.tif"))
        profile = self._geotiff_profile(naip_profile, self.args.output_nclasses)
        with rasterio.open(output_prob_fn, "w", **profile) as f:
            for c in range(self.args.output_nclasses):
                f.write(output_prob[c, :, :], c + 1)

    def _write_classes(self, output_class, naip_fn, naip_profile):
        """Write the ``H x W`` class map to ``<output_dir>/class/<stem>_class.tif``."""
        output_class_dir = Path(self.args.output_dir) / "class"
        output_class_dir.mkdir(parents=True, exist_ok=True)
        output_class_fn = str(output_class_dir / (Path(naip_fn).stem + "_class.tif"))
        with rasterio.open(output_class_fn, "w", **self._geotiff_profile(naip_profile, 1)) as f:
            f.write(output_class, 1)

    @staticmethod
    def _mean_valid_accuracy(accuracies):
        """Mean of per-patch accuracies, ignoring compute_acc's ``-1`` "no evaluable pixel" sentinel.

        Returns NaN when no patch had an evaluable pixel.
        """
        valid = [a for a in accuracies if a >= 0]
        return float(np.mean(valid)) if valid else float("nan")

    def run_on_tiles(self):
        """Run the model over every test tile, save predictions, and log metrics.

        For each tile:
        1. read the image and labels;
        2. predict with run_on_one_tile;
        3. write the probability and class GeoTIFFs;
        4. log the tile accuracy and Jaccard.

        At the end, overall, tile-average and batch-average metrics are logged.
        """
        self.logger.info(f"Starting model test at {get_datetime_str()}")
        start_time = get_datetime_str()

        fns = self.load_tiles()
        model = self.load_model()

        for i, (naip_rel, lc_rel) in enumerate(fns):
            tic = time.time()
            self._reset_tile_metrics()

            naip_fn = os.path.join(self.args.data_dir, naip_rel)
            lc_fn = os.path.join(self.args.data_dir, lc_rel)
            naip_tile, naip_profile, lc_tile = self._read_tile(naip_fn, lc_fn)

            self.logger.info("Running model on %s\t%d/%d" % (naip_fn, i + 1, len(fns)))
            # run_on_one_tile already returns exactly output_nclasses bands.
            output = self.run_on_one_tile(model, naip_tile, lc_tile)

            self._write_probabilities(output, naip_fn, naip_profile)
            self._write_classes(np.argmax(output, axis=0).astype(np.uint8), naip_fn, naip_profile)

            # Tile metrics, from the totals model_predict accumulated over this tile's patches.
            tile_acc = self.tile_acc_sum / self.tile_acc_num
            tile_jac = np.mean(
                np.sum(np.array(self.tile_intersection), axis=0) / np.sum(np.array(self.tile_union), axis=0)
            )

            self.acc_tiles.append(tile_acc)
            self.jac_tiles.append(tile_jac)

            self.overall_acc_sum += self.tile_acc_sum
            self.overall_acc_num += self.tile_acc_num
            self.overall_intersection.extend(self.tile_intersection)
            self.overall_union.extend(self.tile_union)

            self.logger.info(f"Completed test on tile {i + 1}/{fns.shape[0]} in {time.time() - tic:0.2f} seconds. "
                             f"Accuracy: {tile_acc:.5f}  Jaccard: {tile_jac:.5f}")

        self._log_summary(start_time)

    def _log_summary(self, start_time):
        """Log the overall, tile-average and batch-average accuracy and Jaccard."""
        overall_acc = self.overall_acc_sum / self.overall_acc_num
        overall_jac = np.mean(
            np.sum(np.array(self.overall_intersection), axis=0) / np.sum(np.array(self.overall_union), axis=0)
        )
        # Exclude the -1 sentinel so patches without evaluable pixels do not drag the average down.
        batch_acc = self._mean_valid_accuracy(self.acc_batches)

        end_time = get_datetime_str()
        self.logger.info("=" * 100)
        self.logger.info(f"Test completed. Started at {start_time}. Ended at {end_time}")
        self.logger.info("=" * 100)
        self.logger.info(f"Overall Accuracy: {overall_acc:.7f}  "
                         f"Overall Jaccard: {overall_jac:.7f}")
        self.logger.info(f"Tile Average Accuracy: {np.mean(self.acc_tiles):.7f}  "
                         f"Tile Average Jaccard: {np.mean(self.jac_tiles):.7f}")
        self.logger.info(f"Batch Average Accuracy: {batch_acc:.7f}  "
                         f"Batch Average Jaccard: {np.mean(self.jac_batches):.7f}")
        self.logger.info("=" * 100)
        self.logger.info(f"Test results saved at {self.args.output_dir}")
