""" Run weighted retraining for shapes with the optimal model """
import os.path
import sys

# sys.path.append(os.path.abspath(os.path.join(os.path.join(os.path.dirname(__file__), '..'), '..')))


import logging
import itertools
from tqdm.auto import tqdm
import argparse
from pathlib import Path
import numpy as np
import torch
import pytorch_lightning as pl

# from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# My imports
from weighted_retraining.shapes.shapes_data import WeightedNumpyDataset, WeighNumpyDataset
from weighted_retraining.shapes.shapes_model_base import ShapesVAE as oldVAE
from weighted_retraining.shapes.shapes_model_metric import ShapesVAE
from weighted_retraining import utils
from weighted_retraining.opt_scripts import base as wr_base

# from ..shapes.shapes_data import WeightedNumpyDataset
# from ..shapes.shapes_model import ShapesVAE
# from .. import utils
# import base as wr_base

from torch.utils.data import Dataset, DataLoader
from torch.distributions.multivariate_normal import MultivariateNormal



def _weighted_triplet_loss(batch_y, z, weight):
    """
    Weighted soft-margin triplet loss aligning latent distances with property distances.

    Triplets (a, b, c) are formed by joining a "positive" pair (a, b), whose
    property distance lies below the batch median, with a "negative" pair
    (b, c), whose property distance lies above the upper-half threshold.
    Each triplet contributes
    ``w[a] * w[b] * w[c] * softplus(d_z(a, b) - d_z(b, c))``.

    Returns:
        The mean over all triplets, or ``None`` when the batch admits no
        triplet (e.g. all properties equal). This avoids the NaN produced by
        taking the mean of an empty tensor.
    """
    # ``+ 0.0`` promotes integer properties to float so cdist accepts them.
    y_col = batch_y.unsqueeze(1) + 0.0
    distance_y = torch.cdist(y_col, y_col, p=2)
    distance_z = torch.cdist(z, z, p=2)

    value_sorted, _ = torch.sort(distance_y.reshape(-1), descending=False)
    half = int(value_sorted.numel() * 0.5)

    # Positive pairs (a, b): close in property space.
    indices_pos = torch.nonzero(distance_y < value_sorted[half], as_tuple=False)
    # Negative pairs (b, c): far in property space.
    indices_neg = torch.nonzero(distance_y > value_sorted[-half], as_tuple=False)

    # Join on the shared element b (second of the positive, first of the negative).
    pos_sel, neg_sel = (
        indices_pos[:, 1].unsqueeze(1) == indices_neg[:, 0]
    ).nonzero(as_tuple=True)
    if pos_sel.numel() == 0:
        return None

    a = indices_pos[pos_sel, 0]
    b = indices_pos[pos_sel, 1]
    c = indices_neg[neg_sel, 1]

    v_pos = distance_z[a, b]
    v_neg = distance_z[b, c]
    # softplus(x) == log(1 + exp(x)) without overflowing for large x.
    soft_margin = torch.nn.functional.softplus(v_pos - v_neg)
    return (weight[a] * weight[b] * weight[c] * soft_margin).mean()


def retrain_model(model, datamodule, save_dir, version_str, num_epochs, gpu):
    """
    Fine-tune ``model`` with its own loss plus a weighted latent triplet loss.

    Args:
        model: module exposing ``forward_with_reconx(x, y)`` that returns
            ``(loss, recon_x, mu, logstd, z)``.
        datamodule: map-style dataset yielding ``(x, y, weight)`` items.
            A channel axis is inserted at dim 1 of each batch of x.
        save_dir: currently unused; kept for interface compatibility.
        version_str: currently unused; kept for interface compatibility.
        num_epochs: a value in (0, 1) trains on that fraction of one epoch's
            batches (at least one batch). Otherwise it must be a positive
            whole number of epochs.
        gpu: currently unused; kept for interface compatibility.

    Raises:
        ValueError: if ``num_epochs`` is not positive, or is above 1 and not
            a whole number.

    The model is left in eval mode on return.
    """
    # Handle fractional epochs (mirrors Lightning's ``limit_train_batches``).
    if num_epochs <= 0:
        raise ValueError(f"invalid num epochs {num_epochs}")
    if num_epochs < 1:
        max_epochs = 1
        limit_train_batches = num_epochs
    elif int(num_epochs) == num_epochs:
        max_epochs = int(num_epochs)
        limit_train_batches = 1.0
    else:
        raise ValueError(f"invalid num epochs {num_epochs}")

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    train_loader = DataLoader(datamodule, batch_size=5, shuffle=False, drop_last=True)
    max_batches = max(1, int(len(train_loader) * limit_train_batches))

    model.train()
    for _epoch in range(max_epochs):
        batches = itertools.islice(train_loader, max_batches)
        for batch_idx, (batch_x, batch_y, weight) in enumerate(batches):
            optimizer.zero_grad()

            loss, _, _, _, z = model.forward_with_reconx(batch_x.unsqueeze(1), batch_y)
            metric = _weighted_triplet_loss(batch_y, z, weight)
            if metric is not None:
                loss = loss + metric

            loss.backward()
            optimizer.step()

            if batch_idx % 1000 == 0:
                line = 'Train batch_idx: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx, batch_idx * len(batch_x), len(train_loader.dataset),
                               100. * batch_idx / len(train_loader),
                               loss.item() / len(batch_x))
                print(line)

    model.eval()

