import click
import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
import random
import torchvision.models.resnet as resnet
import torchvision.transforms as T
import sacred
import pandas as pd
from datetime import datetime
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms.functional import InterpolationMode
from typing import List, Any, Tuple

from data.VOCdevkit.vocdata import VOCDataModule
from data.coco.coco_data_module import CocoDataModule
from experiments.utils import PredsmIoU, get_backbone_weights
from src.models.resnet import ResnetDilated
from src.models.vit import vit_small, vit_base, vit_large
from src.models.vit_v2 import vit_small as vit_small_v2, vit_base as vit_base_v2, vit_large as vit_large_v2
from src.linear_finetuning_transforms import Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor, SepTransforms

# from src.models.vit_clip_exp import vit_base as vit_clip_base
from data.cityscapes.cityscapes_data import CityscapesDataModule
from data.ade20k.ade20kdata import Ade20kDataModule

# Sacred experiment: `entry` adds its config/observer and runs it; `linear_finetune` is its main function.
ex = sacred.experiment.Experiment()
# Neptune API key. Taken from the NEPTUNE_API_TOKEN environment variable (the variable Neptune's
# client conventionally reads) so credentials do not have to be committed to source; falls back to
# the original placeholder when unset. The logger in `linear_finetune` runs in offline mode, which
# presumably does not require a valid key — confirm if switching to online mode.
api_key = os.environ.get("NEPTUNE_API_TOKEN", "<YOUR API KEY HERE>")

@click.command()
@click.option("--config_path", type=str)
@click.option("--ckpt_path", type=str, default=None)
@click.option('--method', type=str, default=None)
@click.option('--arch', type=str, default=None)
def entry(config_path, ckpt_path, method, arch):
    # CLI entry point (documented with comments, not a docstring, because click would
    # expose a docstring as --help text).
    #
    # 1. Load the YAML config, resolved relative to this script's directory; defaults to
    #    "finetune_dev.yml" when --config_path is not given.
    # 2. Name the run "linear-finetune-<timestamp>" and attach a FileStorageObserver under
    #    <train.ckpt_dir>/<run name>.
    # 3. Run the experiment with seed 400 plus any of --method/--ckpt_path/--arch that were
    #    supplied, applied as overrides of the corresponding `train.*` config entries.
    script_dir = os.path.abspath(os.path.dirname(__file__))
    chosen_config = "finetune_dev.yml" if config_path is None else config_path
    ex.add_config(os.path.join(script_dir, chosen_config))

    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    ex_name = f"linear-finetune-{stamp}"

    # Reads the freshly added config directly (private sacred attribute) to locate the run directory.
    base_ckpt_dir = ex.configurations[0]._conf["train"]["ckpt_dir"]
    ex.observers.append(sacred.observers.FileStorageObserver(os.path.join(base_ckpt_dir, ex_name)))

    # Only CLI options that were actually given override the config file.
    cli_overrides = {
        'train.method': method,
        'train.ckpt_path': ckpt_path,
        'train.arch': arch,
    }
    params = {'seed': 400}
    params.update({key: value for key, value in cli_overrides.items() if value is not None})

    ex.run(config_updates=params, options={'--name': ex_name})

