import os
from models.vint_v9 import VINT_V9_SA
from models.vint_v9_distilled import VINT_V9_SA_distilled
from models.vint_v9_distilled_pos import VINT_V9_SA_distilled_pos
from models.vint_v9_ssl import VINT_V9_SA_SSL
from models.siamese import SiameseModel
from models.stacked import StackedModel
from data.vint_dataset import VINT_Dataset
from data.pairwise_distance_dataset import PairwiseDistanceDataset
from training.train_utils import train_eval_loop, load_model, get_saved_optimizer
from warmup_scheduler import GradualWarmupScheduler
from torch.utils.data import DataLoader, ConcatDataset
from torch.optim import Adam, AdamW
from training.train_utils import train_eval_loop
from torchvision import transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch
import torch.nn as nn
import wandb
import argparse
import torch.backends.cudnn as cudnn
import os
import numpy as np
import yaml


def _validate_config(config):
    """Check that the distance and action category ranges are non-empty.

    Raises:
        ValueError: if ``min_dist_cat >= max_dist_cat`` in either the
            ``"distance"`` or the ``"action"`` section. An explicit raise is
            used rather than ``assert`` so the check survives ``python -O``.
    """
    for section in ("distance", "action"):
        low = config[section]["min_dist_cat"]
        high = config[section]["max_dist_cat"]
        if not low < high:
            raise ValueError(
                f'config["{section}"]: min_dist_cat ({low}) must be smaller '
                f"than max_dist_cat ({high})"
            )


def _setup_device(config):
    """Normalise ``config["gpu_ids"]``, export CUDA env vars and pick the device.

    ``config["gpu_ids"]`` defaults to ``[0]`` and a bare int is wrapped in a
    list. This happens even without CUDA, because ``main`` reads
    ``config["gpu_ids"]`` unconditionally (the DataParallel check).

    Returns:
        ``torch.device("cuda:<first gpu id>")`` when CUDA is available,
        otherwise ``torch.device("cpu")``.
    """
    if "gpu_ids" not in config:
        config["gpu_ids"] = [0]
    elif isinstance(config["gpu_ids"], int):
        config["gpu_ids"] = [config["gpu_ids"]]

    if not torch.cuda.is_available():
        print("Using cpu")
        return torch.device("cpu")

    print("Num available devices:", torch.cuda.device_count())
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in config["gpu_ids"])
    print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"])
    # NOTE(review): CUDA only honours CUDA_VISIBLE_DEVICES if it is set before
    # CUDA initialises, and the torch.cuda calls above may already have done
    # that. If the variable does take effect, visible devices are renumbered
    # from 0. In that case "cuda:<gpu_ids[0]>" (and DataParallel's device_ids)
    # would be wrong unless gpu_ids starts at 0. TODO confirm.
    return torch.device(f"cuda:{config['gpu_ids'][0]}")


def _build_transforms(config):
    """Return ``(eval_transform, train_transform)`` image pipelines.

    Both pipelines convert the input to a tensor, resize it to
    ``config["image_size"]`` and normalise it with the ImageNet mean/std.
    Between resizing and normalisation, the training pipeline also applies:
      - a 3x3 Gaussian blur (sigma 0.75) with probability 0.25;
      - ``ColorJitter`` (brightness/contrast/saturation/hue = 0.1);
      - additive Gaussian noise (std 0.02) with probability 0.25.
    """
    image_size = config["image_size"]
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    # Optional sky cropping (currently disabled), to be inserted after
    # ToTensor(): transforms.Lambda(lambda x: x[:, 75:, :])
    eval_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(image_size),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    train_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(image_size),
            # random blur
            transforms.Lambda(
                lambda x: TF.gaussian_blur(x, 3, 0.75) if np.random.rand() < 0.25 else x
            ),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
            # random additive noise
            transforms.Lambda(
                lambda x: x + torch.randn_like(x) * 0.02 if np.random.rand() < 0.25 else x
            ),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )
    return eval_transform, train_transform


