import os
import PIL.Image
import torch
import numpy as np
from transformers import AutoModelForCausalLM
from .modeling_vlm import MultiModalityCausalLM
from .processing_vlm  import VLChatProcessor
import PIL
from PIL import Image

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import LlamaForCausalLM



class JanusLLamaModel(MultiModalityCausalLM):
    """MultiModalityCausalLM subclass with training setup and image-token sampling.

    The three public generation entry points (``edit_image``,
    ``generate_image_2`` and ``generate_image``) share the same autoregressive
    sampling loop, token decoding and sample dumping code, which lives in the
    private ``_``-prefixed helpers below.
    """

    def __init__(self, config):
        super().__init__(config)
        self.cfg = config

    def set_task2_loss_type(self, loss_type):
        """Store ``loss_type`` on the model and report it on stdout."""
        self.loss_type = loss_type
        print('current task2 loss type is:', loss_type)

    def train_setup(self):
        """Select trainable sub-modules and put each one in the right mode.

        The language model, both aligners, the generation embedding and the
        generation head are unfrozen and set to train mode; both vision models
        are frozen and set to eval mode. Non-reentrant gradient checkpointing
        is enabled on the language model.
        """
        trainable_modules = (
            self.language_model,
            self.gen_embed,
            self.gen_head,
            self.gen_aligner,
            self.aligner,
        )
        for module in trainable_modules:
            module.requires_grad_(True)
            module.train()
        self.language_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={'use_reentrant': False}
        )

        frozen_modules = (self.vision_model, self.gen_vision_model)
        for module in frozen_modules:
            module.requires_grad_(False)
            module.eval()

    # ------------------------------------------------------------------
    # Private helpers shared by the generation entry points.
    # ------------------------------------------------------------------

    def _embed_prompt_with_images(self, input_ids, image1, image_seq_mask):
        """Embed ``input_ids`` and write image embeddings at the masked positions.

        NOTE: negative ids in ``input_ids`` are overwritten with 0 IN PLACE
        (the caller's tensor is modified), exactly as the original code did.
        """
        image_embeds, _ = self.prepare_embedding(image1)
        input_ids[input_ids < 0] = 0  # ignore the image embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        for row in range(inputs_embeds.shape[0]):
            inputs_embeds[row][image_seq_mask[row]] = image_embeds[row]
        return inputs_embeds

    def _sample_image_tokens(
        self,
        inputs_embeds,
        attention_mask,
        num_images,
        image_token_num_per_image,
        temperature,
        cfg_weight,
        set_cfg,
        show_progress,
    ):
        """Autoregressively sample ``image_token_num_per_image`` tokens per image.

        With ``set_cfg`` the batch rows are read as interleaved
        (conditional, unconditional) pairs: even rows are the conditional
        logits, odd rows the unconditional ones, and every sampled token is fed
        back to both rows of its pair.

        Returns:
            ``(generated_tokens, attention_mask)`` where ``generated_tokens`` is
            an int tensor of shape ``(num_images, image_token_num_per_image)``
            and ``attention_mask`` has been extended by one column per step.
        """
        # Allocate on the same device as the embeddings instead of whatever
        # the current CUDA device happens to be.
        generated_tokens = torch.zeros(
            (num_images, image_token_num_per_image),
            dtype=torch.int,
            device=inputs_embeds.device,
        )
        batch_rows = attention_mask.shape[0]

        steps = range(image_token_num_per_image)
        if show_progress:
            from tqdm import tqdm
            steps = tqdm(steps)

        past_key_values = None
        for i in steps:
            outputs = self.language_model.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                use_cache=True,
                past_key_values=past_key_values,
            )
            past_key_values = outputs.past_key_values

            logits = self.gen_head(outputs.last_hidden_state[:, -1, :])
            if set_cfg:
                logit_cond = logits[0::2, :]
                logit_uncond = logits[1::2, :]
                logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            if set_cfg:
                # (P, 1) -> (P, 2): duplicate each token for its cond/uncond pair.
                next_token = torch.cat([next_token, next_token], dim=1)
            # Always hand over flat 1-D ids so that the unsqueeze below yields a
            # single-step sequence per row in both the CFG and non-CFG paths
            # (presumably (rows, 1, dim) -- confirm against prepare_gen_img_embeds).
            img_embeds = self.prepare_gen_img_embeds(next_token.view(-1))
            inputs_embeds = img_embeds.unsqueeze(dim=1)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((batch_rows, 1))], dim=1
            )

        return generated_tokens, attention_mask

    def _decode_image_tokens(self, generated_tokens, num_images, img_size, patch_size):
        """Decode sampled tokens into a ``uint8`` array of shape (N, H, W, 3)."""
        grid = img_size // patch_size
        # The channel count 8 is passed through unchanged from the original code.
        dec = self.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[num_images, 8, grid, grid],
        )
        dec = dec.detach().to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
        # Map the decoder output from [-1, 1] to [0, 255].
        dec = np.clip((dec + 1) / 2 * 255, 0, 255)

        visual_img = np.zeros((num_images, img_size, img_size, 3), dtype=np.uint8)
        visual_img[:, :, :] = dec
        return visual_img

    def _dump_samples(self, visual_img, img_path, cur_step, instruction):
        """Write ``<cur_step>.png`` (and optionally ``<cur_step>.txt``) to ``img_path``.

        Only the process whose rank is a multiple of the visible GPU count
        writes. When ``torch.distributed`` is not initialized the process is
        treated as rank 0, and a host without GPUs counts as one device, so the
        method also works in single-process runs.
        """
        if not img_path:
            return

        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        devices = max(torch.cuda.device_count(), 1)
        if rank % devices != 0:
            return

        os.makedirs(img_path, exist_ok=True)
        self.save_stack_images(
            visual_img,
            batch_size=visual_img.shape[0],
            save_path=os.path.join(img_path, f'{cur_step}.png'),
            # Pad with tiles of the actual image size rather than the 384 default.
            height=visual_img.shape[1],
            weight=visual_img.shape[2],
        )
        if instruction is None:
            return

        # Consecutive duplicate instructions are written only once.
        with open(os.path.join(img_path, f'{cur_step}.txt'), 'w', encoding='utf-8') as f:
            lastins = ''
            for inss in instruction:
                if inss != lastins:
                    f.write(inss + '\n')
                    lastins = inss

    # ------------------------------------------------------------------
    # Public generation entry points.
    # ------------------------------------------------------------------

    @torch.no_grad()
    def edit_image(
        self,
        vl_chat_processor: VLChatProcessor,
        input_ids,
        attention_mask,
        image1,
        image_seq_mask,
        temperature: float = 1,
        parallel_size: int = 16,
        cfg_weight: float = 5,
        set_cfg=True,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
        img_path:str=None,
        instruction=None,
    ):
        """Sample an image conditioned on a prompt that contains image embeddings.

        Args:
            vl_chat_processor: Unused; kept for signature compatibility.
            input_ids: Prompt token ids. Negative entries are zeroed in place.
            attention_mask: Attention mask matching ``input_ids``.
            image1: Input passed to ``self.prepare_embedding``.
            image_seq_mask: Per-row mask selecting the positions in the prompt
                that receive the image embeddings.
            temperature: Softmax temperature used for sampling.
            parallel_size: Ignored; the number of images is derived from the
                batch size (half of it when ``set_cfg`` is true).
            cfg_weight: Classifier-free guidance weight.
            set_cfg: If true, rows must be interleaved (conditional,
                unconditional) pairs.
            image_token_num_per_image: Number of tokens sampled per image.
            img_size: Side length of the decoded image in pixels.
            patch_size: Pixels per token along each side.
            img_path: Unused; kept for signature compatibility.
            instruction: Unused; kept for signature compatibility.

        Returns:
            ``(generated_tokens, final_imgs, attention_mask)`` where
            ``final_imgs`` is a list of PIL images.

        Raises:
            ValueError: If ``set_cfg`` is true and the batch size is odd.
        """
        batch_rows = input_ids.shape[0]
        if set_cfg:
            if batch_rows % 2 != 0:
                raise ValueError(
                    'set_cfg=True requires an even batch of interleaved '
                    f'(cond, uncond) rows, got {batch_rows}'
                )
            parallel_size = batch_rows // 2
        else:
            # Without guidance every row produces its own image.
            parallel_size = batch_rows

        inputs_embeds = self._embed_prompt_with_images(input_ids, image1, image_seq_mask)
        generated_tokens, attention_mask = self._sample_image_tokens(
            inputs_embeds,
            attention_mask,
            num_images=parallel_size,
            image_token_num_per_image=image_token_num_per_image,
            temperature=temperature,
            cfg_weight=cfg_weight,
            set_cfg=set_cfg,
            show_progress=True,
        )
        visual_img = self._decode_image_tokens(
            generated_tokens, parallel_size, img_size, patch_size
        )
        final_imgs = [Image.fromarray(img) for img in visual_img]

        return generated_tokens, final_imgs, attention_mask

    @torch.no_grad()
    def generate_image_2(
        self,
        vl_chat_processor: VLChatProcessor,
        # set_cfg: True,
        input_ids,
        attention_mask,
        image1,
        image_seq_mask,
        temperature: float = 1,
        parallel_size: int = 16,
        cfg_weight: float = 5,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
        img_path:str=None,
        instruction=None,
        cur_step: int = 0,
    ):
        """Sample images from an image-conditioned prompt without guidance.

        Same pipeline as ``edit_image`` with classifier-free guidance disabled,
        plus optional dumping of the results to ``img_path``.

        Args:
            vl_chat_processor: Unused; kept for signature compatibility.
            input_ids: Prompt token ids. Negative entries are zeroed in place.
            attention_mask: Attention mask matching ``input_ids``.
            image1: Input passed to ``self.prepare_embedding``.
            image_seq_mask: Per-row mask selecting the positions in the prompt
                that receive the image embeddings.
            temperature: Softmax temperature used for sampling.
            parallel_size: Ignored; one image is produced per batch row.
            cfg_weight: Unused because no guidance is applied.
            image_token_num_per_image: Number of tokens sampled per image.
            img_size: Side length of the decoded image in pixels.
            patch_size: Pixels per token along each side.
            img_path: If set, directory that receives the stacked sample image.
            instruction: Optional iterable of strings written next to the image.
            cur_step: Step number used as the dumped file stem.

        Returns:
            ``(generated_tokens, final_imgs, attention_mask)``.
        """
        # One image per row: the token buffer must match the logits batch.
        parallel_size = input_ids.shape[0]

        inputs_embeds = self._embed_prompt_with_images(input_ids, image1, image_seq_mask)
        generated_tokens, attention_mask = self._sample_image_tokens(
            inputs_embeds,
            attention_mask,
            num_images=parallel_size,
            image_token_num_per_image=image_token_num_per_image,
            temperature=temperature,
            cfg_weight=cfg_weight,
            set_cfg=False,
            show_progress=False,
        )
        visual_img = self._decode_image_tokens(
            generated_tokens, parallel_size, img_size, patch_size
        )
        self._dump_samples(visual_img, img_path, cur_step, instruction)
        final_imgs = [Image.fromarray(img) for img in visual_img]

        return generated_tokens, final_imgs, attention_mask

    @torch.no_grad()
    def generate_image(
        self,
        vl_chat_processor: VLChatProcessor,
        # set_cfg: True,
        input_ids,
        attention_mask,
        set_cfg=True,
        cur_step=0,
        temperature: float = 1,
        parallel_size: int = 16,
        cfg_weight: float = 5,
        image_token_num_per_image: int = 576,
        img_size: int = 384,
        patch_size: int = 16,
        img_path:str=None,
        instruction=None,
    ):
        """Sample one image per prompt row from text-only prompts.

        With ``set_cfg`` every prompt row is duplicated; in the duplicate
        (unconditional) row everything from ``st`` up to, but excluding, the
        final token is replaced by ``vl_chat_processor.pad_id``. ``st`` is 1
        when the row contains no padding, otherwise two positions past the last
        pad token (presumably to keep a BOS token after left padding -- confirm
        against the tokenizer).

        Args:
            vl_chat_processor: Provides ``pad_id`` for the unconditional rows.
            input_ids: Prompt token ids, one row per image.
            attention_mask: Attention mask matching ``input_ids``.
            set_cfg: Enable classifier-free guidance.
            cur_step: Step number used as the dumped file stem.
            temperature: Softmax temperature used for sampling.
            parallel_size: Ignored; one image is produced per prompt row.
            cfg_weight: Classifier-free guidance weight.
            image_token_num_per_image: Number of tokens sampled per image.
            img_size: Side length of the decoded image in pixels.
            patch_size: Pixels per token along each side.
            img_path: If set, directory that receives the stacked sample image.
            instruction: Optional iterable of strings written next to the image.

        Returns:
            ``(generated_tokens, final_imgs, (tokens, attention_mask))`` where
            ``tokens`` are the (possibly duplicated) prompt ids actually used.
        """
        parallel_size = input_ids.shape[0]
        if set_cfg:
            tokens = torch.repeat_interleave(input_ids, 2, dim=0)
            pad_id = vl_chat_processor.pad_id
            # Odd rows are the unconditional copies.
            for row in range(1, tokens.size(0), 2):
                pad_positions = torch.where(tokens[row] == pad_id)[0]
                if pad_positions.shape[0] == 0:
                    st = 1
                else:
                    st = pad_positions[-1].item() + 2
                tokens[row, st:-1] = pad_id
            attention_mask = torch.repeat_interleave(attention_mask, 2, dim=0)
        else:
            tokens = input_ids

        inputs_embeds = self.language_model.get_input_embeddings()(tokens)
        generated_tokens, attention_mask = self._sample_image_tokens(
            inputs_embeds,
            attention_mask,
            num_images=parallel_size,
            image_token_num_per_image=image_token_num_per_image,
            temperature=temperature,
            cfg_weight=cfg_weight,
            set_cfg=set_cfg,
            show_progress=True,
        )
        visual_img = self._decode_image_tokens(
            generated_tokens, parallel_size, img_size, patch_size
        )
        self._dump_samples(visual_img, img_path, cur_step, instruction)
        final_imgs = [Image.fromarray(img) for img in visual_img]

        return generated_tokens, final_imgs, (tokens, attention_mask)

    def save_stack_images(self, images: np.ndarray, batch_size: int, save_path: str, height: int = 384, weight: int = 384):
        """Tile ``images`` into a single grid image and save it to ``save_path``.

        The grid has ``floor(sqrt(batch_size))`` images per row; unused cells
        in the last row are filled with black tiles.

        Args:
            images: Array of ``batch_size`` images, each ``(height, weight, 3)``.
            batch_size: Number of images in ``images``.
            save_path: Destination file path handed to ``PIL.Image.save``.
            height: Tile height in pixels, used for the padding tiles.
            weight: Tile width in pixels (parameter name kept for compatibility).
        """
        image_num_per_row = int(np.sqrt(batch_size).item())
        # Despite its name this is the number of grid rows.
        image_num_per_column = int(np.ceil(batch_size / image_num_per_row).item())
        total_cells = image_num_per_row * image_num_per_column

        images_to_padding = total_cells - batch_size
        if images_to_padding != 0:
            blank_image = np.zeros((height, weight, 3), dtype=np.uint8)
            images = np.concatenate([images, [blank_image] * images_to_padding], axis=0)

        rows = [
            np.hstack(images[idx:idx + image_num_per_row])
            for idx in range(0, total_cells, image_num_per_row)
        ]
        combined_image = np.vstack(rows)

        pil_image = Image.fromarray(combined_image)
        pil_image.save(save_path)