def _build_data_module(dataset_name: str, data_dir: str, batch_size: int, num_workers: int,
                       train_transforms, val_image_transforms, val_target_transforms):
    """Create the data module for ``dataset_name`` together with its label configuration.

    Args:
        dataset_name: ``"voc"``, ``"ade20k"``, ``"cityscapes"``, or a name containing ``"coco"``
            of the form ``<prefix>-thing`` / ``<prefix>-stuff`` (e.g. ``"coco-thing"``).
        data_dir: Dataset root directory.
        batch_size: Batch size forwarded to the data module.
        num_workers: Number of data-loading workers forwarded to the data module.
        train_transforms: Training transform pipeline.
        val_image_transforms: Validation transform applied to images.
        val_target_transforms: Validation transform applied to masks.

    Returns:
        Tuple ``(data_module, num_classes, ignore_index)``.

    Raises:
        ValueError: If the dataset is unsupported or a COCO dataset name is malformed.
    """
    if dataset_name == "voc":
        data_module = VOCDataModule(batch_size=batch_size,
                                    return_masks=True,
                                    num_workers=num_workers,
                                    train_split="trainaug",
                                    val_split="val",
                                    data_dir=data_dir,
                                    train_image_transform=train_transforms,
                                    drop_last=True,
                                    val_image_transform=val_image_transforms,
                                    val_target_transform=val_target_transforms)
        return data_module, 21, 255

    if "coco" in dataset_name:
        # Validate with exceptions (not asserts, which vanish under `python -O` and would
        # silently fall through to the "stuff" class count).
        name_parts = dataset_name.split("-")
        if len(name_parts) != 2 or name_parts[-1] not in ("thing", "stuff"):
            raise ValueError(f"invalid COCO dataset name {dataset_name!r}: "
                             f"expected '<name>-thing' or '<name>-stuff'")
        mask_type = name_parts[-1]
        num_classes = 12 if mask_type == "thing" else 15
        # os.listdir returns entries in filesystem-dependent order; sort first so that the
        # seeded shuffles below (and hence the sampled subset) are reproducible across machines.
        file_list = sorted(os.listdir(os.path.join(data_dir, "images", "train2017")))
        file_list_val = sorted(os.listdir(os.path.join(data_dir, "images", "val2017")))
        random.shuffle(file_list_val)
        # sample 10% of train images
        random.shuffle(file_list)
        file_list = file_list[:int(len(file_list) * 0.1)]
        print(f"sampled {len(file_list)} COCO images for training")

        data_module = CocoDataModule(batch_size=batch_size,
                                     num_workers=num_workers,
                                     file_list=file_list,
                                     data_dir=data_dir,
                                     file_list_val=file_list_val,
                                     mask_type=mask_type,
                                     train_transforms=train_transforms,
                                     val_transforms=val_image_transforms,
                                     val_target_transforms=val_target_transforms)
        return data_module, num_classes, 255

    if dataset_name == "ade20k":
        val_transforms = SepTransforms(val_image_transforms, val_target_transforms)
        # NOTE(review): shuffle=False here but shuffle=True for Cityscapes — confirm whether this
        # flag controls training-set shuffling in Ade20kDataModule.
        data_module = Ade20kDataModule(data_dir,
                                       train_transforms=train_transforms,
                                       val_transforms=val_transforms,
                                       shuffle=False,
                                       num_workers=num_workers,
                                       batch_size=batch_size)
        # Label 0 is excluded from loss and metric for ADE20K.
        return data_module, 151, 0

    if dataset_name == "cityscapes":
        val_transforms = SepTransforms(val_image_transforms, val_target_transforms)
        data_module = CityscapesDataModule(root=data_dir,
                                           train_transforms=train_transforms,
                                           val_transforms=val_transforms,
                                           shuffle=True,
                                           num_workers=num_workers,
                                           batch_size=batch_size)
        return data_module, 19, 255

    raise ValueError(f"{dataset_name} not supported")


