import os
import argparse
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from source.constants import RESULTS_PATH
from source.utils.seeding import fix_seeds
from source.trainer import fit

from source.data.dots import get_standard_data
from source.data.arrow import get_data
from source.data.cityscapes import get_train_val, CityscapesDataset

from source.models.cnn import get_cnn
from source.models.resnet import get_resnet18


###############
### Parsing ###
###############

# Closed sets of accepted values for the categorical options. Handing them to
# argparse as `choices` validates the input even under `python -O` (where
# `assert` statements are stripped) and lists the valid values in `--help`.
TASK_CHOICES = (
    "dots",
    "dots_gap_tail",
    "arrow",
    "arrow_gap_tail",
    "city_street",
    "city_building",
    "city_sky",
    "city_car",
    "city_vegetation",
    "git_qood",
)
METHOD_CHOICES = ("ensemble", "mc_dropout")
NETWORK_CHOICES = ("cnn", "resnet")

parser = argparse.ArgumentParser()
# general
parser.add_argument("--task", default="dots", choices=TASK_CHOICES)
parser.add_argument("--network", default="cnn", choices=NETWORK_CHOICES)
parser.add_argument("--method", default="ensemble", choices=METHOD_CHOICES)
parser.add_argument("--method_seed", default=42, type=int)
parser.add_argument("--device", default="cuda:0")
# Network
parser.add_argument("--use_natural", action="store_true")
parser.add_argument("--lr", default=1e-3, type=float)
parser.add_argument("--batch_size", default=128, type=int)
parser.add_argument("--weight_decay", default=1e-3, type=float)
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument("--patience", default=0, type=int)
parser.add_argument("--num_workers", default=0, type=int)
# Ensemble
parser.add_argument("--num_networks", default=10, type=int)
# MC Dropout
parser.add_argument("--p_drop", default=0.2, type=float)
# Loss parameters
parser.add_argument("--anti_regularization_weight", default=0.0, type=float)

# parse (argparse exits with a usage message on an invalid task/method/network)
args = parser.parse_args()

# convenience
method_seed, device = args.method_seed, args.device
print("Computation executed on >", device)
print("Use natural parametrization >", args.use_natural)

# Every run writes into its own directory below RESULTS_PATH.
# NOTE(review): the directory name encodes task, network, parametrization,
# seed and anti-regularization weight only; runs that differ solely in
# --method, --p_drop or other options share a directory -- confirm intended.
run_path = os.path.join(
    RESULTS_PATH,
    f"{args.task}_{args.network}_natural{args.use_natural}_mseed{method_seed}_arw{args.anti_regularization_weight}",
)
os.makedirs(run_path, exist_ok=True)

# save command line arguments, one "key: value" pair per line
formatted_args = "\n".join(f"{key}: {value}" for key, value in vars(args).items())
with open(os.path.join(run_path, "args.txt"), "w") as file:
    file.write(formatted_args)

#################
### LOAD DATA ###
#################
def _as_tensor_dataset(inputs, targets):
    """Wrap an (inputs, targets) pair into a float32 ``TensorDataset``."""
    return TensorDataset(
        torch.tensor(inputs, dtype=torch.float32),
        torch.tensor(targets, dtype=torch.float32),
    )


# Keyword arguments forwarded to the network factory; filled in per task.
model_kwargs = {}
if "dots" in args.task:
    model_kwargs['in_channels'] = 1
    x_train, y_train, x_val, y_val, _, _ = get_standard_data()
    train_ds = _as_tensor_dataset(x_train, y_train)
    val_ds = _as_tensor_dataset(x_val, y_val)

elif "arrow" in args.task:
    model_kwargs['in_channels'] = 1
    x_train, y_train, x_val, y_val, _, _ = get_data()
    train_ds = _as_tensor_dataset(x_train, y_train)
    val_ds = _as_tensor_dataset(x_val, y_val)

