import argparse
import os
import time
import numpy as np
import torch
from collections import OrderedDict
from modules_sym import PartEqMod
import wandb
from PIL import Image
from utils import config_to_str
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torchvision.transforms import Resize, ToTensor



# Configuration ---------------------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Evaluate a pretrained PartEqMod as an out-of-distribution "
                "rotation detector on rotated MNIST.")

# General
parser.add_argument("--model_ind", type=int, required=True,
                    help="Experiment ID.")
parser.add_argument("--wandb_key", type=str, default="",
                    help="Optional wandb API key passed to wandb.login(); "
                         "login is skipped when empty.")
parser.add_argument('--seed', type=int, default=0,
                    help="Random seed; -1 draws a random seed.")

# Dataset
parser.add_argument("--dataset_root", type=str,
                    default="/MNIST")
# Data files are whitespace-separated text: one row per sample, 784 pixel
# values followed by the class label in the last column.
parser.add_argument("--customdata_train_path", type=str,
                    default="datasets/mnist_all_rotation_normalized_float_train_valid.amat")
parser.add_argument("--customdata_test_path", type=str,
                    default="datasets/mnist_all_rotation_normalized_float_test.amat")

# Output
parser.add_argument("--out_root", type=str,
                    default="/saves/")

# Net params
parser.add_argument("--batch_sz", type=int, default=240,
                    help="Batch size (also used for the evaluation DataLoader).")
parser.add_argument("--hidden_dim", default=128, type=int,
                    help="Size of the networks in the Inv AE.")
parser.add_argument("--emb_dim", default=32, type=int,
                    help="Dimension of the latent spaces.")
parser.add_argument("--use_one_layer", action='store_true', default=False)

# Pretrained Net
parser.add_argument("--pretrained", action='store_true', default=False,
                    help="Set when passing a saved Inv AE model.")
parser.add_argument("--pretrained_path", type=str,
                    default="./",
                    help="Path to the saved Inv AE model file.")

# Theta net params
parser.add_argument("--hidden_dim_theta", default=64, type=int,
                    help="Size of the theta network.")
parser.add_argument("--emb_dim_theta", default=100, type=int,
                    help="Size of the embedding space in the theta network.")

# Visualizations/Debugs
# main() cannot run without one of these variants (it selects the per-class
# true angles), so reject missing/unknown values at parse time instead of
# failing after the model has been loaded.
parser.add_argument("--scores", type=str, required=True,
                    choices=["MNISTRot60", "MNISTRot60-90", "MNISTMultiple",
                             "MNIST", "MNISTRot"],
                    help="Dataset variant; selects the per-class true rotation angles.")

# Logging
parser.add_argument("--wandb_mode", type=str, default="online")
parser.add_argument("--log_every", type=int, default=100)

