import argparse
import math

import numpy as np
import torch
import wandb
from torch import nn
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
import torchvision.transforms as T

import sys

import models.backbones
from dataset.dataset_util import dreamsim_transform
from dataset.nyud_dataset import NYUv2
from dataset.lf_dataset import LightFieldDataset
from dataset.sunrgbd_dataset import SUNRGBDDataset
from scripts.util import unnorm
from util.gradientloss import GradientLoss
from util.sigloss import SigLoss

sys.path.append('../dinov2')
from PIL import Image
import matplotlib.pyplot as plt
from util.utils import feature_pca
from util.train_utils import seed_worker, seed_everything
import torch.nn.functional as F

def log_image():
    """Plot image / feature-PCA / prediction / target rows for the current batch.

    Takes no arguments: it reads the module-level names set by the
    train/validation loop -- ``i`` (batch index), ``images``, ``norm``,
    ``preds``, ``target`` and ``features_reshaped`` -- and draws up to four
    samples, one per row.

    Returns:
        The matplotlib figure that was drawn (also the current pyplot figure).
    """
    n = min(4, len(images))
    # squeeze=False keeps ``ax`` 2-D even when only a single row is drawn.
    fig, ax = plt.subplots(n, 4, figsize=(4 * 3, n * 3), squeeze=False)

    ax[0][0].set_title(f'image batch {i}')
    ax[0][1].set_title('features')
    ax[0][2].set_title('pred')
    ax[0][3].set_title('target')

    images_unnorm = unnorm(images) if norm else images
    # Move the prediction/target batches to the CPU once instead of per row.
    preds_cpu = preds.detach().cpu()
    target_cpu = target.detach().cpu()

    for j in range(n):
        ax[j][0].imshow(images_unnorm[j].permute(1, 2, 0).detach().cpu())
        features_pca = feature_pca(features_reshaped[j:j + 1])
        ax[j][1].imshow(features_pca[0])
        ax[j][2].imshow(preds_cpu[j], cmap='inferno')
        ax[j][3].imshow(target_cpu[j], cmap='inferno')

    fig.tight_layout()
    return fig

class DepthHead(torch.nn.Module):
    """Per-pixel depth head applied to a (B, channels, H, W) feature map.

    A single 3x3 conv produces either

    * ``n_bins`` logits per pixel (``classify=True``). These are normalized
      into weights over depth-bin centers spanning [min_depth, max_depth] and
      reduced to one depth value per pixel by a weighted sum.
    * A single channel (``classify=False``). It is mapped to depth as
      ``sigmoid(x) * max_depth`` when ``scale_up`` is True, or as
      ``relu(x) + min_depth`` otherwise.

    The output has shape (B, 1, H, W).

    Args:
        in_channels: stored on the module but not used by any layer here; the
            conv's input width is ``channels``.
        channels: channel count of the input feature map.
        conv_cfg, act_cfg, norm_cfg, align_corners: stored only; not used by
            the layers below.
        min_depth: lower end of the depth range.
        max_depth: upper end of the depth range. ``forward`` uses it when
            ``classify`` or ``scale_up`` is True, so it must be set then.
        classify: predict via depth bins instead of direct regression.
        n_bins: number of depth bins when ``classify`` is True.
        bins_strategy: "UD" (bins uniform in depth) or "SID" (bins uniform in
            log-depth).
        norm_strategy: how bin logits become weights. "linear" applies
            relu + 0.1 and then sum-normalizes; "softmax" applies a softmax;
            "sigmoid" applies a sigmoid and then sum-normalizes.
        scale_up: in regression mode, bound the output to (0, max_depth)
            with a sigmoid.
    """

    def __init__(self, in_channels, channels=96, conv_cfg=None, act_cfg=dict(type="ReLU"),
                 align_corners=False, min_depth=1e-3, max_depth=None, norm_cfg=None, classify=False, n_bins=256,
                 bins_strategy="UD", norm_strategy="linear", scale_up=False):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.act_cfg = act_cfg
        # Constructed for parity with the training code; forward() does not use it.
        self.loss_decode = SigLoss(max_depth=max_depth)
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_cfg = norm_cfg
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax",
                                     "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.fp16_enabled = False
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def _depth_bins(self, device):
        """Return the ``n_bins`` depth values that the bin weights are applied to."""
        if self.bins_strategy == "UD":
            # Uniformly spaced in depth.
            return torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=device)
        if self.bins_strategy == "SID":
            # Uniformly spaced in log-depth. torch.logspace takes *exponents*,
            # so the endpoints must be log10 of the depth range.
            return torch.logspace(math.log10(self.min_depth), math.log10(self.max_depth),
                                  self.n_bins, device=device)
        raise ValueError(f"Unsupported bins_strategy: {self.bins_strategy}")

    def _normalize_logits(self, logit):
        """Turn raw bin logits (bin axis = dim 1) into weights that sum to 1 over dim 1."""
        if self.norm_strategy == "linear":
            # Following AdaBins: the +eps shift gives every bin strictly positive mass.
            eps = 0.1
            logit = torch.relu(logit) + eps
            return logit / logit.sum(dim=1, keepdim=True)
        if self.norm_strategy == "softmax":
            return torch.softmax(logit, dim=1)
        if self.norm_strategy == "sigmoid":
            logit = torch.sigmoid(logit)
            return logit / logit.sum(dim=1, keepdim=True)
        raise ValueError(f"Unsupported norm_strategy: {self.norm_strategy}")

    def forward(self, x):
        """Predict depth of shape (B, 1, H, W) from features ``x`` of shape (B, channels, H, W)."""
        if not self.classify:
            if self.scale_up:
                return self.sigmoid(self.conv_depth(x)) * self.max_depth
            return self.relu(self.conv_depth(x)) + self.min_depth

        weights = self._normalize_logits(self.conv_depth(x))
        bins = self._depth_bins(x.device)
        # Expected depth per pixel: weighted sum of bin centers over the bin axis.
        return torch.einsum("ikmn,k->imn", [weights, bins]).unsqueeze(dim=1)