def _build_datasets(config, train_transform, eval_transform, aspect_ratio):
    """Instantiate every configured dataset split.

    For each entry of ``config["datasets"]`` and each split (``"train"``,
    ``"test"``) present in it, one dataset is built per output type. The
    ``"action"`` output type is always built; ``"distance"`` is added unless
    ``config["distilled"]`` is set. Training splits use ``train_transform``
    and are wrapped in a ``Subset`` covering the first
    ``config.get("train_subset_fraction", 1.0)`` of the data. Test splits use
    ``eval_transform``.

    Returns:
        ``(train_dist, train_action, test)``. The first two are lists of
        training datasets. ``test`` maps ``"<dataset_name>_test"`` to
        ``{output_type: dataset}``.
    """
    output_types = ["action"]
    if not config["distilled"]:
        output_types.append("distance")
    # "pairwise" is never added to output_types, so the PairwiseDistanceDataset
    # branch below only becomes reachable if it is appended here.
    subset_fraction = config.get("train_subset_fraction", 1.0)

    train_dist, train_action, test = [], [], {}
    for dataset_name, data_config in config["datasets"].items():
        for split in ("train", "test"):
            if split not in data_config:
                continue
            t = train_transform if split == "train" else eval_transform

            for output_type in output_types:
                if output_type == "pairwise":
                    dataset = PairwiseDistanceDataset(
                        data_config["data_folder"],
                        data_config[split],
                        dataset_name,
                        t,
                        aspect_ratio,
                        data_config["waypoint_spacing"],
                        config["distance"]["min_dist_cat"],
                        config["distance"]["max_dist_cat"],
                        config["close_far_threshold"],
                        data_config["negative_mining"],
                        config["context_size"],
                        config["context_type"],
                        data_config["end_slack"],
                        goal_type=config["goal_type"],
                    )
                else:
                    dataset = VINT_Dataset(
                        data_config["data_folder"],
                        data_config[split],
                        dataset_name,
                        output_type == "action",
                        t,
                        aspect_ratio,
                        data_config["waypoint_spacing"],
                        config[output_type]["min_dist_cat"],
                        config[output_type]["max_dist_cat"],
                        data_config["negative_mining"],
                        config["len_traj_pred"],
                        config["learn_angle"],
                        config["context_size"],
                        config["context_type"],
                        data_config["end_slack"],
                        data_config["goals_per_obs"],
                        config["normalize"],
                        config["bins_provided"],
                        goal_type=config["goal_type"],
                    )
                    if split == "train":
                        print(f"Making subset of {subset_fraction:.0%}")
                        dataset = torch.utils.data.Subset(
                            dataset, range(0, int(len(dataset) * subset_fraction))
                        )

                if split == "train":
                    if output_type == "distance":
                        train_dist.append(dataset)
                        print(f"Loaded {len(dataset)} {dataset_name} training points")
                    elif output_type == "action":
                        print("length of dataset", dataset_name, split, "is", len(dataset))
                        train_action.append(dataset)
                else:
                    test.setdefault(f"{dataset_name}_{split}", {})[output_type] = dataset
    return train_dist, train_action, test


