import json

import torch.nn as nn
import torch
from transformers import ViTImageProcessor, ViTForImageClassification, ViTConfig
from timm.models.layers import trunc_normal_
from omegaconf import OmegaConf
from .ldm.util import instantiate_from_config
import torch.nn.functional as F
from einops import repeat
import lightning as L
from .utils import pad, unpad, silog
from .optimizer import get_optimizer
from .metrics import compute_metrics
from .utils import eigen_crop, garg_crop, custom_crop, no_crop


# Decoder layout: number of stride-2 transposed-conv stages, then the output
# width and kernel size of each stage (one list entry per stage).
NUM_DECONV = 3
NUM_FILTERS = [32] * NUM_DECONV
DECONV_KERNELS = [2] * NUM_DECONV
# Hugging Face model id used for CIDE's ViT image processor / classifier.
VIT_MODEL = 'google/vit-base-patch16-224'


def pad_to_make_square(x):
    """Convert a [-1, 1] image batch to channels-last integer pixels padded to a square.

    Args:
        x: tensor read as (bs, channels, h, w) with values presumably in [-1, 1].
            The padding patch has 3 channels, so x is expected to have 3 channels.

    Returns:
        ``torch.int`` tensor of shape (bs, s, s, 3) with ``s = max(h, w)``. Pixel
        values are ``255 * (x + 1) / 2`` converted with ``.to(torch.int)``
        (truncation). Zero padding is appended below the image when ``w > h``
        and to the right of it when ``h > w``.
    """
    y = 255 * ((x + 1) / 2)
    y = torch.permute(y, (0, 2, 3, 1))  # -> (bs, h, w, c)
    bs, _, h, w = x.shape
    # Allocate the padding directly on y's device with y's dtype instead of
    # building it on the CPU and copying it over on every call.
    if w > h:
        patch = torch.zeros(bs, w - h, w, 3, dtype=y.dtype, device=y.device)
        y = torch.cat([y, patch], dim=1)
    elif h > w:
        patch = torch.zeros(bs, h, h - w, 3, dtype=y.dtype, device=y.device)
        y = torch.cat([y, patch], dim=2)
    # h == w: already square, nothing to pad.
    return y.to(torch.int)


