import gc
from matplotlib import pyplot as plt
import psutil
import yaml as _yaml
import os
import uuid
from data_utils.experiments import save_args, AdaptiveSampler, check_mem, Timer
from data_utils.functional import dataclass_from_dict
from nn_compression._interfaces import quantisable
from nn_compression.coding import DeepCABAC, nnc_compress, Bz2
from nn_compression.networks import recursively_find_named_children
import numpy as np
import pandas as pd
from pathlib import Path
import torch.nn as nn
from typing import Callable, Literal, Optional
import torch
from nn_compression.networks import (
    LayerwiseHessianTracker,
    hessians_to,
)
from nn_compression.evaluation import (
    entropy_net_with_overhead,
    DatasetType,
    VISION_NETS,
    NLP_NETS,
    STRATEGY_T,
    DATASET_T,
    NETWORK_T,
    ALL_STRATEGIES,
)
from nn_compression.cv import CvModel, coco, kodak
from nn_compression.networks._hessian import LayerWiseHessian
from nn_compression.nlp import LanguageModel
from nn_compression.quantisation import (
    extract_quant_weights,
    gptq_quantise_network,
    rd_quantise_gptq_order_deepcabac,
    rd_quantise_direct_deepcabac,
    rtn_quantise_network,
)

from cyclopts import App
import copy
from dataclasses import asdict, dataclass

from nn_compression.quantisation import GptqLayer, GptqLayerNoUpdate

# Greek small letter lambda, used when printing RD trade-off values.
LAMBDA_UNICODE = "\u03BB"
# Command-line application; the functions below register as commands via its decorators.
app = App()

# Debug switch from the environment: any non-zero integer in NNDEBUG turns it on
# (the lambda sweeps then cap their adaptive sampler at two iterations).
NNDEBUG = int(os.environ.get("NNDEBUG", "0")) != 0


@dataclass
class Conf:
    """Snapshot of the arguments of one ``main`` run.

    ``main`` builds it from its local variables via ``dataclass_from_dict`` and
    passes ``asdict(conf)`` to ``save_args``; the experiment functions read
    their settings from it.
    """

    nbatches: int  # number of batches used to estimate the Hessians
    results_folder: Path  # fully nested output folder of this run (not the base folder)
    bits: float  # target bit-width; None if main was called without one, int when integral
    acc_batches: int  # batches passed to data.evaluate per measurement
    device: Literal["cpu", "cuda", "mps"]  # kept as the original string
    strategy: STRATEGY_T  # which experiments main runs
    exp_name: Optional[str]  # optional extra sub-folder level in results_folder
    rowwise: bool  # forwarded as per_row_grid to the quantisers and coders
    network: NETWORK_T
    dataset: DATASET_T  # dataset the quantised networks are evaluated on
    deepcabac_entropymodel: bool  # must be True for the RD-quantisation sweeps
    train_dataset: DATASET_T  # dataset for the Hessians; main falls back to `dataset`
    acc_coarseness: float  # passed as max_function_step to AdaptiveSampler


def get_network_and_dataset(network_name: str, dataset_name: str):
    """Load a pretrained network together with the dataset to evaluate it on.

    Args:
        network_name: ``resnet*``, ``vgg16``, ``vit_b_16``, ``gpt2`` or
            ``gpt2-xl``.
        dataset_name: dataset identifier. For ResNets it also selects the model
            variant (``f"{network_name}_{dataset_name}"``). ``"kodak"`` and
            ``"coco"`` replace the network's default dataset and are only
            valid for networks in ``VISION_NETS``.

    Returns:
        ``(net, dataset)``.

    Raises:
        ValueError: for an unknown network, for kodak/coco with a non-vision
            network, or when ``COCO_PATH`` is unset or does not exist for coco.
    """
    # Validate cheap preconditions before loading any (possibly large) model.
    if dataset_name in ("kodak", "coco") and network_name not in VISION_NETS:
        raise ValueError(
            f"Dataset {dataset_name!r} requires a vision network, got {network_name!r}"
        )

    if network_name.startswith("resnet"):
        net_enum = CvModel.from_string(f"{network_name}_{dataset_name}")
    elif network_name in ("vgg16", "vit_b_16"):
        net_enum = CvModel.from_string(network_name)
    elif network_name == "gpt2":
        net_enum = LanguageModel.GPT2
    elif network_name == "gpt2-xl":
        net_enum = LanguageModel.GPT2_XL
    else:
        raise ValueError(f"Invalid network: {network_name!r}")

    coco_root = None
    if dataset_name == "coco":
        coco_root = os.environ.get("COCO_PATH")
        if coco_root is None or not Path(coco_root).exists():
            raise ValueError(
                "COCO_PATH not found. Did you set the environment variable?"
            )

    net = net_enum.load()

    # Only load the network's default dataset when it is actually used.
    if dataset_name == "kodak":
        dataset = kodak(transform=net_enum.transforms())  # type: ignore
    elif dataset_name == "coco":
        dataset = coco(coco_root, transforms=net_enum.transforms())  # type: ignore
    else:
        dataset = net_enum.get_dataset()

    return net, dataset