def RotMNIST_OOD_Dataloader(config, train=False, test=True, custom_batchsize=0, shuffle=True,
                            equiv_dict=""):
    """Build a DataLoader of randomly rotated MNIST digits with OOD labels.

    Each item is ``(image, label, is_ood)``:
      * ``image``  - ``(1, 28, 28)`` tensor, the stored digit rotated by an
        angle drawn uniformly from [-180, 180) degrees (a fresh angle on
        every access);
      * ``label``  - the class label from the last column of the data file,
        as a float tensor;
      * ``is_ood`` - 0.0 if the angle lies within
        ``[-equiv_dict[label], equiv_dict[label]]``, else 1.0.

    Args:
        config: namespace providing ``customdata_train_path``,
            ``customdata_test_path`` and, unless ``custom_batchsize`` is set,
            ``dataloader_batch_sz``.
        train: load ``config.customdata_train_path`` (takes precedence over ``test``).
        test: load ``config.customdata_test_path``.
        custom_batchsize: batch-size override; 0 means use ``config.dataloader_batch_sz``.
        shuffle: whether the DataLoader shuffles.
        equiv_dict: mapping from integer class label to its symmetry angle in degrees.

    Returns:
        A one-element list containing the DataLoader.

    Raises:
        ValueError: if neither ``train`` nor ``test`` is set, or ``equiv_dict``
            is empty (the default ``""`` can never label a sample).
    """
    print("Loading OOD MNIST Dataset for train:",train,", and for test:",test)
    # Fail fast with a clear message instead of an AttributeError on
    # self.data / an IndexError on the first sample.
    if not (train or test):
        raise ValueError("RotMNIST_OOD_Dataloader needs train=True or test=True.")
    if not equiv_dict:
        raise ValueError("equiv_dict must map each class label to its symmetry angle in degrees.")

    class MNISTRotationDataset(Dataset):
        def __init__(self, train=train, test=test, equiv_dict=equiv_dict):
            self.train = train
            self.test = test
            # Rows: 784 pixel values (a flattened 28x28 image) then the label.
            path = config.customdata_train_path if self.train else config.customdata_test_path
            self.data = np.loadtxt(path)
            self.num_samples = len(self.data)
            self.x = self.data[:, :-1].reshape(self.num_samples, 28, 28)
            self.y = self.data[:, -1]
            self.true_thetas_dict = equiv_dict

            # Transforms
            self.resize28 = Resize(28)
            self.toTensor = ToTensor()

        def __len__(self):
            return self.num_samples

        def __getitem__(self, index):
            x = self.x[index]
            y = int(self.y[index])  # integer key into true_thetas_dict

            # Random rotation angle in degrees, uniform over [-180, 180)
            rotation = np.random.uniform(-180, 180)
            # Convert to a PIL image (no rescaling is applied here) and rotate
            # bilinearly before converting back to a tensor.
            imgRot = Image.fromarray(x)
            imgRot = self.toTensor(self.resize28(imgRot.rotate(rotation, Image.BILINEAR)))

            # Ensure a single-channel (1, 28, 28) layout
            imgRot = imgRot.reshape(1, 28, 28)

            # In-distribution iff the rotation lies within +/- the class's true angle
            true_theta = self.true_thetas_dict[y]
            is_out_of_distrib = torch.tensor(0 if -true_theta <= rotation <= true_theta else 1,
                                             dtype=torch.float)

            label = torch.from_numpy(np.array(self.y[index])).float()
            return imgRot, label, is_out_of_distrib

    dataset = MNISTRotationDataset(train=train, test=test, equiv_dict=equiv_dict)
    batch_size_value = int(custom_batchsize) if custom_batchsize else config.dataloader_batch_sz
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size_value,
                                             shuffle=shuffle,
                                             num_workers=0,
                                             drop_last=False)
    return [dataloader]


def _true_thetas_for(scores):
    """Return the per-class true symmetry angle (degrees) for a ``--scores`` variant.

    Keys are class labels 0-9. RotMNIST_OOD_Dataloader labels a sample of
    class ``c`` rotated by ``r`` degrees as in-distribution when
    ``-theta[c] <= r <= theta[c]``.

    Raises:
        ValueError: if ``scores`` is not a supported variant.
    """
    thetas_by_variant = {
        "MNISTRot60": dict.fromkeys(range(10), 60.),
        "MNISTRot60-90": {c: (60. if c < 5 else 90.) for c in range(10)},
        "MNISTMultiple": {c: 18 * c for c in range(10)},  # 0, 18, ..., 162
        "MNIST": dict.fromkeys(range(10), 0.),
        "MNISTRot": dict.fromkeys(range(10), 180.),
    }
    try:
        return thetas_by_variant[scores]
    except KeyError:
        raise ValueError(f"Dataset not supported: {scores!r}") from None


def _extract_state_dict(checkpoint, prefix="model."):
    """Return a state dict with a leading ``prefix`` stripped from every key.

    Accepts either a plain state dict or a checkpoint dict that nests it
    under ``"state_dict"`` (as PyTorch Lightning checkpoints do). Only a
    *leading* prefix is removed, so keys such as ``"submodel.w"`` or
    ``"enc.model.w"`` are left intact.
    """
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict):
        checkpoint = checkpoint["state_dict"]
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }


