from math import inf
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torchvision import models as tv
from torch.nn.parallel import DistributedDataParallel
from torchvision.utils import save_image
import os
from PIL import Image, ImageDraw, ImageFont
from lpips.pretrained_networks import alexnet
from copy import deepcopy
from collections import OrderedDict
import moxing as mox
import sys
sys.path.append(".")
from mimogpt.models.selftok.sd3.sd3_impls import SDVAE, SD3LatentFormat
from mimogpt.models.selftok.sd3.rectified_flow import RectifiedFlow
from mimogpt.models.selftok.image_tokenizer import ImageTokenizer
import matplotlib.pyplot as plt
from tqdm import tqdm
from diffusers.models import AutoencoderKL
import pdb
from copy import deepcopy

from easydict import EasyDict

from mimogpt.utils import read_from_yaml


def parse_args_from_yaml(yml_path):
    """Read the YAML file at ``yml_path`` and return its content wrapped in an ``EasyDict``."""
    return EasyDict(read_from_yaml(yml_path))

def requires_grad(model, flag=True):
    """
    Enable (``flag=True``) or disable (``flag=False``) gradient tracking
    on every parameter yielded by ``model.parameters()``.
    """
    for param in model.parameters():
        param.requires_grad = flag
        
def load_state(model, prefix, state_dict):
    """
    Load the entries of ``state_dict`` that fit ``model`` and ignore the rest.

    A leading ``prefix`` is stripped from each key. An entry is loaded only
    when the resulting name exists in ``model.state_dict()`` and its shape
    matches the model's tensor; everything else is skipped silently
    (``strict=False``).

    Args:
        model: module exposing ``state_dict()`` / ``load_state_dict()``.
        prefix: key prefix to remove, e.g. ``'first_stage_model.'``.
        state_dict: mapping from (possibly prefixed) names to tensors.
    """
    model_dict = model.state_dict()
    prefix_len = len(prefix)

    compatible = {}
    for key, weight in state_dict.items():
        # Strip the prefix only where it leads the key. ``str.replace`` would
        # also delete later occurrences and corrupt nested names such as
        # ``a.a.weight`` when the prefix is ``a.``.
        name = key[prefix_len:] if key.startswith(prefix) else key
        target = model_dict.get(name)
        if target is not None and target.shape == weight.shape:
            compatible[name] = weight

    model.load_state_dict(compatible, strict=False)
    
def set_sd3_vae(vae_path):
    """
    Build an ``SDVAE`` on CPU in bfloat16, fill it from the checkpoint at
    ``vae_path`` (keys prefixed with ``first_stage_model.``), then move it
    to CUDA and switch it to eval mode.
    """
    sd3_vae = SDVAE(device="cpu", dtype=torch.bfloat16)
    checkpoint = torch.load(vae_path, map_location='cpu')
    load_state(sd3_vae, 'first_stage_model.', checkpoint)
    sd3_vae.cuda()
    sd3_vae.eval()
    return sd3_vae

def set_flux_vae(vae_path):
    """
    Load a diffusers ``AutoencoderKL`` from ``vae_path``, move it to CUDA
    and switch it to eval mode.
    """
    flux_vae = AutoencoderKL.from_pretrained(vae_path)
    flux_vae.cuda()
    flux_vae.eval()
    return flux_vae

def set_ema_model(model):
    """
    Return a frozen float32 copy of ``model`` on CUDA, in eval mode.

    The copy is synchronised through ``update_ema`` with ``decay=0``, which
    overwrites each EMA parameter with the current model parameter.
    """
    # Create an EMA of the model for use after training
    shadow = deepcopy(model).to(torch.float32)
    requires_grad(shadow, False)
    update_ema(shadow, model, decay=0)
    shadow = shadow.cuda()
    shadow.eval()
    return shadow

