import sys
import os
sys.path.insert(0, os.path.abspath('../'))

import torch
import torchvision.transforms as transforms

import numpy as np
import copy
import pickle

from functools import partial

from ZSSGAN.model.sg2_model import Generator, Discriminator

from glob import glob
import os.path as osp
import ipdb
from PIL import Image

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# Compatibility shim: make sure a `_pil_interp` helper is available that maps
# interpolation names to whatever the installed torchvision expects.
try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        """Map an interpolation name to a torchvision ``InterpolationMode``.

        Recognised names are ``'bicubic'``, ``'lanczos'`` and ``'hamming'``;
        any other value falls back to bilinear.
        """
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR

    import timm.data.transforms as timm_transforms

    # Patch timm's module-level helper so it uses the same mapping.
    timm_transforms._pil_interp = _pil_interp
except ImportError:
    # Only import failures trigger the fallback (a bare ``except`` would also
    # swallow KeyboardInterrupt/SystemExit). In that case use timm's own helper.
    from timm.data.transforms import _pil_interp

def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image

def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: ``True`` to enable gradient tracking, ``False`` to disable it.
    """
    for param in model.parameters():
        param.requires_grad = flag

def proj_subspace(query, points, ndim=None):
    """Project ``query`` onto the affine subspace fitted to ``points``.

    The subspace passes through the mean of ``points`` (taken over the last
    axis) and is spanned by the leading ``ndim`` left singular vectors of the
    mean-centred points.

    Args:
        query: batched feature vectors, laid out as B x N x D.
        points: batched support vectors, laid out as B x D x K.
        ndim: subspace dimensionality. ``None`` or any value ``>= K`` is
            replaced by ``K - 1``.

    Returns:
        Tuple ``(loss, projection)`` where ``loss`` is the mean over all
        queries of the squared query-to-projection distance divided by D, and
        ``projection`` has the same B x N x D layout as ``query``.
    """
    # query: 18 x N x 512
    # points: 18 x 512 x K_shot
    _, feat_dim, num_shots = points.shape
    if ndim is None or ndim >= num_shots:
        ndim = num_shots - 1

    center = points.mean(-1, keepdim=True)
    centered = points - center

    # The SVD runs in double precision; the basis is cast back to float after.
    left_vecs = torch.svd(centered.double())[0].float()
    basis = left_vecs[:, :, :ndim]                       # B x D x ndim

    offset = query.transpose(1, 2) - center              # B x D x N
    coords = basis.transpose(1, 2).matmul(offset)        # B x ndim x N
    projected = (basis.matmul(coords) + center).transpose(1, 2)  # B x N x D

    sq_dist = torch.sum((query - projected) ** 2, dim=-1)
    scaled = sq_dist / feat_dim

    return scaled.mean(), projected

class SG2Generator(torch.nn.Module):
    """Wrapper around the project's StyleGAN2 ``Generator``.

    Loads the ``'g_ema'`` weights from a checkpoint, caches a mean latent and
    exposes helpers for selecting, freezing and unfreezing groups of the
    generator's child modules.
    """

    def __init__(self, checkpoint_path, latent_size=512, map_layers=8, img_size=256, channel_multiplier=2, device='cuda:0'):
        """Build the generator on ``device`` and load ``checkpoint_path``.

        Args:
            checkpoint_path: path of a checkpoint containing a ``'g_ema'`` state dict.
            latent_size: passed to ``Generator`` as its second positional argument.
            map_layers: passed to ``Generator`` as its third positional argument.
            img_size: passed to ``Generator`` as its first positional argument.
            channel_multiplier: forwarded to ``Generator``.
            device: device for the generator and for ``torch.load``'s ``map_location``.
        """
        super(SG2Generator, self).__init__()

        self.generator = Generator(
            img_size, latent_size, map_layers, channel_multiplier=channel_multiplier
        ).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        state_dict = checkpoint['g_ema']
        # strict=False: missing or unexpected checkpoint keys are tolerated.
        self.generator.load_state_dict(state_dict, strict=False)

        # Plain tensor attribute (not a registered buffer); it is always used
        # as the truncation latent in forward().
        with torch.no_grad():
            self.mean_latent = self.generator.mean_latent(4096)

    def get_all_layers(self):
        """Return the generator's direct child modules as a list."""
        return list(self.generator.children())

    def get_training_layers(self, phase=None):
        """Return the child modules that should be trained for ``phase``.

        All selections except ``'all'`` and ``'map'`` start from children
        ``[1:3]`` and add a slice of the container at child index 4.
        (Presumably children 1-2 are the learned constant and first conv and
        child 4 holds the synthesis convs -- confirm against ``Generator``.)

        Args:
            phase: one of ``'texture'``, ``'shape'``, ``'no_fine'``,
                ``'shape_expanded'``, ``'all'``, ``'map'``. Any other value
                (including the default ``None``) selects children ``[1:3]``
                plus every entry of child 4.

        Returns:
            A list of modules, except for ``'map'`` which returns child 0 itself.
        """
        layers = self.get_all_layers()
        base = list(layers[1:3])

        if phase == 'texture':
            # children[1:3] + entries 2..9 of children[4]
            return base + list(layers[4][2:10])
        if phase == 'shape':
            # children[1:3] + entries 0..1 of children[4]
            return base + list(layers[4][0:2])
        if phase == 'no_fine':
            # children[1:3] + entries 0..9 of children[4]
            return base + list(layers[4][:10])
        if phase == 'shape_expanded':
            # children[1:3] + entries 0..2 of children[4]
            return base + list(layers[4][0:3])
        if phase == 'all':
            # every child of the generator
            return layers
        if phase == 'map':
            # NOTE: returns the module itself, not a list. freeze_layers /
            # unfreeze_layers iterate over whatever they are given, so this
            # relies on child 0 being an iterable container of sub-modules.
            print(layers[0])
            return layers[0]

        # default: children[1:3] + every entry of children[4]
        return base + list(layers[4][:])

    def trainable_params(self, phase=None):
        """Collect the parameters of the layers selected for ``phase``.

        Args:
            phase: forwarded to :meth:`get_training_layers`; ``None`` uses
                that method's default selection.

        Returns:
            A flat list of parameters.
        """
        params = []
        # get_training_layers requires a phase; calling it with no argument
        # (as this method used to) raised TypeError.
        for layer in self.get_training_layers(phase):
            params.extend(layer.parameters())

        return params

    def freeze_layers(self, layer_list=None):
        '''
        Disable training for all layers in list.

        With ``layer_list=None`` every child of the generator is frozen.
        '''
        if layer_list is None:
            self.freeze_layers(self.get_all_layers())
        else:
            for layer in layer_list:
                requires_grad(layer, False)

    def unfreeze_layers(self, layer_list=None):
        '''
        Enable training for all layers in list.

        With ``layer_list=None`` every child of the generator is unfrozen.
        '''
        if layer_list is None:
            self.unfreeze_layers(self.get_all_layers())
        else:
            for layer in layer_list:
                requires_grad(layer, True)

    def style(self, styles):
        '''
        Convert z codes to w codes.

        Applies ``generator.style`` to each element of ``styles`` and returns
        the results as a list.
        '''
        styles = [self.generator.style(s) for s in styles]
        return styles

    def get_s_code(self, styles, input_is_latent=False):
        """Delegate to ``generator.get_s_code`` with the same arguments."""
        return self.generator.get_s_code(styles, input_is_latent)

    def modulation_layers(self):
        """Return the wrapped generator's ``modulation_layers`` attribute."""
        return self.generator.modulation_layers

    #TODO Maybe convert to kwargs
    def forward(self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        input_is_s_code=False,
        noise=None,
        randomize_noise=True):
        """Run the wrapped generator and return its output unchanged.

        ``inject_index`` and ``truncation_latent`` are accepted for signature
        compatibility but are not forwarded: the cached ``self.mean_latent``
        is always used as the truncation latent. All other arguments are
        passed through to the generator.
        """
        return self.generator(styles, return_latents=return_latents, truncation=truncation, truncation_latent=self.mean_latent, noise=noise, randomize_noise=randomize_noise, input_is_latent=input_is_latent, input_is_s_code=input_is_s_code)

