import numpy as np
import torch
import argparse
import losses
import spaces
import disentanglement_utils
import invertible_network_utils
import torch.nn.functional as F
import random
import os
import latent_spaces
import encoders
from torch.utils.data import Dataset, DataLoader
from plot_utils import plot_latents, plot_probs, plot_critic_sphere, save_gif, generate_grid
from torch.utils.tensorboard import SummaryWriter


# use_cuda = torch.cuda.is_available()
# if use_cuda:
#     device = "cuda:0"
# else:
#     device = "cpu"


def str2bool(x):
    """Convert a command-line style truth value to a ``bool``.

    Intended for use as an argparse ``type=`` callable.

    Args:
        x: Either an actual ``bool`` (returned unchanged) or a value whose
            string form, ignoring case and surrounding whitespace, is one of
            ``yes/true/t/y/1`` or ``no/false/f/n/0``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    # Already-parsed booleans (e.g. programmatic defaults) pass straight through
    # instead of failing on the missing ``.lower()`` attribute.
    if isinstance(x, bool):
        return x

    value = str(x).strip().lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args():
    """Parse the command-line options of the experiment.

    Besides the declared options, the returned namespace carries a derived
    ``run`` attribute: the TensorBoard/log directory built from the seed,
    latent-space settings, loss name, batch size and centering flag.

    Returns:
        argparse.Namespace: The parsed options plus ``run``.
    """
    parser = argparse.ArgumentParser(
        description="Disentanglement with Contrastive Learning - MLP Mixing"
    )
    parser.add_argument(
        '--gpu', type=int, default=0,
        help='used GPU if cuda is available and set to True'
    )
    parser.add_argument("--sphere-r", type=float, default=1.0)
    parser.add_argument(
        "--box-min",
        type=float,
        default=0.0,
        help="For box normalization only. Minimal value of box.",
    )
    parser.add_argument(
        "--box-max",
        type=float,
        default=1.0,
        help="For box normalization only. Maximal value of box.",
    )
    parser.add_argument(
        "--output-norm", default=None,
        choices=[None, 'learnable_box', 'fixed_box', 'learnable_sphere', 'fixed_sphere'],
        help="Normalize output."
    )
    parser.add_argument(
        "--only-supervised", action="store_true", help="Only train supervised model."
    )
    # Parsed through str2bool: without a type, any value given on the command
    # line (including "False") would be a non-empty string and thus truthy.
    # nargs="?" additionally allows the bare flag form to mean True.
    parser.add_argument(
        "--only-unsupervised",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help="Only train unsupervised model.",
    )
    parser.add_argument(
        "--more-unsupervised",
        type=int,
        default=3,
        help="How many more steps to do for unsupervised compared to supervised training.",
    )
    parser.add_argument("--save-dir", type=str, default="")
    parser.add_argument(
        "--num-eval-batches",
        type=int,
        default=10,
        help="Number of batches to average evaluation performance at the end.",
    )
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--act-fct",
        type=str,
        default="leaky_relu",
        help="Activation function in mixing network g.",
    )
    parser.add_argument(
        "--c-param",
        type=float,
        default=0.1,
        help="Concentration parameter of the conditional distribution.",
    )
    parser.add_argument(
        "--m-param",
        type=float,
        default=1.0,
        help="Additional parameter for the marginal (only relevant if it is not uniform).",
    )
    parser.add_argument("--tau", type=float, default=0.1)
    parser.add_argument(
        "--n-mixing-layer",
        type=int,
        default=3,
        help="Number of layers in nonlinear mixing network g.",
    )
    parser.add_argument(
        "--n", type=int, default=10, help="Dimensionality of the latents."
    )
    parser.add_argument(
        "--space-type", type=str, default="box", choices=("box", "sphere", "unbounded", "hollow_ball", "cube_grid")
    )
    parser.add_argument(
        "--m-p",
        type=int,
        default=0,
        help="Type of ground-truth marginal distribution. p=0 means uniform; "
        "all other p values correspond to (projected) Lp Exponential",
    )
    parser.add_argument(
        "--c-p",
        nargs="*",
        type=float,
        default=[1],
        help="Exponent(s) of ground-truth Lp Exponential distribution. Make sure that len(c-p) in [1, n]",
    )
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--p",
        type=float,
        default=1,
        help="Exponent of the assumed model Lp Exponential distribution. p=-1 means the exponents are learnable parameters.",
    )
    parser.add_argument(
        "--loss",
        type=str,
        default='ince',
        choices=('ince', 'nce', 'nwj', 'scl', 'simclr'),
        help="Loss function to minimize (ignored if p==0, where the plain SimCLR loss is used)",
    )
    parser.add_argument(
        "--encoder",
        type=str,
        default='res',
        choices=('mlp', 'res'),
        help="Encoder architecture",
    )
    parser.add_argument(
        "--margin-mode",
        type=str,
        default='second',
        choices=('first', 'second', 'both'),
        help="Margin mode handed to the Delta* loss (first, second or both)",
    )
    parser.add_argument(
        "--center",
        type=str2bool,
        default=True,
        help="Whether to add additional loss to center the representation. Be careful with space constraint!",
    )
    parser.add_argument("--batch-size", type=int, default=5120)
    parser.add_argument("--n-log-steps", type=int, default=10000)
    parser.add_argument("--n-steps", type=int, default=100001)
    parser.add_argument("--resume-training", action="store_true")
    # Same truthiness pitfall as --only-unsupervised: parse the value explicitly.
    parser.add_argument(
        "--early-stopping",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Stop early if disentanglement score is high enough.",
    )

    args = parser.parse_args()

    # nargs="*" accepts "--c-p" with no values; the run name below (and the
    # sampler setup later) index c_p[0], so reject the empty list up front.
    if not args.c_p:
        parser.error("--c-p requires at least one exponent.")

    # Run directory: runs/seed<seed>/<space>_m<m_p>_p<c_p[0]>_n<n>/<loss>_bs<bs>_c<0|1>
    loss = args.loss if args.p == -1 else args.loss + "lp" + str(int(args.p))
    c = '1' if args.center else '0'
    args.run = (
        f"runs/seed{args.seed}"
        f"/{args.space_type.lower()}_m{args.m_p}_p{int(args.c_p[0])}_n{args.n}"
        f"/{loss}_bs{args.batch_size}_c{c}"
    )

    print("Arguments:")
    for k, v in vars(args).items():
        print(f"\t{k}: {v}")

    return args


def _select_device(args):
    """Return the torch device string: the requested GPU if CUDA is usable, else CPU."""
    if torch.cuda.is_available() and args.gpu is not None:
        return 'cuda:' + str(args.gpu)
    return "cpu"


def _build_space(args):
    """Instantiate the ground-truth latent space selected by ``--space-type``."""
    if args.space_type == "box":
        return spaces.NBoxSpace(args.n, args.box_min, args.box_max)
    if args.space_type == "sphere":
        return spaces.NSphereSpace(args.n, args.sphere_r)
    if args.space_type == "hollow_ball":
        return spaces.NHollowBallSpace(args.n)
    if args.space_type == "cube_grid":
        return spaces.NCubeGridSpace(args.n)
    return spaces.NRealSpace(args.n)


def _build_loss(args, device):
    """Instantiate the training loss from ``--p`` and ``--loss``.

    ``p == 0`` always selects the plain SimCLR loss; otherwise ``--loss`` picks
    one of the Delta* losses or the Lp SimCLR loss.
    """
    if args.p == 0:
        return losses.SimCLRLoss(normalize=False, tau=args.tau)

    if args.loss == 'ince':
        return losses.DeltaINCELoss(args.n, args.space_type, args.p, args.tau, args.margin_mode, args.center, device)
    if args.loss == 'nce':
        return losses.DeltaNCELoss(args.n, args.space_type, args.p, args.tau, args.margin_mode, args.center, device)
    if args.loss == 'nwj':
        return losses.DeltaNWJLoss(args.n, args.space_type, args.p, args.tau, args.margin_mode, args.center, device)
    if args.loss == 'scl':
        return losses.DeltaSCLLoss(args.n, args.space_type, args.p, args.tau, args.margin_mode, args.center, device)
    if args.loss == 'simclr':
        return losses.LpSimCLRLoss(
            p=args.p, tau=args.tau, simclr_compatibility_mode=True
        )
    raise NotImplementedError


def _build_marginal_sampler(args, device):
    """Return the callable ``(space, size, device)`` that samples the marginal p(z)."""
    # Location parameter of the non-uniform marginals; on the sphere it is the
    # first unit vector, otherwise the origin.
    eta = torch.zeros(args.n)
    if args.space_type == "sphere":
        eta[0] = 1.0

    if args.m_p == 1:
        return lambda space, size, device=device: space.laplace(
            eta, args.m_param, size, device
        )
    if args.m_p == 2:
        return lambda space, size, device=device: space.normal(
            eta, args.m_param, size, device
        )
    if args.m_p > 0:
        return lambda space, size, device=device: space.generalized_normal(
            eta, args.m_param, p=args.m_p, size=size, device=device
        )
    if args.m_p == 0:
        return lambda space, size, device=device: space.uniform(
            size, device=device
        )
    return lambda space, size, device=device: space.non_uniform(
        1.0/args.m_param, size, device
    )


def _build_conditional_sampler(args, device):
    """Return the callable ``(space, z, size, device)`` that samples p(z~ | z)."""
    if args.m_p < 0 and len(args.c_p) == 1:
        return lambda space, z, size, device=device: space.generalized_normal_with_checkerboard_pattern(
            z, args.c_param, p=args.c_p[0], size=size, device=device
        )
    if len(args.c_p) > 1:
        return lambda space, z, size, device=device: space.mixed_generalized_normal(
            z, args.c_param, p=args.c_p, size=size, device=device
        )
    if args.c_p[0] == 1:
        return lambda space, z, size, device=device: space.laplace(
            z, args.c_param, size, device
        )
    if args.c_p[0] == 2:
        return lambda space, z, size, device=device: space.normal(
            z, args.c_param, size, device
        )
    if args.c_p[0] > 0:
        return lambda space, z, size, device=device: space.generalized_normal(
            z, args.c_param, p=args.c_p[0], size=size, device=device
        )
    return lambda space, z, size, device=device: space.von_mises_fisher(
        z, 1.0/args.c_param, size, device
    )


def _build_encoder(args):
    """Construct the (untrained) encoder f selected by ``--encoder``."""
    if args.encoder == 'mlp':
        return encoders.get_mlp(
            n_in=args.n,
            n_out=args.n,
            layers=[
                args.n * 10,
                args.n * 50,
                args.n * 50,
                args.n * 50,
                args.n * 50,
                args.n * 10,
            ],
            output_normalization=args.output_norm,
        )
    if args.encoder == 'res':
        return encoders.get_resnet(
            n_in=args.n,
            n_out=args.n,
            layers=[
                args.n * 10,
                args.n * 20,
                args.n * 20,
                args.n * 20,
                args.n * 20
            ],
            layer_normalization='bn',
            output_normalization=args.output_norm,
        )
    raise NotImplementedError


def _unpack_item_list(lst):
    """Recursively turn nested tuples/lists of scalar tensors into nested lists via ``.item()``."""
    return [
        _unpack_item_list(it) if isinstance(it, (tuple, list)) else it.item()
        for it in lst
    ]


def _run_experiment(args, device, writer):
    """Build the data pipeline and models, train the encoder(s), and report final scores.

    Args:
        args: Namespace returned by ``parse_args``.
        device: Torch device string used for the models and batches.
        writer: Open ``SummaryWriter``; the caller is responsible for closing it.
    """
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    space = _build_space(args)
    loss = _build_loss(args, device)

    # A critic is only used with the Delta* losses, i.e. exactly when
    # _build_loss did not take a SimCLR branch. Keying this on args.loss alone
    # broke "--p 0" with the default --loss, which then dereferenced
    # loss.critic on the plain SimCLR loss (assumed to expose no critic, the
    # same way --loss simclr is treated).
    critic = loss.critic if (args.p != 0 and args.loss != 'simclr') else None

    latent_space = latent_spaces.LatentSpace(
        space=space,
        sample_marginal=_build_marginal_sampler(args, device),
        sample_conditional=_build_conditional_sampler(args, device),
    )

    # Use dataloader for parallelization (otherwise too slow for strong rejection sampling)
    class SyntheticDataset(Dataset):
        """Endless dataset whose every item is a freshly sampled (z, z_tilde) batch."""

        def __init__(self, size):
            self.size = size

        def __len__(self):
            return np.iinfo(np.int64).max

        def __getitem__(self, idx):
            z = latent_space.sample_marginal(size=self.size, device='cpu')
            z_tilde = latent_space.sample_conditional(z, size=self.size, device='cpu')
            return z, z_tilde

    # NOTE(review): if the sampling code draws from NumPy's global RNG, forked
    # workers may start from identical NumPy states -- confirm, and add a
    # worker_init_fn if so.
    dataloader = iter(
        DataLoader(SyntheticDataset(args.batch_size), batch_size=1, num_workers=8)
    )

    def next_batch():
        # Each dataset item already is a full batch; drop the loader's size-1 dim.
        return [x.squeeze(0).to(device=device) for x in next(dataloader)]

    # Fixed (frozen) invertible mixing network g.
    g = invertible_network_utils.construct_invertible_mlp(
        n=args.n,
        n_layers=args.n_mixing_layer,
        act_fct=args.act_fct,
        cond_thresh_ratio=0.0,
        n_iter_cond_thresh=25000
    )
    g = g.to(device)
    for p in g.parameters():
        p.requires_grad = False

    # Baseline: disentanglement scores of the raw mixing g (no encoder).
    z_disentanglement = latent_space.sample_marginal(4096)
    (linear_disentanglement_score, _), _ = disentanglement_utils.linear_disentanglement(
        z_disentanglement, g(z_disentanglement), mode="r2"
    )
    print(f"Id. Lin. Disentanglement: {linear_disentanglement_score:.4f}")
    (
        permutation_disentanglement_score,
        _,
    ), _ = disentanglement_utils.permutation_disentanglement(
        z_disentanglement,
        g(z_disentanglement),
        mode="pearson",
        solver="munkres",
        rescaling=True,
    )
    print(f"Id. Perm. Disentanglement: {permutation_disentanglement_score:.4f}")

    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(g.state_dict(), os.path.join(args.save_dir, "g.pth"))

    if args.only_unsupervised:
        test_list = [False]
    elif args.only_supervised:
        test_list = [True]
    else:
        test_list = [True, False]

    # None marks "no training history yet"; see the reset logic in the loop.
    total_loss_values = None

    for test in test_list:
        print("supervised test: {}".format(test))

        f = _build_encoder(args)
        f = f.to(device)
        print("f: ", f)

        def h(z):
            # Full pipeline: mix the latents with g, then encode with f.
            return f(g(z))

        def train_step(data, loss, optimizer):
            z1, z2_con_z1 = data

            # create random "negative" pairs
            # this is faster than sampling z3 again from the marginal distribution
            # and should also yield samples as if they were sampled from the marginal
            z3 = torch.roll(z1, 1, 0)

            optimizer.zero_grad()

            z1_rec = h(z1)
            z2_con_z1_rec = h(z2_con_z1)
            z3_rec = torch.roll(z1_rec, 1, 0)

            if test:
                total_loss_value = F.mse_loss(z1_rec, z1)
                losses_value = [total_loss_value]
            else:
                total_loss_value, _, losses_value = loss(
                    z1, z2_con_z1, z3, z1_rec, z2_con_z1_rec, z3_rec
                )

            total_loss_value.backward()
            optimizer.step()

            return total_loss_value.item(), _unpack_item_list(losses_value)

        if critic is None:
            optimizer = torch.optim.Adam(f.parameters(), lr=args.lr)
        else:
            # Critic parameters other than 'c' train with a 100x learning rate.
            optimizer = torch.optim.Adam([
                {'params': f.parameters(), 'lr': args.lr},
                {'params': [p for (n, p) in critic.named_parameters() if n != 'c'], 'lr': 100.0 * args.lr},
                {'params': critic.c, 'lr': args.lr}
            ], lr=args.lr)

        # Start a fresh history unless one exists AND --resume-training is set.
        if total_loss_values is None or not args.resume_training:
            individual_losses_values = []
            total_loss_values = []
            linear_disentanglement_scores = [0]
            permutation_disentanglement_scores = [0]
            latent_maps = []
            critic_maps = []

        global_step = len(total_loss_values) + 1
        best_step = np.inf
        max_steps = args.n_steps if test else args.n_steps * args.more_unsupervised
        while global_step <= max_steps:
            total_loss_value, losses_value = train_step(
                next_batch(), loss=loss, optimizer=optimizer
            )
            total_loss_values.append(total_loss_value)
            individual_losses_values.append(losses_value)

            is_log_step = (
                global_step % args.n_log_steps == 1 or global_step == args.n_steps
            )
            if is_log_step:
                z_disentanglement = latent_space.sample_marginal(4096)
                (
                    linear_disentanglement_score,
                    _,
                ), _ = disentanglement_utils.linear_disentanglement(
                    z_disentanglement, h(z_disentanglement), mode="r2"
                )
                linear_disentanglement_scores.append(linear_disentanglement_score)
                (
                    permutation_disentanglement_score,
                    _,
                ), _ = disentanglement_utils.permutation_disentanglement(
                    z_disentanglement,
                    h(z_disentanglement),
                    mode="pearson",
                    solver="munkres",
                    rescaling=True,
                )
                permutation_disentanglement_scores.append(
                    permutation_disentanglement_score
                )

                # print critic params
                if critic is not None:
                    critic_params = critic.get_param()
                    for key, val in critic_params.items():
                        val = val.view(-1).numpy()
                        print(key, val)
                        for i, e in enumerate(val):
                            writer.add_scalar(f"{key}/{i}", e, global_step)

                # show latent map and critic f1/f2
                with torch.no_grad():
                    if args.n <= 3:
                        latent_grid = generate_grid(args, num_steps=101)  # 8, 21, 101
                        latent_grid = latent_grid.to(z_disentanglement.device)
                        features = h(latent_grid)
                        latent_grid = latent_grid.cpu().numpy()
                        latent_map = plot_latents(
                            latent_grid,
                            features.cpu().numpy(),
                            global_step / args.n_log_steps,
                        )
                        latent_maps.append(latent_map)
                        if global_step % (2 * args.n_log_steps) == 1:
                            os.makedirs(args.run, exist_ok=True)
                            save_gif(latent_maps, os.path.join(args.run, 'latent_map.gif'))

                    if critic is not None and args.n == 2 and args.space_type != 'sphere':
                        f1 = critic.f1(features).cpu().numpy()
                        f2 = critic.f2(features).cpu().numpy()
                        c = critic.c.cpu().numpy()
                        critic_map = plot_probs(latent_grid, f1, f2, c, args, global_step / args.n_log_steps)
                        critic_maps.append(critic_map)
                        if global_step % (5 * args.n_log_steps) == 1:
                            os.makedirs(args.run, exist_ok=True)
                            save_gif(critic_maps, os.path.join(args.run, 'critic_map.gif'))

                    if critic is not None and args.n == 3 and args.space_type == 'sphere':
                        f1 = critic.f1(features).cpu().numpy()
                        f2 = critic.f2(features).cpu().numpy()
                        print(np.min(f1), np.max(f1))
                        print(np.min(f2), np.max(f2))
                        features = features.cpu().numpy()
                        c = critic.c.cpu().numpy()
                        critic_map = plot_critic_sphere(latent_grid, features, f1, f2, c, args, global_step / args.n_log_steps)
                        critic_maps.append(critic_map)
                        if global_step % (5 * args.n_log_steps) == 1:
                            os.makedirs(args.run, exist_ok=True)
                            save_gif(critic_maps, os.path.join(args.run, 'critic_map.gif'))
            else:
                # Between evaluations, carry the last score forward so the score
                # histories stay aligned with the per-step loss history.
                linear_disentanglement_scores.append(linear_disentanglement_scores[-1])
                permutation_disentanglement_scores.append(
                    permutation_disentanglement_scores[-1]
                )

            if is_log_step:
                center_dist = torch.sum(torch.mean(h(z_disentanglement), dim=0).pow(2), dim=-1)
                print(
                    f"Step: {global_step} \t",
                    f"Loss: {total_loss_value:.4f} \t",
                    f"<Loss>: {np.mean(np.array(total_loss_values[-args.n_log_steps:])):.4f} \t",
                    f"Lin. Disentanglement: {linear_disentanglement_score:.4f} \t",
                    f"Perm. Disentanglement: {permutation_disentanglement_score:.4f}",
                    f"Center Distance: {center_dist:.4f}",
                )
                writer.add_scalar('Loss', total_loss_value, global_step)
                writer.add_scalar('Lin. Disentanglement', linear_disentanglement_score, global_step)
                writer.add_scalar('Perm. Disentanglement', permutation_disentanglement_score, global_step)
                writer.add_scalar('Center Distance', center_dist, global_step)
                if args.output_norm == 'learnable_sphere':
                    print(f"r: {f[-1].r}")
                # Stop once the score has stayed above 0.99 for more than 20000
                # steps after it was first reached (checked on log steps only).
                if args.early_stopping and permutation_disentanglement_score > 0.99:
                    best_step = min(global_step, best_step)
                    if best_step + 20000 < global_step:
                        break

            global_step += 1

        if args.save_dir:
            os.makedirs(args.save_dir, exist_ok=True)
            torch.save(
                f.state_dict(),
                os.path.join(
                    args.save_dir, "{}_f.pth".format("sup" if test else "unsup")
                ),
            )
        # torch.cuda.empty_cache()  # Leads to crash (not enough memory)

    # Final evaluation of the last trained encoder on fresh batches.
    final_linear_scores = []
    final_perm_scores = []
    with torch.no_grad():
        for _ in range(args.num_eval_batches):
            z1, _ = next_batch()
            z1_rec = h(z1)
            (
                linear_disentanglement_score,
                _,
            ), _ = disentanglement_utils.linear_disentanglement(z1, z1_rec, mode="r2")
            (
                permutation_disentanglement_score,
                _,
            ), _ = disentanglement_utils.permutation_disentanglement(
                z1, z1_rec, mode="pearson", solver="munkres", rescaling=True
            )
            final_linear_scores.append(linear_disentanglement_score)
            final_perm_scores.append(permutation_disentanglement_score)
    print(
        "linear mean: {} std: {}".format(
            np.mean(final_linear_scores), np.std(final_linear_scores)
        )
    )
    print(
        "perm mean: {} std: {}".format(
            np.mean(final_perm_scores), np.std(final_perm_scores)
        )
    )


def main():
    """Script entry point: parse options, open the TensorBoard writer, run the experiment."""
    args = parse_args()
    device = _select_device(args)
    writer = SummaryWriter(log_dir=args.run)
    try:
        writer.add_text("Options", str(args))
        _run_experiment(args, device, writer)
    finally:
        # Flush and release the event file even if training raises.
        writer.close()


# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