# ---------------------------------------------------------------------------
# Depth-head training script. A frozen backbone (DINO, DINOv2 or a DreamSim
# LoRA-tuned variant) produces patch features under torch.no_grad(); only
# DepthHead's parameters are optimized. Train loss and validation depth
# metrics are logged to wandb unless ``debug`` is set.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('-b', '--backbone', type=str, default='dino_vitb16', help='backbone model')
parser.add_argument('-e', '--epoch', type=int, default=9, help='epoch')
parser.add_argument('-d', '--dataset', type=str, default='nyud', help='dataset')
args = parser.parse_args()

dataset = args.dataset  # e.g. 'nyud'
model_name = args.backbone  # e.g. 'dino_vitb16'

seed = 1235
seed_everything(seed)

patch_hw = 32
# Upper end of the depth range: used for the head's bins and for the
# validity mask applied to the ground truth during evaluation.
max_depth = 10

lr = 3e-4
epochs = 10
bs = 128
head = DepthHead(patch_hw * patch_hw, 768, classify=True, min_depth=0.001, max_depth=max_depth).cuda()
optimizer = torch.optim.Adam(head.parameters(), lr=lr)
dino_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
target_resize = T.Resize((224, 224), interpolation=InterpolationMode.BILINEAR)
target_transform = T.Compose(
    [
        target_resize,
        T.PILToTensor(),
    ]
)

sigloss = SigLoss(valid_mask=True, warm_up=True, max_depth=None)
gradientloss = GradientLoss(valid_mask=True, loss_weight=0.5)


def lossfn(pred, gt):
    """Training objective: SigLoss plus 0.5 x GradientLoss."""
    return sigloss(pred, gt) + 0.5 * gradientloss(pred, gt)


# Whether input images are ImageNet-normalized (log_image un-normalizes them).
norm = False

if model_name == 'dino_vitb16':
    model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16').cuda().eval()
    # Patch tokens of the last block; index 0 along dim 1 (CLS token) is dropped.
    model_fn = lambda x: model.get_intermediate_layers(x)[0][:, 1:]
    transform = dino_transform
    norm = True
elif model_name == 'dinov2_vitb14':
    model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda().eval()
    model_fn = lambda x: model.forward_features(x)['x_norm_patchtokens']
    transform = dino_transform
    norm = True