class SG2Discriminator(torch.nn.Module):
    """Wrapper around the project's StyleGAN2 ``Discriminator``.

    Loads the ``"d"`` weights from a checkpoint and offers helpers to freeze
    or unfreeze the discriminator's child modules.
    """

    def __init__(self, checkpoint_path, img_size=256, channel_multiplier=2, device='cuda:0'):
        """Build the discriminator on ``device`` and load ``checkpoint_path``."""
        super().__init__()

        network = Discriminator(img_size, channel_multiplier=channel_multiplier)
        self.discriminator = network.to(device)

        loaded = torch.load(checkpoint_path, map_location=device)
        # strict=True: the checkpoint keys must match the model exactly.
        self.discriminator.load_state_dict(loaded["d"], strict=True)

    def get_all_layers(self):
        """Return the discriminator's direct child modules as a list."""
        children = self.discriminator.children()
        return list(children)

    def get_training_layers(self):
        """Return the layers to train, which is every child module."""
        return self.get_all_layers()

    def _set_trainable(self, layer_list, flag):
        """Apply ``requires_grad(layer, flag)`` to each layer.

        ``layer_list=None`` selects every child of the discriminator.
        """
        targets = self.get_all_layers() if layer_list is None else layer_list
        for module in targets:
            requires_grad(module, flag)

    def freeze_layers(self, layer_list=None):
        '''
        Disable training for all layers in list (all children when None).
        '''
        self._set_trainable(layer_list, False)

    def unfreeze_layers(self, layer_list=None):
        '''
        Enable training for all layers in list (all children when None).
        '''
        self._set_trainable(layer_list, True)

    def forward(self, images):
        """Return the wrapped discriminator's output for ``images``."""
        scores = self.discriminator(images)
        return scores