def verification_fn(network_name: "NETWORK_T") -> Callable:
    """Return a predicate over layer names selecting the layers to process.

    - ``resnet18``/``resnet34``/``resnet50``/``vgg16``: every layer.
    - any ``gpt2*`` network: only layers whose name starts with
      ``"transformer.h"``.
    - ``vit_b_16``: every layer whose name does not contain ``"out_proj"``.

    Raises:
        ValueError: if ``network_name`` is not recognised (the message names it).
    """
    if network_name in ("resnet18", "resnet34", "resnet50", "vgg16"):
        return lambda _: True
    if network_name.startswith("gpt2"):
        return lambda n: n.startswith("transformer.h")
    if network_name == "vit_b_16":
        return lambda n: "out_proj" not in n
    # f-string: the original literal never interpolated the network name.
    raise ValueError(f"Invalid network: {network_name}")


@app.command
def yaml(cfg: Path, **overwrite_args):
    """Run ``main`` from a YAML config, or from every ``*.yaml`` in a directory.

    Keyword arguments override the values read from the file.

    Raises:
        ValueError: if the file does not contain a YAML mapping.
    """
    if cfg.is_dir():
        # Sorted so a directory of configs always runs in the same order.
        for child in sorted(cfg.glob("*.yaml")):
            yaml(child, **overwrite_args)
        return

    with open(cfg, "r", encoding="utf-8") as f:
        conf = _yaml.safe_load(f)
    if not isinstance(conf, dict):
        raise ValueError(f"{cfg} does not contain a YAML mapping of arguments")
    conf.update(overwrite_args)
    # Convert after applying overrides so an overridden results_folder is a
    # Path too (main uses `/=` on it).
    conf["results_folder"] = Path(conf["results_folder"])
    main(**conf)


def get_hessian_filepath(
    network: NETWORK_T, dataset: DATASET_T, nbatches: int, results_folder: Path
) -> Path:
    """Return the cache path for the layer-wise Hessians of one configuration.

    The file sits in ``<results_folder>/hessians`` and its name encodes the
    network, the dataset and the number of batches. Nothing is created on disk.
    """
    filename = f"{network}_{dataset}_{nbatches}.pt"
    return results_folder.joinpath("hessians", filename)