def norm_ip(img, low, high):
    """
    Rescale ``img`` in place: clamp to ``[low, high]``, then map that range
    onto ``[0, 1]``. The divisor is floored at 1e-5 so ``low == high`` does
    not divide by zero. Returns ``None``.
    """
    span = max(high - low, 1e-5)
    img.clamp_(min=low, max=high)
    img.sub_(low)
    img.div_(span)
    
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Every parameter of ``model`` is blended in place into the parameter of
    the same name in ``ema_model``:
    ``ema = decay * ema + (1 - decay) * param``.

    Args:
        ema_model: module holding the averaged weights (updated in place).
        model: module providing the current weights; its parameter names
            may carry ``module.`` wrapper segments that ``ema_model`` lacks.
        decay: EMA decay factor; ``0`` copies the current weights.

    Raises:
        KeyError: if a parameter of ``model`` has no counterpart in
            ``ema_model``.
    """
    ema_params = dict(ema_model.named_parameters())

    for name, param in model.named_parameters():
        # Prefer the exact name. Only when it is missing fall back to
        # dropping "module." segments; doing that unconditionally would
        # turn e.g. "submodule.weight" into "subweight".
        if name not in ema_params:
            name = name.replace("module.", "")

        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
        
class local_alexnet(alexnet):
    """
    ``lpips`` AlexNet whose feature slices are populated from a local
    torchvision AlexNet checkpoint, with every parameter frozen.
    """

    def __init__(self, path):
        """Load torchvision AlexNet weights from ``path`` and wire them into slice1..slice5."""
        super().__init__(requires_grad=False, pretrained=False)
        backbone = tv.alexnet()
        backbone.load_state_dict(torch.load(path))
        features = backbone.features

        # (target slice, indices of ``features`` layers registered on it)
        slice_layout = (
            (self.slice1, range(0, 2)),
            (self.slice2, range(2, 5)),
            (self.slice3, range(5, 8)),
            (self.slice4, range(8, 10)),
            (self.slice5, range(10, 12)),
        )
        for target_slice, layer_ids in slice_layout:
            for layer_id in layer_ids:
                target_slice.add_module(str(layer_id), features[layer_id])

        for weight in self.parameters():
            weight.requires_grad = False
                

class ImageEval:
    def __init__(self, cfg, ckpt_path, download_ckpt=True, datatype='256', start = 1.0,
                 cfg_scale = 1, tmp_local_ckpt_path = '',
                 model_type='sd3', lognorm_schedule=False, ema_decoder=False, **kwargs):
        """
        Build the VAE, tokenizer and rectified-flow sampler and load the checkpoint.

        Args:
            cfg: configuration object; ``cfg.common.vae_path``,
                ``cfg.common.is_eval`` and ``cfg.tokenizer.params`` are read.
                ``cfg.tokenizer.params.noise_schedule_config.is_eval`` is
                overwritten in place.
            ckpt_path: checkpoint passed to ``torch.load``; must contain
                ``'state_dict'`` (and ``'ema_state_dict'`` when
                ``ema_decoder`` is enabled).
            download_ckpt: when true, rank 0 calls ``mox.file.copy`` if the
                checkpoint path does not exist locally.
            datatype: ``'256'`` or ``'512'``; selects the ``shift`` handed to
                ``RectifiedFlow``.
            start: forwarded to ``RectifiedFlow`` and stored as ``self.start``.
            cfg_scale: stored as ``self.cfg_scale``.
            tmp_local_ckpt_path: accepted for interface compatibility; not
                used by this constructor.
            model_type: ``'sd3'`` or ``'flux'``; selects the VAE loader.
            lognorm_schedule: stored as ``self.lognorm_schedule``.
            ema_decoder: when ``True``, an EMA copy of ``self.model.model``
                is created and loaded from ``'ema_state_dict'``.
            **kwargs: ignored.

        Raises:
            ValueError: for an unsupported ``datatype`` or ``model_type``.
        """
        # Validate first: previously any other datatype left ``shift``
        # unbound and only failed with UnboundLocalError after the VAE and
        # tokenizer had already been built.
        shift_by_datatype = {'256': 1.0, '512': 1.818}
        if datatype not in shift_by_datatype:
            raise ValueError(
                f"Unsupported datatype: {datatype!r}. Expected one of {sorted(shift_by_datatype)}."
            )
        shift = shift_by_datatype[datatype]

        rank = dist.get_rank()
        self.cfg = cfg
        self.datatype = datatype
        self.model_type = model_type
        # define models
        if self.model_type == 'sd3':
            self.vae = set_sd3_vae(cfg.common.vae_path)
        elif self.model_type == 'flux':
            self.vae = set_flux_vae(cfg.common.vae_path)
        else:
            raise ValueError(f"Unsupported MODEL_TYPE: {self.model_type}. Expected 'sd3' or 'flux'.")
        cfg.tokenizer.params.noise_schedule_config.is_eval = cfg.common.is_eval
        self.model = ImageTokenizer(**cfg.tokenizer.params)
        self.model.set_eval()

        self.ema_decoder = ema_decoder
        # Kept as an explicit comparison so non-boolean flags (for example
        # the string 'False' coming from a config) are not treated as enabled.
        use_ema = self.ema_decoder == True
        if use_ema:
            self.ema = set_ema_model(self.model.model)
            self.ema.eval()
        self.vae.eval()
        self.diti = self.model.diti
        self.K = self.diti.K
        self.count = 0
        self.count_cfg = 0
        self.start = start
        self.cfg_scale = cfg_scale
        # A missing or falsy ``cut_of_k`` both mean "no cut".
        self.cut_of_k = getattr(cfg.tokenizer.params, "cut_of_k", None) or None

        # load ckpt
        os.makedirs('/cache/model', exist_ok=True)
        # NOTE(review): the local path equals ``ckpt_path``, so the copy below
        # has identical source and destination, and ``tmp_local_ckpt_path`` is
        # never used. Confirm the intended local cache location.
        self._local_ckpt = ckpt_path

        if download_ckpt and rank == 0 and not os.path.exists(self._local_ckpt):
            print(f'download ckpt {ckpt_path}')
            mox.file.copy(ckpt_path, self._local_ckpt)
        # NOTE(review): other ranks do not wait for rank 0's copy before
        # loading; confirm whether a barrier is needed here.

        print(f'using ckpt {ckpt_path}')

        state_dict = torch.load(self._local_ckpt, map_location="cpu")

        if use_ema:
            self.ema.load_state_dict(state_dict['ema_state_dict'])
        self.model.load_state_dict(state_dict['state_dict'], strict=False)
        requires_grad(self.model, flag=False)

        # set eval-specific params
        self._steps = 50
        self.flow = RectifiedFlow(
            self._steps, self.start, self.cut_of_k, val_schedule='uniform', shift=shift, **cfg.tokenizer.params.noise_schedule_config,
        )
        requires_grad(self.flow, flag=False)
        self.cond_vary = True
        self.saved_images = 8

        # set device
        if use_ema:
            self.ema.cuda()
        self.model.cuda()

        self.lognorm_schedule = lognorm_schedule

    def process_input(self, batch):
        if self.datatype.isnumeric() or self.datatype == 'extract_tokens':
            images = batch.cuda()
            x0 = self.vae.encode(images)
            if self.model_type == 'sd3':
                x0 = SD3LatentFormat().process_in(x0)
            
            dec_in = x0
            enc_in = x0
            return dec_in, enc_in, images
        else:
            images, enc_in = batch
            images = images.cuda()
            enc_in = enc_in.cuda()
            if self.model_type == 'sd3':          
                dec_in = SD3LatentFormat().process_in(self.vae.encode(images))
                enc_in = SD3LatentFormat().process_in(self.vae.encode(enc_in))
            
            return dec_in, enc_in, images

    def clean_up(self, remove_ckpt=False):
        """
        Drop the references to the tokenizer (and the EMA copy when
        ``ema_decoder`` is enabled). With ``remove_ckpt`` set, rank 0 also
        deletes ``/cache/model/pretrained.pth`` if that file exists.
        """
        del self.model
        if self.ema_decoder == True:
            del self.ema

        if not remove_ckpt:
            return

        cached_ckpt = '/cache/model/pretrained.pth'
        if dist.get_rank() == 0 and os.path.exists(cached_ckpt):
            os.remove(cached_ckpt)
    
    
    @torch.no_grad()
    def eval_indices(self, ind, curiter=0, instruction=None, **kwargs):
        """
        Decode a batch of token ids into images and save a preview grid.

        Args:
            ind: tensor of token ids, first dimension is the batch;
                ``kwargs['text_vocab_size']`` is subtracted to obtain
                quantizer indices.
            curiter: used to name the output files ``{curiter}.txt`` and
                ``{curiter}.png``.
            instruction: optional iterable of strings; consecutive duplicates
                are collapsed and each remaining entry is written on its own
                line of the text file. ``None`` writes an empty file.
            **kwargs: must contain ``text_vocab_size``; ``image_save_path``
                is required on the ranks that write output.

        Returns:
            List of ``PIL.Image`` objects, one per decoded image.
        """
        encoder = self.model.module.encoder if hasattr(self.model, 'module') else self.model.encoder

        device = ind.device
        x = ind - kwargs['text_vocab_size']  # [B, 512] per the original author's note

        B = x.shape[0]
        indices = x.reshape(-1, 1)

        outs_q = encoder.quantizer.get_output_from_indices(indices)
        outs_q = outs_q.reshape(B, -1, outs_q.shape[-1])

        if encoder.post_norm:
            outs_q = encoder.final_layer_norm3(outs_q)

        if hasattr(self.cfg.tokenizer.params, "stages"):
            t_mapped = torch.tensor([1000] * B, device=device).long()
        else:
            t_mapped = torch.tensor([(self.flow.timestep_map[0]) / 1000.0] * B, device=device)
        k = self.diti.to_indices(t_mapped)

        enc_mask = encoder.get_encoder_mask(x, k)
        # Broadcast the per-token mask over the embedding dimension.
        mask_v = enc_mask[..., None].expand_as(outs_q)
        encoder_hidden_states = outs_q.to(device) * mask_v.to(device)

        model_kwargs = dict(
            encoder_hidden_states=encoder_hidden_states,
            mask=enc_mask,
        )

        pred_x0, _ = self.model.model(**model_kwargs)
        pred_x0 = SD3LatentFormat().process_out(pred_x0)
        images = self.vae.decode(pred_x0)

        # In-place rescale from [-1, 1] to [0, 1].
        norm_ip(images, -1, 1)

        allimgs = [
            Image.fromarray((img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
            for img in images
        ]

        # Only the rank whose index is a multiple of the visible device count writes.
        if dist.get_rank() % torch.cuda.device_count() == 0:
            image_save_path = kwargs['image_save_path']
            os.makedirs(image_save_path, exist_ok=True)

            txt_path = os.path.join(image_save_path, f'{curiter}.txt')
            # ``with`` closes the handle (the original leaked it); a missing
            # ``instruction`` now yields an empty file instead of a TypeError.
            with open(txt_path, 'w') as f:
                lastins = ''
                for inss in instruction or ():
                    if inss != lastins:
                        f.write(inss + '\n')
                        lastins = inss

            save_image(images, os.path.join(image_save_path, f'{curiter}.png'), nrow=4, normalize=True, value_range=(0, 1))

        return allimgs
    
    @torch.no_grad()
    def eval_indices_token(self, npy_path, save_folder='image', token_num_per_file=4, batch_size=32, draw_text=False, txt_path=None):
        file_batch_size = batch_size // token_num_per_file
        
        npy_file_list = [f for f in os.listdir(npy_path) if 'npy' in f]
        image_tokens = sorted(npy_file_list, key=lambda x: int(x.split('.')[0]))
        
        file_num = len(image_tokens)
        
        steps = file_num // file_batch_size + 1
        
        for idx in tqdm(range(steps)):
            st = idx * file_batch_size
            ed = min(file_num, (idx + 1) * file_batch_size)
            
            # read the token
            input_ids = []
            for j in range(st, ed):
                npy_token = np.load(os.path.join(npy_path, image_tokens[j]))
                
                x = torch.tensor(npy_token-128256).cuda()  # (bsz,len)
                input_ids.append(x)
            
            input_ids = torch.cat(input_ids,dim=0)

            B = input_ids.shape[0]
            
            indices = input_ids.reshape(-1,1)
            
            encoder = self.model.module.encoder if hasattr(self.model, 'module') else self.model.encoder
            
            outs_q = encoder.quantizer.get_output_from_indices(indices)
            outs_q = outs_q.reshape(B, -1, outs_q.shape[-1])
            
            if encoder.post_norm:
                outs_q = encoder.final_layer_norm3(outs_q)
            
            if hasattr(self.cfg.tokenizer.params, "stages"):
                t_mapped = torch.tensor([1000]*B).long().cuda()
            else:
                t_mapped = torch.tensor([(self.flow.timestep_map[0])/1000.0]*B).cuda()    
            k = self.diti.to_indices(t_mapped)
            
            d=k
            
            enc_mask = encoder.get_encoder_mask(input_ids, d)
            attn_mask = enc_mask
            mask_v = enc_mask[..., None].expand_as(outs_q)
            encoder_hidden_states = outs_q.cuda() * mask_v.cuda()  

            model_kwargs = dict(
                encoder_hidden_states=encoder_hidden_states,
                mask=attn_mask,
            )

            pred_x0, _ = self.model.model(**model_kwargs)
            pred_x0 = SD3LatentFormat().process_out(pred_x0)
            images = self.vae.decode(pred_x0)

            norm_ip(images, -1, 1)
            
            for cnt, img in enumerate(images):
                img = img.permute(1, 2, 0).cpu().numpy()
                img = img * 255
                pil_image = Image.fromarray(img.astype(np.uint8))
                os.makedirs(os.path.join(npy_path, save_folder), exist_ok=True)
                
                if draw_text:
                    text_list = []
                    
                    with open(txt_path, 'r') as file:
                        for line in file.readlines():
                            text_list.append(line.strip())
                            
                    draw = ImageDraw.Draw(pil_image)
                    font = ImageFont.load_default()
                    
                    text_position = (10, 10)
                    text_color = (255, 0, 0)
                    
                    draw.text(text_position, text_list[cnt + st * token_num_per_file], font=font, fill=text_color)
                    
                pil_image.save(os.path.join(npy_path, save_folder, f'{str(cnt + st * token_num_per_file)}.png'))
            

def set_random_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and current CUDA device) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