class Args:
    """Attribute namespace populated from a JSON training-config file.

    Every top-level key of the JSON object becomes an attribute of the same
    name, e.g. ``{"max_depth": 10}`` gives ``args.max_depth == 10``.
    """

    def __init__(self, path="train_config.json"):
        """Load the configuration.

        Args:
            path: JSON file to read. Defaults to ``train_config.json``, resolved
                relative to the current working directory.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        for name, value in config.items():
            setattr(self, name, value)

    def __repr__(self):
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"


class EmbeddingAdapter(nn.Module):
    """Residual two-layer MLP that refines a batch of embeddings.

    The result is ``texts + gamma * mlp(texts)`` with a length-1 axis inserted
    at position 1, so an ``(n, c)`` input becomes ``(n, 1, c)``.
    """

    def __init__(self, emb_dim=768):
        super().__init__()
        # Width-preserving MLP, so its output can be added back onto the input.
        self.fc = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, texts, gamma):
        """Return the gamma-scaled residual refinement of ``texts``, shaped ``(n, 1, c)``.

        Args:
            texts: ``(n, c)`` embeddings (the einops pattern below requires 2-D input).
            gamma: multiplied element-wise (with broadcasting) with the MLP output.
        """
        residual = gamma * self.fc(texts)
        refined = texts + residual
        return repeat(refined, 'n c -> n b c', b=1)

class EcoDepthEncoder(nn.Module):
    """Image encoder: frozen first-stage autoencoder + UNet conditioned on CIDE embeddings.

    The UNet and first-stage model are instantiated from ``./v1-inference.yaml``
    (relative to the working directory). Four UNet outputs are fused into a
    single feature map with ``out_dim`` channels.
    """

    def __init__(
        self,
        out_dim=1024,
        ldm_prior=(320, 640, 1280 + 1280),
        sd_path=None,
        emb_dim=768,
        args=None,
        train_from_scratch=False,
    ):
        """Build the encoder.

        Args:
            out_dim: channels of the fused output feature map.
            ldm_prior: channel counts of the three fused feature levels (the
                third is the concatenation of two UNet outputs). Only read, so a
                tuple default avoids a shared mutable default argument.
            sd_path: checkpoint path. NOTE: currently not used to load weights;
                the line that would pass it to the UNet config is disabled.
            emb_dim: width of the CIDE conditioning embedding.
            args: config object forwarded to ``CIDE``.
            train_from_scratch: if true, re-initialise the fusion layers; also
                forwarded to ``CIDE``.
        """
        super().__init__()

        self.args = args

        # Level 0 is reduced 4x (two stride-2 convs) and level 1 2x (one
        # stride-2 conv) so both can be concatenated with level 2 in forward().
        self.layer1 = nn.Sequential(
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, ldm_prior[0]),
            nn.ReLU(),
            nn.Conv2d(ldm_prior[0], ldm_prior[0], 3, stride=2, padding=1),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(ldm_prior[1], ldm_prior[1], 3, stride=2, padding=1),
        )

        # 1x1 projection of the concatenated levels to out_dim channels.
        self.out_layer = nn.Sequential(
            nn.Conv2d(sum(ldm_prior), out_dim, 1),
            nn.GroupNorm(16, out_dim),
            nn.ReLU(),
        )

        # apply() runs before CIDE / UNet / autoencoder are created, so only the
        # fusion layers above are re-initialised here.
        if train_from_scratch:
            self.apply(self._init_weights)

        self.cide_module = CIDE(args, emb_dim, train_from_scratch)

        self.config = OmegaConf.load('./v1-inference.yaml')
        unet_config = self.config.model.params.unet_config
        first_stage_config = self.config.model.params.first_stage_config

        if train_from_scratch:
            if sd_path is None:
                sd_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
            # NOTE(review): sd_path is never used after this point; passing it
            # on as unet_config.params.ckpt_path is currently disabled.

        self.unet = instantiate_from_config(unet_config)
        self.encoder_vq = instantiate_from_config(first_stage_config)
        # Only encoder_vq.encode(...) and the UNet's returned outputs are used
        # in forward(); these submodules are dropped.
        del self.encoder_vq.decoder
        del self.unet.out

        # The first-stage autoencoder is kept frozen.
        for param in self.encoder_vq.parameters():
            param.requires_grad = False

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for Conv2d / Linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers built with bias=False have m.bias == None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Encode an image batch into a fused feature map with ``out_dim`` channels.

        Args:
            x: image batch, indexed as (bs, c, h, w); presumably scaled to
                [-1, 1] — TODO confirm against the caller.
        """
        with torch.no_grad():
            # Encode to the autoencoder's latent space and apply the config's scale factor.
            latents = self.encoder_vq.encode(x).mode().detach() * self.config.model.params.scale_factor

        conditioning_scene_embedding = self.cide_module(x)

        # The same timestep (1) is used for every sample.
        t = torch.ones((x.shape[0],), device=x.device, dtype=torch.long)
        outs = self.unet(latents, t, context=conditioning_scene_embedding)

        # outs[3] is upsampled 2x (nearest) and stacked onto outs[2]; levels 0
        # and 1 are downsampled by layer1 / layer2 so all three concatenate.
        feats = [outs[0], outs[1], torch.cat([outs[2], F.interpolate(outs[3], scale_factor=2)], dim=1)]
        x = torch.cat([self.layer1(feats[0]), self.layer2(feats[1]), feats[2]], dim=1)
        return self.out_layer(x)