elif "city" in args.task:
    model_kwargs['in_channels'] = 3
    # The task name is "city_<class>"; map the class name to its target index.
    target = {"street": 0, "building": 1, "sky": 2, "car": 3, "vegetation": 4}[
        args.task.split("_")[1]
    ]
    train_images, train_targets, val_images, val_targets = get_train_val(target=target)
    # Flipping is enabled for the training split only.
    train_ds = CityscapesDataset(train_images, train_targets, flip=True)
    val_ds = CityscapesDataset(val_images, val_targets, flip=False)

elif "qood" in args.task:
    model_kwargs['in_channels'] = 1
    model_kwargs['image_size'] = 64
    ### TODO: resnet doesn't work because images are 64x64 instead of 32x32
    # Imported here so the other tasks do not need this package.
    from datasets.mosaic_datasets import build_id_and_ood as mosaic_ds_build_id_and_ood

    # Settings shared by both builder calls below; only the seed differs.
    mosaic_kwargs = dict(
        # tile_size=32,
        n_id_train=100000,  # since we are not really doing mnist, gotta take a nice number
        n_id_test=1000,
        n_ood_each=1000,
        download=True,
        normalize_images=True,
    )
    id_train, id_test, ood_fashion, ood_cifar10, ood_svhn, ood_mixture = mosaic_ds_build_id_and_ood(
        seed=method_seed,
        **mosaic_kwargs,
    )
    # The test set relies on the seed. Since the builder does not return a
    # validation set, it is called a second time with seed + 1 and the test
    # split of that call is used for validation.
    _, id_val, _, _, _, _ = mosaic_ds_build_id_and_ood(
        seed=method_seed + 1,
        **mosaic_kwargs,
    )
    # set the datasets as necessary
    train_ds = id_train
    val_ds = id_val

    # print(train_ds[0][0].shape)
    # Show the targets of the first 20 training samples.
    print([train_ds[i][1] for i in range(20)])

else:
    # Fail here with a clear message rather than with a NameError on
    # `train_ds` further down.
    raise ValueError(f"Unsupported task: {args.task!r}")


print(len(train_ds), len(val_ds))

# Sanity check: shapes of the first training sample and its target.
x, y = train_ds[0]

print(x.shape, y.shape)

####################
### LEARN MODELS ###
####################

fix_seeds(seed=method_seed)

# MC dropout trains a single network; an ensemble trains --num_networks of them.
n_networks = 1 if args.method == "mc_dropout" else args.num_networks

# Output locations are the same for every network, so set them up once.
models_dir = os.path.join(run_path, "models")
val_perfs_dir = os.path.join(run_path, "val_perfs")
summary_path = os.path.join(run_path, "val_perfs.txt")
os.makedirs(models_dir, exist_ok=True)
os.makedirs(val_perfs_dir, exist_ok=True)

for n in range(n_networks):
    if args.network == "cnn":
        network = get_cnn(**model_kwargs)
    elif args.network == "resnet":
        network = get_resnet18(**model_kwargs)
    else:
        raise NotImplementedError("Network not supported")
    network.to(device)
    network.train()

    network, _, val_perfs = fit(
        network=network,
        train_loader=DataLoader(
            train_ds,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
        ),
        val_loader=DataLoader(
            val_ds,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
        ),
        device=device,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        patience=args.patience,
        use_natural=args.use_natural,
        anti_regularization_weight=args.anti_regularization_weight,
    )

    torch.save(network.state_dict(), os.path.join(models_dir, f"model_{n}.pt"))

    # save full val_perfs to file, one value per line
    with open(os.path.join(val_perfs_dir, f"model_{n}.txt"), "w") as file:
        file.write("\n".join(map(str, val_perfs)))

    # The smallest recorded value is reported as the best one. Fall back to
    # NaN when nothing was recorded so the summary can still be written for a
    # model that has already been trained and saved.
    best_val_perf = min(val_perfs, default=float("nan"))

    # save best val_perf per model to a text file; the first model truncates
    # a summary left over from a previous run, later models append to it
    with open(summary_path, "w" if n == 0 else "a") as file:
        file.write(f"{n}: {best_val_perf:.4f}\n")

    # print best validation performance
    print(f"Model {n} trained with val_perf: {best_val_perf:.4f}")