@ex.main
@ex.capture
def linear_finetune(_config, _run):
    """Sacred main: train a linear segmentation head on top of a (frozen) backbone.

    Reads the ``data`` and ``train`` config sections, builds the data module and a
    ``LinearFinetune`` model, optionally loads pretrained backbone weights, and fits with
    a checkpoint callback that keeps the 3 best checkpoints by ``miou_val``.

    Args:
        _config: Sacred configuration (injected).
        _run: Sacred run object (injected); its experiment name names the logger run and the
            checkpoint sub-directory.

    Raises:
        ValueError: For an unsupported dataset or when ``size_crops`` is not a multiple of
            ``patch_size``.
    """
    # Init logger
    neptune_logger = NeptuneLogger(
        api_key=api_key,
        mode="offline",
        project="<Your Project Name>",
        name=_run.experiment_info["name"],
        # Drop empty entries so "" or "a,,b" does not produce blank tags.
        tags=[tag.strip() for tag in _config["tags"].split(",") if tag.strip()],
    )
    # Flatten the nested config into dotted keys (e.g. "train.lr") for Neptune.
    params = pd.json_normalize(_config).to_dict(orient='records')[0]
    neptune_logger.experiment["parameters"] = params

    print("Config:")
    print(_config)
    data_config = _config["data"]
    train_config = _config["train"]
    seed_everything(_config["seed"])
    input_size = data_config["size_crops"]

    # Init transforms and train data
    train_transforms = Compose([
        RandomResizedCrop(size=input_size, scale=(0.8, 1.)),
        RandomHorizontalFlip(p=0.5),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_image_transforms = T.Compose([T.Resize((input_size, input_size)),
                                      T.ToTensor(),
                                      T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # Nearest-neighbour resizing keeps mask values as valid labels (no interpolated classes).
    val_target_transforms = T.Compose([T.Resize((input_size, input_size), interpolation=InterpolationMode.NEAREST),
                                       T.ToTensor()])

    data_module, num_classes, ignore_index = _build_data_module(
        dataset_name=data_config["dataset_name"],
        data_dir=data_config["data_dir"],
        batch_size=train_config["batch_size"],
        num_workers=_config["num_workers"],
        train_transforms=train_transforms,
        val_image_transforms=val_image_transforms,
        val_target_transforms=val_target_transforms,
    )

    # Init Method
    arch = train_config["arch"]
    patch_size = train_config["patch_size"]
    restart = train_config["restart"]
    method = train_config["method"]
    decay_rate = train_config.get("decay_rate")
    # ViT tokens are reshaped into a spatial_res x spatial_res grid, so the crop size must be
    # an exact multiple of the patch size.
    spatial_res = input_size / patch_size
    if not spatial_res.is_integer():
        raise ValueError(f"size_crops ({input_size}) must be a multiple of patch_size ({patch_size})")
    model = LinearFinetune(
        patch_size=patch_size,
        head_type=train_config.get("head_type"),
        arch=arch,
        arch_version=train_config.get("arch_version"),
        num_classes=num_classes,
        lr=train_config["lr"],
        input_size=input_size,
        spatial_res=int(spatial_res),
        val_iters=train_config["val_iters"],
        decay_rate=decay_rate if decay_rate is not None else 0.1,
        drop_at=train_config["drop_at"],
        ignore_index=ignore_index,
        num_register_tokens=train_config.get("num_register_tokens", 0),
    )

    # Optionally load weights (skipped when restarting: the trainer resumes from ckpt_path instead)
    if not restart:
        weights = get_backbone_weights(arch, method, patch_size=patch_size, ckpt_path=train_config.get("ckpt_path"))
        # strict=False: the weights are expected to cover the backbone only, not the new head.
        msg = model.load_state_dict(weights, strict=False)
        print(msg)

    # Init checkpoint callback storing top 3 heads
    checkpoint_dir = os.path.join(train_config["ckpt_dir"], _run.experiment_info["name"])
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        monitor='miou_val',
        filename='ckp-{epoch:02d}-{miou_val:.4f}',
        save_top_k=3,
        mode='max',
        verbose=True,
    )

    # Init trainer and start training head
    trainer = Trainer(
        num_sanity_val_steps=0,
        logger=neptune_logger,
        max_epochs=train_config["max_epochs"],
        devices=_config["gpus"],
        accelerator='cuda',
        fast_dev_run=train_config["fast_dev_run"],
        log_every_n_steps=50,
        benchmark=True,
        deterministic=False,
        detect_anomaly=False,
        callbacks=[checkpoint_callback]
    )
    trainer.fit(model, datamodule=data_module, ckpt_path=train_config["ckpt_path"] if restart else None)


class LinearFinetune(pl.LightningModule):
    """Linear segmentation probe on top of a frozen backbone.

    Backbone features are bilinearly resized to a fixed square resolution
    (``train_mask_size`` / ``val_mask_size``) and classified per pixel by a single 1x1
    convolution (``finetune_head``). Only the head's parameters are given to the optimizer.
    Backbone parameters are also set to ``requires_grad=False`` after each backward pass.

    Args:
        patch_size: Patch size forwarded to the ViT constructor.
        num_classes: Number of output channels of the head and classes of the mIoU metric.
        lr: SGD learning rate for the head.
        input_size: Expected input image width (asserted in ``training_step``).
        spatial_res: Side length of the ViT token grid that features are reshaped into.
        val_iters: If not ``None``, only the first ``val_iters`` validation batches are scored.
        drop_at: ``StepLR`` step size.
        arch: ``"vit-small"``, ``"vit-base"``, ``"vit-large"`` or ``"resnet50"``.
        arch_version: ``"v2"`` selects the ``vit_v2`` constructors, which also receive
            ``num_register_tokens``.
        head_type: Not used by this class; only recorded via ``save_hyperparameters``.
        decay_rate: ``StepLR`` gamma.
        ignore_index: Label value ignored by the loss and excluded from the metric.
        num_register_tokens: Register-token count for v2 ViTs; a list/tuple is reduced to its
            first element.

    Raises:
        ValueError: If ``arch`` is not a supported architecture.
    """

    def __init__(self, patch_size: int, num_classes: int, lr: float, input_size: int, spatial_res: int, val_iters: int,
                 drop_at: int, arch: str, arch_version: str = None, head_type: str = None, decay_rate: float = 0.1,
                 ignore_index: int = 255, num_register_tokens=0
                 ):
        super().__init__()
        # Collapse a list/tuple to its first element. isinstance() (unlike an exact type check) also
        # accepts list subclasses such as read-only config containers. This must stay before
        # save_hyperparameters(), which records the argument values currently bound in this frame.
        if isinstance(num_register_tokens, (tuple, list)):
            num_register_tokens = num_register_tokens[0]
        self.save_hyperparameters()
        self.model = self._build_backbone(arch, arch_version, patch_size, num_register_tokens)

        # 1x1 convolution == per-pixel linear classifier over the backbone channels.
        self.finetune_head = nn.Conv2d(self.model.embed_dim, num_classes, 1)

        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.miou_metric = PredsmIoU(num_classes, num_classes)
        self.num_classes = num_classes
        self.lr = lr
        self.val_iters = val_iters
        self.input_size = input_size
        self.spatial_res = spatial_res
        self.drop_at = drop_at
        self.arch = arch
        self.ignore_index = ignore_index
        self.decay_rate = decay_rate
        # Side length (pixels) at which predictions and ground-truth masks are compared.
        self.train_mask_size = 100
        self.val_mask_size = 100

    @staticmethod
    def _build_backbone(arch, arch_version, patch_size, num_register_tokens):
        """Instantiate the (randomly initialised) backbone named by ``arch``.

        Raises:
            ValueError: If ``arch`` is not one of the supported architectures.
        """
        if 'vit' in arch:
            use_v2 = arch_version == 'v2'
            vit_factories = {
                'vit-small': vit_small_v2 if use_v2 else vit_small,
                'vit-base': vit_base_v2 if use_v2 else vit_base,
                'vit-large': vit_large_v2 if use_v2 else vit_large,
            }
            if arch not in vit_factories:
                raise ValueError(f"unsupported ViT architecture: {arch}")
            # Only the v2 constructors take a register-token count.
            extra_args = {'num_register_tokens': num_register_tokens} if use_v2 else {}
            return vit_factories[arch](patch_size=patch_size, **extra_args)
        if arch == 'resnet50':
            backbone = resnet.__dict__[arch](pretrained=False)
            return ResnetDilated(backbone)
        raise ValueError(f"unsupported architecture: {arch}")

    def _extract_features(self, imgs: torch.Tensor, mask_size: int) -> torch.Tensor:
        """Run the backbone without gradients and return features resized to ``mask_size`` x ``mask_size``."""
        bs = imgs.size(0)
        with torch.no_grad():
            tokens = self.model.forward_backbone(imgs)
            if 'vit' in self.arch:
                # Drop the first token (presumably CLS) and fold the remaining tokens into a
                # (B, embed_dim, spatial_res, spatial_res) feature map.
                # NOTE(review): for v2 backbones with register tokens this assumes forward_backbone
                # does not also return the register tokens — confirm in src.models.vit_v2.
                tokens = tokens[:, 1:].reshape(bs, self.spatial_res, self.spatial_res, self.model.embed_dim).\
                    permute(0, 3, 1, 2)
            tokens = nn.functional.interpolate(tokens, size=(mask_size, mask_size), mode='bilinear')
        return tokens

    def on_after_backward(self):
        # Freeze all layers of backbone
        for param in self.model.parameters():
            param.requires_grad = False

    def configure_optimizers(self):
        """SGD over the head parameters only, with a StepLR decay every ``drop_at`` epochs/steps."""
        optimizer = torch.optim.SGD(self.finetune_head.parameters(), weight_decay=0.0001,
                                    momentum=0.9, lr=self.lr)
        scheduler = StepLR(optimizer, gamma=self.decay_rate, step_size=self.drop_at)
        return [optimizer], [scheduler]

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Compute the cross-entropy loss of the head at ``train_mask_size`` resolution.

        Masks are expected as float tensors scaled to [0, 1] (they are multiplied by 255 to
        recover integer labels) with a singleton channel dim — assumed (B, 1, H, W).
        """
        imgs, masks = batch
        assert imgs.size(3) == self.input_size
        # Keep the backbone in eval mode (e.g. fixed normalisation statistics / no dropout)
        # even though Lightning switches the whole module to train mode.
        self.model.eval()

        tokens = self._extract_features(imgs, self.train_mask_size)
        mask_preds = self.finetune_head(tokens)

        # Non-inplace so the batch tensor handed in by the dataloader is not mutated.
        masks = masks * 255
        if self.train_mask_size != self.input_size:
            with torch.no_grad():
                masks = nn.functional.interpolate(masks, size=(self.train_mask_size, self.train_mask_size),
                                                  mode='nearest')

        # Squeeze only the channel dim: a bare .squeeze() would also drop the batch dim when the
        # batch size is 1, giving CrossEntropyLoss a target of the wrong rank.
        loss = self.criterion(mask_preds, masks.long().squeeze(1))

        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Accumulate mIoU statistics for one batch (only the first ``val_iters`` batches if set)."""
        if self.val_iters is not None and batch_idx >= self.val_iters:
            return
        imgs, masks = batch
        with torch.no_grad():
            tokens = self._extract_features(imgs, self.val_mask_size)
            mask_preds = self.finetune_head(tokens)

            # downsample masks and preds
            gt = nn.functional.interpolate(masks * 255, size=(self.val_mask_size, self.val_mask_size),
                                           mode='nearest')
            valid = (gt != self.ignore_index)  # mask to remove object boundary class
            mask_preds = torch.argmax(mask_preds, dim=1).unsqueeze(1)

            # update metric
            self.miou_metric.update(gt[valid], mask_preds[valid])

    def on_validation_epoch_end(self) -> None:
        """Compute, reset and log the epoch mIoU as ``miou_val`` (monitored for checkpointing)."""
        # The first element of compute()'s result is used as the mIoU value.
        miou = self.miou_metric.compute(True, many_to_one=False, linear_probe=True)[0]
        self.miou_metric.reset()
        print(miou)
        self.log('miou_val', round(miou, 6))


if __name__ == "__main__":
    print('+' * 50)
    try:
        # In standalone mode click ends every invocation (success included) with sys.exit(),
        # so without `finally` the closing banner would never be printed.
        entry()
    finally:
        print('-' * 50)
