import os

import torch
from torch.optim import SGD, Adam

from class_separator import ClassSeparator
from tqdm import tqdm


def _run_name(num_classes, num_mixtures, num_features, lambda1, seed):
    """Return the file stem that identifies one configuration of the sweep."""
    return (
        f"C_{num_classes}_M_{num_mixtures}_"
        f"D_{num_features}_"
        f"lambda_{lambda1}_"
        f"seed_{seed}"
    )


def _ccns_regularizer(lb_ccns, num_classes):
    """Combine the CCNS lower bounds into a single regularization term.

    Diagonal (intra-class) entries are subtracted so that minimizing the
    result maximizes them; off-diagonal (inter-class) entries are added so
    that they are minimized.
    """
    regularizer = 0.0
    for c_prime in range(num_classes):
        for c in range(num_classes):
            if c_prime == c:
                # maximize intra-class ccns
                regularizer -= lb_ccns[c, c_prime]
            else:
                # minimize inter-class ccns
                regularizer += lb_ccns[c, c_prime]
    return regularizer


def _train_configuration(
    num_classes,
    num_mixtures,
    num_features,
    lambda1,
    seed,
    device,
    epochs=10000,
    patience=3000,
    load_checkpoint=False,
):
    """Train one ClassSeparator and return the statistics of its best epoch.

    The model is optimized with Adam (lr=0.01) for at most ``epochs`` steps
    and stops early once deltaSED has not improved for more than ``patience``
    epochs. Every time deltaSED improves, the model state dict is written to
    ``./checkpoints/<run name>.pth``.

    Args:
        num_classes: number of classes passed to ClassSeparator.
        num_mixtures: number of mixtures passed to ClassSeparator.
        num_features: number of features passed to ClassSeparator.
        lambda1: weight of the CCNS lower-bound regularizer; 0.0 disables it.
        seed: only used to build the checkpoint file name.
        device: torch device the model is moved to.
        epochs: maximum number of optimization steps.
        patience: number of epochs without improvement tolerated before
            stopping.
        load_checkpoint: if True and a checkpoint file exists, start from it.

    Returns:
        The dictionary that is stored as the result of this configuration.

    Raises:
        RuntimeError: if deltaSED never dropped below its initial value of
            infinity (for example because it was NaN from the first epoch),
            in which case there are no best-epoch statistics to report.
    """
    model_ckpt_pth = os.path.join(
        "./checkpoints",
        _run_name(num_classes, num_mixtures, num_features, lambda1, seed)
        + ".pth",
    )

    # k, sigma, epsilon are optimized by the method
    model = ClassSeparator(
        num_classes=num_classes,
        num_mixtures=num_mixtures,
        num_features=num_features,
        use_full_covariance=(lambda1 == 0.0 and num_features > 1),
    )
    optimizer = Adam(params=model.parameters(), lr=0.01)

    if load_checkpoint and os.path.exists(model_ckpt_pth):
        model.load_state_dict(torch.load(model_ckpt_pth, map_location="cpu"))

    model.to(device)
    model.train()

    print(
        f"Start training with D={num_features}, "
        f"|C|={num_classes}, |M|={num_mixtures}, "
        f"lambda={lambda1}..."
    )

    os.makedirs("./checkpoints", exist_ok=True)

    # Reset the best-epoch state for every run so that statistics of a
    # previously trained configuration can never leak into this one.
    best_epoch = 0
    best_deltaSED = torch.inf
    best_lb_ccns = None
    best_sed_tuples = None

    with tqdm(total=epochs) as tepoch:
        for epoch in range(1, epochs + 1):
            optimizer.zero_grad()
            sed_tuples, lb_ccns, _m_cc = model()

            # combine all SED differences for each pair of c,c'
            deltaSED = 0.0
            for s in sed_tuples:
                # We want to maximize SED_H - SED_X,
                # so we append a minus and minimize deltaSED
                deltaSED -= s[0]

            if lambda1 != 0.0:
                # regularizer based on the lower bound of the CCNS
                loss = deltaSED + lambda1 * _ccns_regularizer(
                    lb_ccns, model.C
                )
            else:
                loss = deltaSED

            loss.backward()
            optimizer.step()

            delta_value = deltaSED.item()
            tepoch.set_description(f"Epoch {epoch}")
            tepoch.set_postfix(loss=loss.item(), deltaSED=delta_value)
            tepoch.update(1)

            if delta_value < best_deltaSED:
                best_epoch = epoch
                best_deltaSED = delta_value
                best_lb_ccns = lb_ccns.detach() if lambda1 > 0.0 else None
                # NOTE(review): the trailing comma wraps the list in a
                # 1-tuple. It looks accidental, but it is kept because
                # existing result files and their readers may rely on
                # this layout.
                best_sed_tuples = (
                    [(t[1], t[2], t[3], t[4]) for t in sed_tuples],
                )
                # NOTE(review): the checkpoint is written after
                # optimizer.step(), so it holds the parameters one update
                # past those that produced the values recorded above.
                torch.save(model.state_dict(), model_ckpt_pth)

            if epoch - best_epoch > patience:
                print("Stopping training - early stopping")
                break

    if best_epoch == 0:
        raise RuntimeError(
            "deltaSED never improved; no best-epoch statistics available"
        )

    return {
        "best_epoch": best_epoch,
        "best_sed_tuples": best_sed_tuples,
        "best_deltaSED": best_deltaSED,
        "best_lb_ccns": best_lb_ccns,
        "C": num_classes,
        "M": num_mixtures,
        "D": num_features,
        "lambda": lambda1,
        "seed": seed,
    }