@app.command
def hessian(
    nbatches: int,
    results_folder: Path,
    network: NETWORK_T,
    dataset: DATASET_T,
    device: Literal["cpu", "cuda", "mps"] = "cpu",
    force: bool = False,
):
    """Estimate the layer-wise Hessians of ``network`` on ``dataset`` and cache them.

    Feeds ``nbatches`` batches from the training dataloader through the network,
    cycling through the dataloader as often as needed. This happens inside a
    ``LayerwiseHessianTracker`` that saves to :func:`get_hessian_filepath`.

    Args:
        nbatches: number of batches to feed through the network.
        results_folder: base results folder (the file goes to ``hessians/``).
        network: network name.
        dataset: dataset whose ``train_dataloader`` provides the inputs.
        device: device to run the forward passes on.
        force: recompute even if a cached file already exists.

    Raises:
        RuntimeError: if the training dataloader yields no batches. Without
            this check the batch loop would never terminate.
    """
    hessian_filepath = get_hessian_filepath(network, dataset, nbatches, results_folder)
    hessian_filepath.parent.mkdir(parents=True, exist_ok=True)
    already_cached = hessian_filepath.exists()
    if already_cached:
        print(f"Hessian already computed at {hessian_filepath}")
        if not force:
            return
        print("force=True, recomputing Hessian...")

    net, train_data = get_network_and_dataset(network, dataset)
    net.train(False)

    batches_so_far = 0
    dataloader_empty = False

    print(
        f"Calculating Hessians for network {network} and dataset {dataset} for {nbatches} batches..."
    )
    check_mem()
    net.to(device)
    check_mem()
    with (
        LayerwiseHessianTracker(
            net, save_to=hessian_filepath, is_large_net=network == "gpt2-xl"
        ),
        Timer("Hessian"),
    ):
        while batches_so_far < nbatches:
            batches_before_pass = batches_so_far
            for xs, _ in train_data.train_dataloader:
                if batches_so_far >= nbatches:
                    break
                net(xs.to(device))
                batches_so_far += 1
                if batches_so_far % 100 == 0:
                    print(f"{batches_so_far}/{nbatches} batches done", flush=True)
            if batches_so_far == batches_before_pass:
                # A full pass over the dataloader produced nothing.
                dataloader_empty = True
                break
    if dataloader_empty:
        if not already_cached:
            # Don't leave behind a file estimated from zero batches.
            hessian_filepath.unlink(missing_ok=True)
        raise RuntimeError(
            f"Training dataloader for {network}/{dataset} yielded no batches; "
            "cannot compute the Hessian."
        )
    assert hessian_filepath.exists()
    print("Hessian computed and saved at ", hessian_filepath)


@app.default
def main(
    nbatches: int,
    results_folder: Path,
    bits: Optional[float],
    network: NETWORK_T,
    dataset: DATASET_T,
    acc_batches: int = 5,
    device: Literal["cpu", "cuda", "mps"] = "cpu",
    strategy: STRATEGY_T = [],
    exp_name: Optional[str] = None,
    rowwise: bool = False,
    deepcabac_entropymodel: bool = False,
    train_dataset: Optional[DATASET_T] = None,
    acc_coarseness: float = 0.05,
):
    """Run the selected compression experiments for one network/dataset pair.

    Results go to ``results_folder/[exp_name]/network/dataset/[bits]``. Cached
    Hessians live under ``results_folder/hessians``. They are estimated on
    ``train_dataset`` (default: ``dataset``) and computed first if missing.

    ``strategy`` selects the experiments: ``"nncodec"``, ``"gptq"``,
    ``"rtn"``, the trace-scaled RD variants (``"alpha_inv_tr_rescaled"``,
    ``"alpha_inv_tr"``, ``"alpha_inv_tr_scale"``, ``"alpha_inv_tr_unscaled"``),
    ``"uniform"`` and ``"direct_rd"``. ``"all"`` expands to ``ALL_STRATEGIES``.
    """
    base_results_folder = results_folder
    nbatches = int(nbatches)
    bits = float(bits) if bits is not None else None
    acc_batches = int(acc_batches)
    acc_coarseness = float(acc_coarseness)
    # Resolve before the config snapshot so the saved config records the
    # dataset the Hessians are actually computed on.
    train_dataset = train_dataset or dataset
    # `strategy` is only ever rebound, never mutated, so the list default is safe.
    if "all" in strategy:
        strategy = ALL_STRATEGIES
    if exp_name is not None:
        results_folder /= exp_name
    results_folder /= network
    results_folder /= dataset
    if bits is not None:
        results_folder /= f"{bits}"
    # dataclass_from_dict presumably picks the Conf fields out of the locals.
    conf = dataclass_from_dict(Conf, locals())
    # Create the output folder before saving the arguments into it.
    results_folder.mkdir(parents=True, exist_ok=True)
    save_args(asdict(conf), results_folder)

    torch.manual_seed(18)
    device = torch.device(device)  # type: ignore
    if bits is not None:
        # Store integral bit-widths as int (e.g. 4 instead of 4.0).
        if np.abs(bits - np.round(bits)) < 1e-6:
            bits = int(bits)
        conf.bits = bits

    net, data = get_network_and_dataset(network, dataset)

    if "nncodec" in strategy:
        experiment_nnc(net, results_folder, acc_batches, data, conf.network, device)

    net.to(device)
    net.train(False)

    hessian_filepath = get_hessian_filepath(
        network, train_dataset, nbatches, base_results_folder
    )
    if not hessian_filepath.exists():
        print("Hessian not found, computing..")
        hessian(nbatches, base_results_folder, network, train_dataset, device=device)
    hessians = torch.load(hessian_filepath, map_location=device)
    print("Loading hessian into model...")
    LayerWiseHessian.load_into_model(net, hessians)

    # The quantisation experiments below run on CPU.
    net.to("cpu")
    hessians_to(net, "cpu")

    qnet_gptq = experiment_gptq(net, data, conf) if "gptq" in strategy else None

    if "rtn" in strategy:
        experiment_rtn(net, data, conf)

    # Trace-scaled RD quantisation, run in this fixed order.
    for method in (
        "alpha_inv_tr_rescaled",
        "alpha_inv_tr",
        "alpha_inv_tr_scale",
        "alpha_inv_tr_unscaled",
    ):
        if method in strategy:
            alphas = _trace_alphas(net, method)
            experiment_rd_quantisation(net, alphas, qnet_gptq, data, method, conf)

    if "uniform" in strategy:
        experiment_rd_quantisation(net, None, qnet_gptq, data, "uniform", conf)

    if "direct_rd" in strategy:
        experiment_rd_quantisation(
            net, None, qnet_gptq, data, "direct_rd", conf, direct_rd=True
        )