elif 'dreamsim' in model_name:
    import os
    from dreamsim.dreamsim.model import dreamsim, PerceptualModel
    from dreamsim.dreamsim.feature_extraction.vit_wrapper import ViTConfig, ViTModel
    import json
    from peft import LoraConfig, get_peft_model, PeftModel

    if model_name == 'dreamsim_dinov2_vitb14':
        model_dir = 'lora_single_cat_dinov2_vitb14_n-1_cls_patch_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_16.0_loradropout_0.1/lightning_logs/version_0/checkpoints'
        backbone = 'dinov2_vitb14'
        stride = '14'
        epoch = args.epoch
    elif model_name == 'dreamsim_dino_vitb16':
        model_dir = 'lora_single_cat_dino_vitb16_n-1_cls_patch_lora_lr_0.0003_batchsize_16_wd_0.0_hiddensize_1_margin_0.05_lorar_16_loraalpha_16.0_loradropout_0.1/lightning_logs/version_0/checkpoints'
        backbone = 'dino_vitb16'
        stride = '16'
        epoch = args.epoch
    else:
        raise ValueError(f'Unknown dreamsim backbone: {model_name}')

    load_dir = '/home/fus/repos/repalignment/dreamsim/output/new_backbones2/'
    model_dir = os.path.join(load_dir, model_dir)
    device = 'cuda:0'

    ours_model = PerceptualModel(backbone, device=device, load_dir='/home/fus/repos/repalignment/dreamsim/models/',
                                 normalize_embeds=False, stride=stride, feat_type='cls_patch', lora=True)

    load_dir = os.path.join(model_dir, f'epoch_{epoch}_{backbone}')
    with open(os.path.join(load_dir, 'adapter_config.json'), 'r') as f:
        dreamsim_config = json.load(f)
    lora_config = LoraConfig(**dreamsim_config)
    print(lora_config)
    # Wrap every extractor with a fresh LoRA adapter, then load the trained
    # adapter weights from the checkpoint directory into it.
    for extractor in ours_model.extractor_list:
        model = get_peft_model(ViTModel(extractor.model, ViTConfig()), lora_config).to(device)
        extractor.model = model
    for extractor in ours_model.extractor_list:
        extractor.model = PeftModel.from_pretrained(extractor.model, load_dir).to(device)
        extractor.model.eval().requires_grad_(False)

    ours_model.eval().requires_grad_(False)
    ours_model = ours_model.to(device)
    transform = dreamsim_transform

    # Drop index 0 along the token axis (presumably the CLS token) -- confirm.
    model_fn = lambda x: ours_model.embed(x)[:, 1:, :]
else:
    raise ValueError(f'Unknown backbone: {model_name}')


if dataset == 'nyud':
    train_dset = NYUv2(root='/datasets/nyuv2_2024-01-10_1001/data/nyu', split='train', transform=transform, target_transform=target_transform)
    val_dset = NYUv2(root='/datasets/nyuv2_2024-01-10_1001/data/nyu', split='test', transform=transform, target_transform=target_transform)
elif dataset == 'lf':
    train_dset = NYUv2(root='/datasets/nyuv2_2024-01-10_1001/data/nyu', split='train', transform=transform, target_transform=target_transform)
    val_dset = LightFieldDataset(root='/scratch/one_month/2024_05/fus/lf/full_data', split='val', transform=transform, target_transform=target_transform)
elif dataset == 'sunrgbd':
    train_dset = SUNRGBDDataset(root='/scratch/one_month/2024_05/fus/sunrgbd/SUNRGBD', split='train', transform=transform, target_transform=target_transform)
    val_dset = SUNRGBDDataset(root='/scratch/one_month/2024_05/fus/sunrgbd/SUNRGBDv2Test', split='val', transform=transform, target_transform=target_transform)
else:
    raise ValueError(f'Unknown dataset: {dataset}')

nw = 12
# Seed the shuffling generator so the sample order follows ``seed``.
g = torch.Generator()
g.manual_seed(seed)
train_loader = torch.utils.data.DataLoader(train_dset, batch_size=bs, shuffle=True, worker_init_fn=seed_worker, generator=g, num_workers=nw)
val_loader = torch.utils.data.DataLoader(val_dset, batch_size=64, shuffle=False, num_workers=nw)

