import sys
from pathlib import Path
import yaml
import json
import numpy as np
import torch
import click
import argparse
import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP

from utils import distributed
import utils.torch as ptu
import config

from model.factory import create_segmenter
from optim.factory import create_optimizer, create_scheduler
from data.factory import create_dataset
from model.utils import num_params

from timm.utils import NativeScaler
from contextlib import suppress

from utils.distributed import sync_model
from engine import train_one_epoch, evaluate


# Corruption names evaluated by ``--eval-c``; each one is looked up as the
# dataset split ``val_<name>``.
_CORRUPTIONS = (
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
    'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
    'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
    'speckle_noise', 'gaussian_blur', 'spatter', 'saturate',
)

# Checkpoint loaded by the ``--eval-c`` path.
_EVAL_C_CHECKPOINT = 'checkpoints/deit_elliptical-delta5-2/checkpoint-backup.pth'


def _meter_averages(logger):
    """Return ``{meter_name: meter.global_avg}`` for every meter of *logger*."""
    return {k: meter.global_avg for k, meter in logger.meters.items()}


def _append_log(log_dir, stats):
    """Append *stats* as a single JSON line to ``log.txt`` inside *log_dir*."""
    with open(log_dir / "log.txt", "a") as f:
        f.write(json.dumps(stats) + "\n")


def _unwrap_ddp(model):
    """Return ``model.module`` when *model* has one (e.g. DDP), else *model*."""
    return model.module if hasattr(model, "module") else model


def _evaluate_corruptions(
    train_loader,
    val_kwargs,
    val_seg_gt,
    window_size,
    window_stride,
    amp_autocast,
    log_dir,
):
    """Evaluate the ``--eval-c`` checkpoint on every corrupted validation split.

    Loads ``_EVAL_C_CHECKPOINT``, runs one forward pass over *train_loader*
    in eval mode, sets ``use_W_save = True`` on the attention module of every
    encoder block, then calls ``evaluate`` on each ``val_<corruption>`` split.
    Per-split stats and their mean over all corruptions are printed and (on
    rank 0) appended to ``log.txt`` in *log_dir*.

    *val_kwargs* is copied, so the caller's dict is not modified.
    """
    from model.factory import load_model

    model, _ = load_model(_EVAL_C_CHECKPOINT)
    model.to(ptu.device)
    if ptu.distributed:
        model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)
    # Submodules such as ``encoder`` are only reachable on the raw model,
    # not on the DDP wrapper.
    model_without_ddp = _unwrap_ddp(model)

    # Forward the training data once; the outputs are discarded.
    # NOTE(review): per the messages below this pass is what populates
    # ``W_save`` inside the attention layers -- confirm against the model code.
    model.eval()
    for batch in train_loader:
        im = batch["im"].to(ptu.device)
        with amp_autocast():
            model.forward(im)

    print('passed through train data, W_save now populated')
    print('activating use of W_save')
    for block in model_without_ddp.encoder.blocks.children():
        block.attn.use_W_save = True

    print('evaluating all corruptions')
    metric_names = ('pixel_accuracy', 'mean_accuracy', 'mean_iou')
    cumulative_stats = dict.fromkeys(metric_names, 0.0)
    corrupt_kwargs = val_kwargs.copy()
    for c in _CORRUPTIONS:
        print(f'calling {c}')
        corrupt_kwargs['split'] = f'val_{c}'
        val_loader = create_dataset(corrupt_kwargs)
        print('loader built')

        eval_logger = evaluate(
            model,
            val_loader,
            val_seg_gt,
            window_size,
            window_stride,
            amp_autocast,
        )
        print("Stats [EVAL]:", eval_logger, flush=True)
        val_stats = _meter_averages(eval_logger)
        for name in metric_names:
            cumulative_stats[name] = cumulative_stats[name] + val_stats[name]
        # Only rank 0 writes, matching the training loop's logging.
        if ptu.dist_rank == 0:
            _append_log(log_dir, {f"val_{k}": v for k, v in val_stats.items()})

    # Turn the running sums into means over all corruptions.
    for name in metric_names:
        cumulative_stats[name] = cumulative_stats[name] / len(_CORRUPTIONS)
    for k, v in cumulative_stats.items():
        print(f'{k} | {v}')
    if ptu.dist_rank == 0:
        _append_log(log_dir, cumulative_stats)