def _trace_alphas(net: nn.Module, method: str) -> dict:
    """Per-layer alpha factors from Hessian traces for the trace-scaled strategies.

    The RD objective ``lm / alpha * rate + distortion`` is equivalent to
    ``lm * rate + alpha * distortion``. With ``tr`` the trace of a quantisable
    layer's Hessian and its weight of shape ``(nrows, ncols, ...)``:

    * ``alpha_inv_tr_rescaled``: ``ncols / (tr * nrows)``
    * ``alpha_inv_tr_unscaled``: ``1 / tr``
    * ``alpha_inv_tr``: ``1 / (tr * nrows)``
    * ``alpha_inv_tr_scale``: ``ncols / tr``

    Raises:
        ValueError: for an unknown ``method``.
    """
    formulas = {
        "alpha_inv_tr_rescaled": lambda tr, nrows, ncols: ncols / (tr * nrows),
        "alpha_inv_tr_unscaled": lambda tr, nrows, ncols: 1 / tr,
        "alpha_inv_tr": lambda tr, nrows, ncols: 1 / (tr * nrows),
        "alpha_inv_tr_scale": lambda tr, nrows, ncols: ncols / tr,
    }
    if method not in formulas:
        raise ValueError("Invalid method")
    formula = formulas[method]

    alphas = {}
    for layer_name, layer in recursively_find_named_children(net):
        if quantisable(layer):
            nrows, ncols = layer.weight.shape[0], layer.weight.shape[1]
            alphas[layer_name] = formula(layer.hessian.H.trace(), nrows, ncols)
    return alphas