class CIDE(nn.Module):
    """Scene-conditioning module: turns an image batch into a conditioning embedding.

    A frozen ViT classifier yields 1000-way logits. A trainable MLP maps them to
    ``args.no_of_classes`` scores, and a softmax turns those into weights that
    mix a learnable table of class embeddings. The mixture is then refined by an
    ``EmbeddingAdapter``.
    """

    def __init__(self, args, emb_dim, train_from_scratch):
        super().__init__()
        self.args = args
        self.vit_processor = ViTImageProcessor.from_pretrained(VIT_MODEL, resume_download=True)
        # Randomly initialised ViT (1000-label head) when training from scratch,
        # otherwise the pretrained VIT_MODEL checkpoint.
        self.vit_model = (
            ViTForImageClassification(ViTConfig(num_labels=1000))
            if train_from_scratch
            else ViTForImageClassification.from_pretrained(VIT_MODEL, resume_download=True)
        )
        # The ViT stays frozen; only the layers created below are trainable.
        self.vit_model.requires_grad_(False)

        num_classes = args.no_of_classes
        self.fc = nn.Sequential(
            nn.Linear(1000, 400),
            nn.GELU(),
            nn.Linear(400, num_classes),
        )
        self.dim = emb_dim
        self.m = nn.Softmax(dim=1)

        # One learnable embedding row per scene class.
        self.embeddings = nn.Parameter(torch.randn(num_classes, emb_dim))
        self.embedding_adapter = EmbeddingAdapter(emb_dim=emb_dim)

        # Small initial scale for the adapter's residual branch.
        self.gamma = nn.Parameter(torch.ones(emb_dim) * 1e-4)

    def forward(self, x):
        """Return the adapted, class-probability-weighted scene embedding for ``x``."""
        square_pixels = pad_to_make_square(x)
        # The ViT is frozen, so no autograd graph is needed through it.
        with torch.no_grad():
            vit_inputs = self.vit_processor(images=square_pixels, return_tensors="pt").to(x.device)
            logits = self.vit_model(**vit_inputs).logits

        weights = self.m(self.fc(logits))
        mixed = weights @ self.embeddings
        return self.embedding_adapter(mixed, self.gamma)


class Decoder(nn.Module):
    """Upsampling decoder: transposed-conv stages, a 3x3 conv block, then two 2x bilinear upsamples.

    Only the first element of the feature list given to ``forward`` is used.
    """

    def __init__(self, in_channels, out_channels, args):
        super().__init__()
        self.deconv = NUM_DECONV
        self.in_channels = in_channels  # read by _make_deconv_layer, so set it first
        self.args = args
        self.deconv_layers = self._make_deconv_layer(NUM_DECONV, NUM_FILTERS, DECONV_KERNELS)

        # Same-padding 3x3 conv from the last deconv width to out_channels, then BN + ReLU.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(
                in_channels=NUM_FILTERS[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, conv_feats):
        """Decode ``conv_feats[0]`` and upsample the result 2x twice."""
        feats = self.conv_layers(self.deconv_layers(conv_feats[0]))
        for _ in range(2):
            feats = self.up(feats)
        return feats

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build ``num_layers`` stride-2 ConvTranspose2d stages, each followed by BN + ReLU."""
        stages = []
        channels = self.in_channels
        for idx in range(num_layers):
            kernel_size, pad_amount, extra_pad = self._get_deconv_cfg(num_kernels[idx])
            width = num_filters[idx]
            stages += [
                nn.ConvTranspose2d(
                    in_channels=channels,
                    out_channels=width,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=pad_amount,
                    output_padding=extra_pad,
                    bias=False,
                ),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            ]
            channels = width
        return nn.Sequential(*stages)

    def _get_deconv_cfg(self, deconv_kernel):
        """Return ``(kernel, padding, output_padding)`` for a supported kernel size (2, 3 or 4).

        Raises:
            ValueError: for any other kernel size.
        """
        if deconv_kernel == 4:
            return deconv_kernel, 1, 0
        if deconv_kernel == 3:
            return deconv_kernel, 1, 1
        if deconv_kernel == 2:
            return deconv_kernel, 0, 0
        raise ValueError(f'Not supported num_kernels ({deconv_kernel}).')

    def init_weights(self):
        """Re-initialise: N(0, 0.001) weights / zero bias for (transposed) convs; BN weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, std=0.001)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)