def main(max_attempts=10):
    """Run the hyper-parameter sweep and store one result file per setting.

    For every combination of feature dimension, number of classes, number of
    mixtures and lambda1 a ClassSeparator is trained and its best-epoch
    statistics are saved to ``RESULTS/<run name>.pt``. Configurations whose
    result file already exists are skipped, so an interrupted sweep can be
    restarted.

    Args:
        max_attempts: how many times a failing configuration is tried before
            it is reported and skipped. Failed attempts are retried without
            re-seeding, so each retry draws from the advanced RNG state.
    """
    seed = 42
    torch.manual_seed(seed)
    # Fall back to the CPU when no GPU is present instead of failing on
    # every single configuration.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    failed = []

    for num_features in [1, 2, 5, 10, 50]:
        for num_classes in [2, 5]:
            for num_mixtures in [2, 5, 10]:
                # lambda1 controls the importance of the CCNS lower bound
                # term. Minimizing deltaSED alone might make the
                # distributions of different classes completely identical,
                # so the regularizer pushes the model away from that naive
                # solution, which possibly is a local minimum of the loss
                # landscape.
                for lambda1 in [0.0, 1.0, 5.0, 10.0]:
                    config = {
                        "num_classes": num_classes,
                        "num_mixtures": num_mixtures,
                        "num_features": num_features,
                        "lambda1": lambda1,
                    }
                    results_path = os.path.join(
                        "RESULTS", _run_name(seed=seed, **config) + ".pt"
                    )

                    if os.path.exists(results_path):
                        print("Skipping configuration due to existing results")
                        continue

                    for attempt in range(1, max_attempts + 1):
                        try:
                            results = _train_configuration(
                                seed=seed, device=device, **config
                            )
                            os.makedirs("./RESULTS", exist_ok=True)
                            torch.save(results, results_path)
                        except Exception as exc:
                            # Training is best effort: report the cause and
                            # try again, but only a bounded number of times
                            # so a deterministic error cannot loop forever.
                            print(
                                f"Exception caught on attempt "
                                f"{attempt}/{max_attempts}: {exc!r}"
                            )
                            if attempt < max_attempts:
                                print("Retrying...")
                        else:
                            print("Training of configuration has finished.")
                            break
                    else:
                        print(
                            f"Giving up on configuration {config} after "
                            f"{max_attempts} failed attempts."
                        )
                        failed.append(config)

    if failed:
        print(f"{len(failed)} configuration(s) could not be trained:")
        for config in failed:
            print(f"  {config}")


# Start the hyper-parameter sweep only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