def _batch_decode_z_and_props(model, z, args, filter_unique=True):
    """
    helper function to decode some latent vectors and calculate their properties
    """
    # Decode all points in a fixed decoding radius
    z_decode = []
    batch_size = 1000
    for j in range(0, len(z), batch_size):
        with torch.no_grad():
            img = model.decode_deterministic(z[j : j + batch_size])
            img = img.cpu().numpy()
            z_decode.append(img)
            del img

    # Concatentate all points and convert to numpy
    z_decode = np.concatenate(z_decode, axis=0)
    z_decode = np.around(z_decode)  # convert to int
    z_decode = z_decode[:, 0, ...]  # Correct slicing
    if filter_unique:
        z_decode, uniq_indices = np.unique(
            z_decode, axis=0, return_index=True
        )  # Unique elements only
        z = z.cpu().numpy()[uniq_indices]

    # Calculate objective function values and choose which points to keep
    if args.property_key == "areas":
        z_prop = np.sum(z_decode, axis=(-1, -2))
    else:
        raise ValueError(args.property)

    if filter_unique:
        return z_decode, z_prop, z
    else:
        return z_decode, z_prop


def latent_optimization(args, model, datamodule, num_queries_to_do):
    """
    Latent-space optimization by exhaustive grid search with the optimal model (aka cheating).

    Every point of the grid ``linspace(-opt_bounds, opt_bounds, opt_grid_len)``
    raised to ``model.latent_dim`` dimensions is decoded. The unique decodings
    with the largest property are returned, best first.

    Args:
        args: namespace with ``opt_bounds``, ``opt_grid_len`` and the fields
            read by ``_batch_decode_z_and_props``.
        model: decoder model exposing ``latent_dim`` and ``device``.
        datamodule: unused; kept for interface compatibility.
        num_queries_to_do: number of top points to return.

    Returns:
        ``(new_points, y_new, z_query)`` as numpy arrays.
    """
    # ``--opt_grid_len`` may arrive as a float from argparse; np.linspace needs an int.
    grid_len = int(args.opt_grid_len)
    unit_line = np.linspace(-args.opt_bounds, args.opt_bounds, grid_len)
    latent_grid = np.array(
        list(itertools.product(unit_line, repeat=model.latent_dim)), dtype=np.float32
    )
    z_latent_opt = torch.as_tensor(latent_grid, device=model.device)

    z_decode, z_prop, z = _batch_decode_z_and_props(model, z_latent_opt, args)

    # Descending order of the property (assuming maximization).
    top = np.argsort(-1 * z_prop)[:num_queries_to_do]
    return z_decode[top], z_prop[top], z[top]


def latent_sampling(args, model, datamodule, num_queries_to_do, filter_unique=True):
    """
    Draw standard-normal latent samples and decode them.

    ``datamodule`` is not used, and nothing is appended to any dataset here.
    Returns whatever ``_batch_decode_z_and_props`` returns:
    ``(z_decode, z_prop, z)`` when ``filter_unique`` is True, otherwise
    ``(z_decode, z_prop)``.
    """
    sample_shape = (num_queries_to_do, model.latent_dim)
    latent_draws = torch.randn(*sample_shape, device=model.device)
    return _batch_decode_z_and_props(
        model, latent_draws, args, filter_unique=filter_unique
    )


class _Tee:
    """Minimal file-like object that writes to, and flushes, several streams at once."""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, data):
        for stream in self.streams:
            stream.write(data)
            stream.flush()

    def flush(self):
        for stream in self.streams:
            stream.flush()


def _build_metric_model(pretrained_model_file):
    """Build the metric ShapesVAE and warm-start it from the pretrained (old-class) checkpoint."""
    pretrained = oldVAE.load_from_checkpoint(pretrained_model_file)
    pretrained.beta = pretrained.hparams.beta_final  # Override any beta annealing
    model = ShapesVAE(pretrained.hparams)

    # Constructing from hparams alone does not load the checkpoint's weights.
    # Copy every pretrained tensor whose name and shape match the new model,
    # and leave any metric-model-specific parameters at their initial values.
    own_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in pretrained.state_dict().items()
        if name in own_state and own_state[name].shape == tensor.shape
    }
    model.load_state_dict(compatible, strict=False)
    # Carry the beta override over; otherwise it would be lost with ``pretrained``.
    model.beta = pretrained.beta
    return model