class Experiment:
    """Collect per-layer compression statistics for one experiment and save them as CSV.

    Rows are written to ``<config.results_folder>/<name>.csv``. The first write
    truncates the file and emits a header; later writes append. Each batch of
    rows from :meth:`record` is written exactly once, so extra ``save()`` calls
    never duplicate rows.
    """

    def __init__(
        self,
        name: str,
        config: "Conf",
        verbose: bool = True,
        transpose_deepcabac: bool = False,
    ):
        """
        Args:
            name: experiment name; also the CSV file stem.
            config: run configuration (reads ``results_folder``, ``network``,
                ``acc_batches``, ``device``, ``rowwise`` and ``bits``).
            verbose: print progress and the summary row of each record.
            transpose_deepcabac: forwarded as ``transpose`` to the DeepCABAC
                and bz2 coders.
        """
        self.config = config
        self.name = name
        self.verbose = verbose
        self.transpose_deepcabac = transpose_deepcabac
        self._reset()
        if self.verbose:
            print(f"Experiment {name} started\n" + "-" * 80)

    def record(self, qnet: nn.Module, data: "DatasetType", lm: float):
        """Evaluate and code ``qnet``, then store and save one row per layer.

        Args:
            qnet: quantised network to measure.
            data: dataset whose ``evaluate`` yields the accuracy (perplexity
                for language models, judging by the ``ppl`` rename in ``save``).
            lm: the RD trade-off ``lambda`` this network was produced with.

        Returns:
            ``(acc, entropy)``, where ``entropy`` is the whole-network
            (``"all"``) entry of ``entropy_net_with_overhead``.
        """
        c = self.config
        acc = data.evaluate(qnet, c.acc_batches, device=c.device)
        layer_filter = verification_fn(c.network)

        # Scratch path handed to both coders; removed after each use.
        tmp_filepath = c.results_folder / f"{uuid.uuid4()}.nnc"
        coder = DeepCABAC(
            tmp_filepath,
            filter=layer_filter,
            transpose=self.transpose_deepcabac,
            per_row_grid=c.rowwise,
        )
        try:
            with Timer("DeepCABAC Encode"):
                coder.encode(qnet)
            with Timer("DeepCABAC Decode"):
                coder.dummy_decode()
        finally:
            # missing_ok: an error raised before the file exists must not be
            # masked by a FileNotFoundError from the cleanup.
            tmp_filepath.unlink(missing_ok=True)
        entropy_deepcabac = coder.bpw

        entropy_all_layers = entropy_net_with_overhead(
            qnet,
            axis_specialisation=0 if c.rowwise else None,
            filter=layer_filter,
            regular_grid=True,
        )

        coder_bz2 = Bz2(
            tmp_filepath,
            filter=layer_filter,
            transpose=self.transpose_deepcabac,
            per_row_grid=c.rowwise,
        )
        try:
            coder_bz2.encode(qnet)
            # Read the rate before cleanup in case it is derived from the file.
            entropy_bz2 = coder_bz2.bpw
        finally:
            # Remove anything the bz2 coder left at the scratch path.
            tmp_filepath.unlink(missing_ok=True)

        if self.verbose:
            print(
                f"{c.bits}b, {LAMBDA_UNICODE}={lm:.2e} ({entropy_deepcabac}): {acc}",
                flush=True,
            )

        # The whole-network "all" entry goes last so save() prints it as the summary.
        layer_names = [k for k in entropy_all_layers.keys() if k != "all"] + ["all"]
        rows = []
        for layer_name in layer_names:
            entropy = entropy_all_layers[layer_name]
            numel_quant = entropy.numel - entropy.numel_unquantised_params
            sparsity = 0 if numel_quant == 0 else entropy.nzero_quant / numel_quant
            rows.append(
                {
                    "bits": c.bits,
                    "name": self.name,
                    "lm": lm,
                    #
                    "nquant": numel_quant,
                    "nunquant": entropy.numel_unquantised_params,
                    "nzero_quant": entropy.nzero_quant,
                    #
                    "shannon_quantised": entropy.entropy_quantised_params,
                    "shannon_unquantised": entropy.entropy_unquantised_params,
                    "bits_overhead": entropy.overhead,
                    #
                    "sparsity": sparsity,
                    "acc": acc,
                    "entropy_deepcabac": entropy_deepcabac,
                    "entropy_bz2": entropy_bz2,
                    "layer_name": layer_name,
                }
            )
        self.last_rows = rows
        self._unsaved = True
        self.save()
        return acc, entropy_all_layers["all"]

    def save(self):
        """Write the rows of the latest :meth:`record` call to the CSV, once.

        A call without new rows since the last write is a no-op.
        """
        if not self._unsaved:
            return
        c = self.config
        if self.verbose:
            print(self.last_rows[-1])
        df = pd.DataFrame(self.last_rows)
        # NOTE(review): only plain "gpt2" gets the "ppl" column; gpt2-xl keeps
        # "acc". Left unchanged to keep existing CSV schemas stable.
        if c.network == "gpt2":
            df = df.rename(columns={"acc": "ppl"})
        df.to_csv(
            c.results_folder / f"{self.name}.csv",
            mode="w" if self.first_save else "a",
            header=self.first_save,
        )
        self.first_save = False
        self._unsaved = False

    def _reset(self):
        """Forget recorded rows; the next save truncates the CSV and writes a header."""
        self.last_rows = []
        self.first_save = True
        self._unsaved = False


