import os
import argparse
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from source.constants import RESULTS_PATH

from source.data.dots import get_standard_data, _generate_data
from source.data.arrow import get_data
from source.data.cityscapes import get_test, CityscapesDataset

from source.models.cnn import get_cnn
from source.models.resnet import get_resnet18
from source.models.utils import variance_link, natural_link, natural_to_gauss


# Command-line options, declared as (flag, add_argument keyword arguments)
# pairs and registered in the order listed.
_CLI_OPTIONS = (
    ("--task", {"default": "dots"}),
    ("--subset", {"default": ""}),
    ("--network", {"default": "cnn"}),
    ("--method", {"default": "ensemble"}),
    ("--method_seed", {"default": 42, "type": int}),
    ("--use_natural", {"action": "store_true"}),
    ("--device", {"default": "cuda:0"}),
    ("--batch_size", {"default": 128, "type": int}),
    # Loss parameters
    ("--anti_regularization_weight", {"default": 0.0, "type": float}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parse the process arguments once; the rest of the script reads `args`.
args = parser.parse_args()

# Convenience alias for the device that models and batches are moved to.
device = args.device
print("Evaluation executed on >", device)

# Validate the CLI selections explicitly rather than with `assert`:
# assertions are stripped when Python runs with -O, which would let an
# unsupported value slip through and fail later in a less obvious place.
# `parser.error` prints the usage text and exits with status 2.
SUPPORTED_TASKS = (
    "dots",
    "dots_gap_tail",
    "arrow",
    "arrow_gap_tail",
    "city_street",
    "city_building",
    "city_sky",
    "city_car",
    "city_vegetation",
    "git_qood",
)
SUPPORTED_METHODS = ("ensemble", "mc_dropout")
SUPPORTED_NETWORKS = ("cnn", "resnet")

# Checked in the order task, method, network.
for _option, _value, _supported in (
    ("--task", args.task, SUPPORTED_TASKS),
    ("--method", args.method, SUPPORTED_METHODS),
    ("--network", args.network, SUPPORTED_NETWORKS),
):
    if _value not in _supported:
        parser.error(
            f"unsupported {_option} {_value!r}; choose from: {', '.join(_supported)}"
        )

# --subset has the form "<name>[_<arg>...]": the text before the first
# underscore names the subset and the remaining underscore-separated pieces
# are its arguments.  The unsplit string is kept for output file names.
full_subset = args.subset
subset, *subset_args = full_subset.split("_")

# Directory of the run being evaluated; its name encodes the task, the
# network and the training options given on the command line.
run_name = f"{args.task}_{args.network}_natural{args.use_natural}_mseed{args.method_seed}_arw{args.anti_regularization_weight}"
run_path = os.path.join(RESULTS_PATH, run_name)

# Build the evaluation dataset for the selected task.  Every branch sets
# `in_channels`, `test_ds` and `y_test`; `image_size` stays None unless a
# branch overrides it.  Tasks without a loader fail here with an explicit
# error instead of surfacing later as a NameError on `test_ds` / `y_test`.
image_size = None
if args.task == "dots":
    in_channels = 1
    # Only the last two of the six returned values (test inputs/targets) are used.
    _, _, _, _, x_test, y_test = get_standard_data()
    test_ds = TensorDataset(
        torch.tensor(x_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.float32),
    )
elif args.task == "dots_gap_tail":
    in_channels = 1
    # Generate new data for testing on the whole range.
    x_test, y_test = _generate_data(n_samples=8200, seed=7532)
    x_test = x_test[:, None, :, :]  # insert a singleton axis at position 1
    y_test = y_test.astype(np.float32) / 50
    test_ds = TensorDataset(
        torch.tensor(x_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.float32),
    )
elif args.task == "arrow":
    in_channels = 1
    _, _, _, _, x_test, y_test = get_data()
    test_ds = TensorDataset(
        torch.tensor(x_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.float32),
    )
elif "city" in args.task:
    in_channels = 3
    # The second underscore-separated token of the task selects the target index.
    target = {"street": 0, "building": 1, "sky": 2, "car": 3, "vegetation": 4}[
        args.task.split("_")[1]
    ]
    x_test, y_test = get_test(target=target)
    test_ds = CityscapesDataset(x_test, y_test, flip=False)
elif "qood" in args.task:
    ## TODO: change to model_kwargs
    in_channels = 1
    image_size = 64
    ### TODO: resnet doesnt work due to dimension being 64x64 instead of 32x32
    # Imported lazily so the other tasks do not need this package.
    from datasets.mosaic_datasets import (
        build_id_and_ood as mosaic_ds_build_id_and_ood,
        build_id_and_ood_emnist_one_at_a_time,
    )

    if "emnist" in subset:
        # Interpret the subset arguments as OOD positions.
        ood_positions = [int(i) for i in subset_args]
        print("ood positions", ood_positions)
        test_ds = build_id_and_ood_emnist_one_at_a_time(
            # tile_size = 32,
            seed=args.method_seed,
            n_id_train=100000,  # since we are not really doing mnist, gotta take a nice number
            n_id_test=1000,
            n_ood_each=1000,
            download=True,
            normalize_images=True,
            ood_positions=ood_positions,
        )
    else:
        (
            id_train,
            id_test,
            ood_fashion,
            ood_cifar10,
            ood_svhn,
            ood_mixture,
        ) = mosaic_ds_build_id_and_ood(
            # tile_size = 32,
            seed=args.method_seed,
            n_id_train=100000,  # since we are not really doing mnist, gotta take a nice number
            n_id_test=1000,
            n_ood_each=1000,
            download=True,
            normalize_images=True,
        )
        # One selection chain.  "test" is checked last so that an OOD subset
        # name takes precedence when a subset string happens to contain both.
        if "mixture" in subset:
            test_ds = ood_mixture
        elif "fashion" in subset:
            test_ds = ood_fashion
        elif "cifar" in subset:
            test_ds = ood_cifar10
        elif "svhn" in subset:
            test_ds = ood_svhn
        elif "test" in subset:
            test_ds = id_test
        else:
            raise ValueError(
                f"Unknown --subset {full_subset!r} for task {args.task!r}; expected a "
                "name containing one of: emnist, test, mixture, fashion, cifar, svhn"
            )

    # Each dataset item is indexed as (input, target).
    y_test = [t[1] for t in test_ds]
    x_test = [t[0] for t in test_ds]
else:
    raise NotImplementedError(
        f"No test data loader is implemented for task {args.task!r}"
    )

    # print(train_ds[0][0].shape)
   


# Persist the ground-truth targets inside the run directory.
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)
y_test_file = os.path.join(run_path, f"y_test{full_subset}.pt")
torch.save(y_test_tensor, y_test_file)

# Sanity output: dataset size, then the shapes of the first sample.
print(len(test_ds))
x, y = test_ds[0]
print(x.shape, y.shape)

# Number of networks to evaluate: one for MC dropout, otherwise one per
# saved checkpoint.  Only entries named "model_*.pt" are counted because the
# evaluation loop loads "model_{n}.pt" for n in range(n_networks); any other
# entry in the directory would inflate the count and trigger a load of a
# checkpoint that does not exist.
if args.method == "mc_dropout":
    n_networks = 1
else:
    models_dir = os.path.join(run_path, "models")
    n_networks = sum(
        1
        for entry in os.listdir(models_dir)
        if entry.startswith("model_") and entry.endswith(".pt")
    )
    if n_networks == 0:
        # Fail early with a clear message instead of an opaque error when
        # the (empty) prediction lists are concatenated later.
        raise FileNotFoundError(f"No model_*.pt checkpoints found in {models_dir}")

# Evaluate every network on the whole test set and store the predicted
# means and variances.
#
# NOTE(review): with --method mc_dropout a single network is evaluated once,
# in eval() mode, which disables standard nn.Dropout layers -- confirm that
# stochastic forward passes are not expected from this script.

# The loader does not depend on the network, so build it once.  With
# shuffle=False every pass visits the samples in the same order, which the
# reshape below relies on.
test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

means, variances = [], []
for n in tqdm(range(n_networks)):
    if args.network == "cnn":
        network = get_cnn(in_channels, image_size=image_size)
    elif args.network == "resnet":
        network = get_resnet18(in_channels, image_size=image_size)
    else:
        raise NotImplementedError("Network not supported")

    network.to(device)
    checkpoint_file = os.path.join(run_path, "models", f"model_{n}.pt")
    network.load_state_dict(torch.load(checkpoint_file, map_location=device))
    network.eval()

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        for x, _ in test_loader:
            # Call the module itself (not .forward) so registered hooks run.
            y_pred = network(x.to(device)).cpu()

            # Map the raw network output to a Gaussian mean and variance.
            if args.use_natural:
                mean, var = natural_to_gauss(*natural_link(y_pred))
            else:
                mean, var = variance_link(y_pred)

            means.append(mean)
            variances.append(var)

# Batches were appended network by network, so the concatenated tensors are
# network-major: reshape gives one row per network and moveaxis puts the
# network axis last.
means = torch.concat(means, dim=0).reshape(n_networks, -1).moveaxis(0, -1)
variances = torch.concat(variances, dim=0).reshape(n_networks, -1).moveaxis(0, -1)

# The last axis holds [mean, variance]; presumably (n_samples, n_networks, 2)
# when each sample yields one mean and one variance -- TODO confirm.
preds = torch.stack([means, variances], dim=-1)

torch.save(preds, os.path.join(run_path, f"preds{full_subset}.pt"))