def main_loop(args):
    """
    Run the retrain / query loop and save results under ``args.result_root``.

    Each of the ``ceil(query_budget / retraining_frequency)`` iterations:
      1. optionally retrains the model (``n_init_retrain_epochs`` on the
         first iteration if given, otherwise ``n_retrain_epochs``),
      2. queries new points with ``args.lso_strategy`` ("opt" or "sample"),
      3. appends them to the training data and rewrites ``results.npz``.
    Progress is mirrored to stdout and ``<result_root>/progress_log.txt``.

    Raises:
        NotImplementedError: for an unknown ``args.lso_strategy``.
    """
    # Make results directory
    result_dir = Path(args.result_root).resolve()
    result_dir.mkdir(parents=True, exist_ok=True)
    data_dir = result_dir / "data"
    data_dir.mkdir(exist_ok=True)

    # Load data
    print(args.dataset_path)
    with np.load(args.dataset_path) as npz:
        all_data = npz["data"]
        all_properties = npz[args.property_key]
    datamodule = WeighNumpyDataset(all_data, all_properties)

    # Load model
    model = _build_metric_model(args.pretrained_model_file)

    # Set up results tracking. NOTE: sampling, latent snapshots and per-iteration
    # dataset saving are currently disabled, so their entries stay empty.
    results = dict(
        opt_points=[],
        opt_latent_points=[],
        opt_point_properties=[],
        opt_model_version=[],
        params=str(sys.argv),
        sample_points=[],
        sample_versions=[],
        sample_properties=[],
        latent_space_snapshots=[],
        latent_space_snapshot_version=[],
    )
    results["latent_space_grid"] = np.array(
        list(itertools.product(np.arange(-4, 4.01, 0.5), repeat=model.latent_dim)),
        dtype=np.float32,
    )

    # Set up some stuff for the progress bar
    num_retrain = int(np.ceil(args.query_budget / args.retraining_frequency))
    postfix = dict(
        retrain_left=num_retrain, best=-float("inf"), n_train=len(datamodule)
    )

    with open(result_dir / "progress_log.txt", "w") as log_file, tqdm(
        total=args.query_budget,
        dynamic_ncols=True,
        smoothing=0.0,
        file=_Tee(sys.stdout, log_file),
    ) as pbar:
        for ret_idx in range(num_retrain):
            pbar.set_postfix(postfix)
            pbar.set_description("retraining")

            samples_so_far = args.retraining_frequency * ret_idx

            # Optionally do retraining
            num_epochs = args.n_retrain_epochs
            if ret_idx == 0 and args.n_init_retrain_epochs is not None:
                num_epochs = args.n_init_retrain_epochs
            if num_epochs > 0:
                retrain_dir = result_dir / "retraining"
                version = f"retrain_{samples_so_far}"
                retrain_model(
                    model, datamodule, retrain_dir, version, num_epochs, args.gpu
                )

            postfix["retrain_left"] -= 1
            pbar.set_postfix(postfix)
            pbar.set_description("querying")

            # NOTE: hard-coded to one query per iteration (instead of
            # min(retraining_frequency, query_budget - samples_so_far)), so
            # only ``num_retrain`` queries are made in total.
            num_queries_to_do = 1
            if args.lso_strategy == "opt":
                x_new, y_new, z_query = latent_optimization(
                    args, model, datamodule, num_queries_to_do
                )
            elif args.lso_strategy == "sample":
                x_new, y_new, z_query = latent_sampling(
                    args, model, datamodule, num_queries_to_do
                )
            else:
                raise NotImplementedError(args.lso_strategy)

            # Append the new points and rebuild the dataset
            all_data = np.concatenate((all_data, x_new), axis=0)
            all_properties = np.concatenate((all_properties, y_new), axis=0)
            datamodule = WeighNumpyDataset(all_data, all_properties)

            # Save results
            results["opt_latent_points"] += list(z_query)
            results["opt_points"] += list(x_new)
            results["opt_point_properties"] += list(y_new)
            results["opt_model_version"] += [ret_idx] * len(x_new)
            np.savez_compressed(str(result_dir / "results.npz"), **results)

            # Final update of progress bar
            postfix["best"] = max(postfix["best"], float(y_new.max()))
            postfix["n_train"] = len(datamodule)
            pbar.set_postfix(postfix)
            pbar.update(n=num_queries_to_do)


if __name__ == "__main__":

    # arguments and argument checking
    parser = argparse.ArgumentParser()
    parser = WeightedNumpyDataset.add_model_specific_args(parser)
    parser = utils.DataWeighter.add_weight_args(parser)
    parser = wr_base.add_common_args(parser)

    # Optimal model arguments
    opt_group = parser.add_argument_group(title="opt-model")
    opt_group.add_argument("--opt_bounds", type=float, default=3)
    # Number of grid points per latent dimension: it is passed as the
    # ``num`` argument of np.linspace, which must be an integer.
    opt_group.add_argument("--opt_grid_len", type=int, default=50)

    args = parser.parse_args()

    main_loop(args)