# # specify the path to the model
# model_path = "deepseek-ai/Janus-1.3B"
# vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
# tokenizer = vl_chat_processor.tokenizer

# vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
#     model_path, trust_remote_code=True
# )
# vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

# conversation = [
#     {
#         "role": "User",
#         "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair",
#     },
#     {"role": "Assistant", "content": ""},
# ]

# sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
#     conversations=conversation,
#     sft_format=vl_chat_processor.sft_format,
#     system_prompt="",
# )
# prompt = sft_format + vl_chat_processor.image_start_tag


# @torch.inference_mode()
# def generate(
#     mmgpt: MultiModalityCausalLM,
#     vl_chat_processor: VLChatProcessor,
#     prompt: str,
#     temperature: float = 1,
#     parallel_size: int = 16,
#     cfg_weight: float = 5,
#     image_token_num_per_image: int = 576,
#     img_size: int = 384,
#     patch_size: int = 16,
# ):
#     input_ids = vl_chat_processor.tokenizer.encode(prompt)
#     input_ids = torch.LongTensor(input_ids)

#     tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
#     for i in range(parallel_size*2):
#         tokens[i, :] = input_ids
#         if i % 2 != 0:
#             tokens[i, 1:-1] = vl_chat_processor.pad_id

#     inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

#     generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

#     for i in range(image_token_num_per_image):
#         outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
#         hidden_states = outputs.last_hidden_state
        
#         logits = mmgpt.gen_head(hidden_states[:, -1, :])
#         logit_cond = logits[0::2, :]
#         logit_uncond = logits[1::2, :]
        
#         logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
#         probs = torch.softmax(logits / temperature, dim=-1)

#         next_token = torch.multinomial(probs, num_samples=1)
#         generated_tokens[:, i] = next_token.squeeze(dim=-1)

#         next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
#         img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
#         inputs_embeds = img_embeds.unsqueeze(dim=1)


#     dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
#     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

#     dec = np.clip((dec + 1) / 2 * 255, 0, 255)

#     visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
#     visual_img[:, :, :] = dec

#     os.makedirs('generated_samples', exist_ok=True)
#     for i in range(parallel_size):
#         save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
#         PIL.Image.fromarray(visual_img[i]).save(save_path)


# generate(
#     vl_gpt,
#     vl_chat_processor,
#     prompt,
# )