exp_name = f'depth_{model_name}_{dataset}_lr_{lr}_bs_{bs}_epochs_{epochs}'
if 'dreamsim' in model_name:
    exp_name = f'e{epoch}_' + exp_name
debug = False

if not debug:
    wandb.init(
        project='rep_segment',
        name=exp_name,
        config={
            "learning_rate": lr,
            "epochs": epochs,
        }
    )


def _predict_depth(images, masks):
    """Run the frozen backbone and the depth head on one batch.

    Returns ``(features_reshaped, preds, target)``:

    * the 4x-upsampled feature map fed to the head,
    * the head's prediction with its channel dim removed, and
    * the ground truth resized to the prediction's spatial size, also with
      its channel dim removed.

    ``.squeeze(1)`` is used instead of a bare ``.squeeze()`` so that a batch
    of size 1 keeps its batch dimension.
    """
    with torch.no_grad():
        features = model_fn(images)
    b, hw, c = features.shape
    s = int(hw ** 0.5)  # assumes a square patch grid
    feats = features.permute(0, 2, 1).reshape(b, c, s, s)
    feats = F.interpolate(feats, scale_factor=4, mode='bilinear')
    out = head(feats).squeeze(1)
    gt = F.interpolate(masks, out.shape[-2:], mode='bilinear').squeeze(1)
    return feats, out, gt


# NOTE: the loops below run at module level on purpose. log_image() reads
# ``i``, ``images``, ``preds``, ``target`` and ``features_reshaped`` as globals.
for epoch_i in range(epochs):
    print(f'Training Epoch {epoch_i}')
    head.train()
    for i, (images, masks) in tqdm(enumerate(train_loader), total=len(train_loader)):
        images = images.cuda()
        masks = masks.cuda().float()

        features_reshaped, preds, target = _predict_depth(images, masks)
        loss = lossfn(preds, target)

        if i % 10 == 0 and not debug:
            wandb.log({'train/loss': loss.item()})

        # Reset gradients each step; otherwise they accumulate across batches.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Val Epoch {epoch_i}')
    head.eval()
    with torch.no_grad():
        val_loss = 0
        rmse = 0
        abs_rel_error = 0
        delta1 = 0
        delta2 = 0
        delta3 = 0
        log10 = 0
        val_total = 0
        for i, (images, masks) in tqdm(enumerate(val_loader), total=len(val_loader)):
            images = images.cuda()
            masks = masks.cuda().float()

            features_reshaped, preds, target = _predict_depth(images, masks)
            loss = lossfn(preds, target)

            if i == 0 and not debug:
                f = log_image()
                wandb.log({'val/depth': wandb.Image(f)})

            n_imgs = len(images)

            # Metrics only over pixels with a ground-truth depth in (0, max_depth].
            valid_mask = torch.logical_and(target > 0, target <= max_depth)
            preds = preds[valid_mask]
            target = target[valid_mask]

            # Per-batch metrics are weighted by batch size, then averaged per image below.
            ratio = torch.max(preds / target, target / preds)
            rmse += torch.sqrt(torch.mean((preds - target) ** 2)) * n_imgs
            abs_rel_error += torch.mean(torch.abs(preds - target) / target) * n_imgs
            delta1 += torch.mean((ratio < 1.25).float()) * n_imgs
            delta2 += torch.mean((ratio < 1.25 ** 2).float()) * n_imgs
            delta3 += torch.mean((ratio < 1.25 ** 3).float()) * n_imgs
            log10 += torch.mean(torch.abs(torch.log10(preds) - torch.log10(target))) * n_imgs

            # lossfn returns a batch mean; weight it like the metrics above.
            val_loss += loss * n_imgs
            val_total += n_imgs

        metrics = {
            'val/loss': float(val_loss / val_total),
            'val/rmse': float(rmse / val_total),
            'val/abs_rel_error': float(abs_rel_error / val_total),
            'val/delta1': float(delta1 / val_total),
            'val/delta2': float(delta2 / val_total),
            'val/delta3': float(delta3 / val_total),
            'val/log10': float(log10 / val_total),
        }
        print(', '.join(f'{k}: {v:.4f}' for k, v in metrics.items()))

        if not debug:
            wandb.log(metrics)
        plt.close()

if not debug:
    wandb.finish()