@click.command(help="")
@click.option("--log-dir", type=str, help="logging directory")
@click.option("--dataset", type=str)
@click.option("--im-size", default=None, type=int, help="dataset resize size")
@click.option("--crop-size", default=None, type=int)
@click.option("--window-size", default=None, type=int)
@click.option("--window-stride", default=None, type=int)
@click.option("--backbone", default="", type=str)
@click.option("--decoder", default="", type=str)
@click.option("--optimizer", default="sgd", type=str)
@click.option("--scheduler", default="polynomial", type=str)
@click.option("--weight-decay", default=0.0, type=float)
@click.option("--dropout", default=0.0, type=float)
@click.option("--drop-path", default=0.1, type=float)
@click.option("--batch-size", default=None, type=int)
@click.option("--epochs", default=None, type=int)
@click.option("-lr", "--learning-rate", default=None, type=float)
@click.option("--normalization", default=None, type=str)
@click.option("--eval-freq", default=None, type=int)
@click.option("--amp/--no-amp", default=False, is_flag=True)
@click.option("--resume/--no-resume", default=True, is_flag=True)
@click.option("--model_name", default=None, type=str)
@click.option("--attn_type", default=None, type=str)
@click.option("--alpha", default=None, type=float)
@click.option("--eval", default=False, is_flag=True)
@click.option("--eval-c", default=False, is_flag=True)
def main(
    log_dir,
    dataset,
    im_size,
    crop_size,
    window_size,
    window_stride,
    backbone,
    decoder,
    optimizer,
    scheduler,
    weight_decay,
    dropout,
    drop_path,
    batch_size,
    epochs,
    learning_rate,
    normalization,
    eval_freq,
    amp,
    resume,
    model_name,
    attn_type,
    alpha,
    eval,
    eval_c,
):
    """Train and/or evaluate a segmentation model.

    Steps, in order:

    1. Initialise GPU/distributed mode and seed torch and numpy with 0.
    2. Merge the loaded config with the command-line overrides into a
       ``variant`` dict, which is printed and written to
       ``<log_dir>/variant.yml``.
    3. Build the train/val loaders, the model, the optimizer and the LR
       scheduler; with ``resume`` and an existing
       ``<log_dir>/checkpoint.pth``, restore their state.
    4. With ``eval``: run one evaluation and append the stats to
       ``<log_dir>/log.txt``.
    5. With ``eval_c``: run the corruption evaluation and return.
    6. Otherwise train from ``start_epoch`` to ``num_epochs``, saving a
       checkpoint after every epoch and evaluating every ``eval_freq``
       epochs and on the final epoch.

    Every parameter is supplied by the click option of the same name.
    ``model_name`` is accepted but not used in this function. On normal
    completion of the training path the process exits with status 0.
    """
    # start distributed mode
    ptu.set_gpu_mode(True)
    distributed.init_process()

    torch.manual_seed(0)
    np.random.seed(0)

    # set up configuration
    cfg = config.load_config()
    model_cfg = cfg["model"][backbone]

    dataset_cfg = cfg["dataset"][dataset]
    # Any decoder name containing "mask_transformer" shares one config entry.
    if "mask_transformer" in decoder:
        decoder_cfg = cfg["decoder"]["mask_transformer"]
    else:
        decoder_cfg = cfg["decoder"][decoder]

    # model config: CLI values win, then dataset config, then im_size
    if not im_size:
        im_size = dataset_cfg["im_size"]
    if not crop_size:
        crop_size = dataset_cfg.get("crop_size", im_size)
    if not window_size:
        window_size = dataset_cfg.get("window_size", im_size)
    if not window_stride:
        window_stride = dataset_cfg.get("window_stride", im_size)

    model_cfg["image_size"] = (crop_size, crop_size)
    model_cfg["backbone"] = backbone
    model_cfg["dropout"] = dropout
    model_cfg["drop_path_rate"] = drop_path
    decoder_cfg["name"] = decoder
    model_cfg["decoder"] = decoder_cfg

    # dataset config
    world_batch_size = dataset_cfg["batch_size"]
    num_epochs = dataset_cfg["epochs"]
    lr = dataset_cfg["learning_rate"]
    if batch_size:
        world_batch_size = batch_size
    if epochs:
        num_epochs = epochs
    if learning_rate:
        lr = learning_rate
    if eval_freq is None:
        eval_freq = dataset_cfg.get("eval_freq", 1)

    if normalization:
        model_cfg["normalization"] = normalization

    # experiment config
    # NOTE(review): the per-process batch size is pinned to 1, so
    # ``--batch-size`` only affects the recorded ``world_batch_size``. The
    # previous formula was ``world_batch_size // ptu.world_size`` -- confirm
    # that the override is intentional.
    batch_size = 1
    variant = dict(
        world_batch_size=world_batch_size,
        version="normal",
        resume=resume,
        dataset_kwargs=dict(
            dataset=dataset,
            image_size=im_size,
            crop_size=crop_size,
            batch_size=batch_size,
            normalization=model_cfg["normalization"],
            split="train",
            num_workers=10,
        ),
        algorithm_kwargs=dict(
            batch_size=batch_size,
            start_epoch=0,
            num_epochs=num_epochs,
            eval_freq=eval_freq,
        ),
        optimizer_kwargs=dict(
            opt=optimizer,
            lr=lr,
            weight_decay=weight_decay,
            momentum=0.9,
            clip_grad=None,
            sched=scheduler,
            epochs=num_epochs,
            min_lr=1e-5,
            poly_power=0.9,
            poly_step_size=1,
        ),
        net_kwargs=model_cfg,
        amp=amp,
        log_dir=log_dir,
        inference_kwargs=dict(
            im_size=im_size,
            window_size=window_size,
            window_stride=window_stride,
        ),
    )

    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = log_dir / "checkpoint.pth"

    # dataset
    dataset_kwargs = variant["dataset_kwargs"]

    train_loader = create_dataset(dataset_kwargs)
    val_kwargs = dataset_kwargs.copy()
    val_kwargs["split"] = "val"
    val_kwargs["batch_size"] = 1
    val_kwargs["crop"] = False
    val_loader = create_dataset(val_kwargs)
    n_cls = train_loader.unwrapped.n_cls

    # model
    net_kwargs = variant["net_kwargs"]
    net_kwargs["n_cls"] = n_cls

    # choose attention: the same settings go to the model and its decoder
    net_kwargs['attn_type'] = attn_type
    net_kwargs['decoder']['attn_type'] = attn_type
    net_kwargs['alpha'] = alpha
    net_kwargs['decoder']['alpha'] = alpha

    model = create_segmenter(net_kwargs)
    model.to(ptu.device)

    # optimizer
    optimizer_kwargs = variant["optimizer_kwargs"]
    optimizer_kwargs["iter_max"] = len(train_loader) * optimizer_kwargs["epochs"]
    optimizer_kwargs["iter_warmup"] = 0.0
    # The optimizer/scheduler factories take an argparse-style namespace.
    opt_args = argparse.Namespace()
    opt_vars = vars(opt_args)
    for k, v in optimizer_kwargs.items():
        opt_vars[k] = v
    optimizer = create_optimizer(opt_args, model)
    lr_scheduler = create_scheduler(opt_args, optimizer)
    amp_autocast = suppress
    loss_scaler = None
    if amp:
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()

    # resume
    if resume and checkpoint_path.exists():
        print(f"Resuming training from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if loss_scaler and "loss_scaler" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["loss_scaler"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        # The checkpoint stores the index of the last *completed* epoch
        # (it is written after train_one_epoch), so continue with the next.
        variant["algorithm_kwargs"]["start_epoch"] = checkpoint["epoch"] + 1
    else:
        sync_model(log_dir, model)

    if ptu.distributed:
        model = DDP(model, device_ids=[ptu.device], find_unused_parameters=True)

    # save config
    variant_str = yaml.dump(variant)
    print(f"Configuration:\n{variant_str}")
    variant["net_kwargs"] = net_kwargs
    variant["dataset_kwargs"] = dataset_kwargs
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / "variant.yml", "w") as f:
        f.write(variant_str)

    # train
    start_epoch = variant["algorithm_kwargs"]["start_epoch"]
    num_epochs = variant["algorithm_kwargs"]["num_epochs"]
    eval_freq = variant["algorithm_kwargs"]["eval_freq"]

    model_without_ddp = _unwrap_ddp(model)

    val_seg_gt = val_loader.dataset.get_gt_seg_maps()

    print(f"Train dataset length: {len(train_loader.dataset)}")
    print(f"Val dataset length: {len(val_loader.dataset)}")
    print(f"Encoder parameters: {num_params(model_without_ddp.encoder)}")
    print(f"Decoder parameters: {num_params(model_without_ddp.decoder)}")

    if eval:
        # NOTE(review): without --eval-c, execution continues into the
        # training loop after this evaluation -- confirm that is intended.
        print("here")
        eval_logger = evaluate(
            model,
            val_loader,
            val_seg_gt,
            window_size,
            window_stride,
            amp_autocast,
        )
        print("Stats [EVAL]:", eval_logger, flush=True)
        print("")
        # Only rank 0 writes, matching the training loop's logging.
        if ptu.dist_rank == 0:
            val_stats = _meter_averages(eval_logger)
            _append_log(log_dir, {f"val_{k}": v for k, v in val_stats.items()})

    if eval_c:
        _evaluate_corruptions(
            train_loader,
            val_kwargs,
            val_seg_gt,
            window_size,
            window_stride,
            amp_autocast,
            log_dir,
        )
        return

    for epoch in range(start_epoch, num_epochs):
        # train for one epoch
        train_logger = train_one_epoch(
            model,
            train_loader,
            optimizer,
            lr_scheduler,
            epoch,
            amp_autocast,
            loss_scaler,
        )

        # save checkpoint
        if ptu.dist_rank == 0:
            snapshot = dict(
                model=model_without_ddp.state_dict(),
                optimizer=optimizer.state_dict(),
                n_cls=model_without_ddp.n_cls,
                lr_scheduler=lr_scheduler.state_dict(),
            )
            if loss_scaler is not None:
                snapshot["loss_scaler"] = loss_scaler.state_dict()
            snapshot["epoch"] = epoch
            torch.save(snapshot, checkpoint_path)

        # evaluate
        eval_epoch = epoch % eval_freq == 0 or epoch == num_epochs - 1
        if eval_epoch:
            eval_logger = evaluate(
                model,
                val_loader,
                val_seg_gt,
                window_size,
                window_stride,
                amp_autocast,
            )
            print(f"Stats [{epoch}]:", eval_logger, flush=True)
            print("")

        # log stats
        if ptu.dist_rank == 0:
            train_stats = _meter_averages(train_logger)
            val_stats = _meter_averages(eval_logger) if eval_epoch else {}

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"val_{k}": v for k, v in val_stats.items()},
                "epoch": epoch,
                "num_updates": (epoch + 1) * len(train_loader),
            }
            _append_log(log_dir, log_stats)

    distributed.barrier()
    distributed.destroy_process()
    # Exit explicitly with status 0: the run completed successfully.
    sys.exit(0)


# Script entry point: main is a click command, so calling it with no
# arguments makes click parse the command line and supply the options.
if __name__ == "__main__":
    main()