@torch.no_grad()
def experiment_direct_rd(net: nn.Module, data: DatasetType, c: Conf):
    """Sweep ``lambda`` for direct RD quantisation with DeepCABAC and record each point.

    Lambdas come from an ``AdaptiveSampler`` over ``(1e-12, 1)``, with ``0.0``
    added manually. For each lambda, a deep copy of ``net`` is quantised, so
    ``net`` itself is not modified. Each point is saved by ``Experiment.record``.

    Returns:
        The network quantised at the last sampled lambda, or ``None`` if the
        sampler produced none.
    """
    print("Performing Direct RD quantisation..")
    exp = Experiment("direct_rd", c, transpose_deepcabac=False)
    sampler = AdaptiveSampler((1e-12, 1), max_function_step=c.acc_coarseness)
    if NNDEBUG:
        sampler.max_iter = 2
    sampler.manually_add(0.0)
    rdq_net = None
    while (base_lm := sampler.get_next()) is not None:
        rdq_net = rd_quantise_direct_deepcabac(
            copy.deepcopy(net), c.bits, base_lm, per_row_grid=c.rowwise, inplace=True
        )
        acc, _ = exp.record(rdq_net, data, base_lm)
        sampler.record(acc)
    # No trailing exp.save(): record() already saved every point, and another
    # save would append the last rows a second time.
    return rdq_net


@torch.no_grad()
def experiment_gptq(net: nn.Module, data: DatasetType, c: Conf):
    """Quantise a copy of ``net`` with GPTQ to ``c.bits`` bits and record the result.

    Returns:
        The GPTQ-quantised copy; ``net`` itself is left unmodified.
    """
    print(f"Quantising network with GPTQ to {c.bits} bits..")
    exp = Experiment("gptq", c, transpose_deepcabac=c.rowwise)

    # The decorator already disables autograd; no inner no_grad block needed.
    Timer.start("GPTQ")
    qnet = gptq_quantise_network(
        copy.deepcopy(net),
        c.bits,
        None,
        inplace=True,
        per_row_grid=c.rowwise,
        verification_fn=verification_fn(c.network),
    )
    Timer.end("GPTQ")
    # record() saves its rows itself; a trailing save() would append them twice.
    exp.record(qnet, data, 0)

    return qnet


@torch.no_grad()
def experiment_rtn_deepcabac(net, data: "DatasetType", c: "Conf"):
    """Placeholder for a round-to-nearest + DeepCABAC experiment (not implemented).

    Raises:
        NotImplementedError: always. This makes accidental use fail loudly
            instead of silently returning ``None``.
    """
    raise NotImplementedError("experiment_rtn_deepcabac is not implemented yet")


@torch.no_grad()
def experiment_rtn(net: nn.Module, data: DatasetType, c: Conf):
    """Quantise ``net`` with round-to-nearest to ``c.bits`` bits and record the result.

    Returns:
        The network returned by ``rtn_quantise_network``.
    """
    print("Performing Round-To-Nearest quantisation..")
    exp = Experiment("rtn", c, transpose_deepcabac=False)

    # No inplace flag is passed; per the existing note, rtn_quantise_network
    # copies the network internally (TODO confirm net stays unmodified).
    qnet = rtn_quantise_network(net, c.bits, per_row_grid=c.rowwise)
    # record() saves its rows itself; a trailing save() would append them twice.
    exp.record(qnet, data, 0)

    return qnet