class ZSSGAN(torch.nn.Module):
    """Pairs a frozen and a trainable generator with subspace-based losses.

    Both generators receive the same latent codes. Features of their outputs
    are extracted with a Swin and/or DINO encoder and compared against
    encodings of the target images listed in ``args.target_img_list``.
    """

    def __init__(self, args):
        """Build both generators and whichever encoders ``args`` configures.

        Reads ``args.frozen_gen_ckpt``, ``args.train_gen_ckpt``, ``args.size``,
        ``args.phase``, ``args.swin_config`` and ``args.dino_config`` here;
        other attributes are read by the methods below.
        """
        super(ZSSGAN, self).__init__()

        self.args = args

        self.device = 'cuda:0'

        self.generator_frozen = SG2Generator(args.frozen_gen_ckpt, img_size=args.size).to(self.device)
        self.generator_trainable = SG2Generator(args.train_gen_ckpt, img_size=args.size).to(self.device)

        self.generator_frozen.freeze_layers()
        self.generator_frozen.eval()

        # Freeze everything first, then re-enable only the layers chosen for this phase.
        self.generator_trainable.freeze_layers()
        self.generator_trainable.unfreeze_layers(self.generator_trainable.get_training_layers(args.phase))
        self.generator_trainable.train()

        self.set_transform()
        if self.args.swin_config is not None:
            self.set_model_swin()
            self.swin_encodings = self.set_target(self.swin_model)
        if self.args.dino_config is not None:
            self.set_model_dino()
            self.dino_encodings = self.set_target(self.dino_model)

    def set_transform(self):
        """Create the preprocessing pipelines for target and generated images.

        ``preprocess_img`` takes a PIL image: resize to 224x224 (bicubic),
        convert to RGB, to tensor, ImageNet normalisation.
        ``preprocess_gen`` takes a generator output tensor: ``(x + 1) / 2``
        rescale, resize to 224x224, ImageNet normalisation.
        """
        self.preprocess_img = transforms.Compose([
                                    transforms.Resize([224, 224], interpolation=_pil_interp('bicubic')),
                                    # transforms.CenterCrop([224, 224]),
                                    _convert_image_to_rgb,
                                    transforms.ToTensor(),
                                    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])
        self.preprocess_gen = transforms.Compose([
                                    # (x - (-1)) / 2; presumably maps generator output from [-1, 1] to [0, 1]
                                    transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
                                    transforms.Resize([224, 224], interpolation=_pil_interp('bicubic')),
                                    # transforms.CenterCrop([224, 224]),
                                    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    def set_model_swin(self):
        """Build the Swin encoder from ``args.swin_config`` and load ``args.swin_ckpt``."""
        from Swin.models import build_model
        from Swin.config import get_config

        # get_config receives the whole args object; the config path is stored on args.cfg.
        self.args.cfg = self.args.swin_config
        config = get_config(self.args)
        self.swin_model = build_model(config).to(self.device)
        checkpoint = torch.load(self.args.swin_ckpt, map_location='cpu')
        # strict=False: key mismatches are tolerated; the result is printed for inspection.
        msg = self.swin_model.load_state_dict(checkpoint['model'], strict=False)
        print(msg)

    def set_model_dino(self):
        """Build the DINOv2 teacher model from ``args.dino_config`` and load ``args.dino_ckpt``."""
        from omegaconf import OmegaConf
        from dinov2.configs import dinov2_default_config
        default_cfg = OmegaConf.create(dinov2_default_config)
        cfg = OmegaConf.load(self.args.dino_config)
        # Values from the user config override the defaults.
        config = OmegaConf.merge(default_cfg, cfg)

        from dinov2.models import build_model_from_cfg
        import dinov2.utils.utils as dinov2_utils
        self.dino_model, _ = build_model_from_cfg(config, only_teacher=True)
        dinov2_utils.load_pretrained_weights(self.dino_model, self.args.dino_ckpt, "teacher")
        self.dino_model = self.dino_model.to(self.device)

    def set_target(self, model):
        """Encode every target image with ``model``.

        Args:
            model: encoder exposing ``get_image_features``.

        Returns:
            A list with one tensor per group in ``args.target_img_list``; each
            tensor concatenates that group's image encodings along dim 0.
        """
        with torch.no_grad():
            target_encodings = []
            for i, arr in enumerate(self.args.target_img_list):
                target_encodings.append([])
                for target_img in arr:
                    # Context manager so the image file handle is closed once
                    # the pixels have been turned into a tensor.
                    with Image.open(target_img) as img:
                        preprocessed = self.preprocess_img(img).unsqueeze(0).to(self.device)
                    encoding = model.get_image_features(preprocessed)

                    target_encodings[i].append(encoding)
            target_encodings = [torch.cat(encoding, 0) for encoding in target_encodings]
        return target_encodings

    def subspace_loss(self, mix_img, src_img, encoder, target_encodings):
        """Compute the weighted distance and direction losses for one encoder.

        Args:
            mix_img: images from the trainable generator.
            src_img: images from the frozen generator (encoded without gradients).
            encoder: model exposing ``get_image_features_norm``.
            target_encodings: list of per-group target encodings, as returned
                by :meth:`set_target`.

        Returns:
            Tuple ``(args.w_dist * dist, args.w_direction * direction)``.
            ``dist`` is the ``args.alpha``-weighted sum of the
            ``proj_subspace`` distances over the target groups.
            ``direction`` is the mean of one minus the cosine similarity
            between the edit direction and the combined projected direction.
        """
        with torch.no_grad():
            src_feat = encoder.get_image_features_norm(self.preprocess_gen(src_img))
        mix_feat = encoder.get_image_features_norm(self.preprocess_gen(mix_img))
        edit_direction = mix_feat - src_feat
        if edit_direction.sum() == 0:
            # Degenerate case: perturb the image slightly so the direction
            # is not computed from an all-zero difference.
            target_encoding = encoder.get_image_features_norm(self.preprocess_gen(mix_img + 1e-6))
            edit_direction = (target_encoding - src_feat)
        edit_direction = edit_direction / edit_direction.norm(dim=-1, keepdim=True)

        dist = 0.
        proj_directions = []
        for i, trg in enumerate(target_encodings):
            # proj_subspace expects a leading batch axis and targets laid out as D x K.
            dist_i, proj_feat = proj_subspace(mix_feat.unsqueeze(0).float(), trg.permute(1, 0).unsqueeze(0), self.args.ndim)
            dist += dist_i * self.args.alpha[i]

            proj_direction = (proj_feat.squeeze(0) - src_feat) * self.args.alpha[i]
            proj_directions.append(proj_direction)

        proj_direction = torch.stack(proj_directions).sum(dim=0)
        proj_direction = proj_direction / proj_direction.norm(dim=-1, keepdim=True)
        direction = (1. - torch.nn.CosineSimilarity()(edit_direction, proj_direction)).mean()

        return self.args.w_dist * dist, self.args.w_direction * direction

    def forward(
        self,
        styles,
        return_latents=False,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Generate images with both generators and compute the losses.

        Args:
            styles: latent codes; mapped through the frozen generator's
                ``style`` unless ``input_is_latent`` is true.
            truncation, randomize_noise: forwarded to both generators.
            return_latents, inject_index, truncation_latent, noise: accepted
                but not used by this method.

        Returns:
            ``([frozen_img, trainable_img], dist_loss, direct_loss)``. Each
            loss sums the contributions of the encoders whose weight
            (``args.swin_w`` / ``args.dino_w``) is truthy, and stays ``0`` if
            neither is.
        """

        with torch.no_grad():
            if input_is_latent:
                w_styles = styles
            else:
                w_styles = self.generator_frozen.style(styles)

            frozen_img, _ = self.generator_frozen(w_styles, return_latents=True, input_is_latent=True, truncation=truncation, randomize_noise=randomize_noise)

        trainable_img, _ = self.generator_trainable(w_styles, return_latents=True, input_is_latent=True, truncation=truncation, randomize_noise=randomize_noise)

        dist_loss = 0
        direct_loss = 0

        if self.args.swin_w:
            dist, direct = self.subspace_loss(trainable_img, frozen_img, self.swin_model, self.swin_encodings)
            dist_loss += dist * self.args.swin_w
            direct_loss += direct * self.args.swin_w
        if self.args.dino_w:
            dist, direct = self.subspace_loss(trainable_img, frozen_img, self.dino_model, self.dino_encodings)
            dist_loss += dist * self.args.dino_w
            direct_loss += direct * self.args.dino_w

        return [frozen_img, trainable_img], dist_loss, direct_loss

    def pivot(self):
        """Copy the trainable generator's parameter values into the frozen one.

        Values are copied in place, so the two generators keep separate
        storage and the frozen parameters' ``requires_grad`` flags are left
        unchanged. Only parameters are copied, not buffers.
        """
        par_frozen = dict(self.generator_frozen.named_parameters())
        par_train = dict(self.generator_trainable.named_parameters())

        # Rebinding entries of the local dict (the earlier approach) never
        # touched the module; an in-place copy under no_grad updates it.
        with torch.no_grad():
            for name, frozen_param in par_frozen.items():
                frozen_param.copy_(par_train[name])