def _make_loader(dataset, batch_size, num_workers):
    """Wrap ``dataset`` in a shuffling DataLoader that drops the last partial batch."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )


def _build_vint_v9_sa(config):
    """Build a ``vint_v9_sa`` model, optionally distilled from a directed backbone."""
    common_kwargs = dict(
        context_size=config["context_size"],
        len_traj_pred=config["len_traj_pred"],
        learn_angle=config["learn_angle"],
        obs_encoder=config["obs_encoder"],
        goal_embedding_size=config["goal_embedding_size"],
        obs_encoding_size=config["obs_encoding_size"],
        mha_num_attention_heads=config["mha_num_attention_heads"],
        mha_num_attention_layers=config["mha_num_attention_layers"],
        mha_ff_dim_factor=config["mha_ff_dim_factor"],
    )
    if not config["distilled"]:
        return VINT_V9_SA(**common_kwargs)

    load_project_folder = os.path.join("logs", config["directed_backbone"])
    print("Loading directed backbone model from ", load_project_folder)
    backbone_path = os.path.join(load_project_folder, "latest.pth")

    if config["pos_conditioned"]:
        return VINT_V9_SA_distilled_pos(
            **common_kwargs,
            checkpoint_path=backbone_path,
            pretrained=config["pretrained"],
            learn_mapping=False,
            num_categories=1,
        )
    return VINT_V9_SA_distilled(
        **common_kwargs,
        checkpoint_path=backbone_path,
        pretrained=config["pretrained"],
        learn_mapping=config["learn_mapping"],
        num_categories=3 if config["binned"] else 1,
    )


def _build_model(config):
    """Instantiate the model selected by ``config["model_type"]``.

    Raises:
        ValueError: if ``config["model_type"]`` is not one of ``"siamese"``,
            ``"stacked"`` or ``"vint_v9_sa"``.
    """
    model_type = config["model_type"]
    if model_type == "siamese":
        return SiameseModel(
            config["context_size"],
            config["len_traj_pred"],
            config["learn_angle"],
            config["obs_encoding_size"],
            config["goal_encoding_size"],
        )
    if model_type == "stacked":
        return StackedModel(
            config["context_size"],
            config["len_traj_pred"],
            config["learn_angle"],
            config["obsgoal_encoding_size"],
        )
    if model_type == "vint_v9_sa":
        return _build_vint_v9_sa(config)
    # Report the key that was actually read; "model" is not a config key.
    raise ValueError(f"Model {model_type} not supported")


def _register_grad_clamp_hooks(model, max_value):
    """Clamp every trainable parameter's gradient element-wise to [-max_value, max_value].

    NOTE: this is per-element value clipping via ``Tensor.register_hook``, not
    norm clipping, despite the ``max_norm`` config key it is fed from.
    """
    for param in model.parameters():
        if param.requires_grad:
            param.register_hook(lambda grad: torch.clamp(grad, -max_value, max_value))


def _build_optimizer(model, config, lr):
    """Create the optimizer named by ``config["optimizer"]``.

    The name is lower-cased in place.

    Raises:
        ValueError: for names other than ``adam``, ``adamw`` or ``sgd``.
    """
    config["optimizer"] = config["optimizer"].lower()
    name = config["optimizer"]
    if name == "adam":
        return Adam(model.parameters(), lr=lr)
    if name == "adamw":
        return AdamW(model.parameters(), lr=lr)
    if name == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    raise ValueError(f"Optimizer {config['optimizer']} not supported")


def _build_scheduler(optimizer, config, lr):
    """Create the LR scheduler named by ``config["scheduler"]``, or return None.

    The name is lower-cased in place. When ``config["warmup"]`` is truthy, the
    scheduler is wrapped in a ``GradualWarmupScheduler`` running for
    ``config["warmup_epochs"]``.

    Raises:
        ValueError: for names other than ``cosine``, ``cyclic`` or ``plateau``.
    """
    if config["scheduler"] is None:
        return None

    config["scheduler"] = config["scheduler"].lower()
    name = config["scheduler"]
    if name == "cosine":
        print("Using cosine annealing with T_max", config["epochs"])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config["epochs"]
        )
    elif name == "cyclic":
        print("Using cyclic LR with cycle", config["cyclic_period"])
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=lr / 10.0,
            max_lr=lr,
            step_size_up=config["cyclic_period"] // 2,
            cycle_momentum=False,
        )
    elif name == "plateau":
        print("Using ReduceLROnPlateau")
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=config["plateau_factor"],
            patience=config["plateau_patience"],
            verbose=True,
        )
    else:
        raise ValueError(f"Scheduler {config['scheduler']} not supported")

    if config["warmup"]:
        print("Using warmup scheduler")
        scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=config["warmup_epochs"],
            after_scheduler=scheduler,
        )
    return scheduler


def _resume_from_checkpoint(config, model, optimizer, scheduler, device):
    """Restore training state from ``logs/<config["load_run"]>/latest.pth``.

    The saved optimizer state is copied INTO ``optimizer`` rather than
    replacing it with the object returned by ``get_saved_optimizer``. The
    scheduler built earlier holds a reference to ``optimizer``; replacing it
    would leave LR scheduling acting on an optimizer that is never stepped.

    Returns:
        The epoch to resume from (saved ``"epoch"`` + 1).
    """
    load_project_folder = os.path.join("logs", config["load_run"])
    print("Loading model from ", load_project_folder)
    latest_path = os.path.join(load_project_folder, "latest.pth")
    latest_checkpoint = torch.load(latest_path, map_location=device)
    load_model(model, latest_checkpoint)
    next_epoch = latest_checkpoint["epoch"] + 1

    saved_optimizer = get_saved_optimizer(latest_checkpoint, device)
    optimizer.load_state_dict(saved_optimizer.state_dict())

    if scheduler is not None:
        scheduler.load_state_dict(latest_checkpoint["scheduler"].state_dict())
        # NOTE(review): a warm-up wrapper's restored state presumably includes
        # its pickled `after_scheduler`, which would still reference the
        # checkpoint's optimizer object. Rebind it so post-warm-up scheduling
        # drives the optimizer actually used for training. TODO confirm.
        after_scheduler = getattr(scheduler, "after_scheduler", None)
        if after_scheduler is not None:
            after_scheduler.optimizer = optimizer
    return next_epoch


def main(config):
    """Build datasets, model, optimizer and scheduler from ``config`` and train.

    ``config`` is the parsed YAML dict. Several keys are normalised or
    defaulted in place: ``gpu_ids``, ``context_type``, ``clip_goals``,
    ``eval_batch_size``, ``optimizer`` and ``scheduler``.

    Raises:
        ValueError: for inconsistent category bounds or an unsupported model,
            optimizer or scheduler name.
    """
    _validate_config(config)
    device = _setup_device(config)

    if "seed" in config:
        np.random.seed(config["seed"])
        torch.manual_seed(config["seed"])
        cudnn.deterministic = True
    # NOTE(review): cuDNN benchmarking stays on even when a seed is given.
    # PyTorch's reproducibility notes recommend disabling it for fully
    # deterministic runs.
    cudnn.benchmark = True  # good if input sizes don't vary

    eval_transform, train_transform = _build_transforms(config)
    # ratio image_size[0] / image_size[1], passed through to the datasets
    aspect_ratio = config["image_size"][0] / config["image_size"][1]

    config.setdefault("context_type", "temporal")
    config.setdefault("clip_goals", False)

    train_dist_datasets, train_action_datasets, test_datasets = _build_datasets(
        config, train_transform, eval_transform, aspect_ratio
    )

    # combine all the datasets from different robots
    batch_size = config["batch_size"]
    num_workers = config["num_workers"]
    if train_dist_datasets:
        train_dist_loader = _make_loader(
            ConcatDataset(train_dist_datasets), batch_size, num_workers
        )
    else:
        train_dist_loader = None
    train_action_loader = _make_loader(
        ConcatDataset(train_action_datasets), batch_size, num_workers
    )

    config.setdefault("eval_batch_size", batch_size)
    test_dataloaders = {}
    for dataset_type, datasets_by_output in test_datasets.items():
        test_dataloaders[dataset_type] = {}
        for loader_type, dataset in datasets_by_output.items():
            print("len for", dataset_type, loader_type, len(dataset))
            test_dataloaders[dataset_type][loader_type] = _make_loader(
                dataset, config["eval_batch_size"], num_workers
            )

    model = _build_model(config)
    if len(config["gpu_ids"]) > 1:
        model = nn.DataParallel(model, device_ids=config["gpu_ids"])
    model = model.to(device)

    if config["clipping"]:
        print("Clipping gradients to", config["max_norm"])
        # float(): YAML loads values like `1e-3` as strings (same reason as lr).
        _register_grad_clamp_hooks(model, float(config["max_norm"]))

    lr = float(config["lr"])
    optimizer = _build_optimizer(model, config, lr)
    scheduler = _build_scheduler(optimizer, config, lr)

    current_epoch = 0
    if "load_run" in config:
        current_epoch = _resume_from_checkpoint(
            config, model, optimizer, scheduler, device
        )

    # Anomaly detection is a debugging aid that slows autograd; it stays on by
    # default but can be disabled with `detect_anomaly: false`.
    torch.autograd.set_detect_anomaly(config.get("detect_anomaly", True))
    if config["train"]:
        train_eval_loop(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            train_dist_loader=train_dist_loader,
            train_action_loader=train_action_loader,
            test_dataloaders=test_dataloaders,
            epochs=config["epochs"],
            device=device,
            project_folder=config["project_folder"],
            normalized=config["normalize"],
            print_log_freq=config["print_log_freq"],
            image_log_freq=config["image_log_freq"],
            num_images_log=config["num_images_log"],
            pairwise_test_freq=config["pairwise_test_freq"],
            current_epoch=current_epoch,
            learn_angle=config["learn_angle"],
            alpha=config["alpha"],
            use_wandb=config["use_wandb"],
            distilled=config.get("distilled"),
            bins_provided=config.get("bins_provided"),
        )
    print("FINISHED TRAINING")


if __name__ == "__main__":
    # Command-line entry point: load the YAML config, optionally start a
    # Weights & Biases run, create the log folder, then train.
    parser = argparse.ArgumentParser(description="Mobile Robot Agnostic Learning")

    # project setup
    parser.add_argument(
        "--config",
        "-c",
        default="debug.yaml",
        type=str,
        help="Path to the YAML training config file",
    )
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    if config["use_wandb"]:
        wandb.login()
        wandb.init(
            project=config["project_name"],
            settings=wandb.Settings(start_method="fork"),
            # the "wandb_entity" config key is optional; it defaults to the original team
            entity=config.get("wandb_entity", "vintv2"),
        )
        # Only touch the run object once we know one was created.
        if wandb.run:
            wandb.run.name = config["run_name"]
            # update the wandb args with the training configurations
            wandb.config.update(config)
        wandb.save(args.config, policy="now")  # save the config file

    config["project_folder"] = os.path.join(
        "logs", config["project_name"], config["run_name"]
    )
    os.makedirs(config["project_folder"], exist_ok=True)

    print(config)
    main(config)