@torch.no_grad()
def experiment_rd_quantisation(
    net: nn.Module,
    alphas,
    qnet,
    data: DatasetType,
    name,
    c: Conf,
    direct_rd: bool = False,
):
    """Sweep the RD trade-off ``lambda`` and record one operating point per value.

    Lambdas come from an ``AdaptiveSampler`` over ``(1e-12, 1)``, with ``0.0``
    added manually. For each lambda, a deep copy of ``net`` is RD-quantised in
    GPTQ order with the DeepCABAC rate model. The result is recorded via
    ``Experiment`` under ``name``.

    Args:
        net: network with Hessians loaded; only copies of it are quantised.
        alphas: ``None`` to use the same lambda for every layer, or a mapping
            ``layer name -> alpha`` giving a per-layer ``lambda / alpha``.
        qnet: not used by this function; kept for call compatibility.
        data: dataset used to evaluate each quantised network.
        name: experiment name (CSV file stem).
        c: run configuration.
        direct_rd: use ``GptqLayerNoUpdate`` instead of ``GptqLayer``.

    Raises:
        NotImplementedError: if ``c.deepcabac_entropymodel`` is False. This is
            checked up front, before any work is done.
    """
    if not c.deepcabac_entropymodel:
        raise NotImplementedError("Only deepcabac is supported")

    gptq_layer = GptqLayerNoUpdate if direct_rd else GptqLayer
    exp = Experiment(name, c, transpose_deepcabac=c.deepcabac_entropymodel)
    sampler = AdaptiveSampler((1e-12, 1), max_function_step=c.acc_coarseness)
    if NNDEBUG:
        sampler.max_iter = 2
    sampler.manually_add(0.0)
    layer_filter = verification_fn(c.network)
    while True:
        check_mem()
        print(f"RAM usage before Quant: {psutil.virtual_memory().percent}", flush=True)
        base_lm = sampler.get_next()
        if base_lm is None:
            break
        if alphas is None:
            lm_alpha = base_lm
        else:
            # lm/a * rate + distortion  <=>  lm * rate + a * distortion
            lm_alpha = {k: base_lm / (a + 1e-10) for k, a in alphas.items()}
        # Copy only after the sampler check, so the terminating iteration does
        # not pay for a deep copy of a possibly large network.
        net_copy = copy.deepcopy(net)
        with Timer("OPTQ-RD"):
            rdq_net = rd_quantise_gptq_order_deepcabac(
                net_copy,
                c.bits,  # type: ignore
                None,
                lm_alpha,
                False,
                per_row_grid=c.rowwise,
                inplace=True,
                verification_fn=layer_filter,
                blocksize=64,
                gptq_class=gptq_layer,
            )
        acc, _ = exp.record(rdq_net, data, base_lm)

        print(f"RAM usage after quant: {psutil.virtual_memory().percent}")
        del rdq_net
        del net_copy
        gc.collect()  # else large networks might be problematic
        print(f"RAM usage after free: {psutil.virtual_memory().percent}", flush=True)

        # Cap the value fed to the sampler at 100; presumably this keeps
        # exploding perplexities from dominating its step control.
        sampler.record(min(acc, 100))
        # Stop early if even lambda=0 gives a broken model (accuracy < 0.2, or
        # presumably perplexity > 50 for language models).
        if base_lm == 0.0 and (acc < 0.2 or acc > 50):
            break
    # No trailing exp.save(): record() already saved every point, and another
    # save would append the last rows a second time.


@app.command
def nnc(
    network: NETWORK_T,
    dataset: DATASET_T,
    results_folder: Path,
    acc_batches: int,
    device: Literal["cpu", "mps", "cuda"] = "cpu",
):
    """CLI: sweep NNCodec quantisation parameters for ``network`` on ``dataset``.

    Outputs are written under ``<results_folder>/<network>/<dataset>``.
    """
    out_dir = results_folder.joinpath(network, dataset)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(18)
    torch_device = torch.device(device)

    model, eval_data = get_network_and_dataset(network, dataset)

    model.to(torch_device)
    model.train(False)
    experiment_nnc(model, out_dir, acc_batches, eval_data, network, torch_device)


def calc_sparsity(net, ver_Fn):
    """Return the fraction of exactly-zero weights in the selected quantisable layers.

    Args:
        net: module to inspect.
        ver_Fn: predicate on layer names (e.g. from ``verification_fn``).

    Returns:
        ``zeros / total`` over quantisable layers accepted by ``ver_Fn``.
        Returns ``0.0`` when no layer is selected instead of raising
        ``ZeroDivisionError``, matching the sparsity convention in
        ``Experiment.record``.
    """
    total = 0
    zeros = 0
    for layer_name, layer in recursively_find_named_children(net):
        if quantisable(layer) and ver_Fn(layer_name):
            total += layer.weight.numel()
            zeros += int((layer.weight == 0).sum().item())
    return zeros / total if total else 0.0