class EcoDepth(L.LightningModule):
    """Depth-estimation LightningModule: EcoDepthEncoder -> Decoder -> 1-channel depth head.

    ``forward`` maps an image batch to depth predictions scaled into
    [0, ``max_depth``]. Validation and test average predictions over the image
    and its horizontal flip.
    """

    def __init__(self, args=None):
        """Build the model.

        Args:
            args: configuration object. At least ``max_depth``,
                ``train_from_scratch`` and ``eval_crop`` are read here, and it is
                also passed to the encoder, decoder, metrics and optimizer
                helpers. If ``None``, the configuration is loaded with ``Args()``.

        Raises:
            NotImplementedError: if ``args.train_from_scratch`` is set; only
                finetuning is supported for now.
        """
        super().__init__()
        # Honour a caller-supplied config; only fall back to the JSON file otherwise.
        if args is None:
            args = Args()
        # Only finetuning is supported for now. Check before building the
        # expensive encoder, and raise (not assert) so the check survives -O.
        if args.train_from_scratch:
            raise NotImplementedError(
                "Training from scratch is not supported; only finetuning is."
            )

        self.max_depth = args.max_depth
        self.args = args
        embed_dim = 192
        channels_in = embed_dim * 8
        channels_out = embed_dim

        self.encoder = EcoDepthEncoder(out_dim=channels_in, args=args, train_from_scratch=args.train_from_scratch)
        self.decoder = Decoder(channels_in, channels_out, args)

        # Crop applied to ground-truth depth during evaluation; unrecognised
        # names fall back to no cropping.
        crop_fns = {"eigen": eigen_crop, "garg": garg_crop, "custom": custom_crop}
        self.eval_crop = crop_fns.get(args.eval_crop, no_crop)

        # Unreachable while the guard above rejects train_from_scratch; kept so
        # that enabling from-scratch training also re-initialises the decoder.
        if args.train_from_scratch:
            self.decoder.init_weights()

        self.last_layer_depth = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels_out, 1, kernel_size=3, stride=1, padding=1),
        )

        # Small-variance init for the depth head.
        for m in self.last_layer_depth.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Predict depth.

        Args:
            x: tensor of shape (bs, 3, h, w) with values in [0, 1].

        Returns:
            Tensor of shape (bs, 1, h, w) with values in [0, ``self.max_depth``].
        """
        x = x * 2.0 - 1.0  # map [0, 1] -> [-1, 1]

        # Presumably pads spatial dims to a multiple of 64, with unpad undoing
        # it afterwards — see .utils.pad / .utils.unpad.
        x, padding = pad(x, 64)
        conv_feats = self.encoder(x)
        out = self.decoder([conv_feats])
        out = unpad(out, padding)
        out_depth = self.last_layer_depth(out)
        return torch.sigmoid(out_depth) * self.max_depth

    def training_step(self, batch, batch_idx):
        image, depth = batch["image"], batch["depth"]
        pred = self(image)
        loss = silog(pred, depth)
        return loss

    def _shared_eval_step(self, batch, batch_idx, prefix):
        """Evaluate one batch with horizontal-flip test-time augmentation and log loss/metrics."""
        image, depth = batch["image"], batch["depth"]
        # NOTE(review): only the ground truth goes through eval_crop, so it
        # presumably preserves shape (e.g. masks) — confirm in .utils.
        depth = self.eval_crop(depth)
        bs = image.shape[0]
        # Run the images and their horizontal flips through one forward pass,
        # then average each prediction with its flipped-back counterpart.
        # Splitting by batch size keeps this correct for any bs (indexing
        # [0] / [1] is only right when bs == 1).
        pred_concat = self(torch.cat([image, image.flip(-1)]))
        pred = (pred_concat[:bs] + pred_concat[bs:].flip(-1)) / 2
        loss = silog(pred, depth)
        metrics = compute_metrics(pred, depth, self.args)
        self.log(f"{prefix}_loss", loss)
        self.log_dict(metrics)
        return loss, metrics

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        optimizer = get_optimizer(self, self.args)
        return optimizer