def main():
    """Evaluate a pretrained PartEqMod as an out-of-distribution rotation detector.

    A test sample is predicted OOD when the magnitude of the angle obtained
    from the encoder's rotation estimate exceeds the angle predicted by the
    theta network. Accuracy is computed against the labels that
    RotMNIST_OOD_Dataloader derives from the ``--scores`` variant's true
    per-class angles, printed, and logged to wandb.
    """
    config = parser.parse_args()

    # Resolve the ground-truth angles first so an unsupported --scores value
    # fails before any directories, wandb runs or GPU memory are created.
    true_thetas_dict = _true_thetas_for(config.scores)

    # Set seed (-1 draws a random one)
    if config.seed == -1:
        config.seed = np.random.randint(0, 100000)
    pl.seed_everything(config.seed)

    # Setup ------------------------------------------------------------------------
    config.dataloader_batch_sz = int(config.batch_sz)
    # NOTE(review): --out_root/--model_ind do not influence the output
    # directory; each run uses a timestamped folder under "saves/".
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    config.out_dir = f"saves/exp_{timestamp}/"
    os.makedirs(config.out_dir, exist_ok=True)

    print("Config: %s" % config_to_str(config))

    # Initialize wandb (explicit login only when a key was passed)
    if config.wandb_key:
        wandb.login(key=config.wandb_key)
    wandb.init(
        project="unsup-equiv",
        config=config,
        entity="ck-experimental",
        mode=config.wandb_mode
    )

    # Model ------------------------------------------------------------------------
    net = PartEqMod(hparams=config)
    # NOTE(review): --pretrained is not consulted; a checkpoint is always loaded.
    checkpoint = torch.load(config.pretrained_path, map_location="cpu")
    print("Loading pretrained model")
    incompatible = net.load_state_dict(_extract_state_dict(checkpoint), strict=False)
    # strict=False skips mismatching keys silently; surface them so a wrong
    # checkpoint does not go unnoticed.
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        print(f"Warning: {len(missing)} missing and {len(unexpected)} unexpected keys "
              f"while loading {config.pretrained_path}")
    net.cuda()
    net.eval()

    # Load labelled data
    test_dataloader = RotMNIST_OOD_Dataloader(config,
                                              equiv_dict=true_thetas_dict, train=False, test=True,
                                              shuffle=True)[0]

    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for x, _label, ood_label in test_dataloader:
            x = x.cuda().squeeze(1)  # (B, 1, 28, 28) -> (B, 28, 28)
            ood_label = ood_label.cuda()

            # Encoder pass -> rotation matrix -> angle in degrees
            _emb, v = net.encoder(x)
            degrees_rot = net.get_degrees(net.get_rotation_matrix(v))

            # Theta function: predicted symmetry angle per sample
            degrees_theta = net.theta_function(x).squeeze()

            # Out-of-distribution symmetry detector: OOD when the recovered
            # rotation lies outside +/- the predicted angle (same rule the
            # dataset uses to build ood_label).
            is_out_of_distribution = (degrees_rot.abs() > degrees_theta).float()

            correct_predictions += (is_out_of_distribution == ood_label.float()).sum().item()
            total_predictions += ood_label.size(0)

    if total_predictions == 0:
        raise RuntimeError("Test set is empty; cannot compute OOD accuracy.")

    # Compute accuracy
    accuracy = (correct_predictions / total_predictions) * 100
    print(f"Accuracy of OOD classifier in {config.scores}: {accuracy:.4f}")
    wandb.log({"ood_accuracy": accuracy})


# Script entry point: main() parses the CLI arguments and evaluates the OOD
# detector; nothing runs when this module is imported.
if __name__ == "__main__":
    main()