@torch.no_grad()
def experiment_nnc(
    net: nn.Module,
    results_folder: Path,
    acc_batches: int,
    data: DatasetType,
    network_type: NETWORK_T,
    device: Literal["cpu", "cuda", "mps"] = "cpu",
):
    """Sweep NNCodec's quantisation parameter (QP) and record the resulting RD curve.

    Starting at ``QP = -38`` and increasing by one per step, the function:

    * compresses ``net`` with NNCodec (``"urq"``, quantisable layers selected
      by ``verification_fn``);
    * evaluates the result;
    * rewrites ``<results_folder>/nncodec-results.csv`` with all points so far.

    It stops once the evaluated metric is <= 0.2 or after 100 steps. It then
    saves the curve to ``nncodec.png`` and removes ``config.yaml`` from
    ``results_folder`` if present.
    """
    print("Compressing network with NNCodec..")

    transpose = False
    only_quantisable = True
    method = "urq"
    layer_filter = verification_fn(network_type)
    print(f"Compressing {network_type}\n\n" + "-" * 80)

    qp = -38
    max_iter = 100
    acc = 1.0
    entropies = []
    accs = []
    qps = []
    sparsities = []
    csv_path = results_folder / "nncodec-results.csv"

    while acc > 0.2 and max_iter > 0:
        qnet, entropy = nnc_compress(
            net,
            qp,
            method,
            only_quantisable,
            transpose,
            verification_fn=layer_filter,
        )
        acc = data.evaluate(qnet, acc_batches, device=device)
        print(f"QP={qp}, Top-1 Accuracy: {acc:.2f}, Entropy: {entropy:.2f}")

        entropies.append(entropy)
        accs.append(acc)
        qps.append(qp)
        sparsities.append(calc_sparsity(qnet, layer_filter))

        df = pd.DataFrame(
            {"qp": qps, "acc": accs, "entropy": entropies, "sparsity": sparsities}
        )
        # Rewritten every step so partial results survive an interrupted sweep.
        df.to_csv(csv_path)
        print("Sparsity: ", sparsities[-1])

        max_iter -= 1
        qp += 1

    fig = plt.figure()
    plt.plot(df.entropy, df.acc, "--.", label="NNCodec")
    plt.xlabel("Entropy [bits]")
    plt.ylabel("Top-1 Accuracy")
    plt.title(f"RD Curve of NNCodec on {network_type}")
    plt.savefig(results_folder / "nncodec.png")
    plt.close(fig)  # release the figure; sweeps may run many times per process
    # missing_ok: the `nnc` command never writes config.yaml, so it may be absent.
    # NOTE(review): when run from main this removes the config presumably written
    # there by save_args — confirm that is intended.
    (results_folder / "config.yaml").unlink(missing_ok=True)


@app.command
def baseperf(yaml):
    """CLI: evaluate the unquantised network(s) named in a YAML config.

    Reads ``network`` (a name or a list of names), ``dataset``,
    ``acc_batches``, ``device`` and ``results_folder`` from the config. Each
    network is evaluated, and ``{"<network>_<dataset>": metric}`` is merged into
    ``<results_folder>/base_performance.yaml``. The file is rewritten after
    every network so completed results are kept.
    """
    with open(yaml, "r", encoding="utf-8") as f:
        conf = _yaml.safe_load(f)
    if isinstance(conf["network"], list):
        nets = conf["network"]
    else:
        nets = [conf["network"]]

    results_folder = Path(conf["results_folder"])
    results_folder.mkdir(parents=True, exist_ok=True)
    results_file = results_folder / "base_performance.yaml"
    # Check the file that is actually written, so earlier results are merged
    # rather than overwritten.
    if results_file.exists():
        with open(results_file, "r", encoding="utf-8") as f:
            results_so_far = _yaml.safe_load(f) or {}
    else:
        results_so_far = {}
    for net in nets:
        model, eval_data = get_network_and_dataset(net, conf["dataset"])
        acc = eval_data.evaluate(model, conf["acc_batches"], device=conf["device"])
        key = f"{net}_{conf['dataset']}"
        # Plain float so yaml.dump emits a value safe_load can read back.
        results_so_far[key] = float(acc)
        print("key: ", key, "acc: ", acc)
        with open(results_file, "w", encoding="utf-8") as f:
            _yaml.dump(results_so_far, f)


if __name__ == "__main__":
    # Dispatch to the cyclopts CLI: `main` is the default command; `yaml`,
    # `hessian`, `nnc` and `baseperf` are registered as sub-commands.
    app()
