import torch
import argparse
from collections import OrderedDict
import torch.nn.functional as F
import numpy as np
import sys
import pandas as pd
import matplotlib.pyplot as plt
import os
from sklearn.manifold import TSNE
import wandb

# Command-line options of the evaluation script, listed in the order they are
# registered (and therefore shown by --help). Each entry pairs a flag with the
# keyword arguments handed to ``parser.add_argument``.
_ARGUMENT_SPECS = (
    ("--model_ind", {"type": int, "default": 111}),
    ("--dataset", {"type": str, "default": "MNIST"}),
    ("--dataset_root", {"type": str, "default": "/MNIST"}),
    ("--wandb_mode", {"type": str, "default": "online"}),
    ("--batch_sz", {"type": int, "default": 240}),
    # Inv AE
    ("--hidden_dim", {"type": int, "default": 128}),
    ("--emb_dim", {"type": int, "default": 32}),
    ("--in_channels", {"type": int, "default": 1}),
    ("--input_sz", {"type": int, "default": 28}),
    ("--ae_pass", {"type": int, "default": 0}),
    # Size of theta network
    ("--hidden_dim_theta", {"type": int, "default": 64}),
    # Size of embedding space in theta network
    ("--emb_dim_theta", {"type": int, "default": 100}),
    ("--use_one_layer", {"action": "store_true", "default": False}),
    ("--digit", {"type": int, "default": 0}),
    ("--scores", {"action": "store_true", "default": False}),
    # Theta nets
    ("--theta_net", {"type": str, "default": "EqNet"}),
    ("--pretrained", {"action": "store_true", "default": False}),
    ("--pretrained_path", {"type": str, "default": "./"}),
    # Only for evaluating
    ("--customdata_train_path", {
        "type": str,
        "default": "./datasets/mnist_all_rotation_normalized_float_train_valid.amat",
    }),
    ("--customdata_test_path", {
        "type": str,
        "default": "./datasets/mnist_all_rotation_normalized_float_test.amat",
    }),
    ("--tsne", {"action": "store_true", "default": False}),
    ("--wandb_key", {"type": str, "default": ""}),
    ("--net_path", {"type": str, "default": "./"}),
    ("--linear_classifier", {"action": "store_true", "default": False}),
    ("--classifier_epochs", {"type": int, "default": 150}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Parse the command line once; `config` is the single namespace the rest of the
# script reads from and attaches derived settings to.
config = parser.parse_args()
# Not a CLI option: always set here. Nothing in this script reads it back;
# presumably consumed by the model code that receives `config` -- TODO confirm.
config.twohead = True

# Authenticate only when a key was supplied; otherwise wandb falls back to
# whatever credentials/mode are already configured in the environment.
if config.wandb_key:
    wandb.login(key=config.wandb_key)
# Log every hyper-parameter EXCEPT the API key: passing the raw namespace
# would store `wandb_key` in the run config, where it is visible to anyone
# with access to the project.
wandb.init(
    project="unsup-equiv",
    config={key: value for key, value in vars(config).items() if key != "wandb_key"},
    entity="ck-experimental",
    mode=config.wandb_mode,
)

# RotMNIST runs always force ae_pass to 1, overriding the command-line value.
if config.dataset == "RotMNIST":
    dataset_class = None
    config.ae_pass = 1

# Every artefact of this run is written to "eval<model_ind>/" (relative to the
# current working directory).
folder_name = "eval"+str(config.model_ind)
if os.path.isdir(folder_name):
    print("eval folder exists")
# exist_ok tolerates a pre-existing directory, while genuine failures
# (permissions, invalid path, a regular file with that name) now raise instead
# of being misreported as "folder exists" by a bare except.
os.makedirs(folder_name, exist_ok=True)
config.out_dir = folder_name+"/"

# Change these parameters

# Load net
config.dataloader_batch_sz = int(config.batch_sz)
from modules_sym import PartEqMod
# NOTE(review): without --pretrained `net` stays None, and every evaluation
# branch further down dereferences it (net.eval(), net.encoder, ...).
net = None


if config.pretrained:
    # Don't train Inv AE when a pretrained net is passed
    net = PartEqMod(hparams=config)
    # Load onto the CPU first so the checkpoint does not depend on the device
    # it was saved from; the model is moved to the GPU below.
    state_dict = torch.load(config.pretrained_path, map_location="cpu")

    # Strip only a LEADING "net." from each key. A blanket str.replace would
    # also delete "net." inside names (e.g. "net.theta_net.w" -> "theta_w"),
    # and with strict=False such keys would be skipped without any error.
    prefix = "net."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    print("Loading pretrained net")
    load_result = net.load_state_dict(new_state_dict, strict=False)
    # strict=False never raises on mismatches, so surface them explicitly.
    if load_result.missing_keys:
        print("Missing keys (left at their initial values):", load_result.missing_keys)
    if load_result.unexpected_keys:
        print("Unexpected keys (ignored):", load_result.unexpected_keys)
    net.cuda()
    net.eval()
# Evaluations

# Build the test dataloader(s) for the selected dataset. The result is indexed
# and iterated as a sequence of loaders by the evaluation branches below.
if config.dataset == "PartMNIST":
    from data_loading_sym import PartialMNIST_AE_Dataloader
    test_dataloader = PartialMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True, no_val_split=True)
elif config.dataset == "RotMNIST":
    from data_loading_sym import RotMNIST_AE_Dataloader
    test_dataloader = RotMNIST_AE_Dataloader(config, train=False, test=True, shuffle=True, no_val_split=True)
else:
    # Every branch below needs test_dataloader; fail here with a clear message
    # rather than with a NameError later on.
    raise ValueError(
        f"Unsupported dataset '{config.dataset}': expected 'PartMNIST' or 'RotMNIST'"
    )


if config.tsne:
    # Project the encoder embeddings of the test set to 2-D with t-SNE, save a
    # scatter plot coloured by label, upload it to wandb and stop the script.
    test_dataloader = test_dataloader[0]
    embeddings = []
    labels = []
    # Inference only: switch mode once and disable autograd for the whole loop.
    net.eval()
    with torch.no_grad():
        for x, label in test_dataloader:
            emb, _ = net.encoder(x.cuda())
            embeddings.append(emb.cpu())
            labels.append(label.cpu())
    embeddings = torch.cat(embeddings, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    # Fixed seed so repeated runs give the same layout.
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_2d = tsne.fit_transform(embeddings)

    fig = plt.figure(figsize=(10, 8))
    for i in range(10):  # assuming there are 10 classes (digits) in MNIST
        indices = labels == i
        plt.scatter(embeddings_2d[indices, 0], embeddings_2d[indices, 1], alpha=0.6, label=str(i))
    plt.legend()

    # Save the figure
    figure_path = config.out_dir + "test_dataset_tsne.png"
    fig.savefig(figure_path)
    plt.close(fig)
    wandb.save(figure_path, policy="now")
    # sys.exit instead of the site-provided exit(), which is meant for the
    # interactive interpreter and is not guaranteed to exist.
    sys.exit(0)

if config.linear_classifier:
    # Linear-probe evaluation: freeze the encoder, train one linear layer on
    # its embeddings, keep the weights with the best validation accuracy and
    # report the accuracy of those weights on the test set.
    if config.dataset == "PartMNIST":
        from data_loading_sym import PartialMNIST_AE_Dataloader

        train_dataloader = PartialMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True)
    elif config.dataset == "RotMNIST":
        from data_loading_sym import RotMNIST_AE_Dataloader

        train_dataloader = RotMNIST_AE_Dataloader(config, train=True, test=False, shuffle=True)
    else:
        raise ValueError(f"No training dataloader defined for dataset '{config.dataset}'")
    train_dataloader, val_dataloader = train_dataloader[0], train_dataloader[1]
    test_dataloader = test_dataloader[0]
    # Train a classifier with the embeddings to evaluate using accuracy
    # Freeze weights
    net.eval()
    for param in net.parameters():
        param.requires_grad = False

    # Define the linear classifier
    if "MNIST" in config.dataset:
        num_classes = 10
    else:
        raise ValueError("Define number of classes for accuracy evaluation of this dataset")

    # A single classifier/optimizer pair is trained across all epochs.
    classifier = torch.nn.Linear(config.emb_dim, num_classes).cuda()
    optimiser_classifier = torch.optim.Adam(classifier.parameters(), lr=0.01)

    classifier_path = config.out_dir + "best_classifier.pt"
    best_val = 0.
    best_state = None  # CPU copy of the classifier weights with the best validation accuracy
    for epoch in range(1, config.classifier_epochs + 1):
        # Train the linear classifier with the net embeddings
        for x, label in train_dataloader:
            classifier.zero_grad()

            x = x.cuda()
            label = label.long().cuda()

            # The encoder is frozen: no graph is needed for the embedding.
            with torch.no_grad():
                emb, _ = net.encoder(x)

            logits = classifier(emb).float()
            classification_loss = F.cross_entropy(logits, label)

            classification_loss.backward()
            optimiser_classifier.step()

        # Validation accuracy
        correct = 0
        total = 0
        with torch.no_grad():
            for x, label in val_dataloader:
                x = x.cuda()
                label = label.long().cuda()
                emb, _ = net.encoder(x)
                logits = classifier(emb).float()

                _, predicted = torch.max(logits, 1)
                total += label.size(0)
                correct += (predicted == label).sum().item()

        accuracy_val = 100. * correct / total
        print(f'Accuracy of the linear classifier on the validation set: {accuracy_val:.4f}%')
        if accuracy_val > best_val:
            best_val = accuracy_val
            wandb.log({"linear_classifier/val":accuracy_val}, step=epoch)
            print("Saving model with best val acc.")

            # Checkpoint the CLASSIFIER (not the frozen encoder): these are the
            # weights restored for the test-set evaluation below.
            best_state = {
                key: value.detach().cpu().clone()
                for key, value in classifier.state_dict().items()
            }
            torch.save(best_state, classifier_path)
            # Save weights to wandb
            wandb.save(classifier_path, policy="now")

    # Evaluate on test set after training, using the best-validation weights.
    if best_state is not None:
        print("Loading best validation accruacy classifier")
        classifier.load_state_dict(best_state)
    else:
        # No epoch ran, or none exceeded 0% validation accuracy.
        print("No best classifier checkpoint available; evaluating the current classifier weights")
    classifier.cuda()
    classifier.eval()
    # Test accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for x, label in test_dataloader:
            x = x.cuda()
            label = label.long().cuda()
            emb, _ = net.encoder(x)
            logits = classifier(emb).float()

            _, predicted = torch.max(logits, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

    accuracy_test = 100. * correct / total
    print(f'Accuracy of the linear classifier on the test set: {accuracy_test:.4f}%')

    # Unfreeze weights
    for param in net.parameters():
        param.requires_grad = True
    net.train()

    # Write and upload the SAME file (previously a different, non-existent
    # file name was passed to wandb.save).
    results_path = config.out_dir + 'linear_classifier_results.txt'
    with open(results_path, 'w') as f:
        f.write(f'Accuracy of the linear classifier on the test set: {accuracy_test:.4f}%')

    wandb.save(results_path)
    sys.exit(0)

if not config.scores:
    # Qualitative evaluation: for up to `sample_size` test images of digit
    # `config.digit`, plot one column with four rows -- input, decoded
    # canonical representation, input rotated by the inverse of the predicted
    # rotation, and reconstruction -- then save/upload the figure.
    if config.theta_net not in ("EqNet", "FullyConnected"):
        raise ValueError(
            f"Unknown theta_net '{config.theta_net}': expected 'EqNet' or 'FullyConnected'"
        )

    sample_size = 20
    count = 0
    fig, ax = plt.subplots(nrows=4, ncols=sample_size, figsize=(20, 7))
    target_digit = config.digit
    figure_path = config.out_dir + "Figure_" + str(target_digit) + ".png"

    label_batches = []
    rotation_batches = []

    def rot_img(x, rot, rot_inverse=False):
        """Resample `x` through the affine grid defined by `rot`.

        With ``rot_inverse=True`` the off-diagonal entries rot[:, 0, 1] and
        rot[:, 1, 0] are negated on a copy (the caller's `rot` is left
        untouched) and a channel axis is inserted into `x` before sampling.
        NOTE(review): that unsqueeze(1) expects `x` to arrive without a channel
        axis -- confirm against what the dataloader yields.
        """
        if rot_inverse:
            rot = rot.clone()
            # Inverse of a rotation is just the negative of the sin(theta) components
            rot[:, 0, 1] = rot[:, 0, 1] * -1
            rot[:, 1, 0] = rot[:, 1, 0] * -1
            grid = F.affine_grid(rot, x.unsqueeze(1).size(), align_corners=False).type_as(x)
            return F.grid_sample(x.unsqueeze(1), grid, align_corners=False)
        grid = F.affine_grid(rot, x.size(), align_corners=False).type_as(x)
        return F.grid_sample(x, grid, align_corners=False)

    net.eval()
    with torch.no_grad():
        # test_dataloader is a sequence of loaders iterated in lockstep; only
        # the first loader's batch (tup[0]) is used.
        for tup in zip(*test_dataloader):
            imgs_curr = tup[0][0]  # only one here
            x = imgs_curr.cuda()
            labels = tup[0][1]

            # Calculations
            emb, v = net.encoder(x)
            rot = net.get_rotation_matrix(v)
            degrees_rot = net.get_degrees(rot)

            canonical_rep = net.decoder(emb).squeeze()
            recon = net.rot_img(canonical_rep, rot)

            # Theta function: EqNet consumes the image, FullyConnected the embedding.
            theta_input = x if config.theta_net == "EqNet" else emb
            degrees_theta = net.theta_function(theta_input).squeeze()

            # Orientate data
            oriented = rot_img(x, rot, rot_inverse=True)

            # Plot digits. Iterate over the actual batch length: the last batch
            # may hold fewer than dataloader_batch_sz samples.
            for j in range(len(labels)):
                if int(labels[j].item()) != target_digit:
                    continue
                ax[0, count].imshow(imgs_curr[j].cpu().squeeze(0))
                ax[1, count].imshow(canonical_rep[j].detach().cpu().squeeze(0))
                img_or = oriented[j].squeeze().detach().cpu()
                ax[2, count].imshow(img_or)
                ax[3, count].imshow(recon[j].detach().cpu().squeeze(0))
                g = round(degrees_rot[j].item(), 1)
                thet = round(degrees_theta[j].item(), 1)
                ax[0,count].set_xlabel("g:"+str(g)+",\n \u03B8:"+str(thet), rotation=0)
                count += 1
                if count == sample_size:
                    fig.savefig(figure_path)
                    wandb.save(figure_path, policy="now")
                    sys.exit(0)

            label_batches.append(labels.cpu())
            rotation_batches.append(rot.detach().cpu())

    all_labels = torch.cat(label_batches, dim=0) if label_batches else None
    transformations_seen_test = torch.cat(rotation_batches, dim=0) if rotation_batches else None

    # The test set ran out before `sample_size` matches were found: still keep
    # whatever was plotted instead of silently discarding it.
    if count > 0:
        print(f"Only {count} of {sample_size} samples found for digit {target_digit}; saving partial figure")
        fig.savefig(figure_path)
        wandb.save(figure_path, policy="now")
    else:
        print(f"No samples found for digit {target_digit}; no figure saved")

if config.scores:
    # Quantitative evaluation: collect the theta prediction of every test
    # sample per label, write per-label mean / MAE / std-dev to results.txt and
    # plot a histogram of the predicted group-function angles.
    if config.theta_net not in ("EqNet", "FullyConnected"):
        raise ValueError(
            f"Unknown theta_net '{config.theta_net}': expected 'EqNet' or 'FullyConnected'"
        )

    thetas_dict = {i: [] for i in range(10)}      # per label: one mean theta per batch
    labels_dict = {i: [] for i in range(10)}      # per label: the labels of the collected samples
    all_thetas_dict = {i: [] for i in range(10)}  # per label: every individual theta
    all_transforms = []                           # predicted rotation (degrees) of every sample
    # Insert true thetas of dataset
    # Example: If we want to evaluate on RotMNIST60 then set every label to 60.
    # If we want to evaluate on RotMNIST60-90 set the first 5 labels to 60 and the last 5 labels to 90
    # Otherwise scores won't make sense
    true_thetas_dict = {0: 0, 1: 18, 2: 36, 3: 54, 4: 72,
                        5: 90, 6: 108, 7: 126, 8: 144, 9: 162}

    net.eval()
    with torch.no_grad():
        # test_dataloader is a sequence of loaders iterated in lockstep; only
        # the first loader's batch (tup[0]) is used.
        for tup in zip(*test_dataloader):
            imgs_curr = tup[0][0]  # only one here
            x = imgs_curr.cuda()
            labels = tup[0][1]

            # Calculations
            emb, v = net.encoder(x)
            rot = net.get_rotation_matrix(v)
            degrees_rot = net.get_degrees(rot)

            # Theta function: EqNet consumes the image, FullyConnected the embedding.
            theta_input = x if config.theta_net == "EqNet" else emb
            degrees_theta = net.theta_function(theta_input).squeeze()

            # reshape(-1) (not squeeze) keeps a 1-D array even for a single
            # element, so the values can always be iterated.
            all_transforms.extend(degrees_rot.detach().reshape(-1).cpu().numpy())
            flat_labels = labels.detach().cpu().reshape(-1)
            for lab in range(10):
                mask = flat_labels == lab
                sub_thetas = degrees_theta[mask.to(degrees_theta.device)]
                # Mean of an empty selection is NaN; it is counted and skipped
                # by the nan-aware statistics when the report is written.
                thetas_dict[lab].append(torch.mean(sub_thetas).cpu().item())
                labels_dict[lab].extend(flat_labels[mask].numpy())
                all_thetas_dict[lab].extend(sub_thetas.detach().reshape(-1).cpu().numpy())

    results_path = config.out_dir + 'results.txt'
    with open(results_path, 'w') as f:
        for lab in range(10):
            nan_values = np.sum(np.isnan(thetas_dict[lab]))
            f.write("nan values: "+str(nan_values))
            mean_value = np.nanmean(thetas_dict[lab])
            f.write(f'\nMean value for label {lab}: {mean_value}\n')
            label_thetas = np.array(all_thetas_dict[lab])
            mae_value = np.nanmean(np.abs(label_thetas-np.array(true_thetas_dict[lab])))
            f.write(f'\nMAE for label {lab}: {mae_value}\n')
            std_dev = np.nanstd(label_thetas)
            f.write(f'\nstd dev for label {lab}: {std_dev}\n')

    # Print Histogram for Psi
    histogram_path = config.out_dir + "histogram.png"
    fig, ax = plt.subplots()
    df = pd.DataFrame()
    df["psi"] = all_transforms
    df["psi"].plot.hist(density=True, bins=65, ax=ax)
    plt.xlim(-180, 180)
    ax.set_xlabel("Angle of Group Function Prediction (º)")
    plt.title("Observed Thetas Historigram")
    fig.savefig(histogram_path)
    wandb.save(histogram_path, policy="now")
    plt.close(fig)
    wandb.save(results_path)

