import io
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPVisionModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from peft.tuners.lora import LoraLayer
from peft import LoraConfig, get_peft_model
from peft.utils import _get_submodules, ModulesToSaveWrapper
from tqdm import tqdm
import chess
import chess.pgn
import chess.svg
import cairosvg
from PIL import Image
from torchvision import transforms

# Label value skipped by the loss; equal to the default ``ignore_index`` of torch's CrossEntropyLoss.
IGNORE_INDEX: int = -100
# Text placeholder whose token positions are overwritten with projected image features.
DEFAULT_IMAGE_TOKEN: str = "<image>"

class CoIForGPTNeoXCausalLM(nn.Module):
    """GPT-NeoX causal LM whose ``<image>`` token embeddings are replaced by CLIP image features.

    Sub-modules created in ``__init__``:
      * ``llm``: causal LM from ``model_args.model_name_or_path``. ``forward`` calls its
        ``gpt_neox`` trunk and ``embed_out`` head directly.
      * ``encoder``: ``CLIPVisionModel`` from ``model_args.vision_tower_name``.
      * ``visual_proj``: bias-free linear map from the encoder hidden size to the LM
        hidden size.

    ``generate`` also parses a FEN string from the generated text when the image token
    is produced. It renders that chess board and feeds the board's features back in.
    """

    def __init__(self, model_args, training_args):
        """Load the LM, tokenizer and vision encoder; optionally attach LoRA adapters.

        Args:
            model_args: provides ``model_name_or_path``, ``add_image_token`` and
                ``vision_tower_name``.
            training_args: provides ``fp16``, ``bf16``, ``device``, ``use_patch``,
                ``lora_enable`` and, when LoRA is enabled, the ``lora_*`` values read by
                ``lora_enable``.
        """
        super().__init__()
        # fp16 takes precedence over bf16; float32 when neither flag is set.
        self.compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        self.compute_device = training_args.device
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.compute_decice = self.compute_device
        self.llm = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, torch_dtype=self.compute_dtype, device_map=training_args.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
        if model_args.add_image_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN], special_tokens=True)
            # Grow the embedding matrix so the newly added token has a row.
            self.llm.resize_token_embeddings(len(self.tokenizer))
        # NOTE(review): without add_image_token, "<image>" may tokenize into several
        # sub-tokens and only the last one is used as the image marker. Confirm intended.
        self.IMAGE_TOKEN_INDEX = self.tokenizer(DEFAULT_IMAGE_TOKEN).input_ids[-1]
        self.encoder = CLIPVisionModel.from_pretrained(model_args.vision_tower_name, torch_dtype=self.compute_dtype, device_map=training_args.device)
        self.visual_proj = nn.Linear(self.encoder.config.hidden_size, self.llm.config.hidden_size, bias=False).to(dtype=self.compute_dtype, device=training_args.device)
        self.use_patch = training_args.use_patch
        if training_args.lora_enable:
            # LoRA layers are injected into this module's sub-modules in place
            # (merge_and_unload below finds them on ``self``). The returned wrapper is
            # not kept. Rebinding ``self`` here, as before, had no effect on the object.
            self.lora_enable(self, training_args)
        self.requires_grad_layers()

    def lora_enable(self, model, training_args):
        """Attach LoRA adapters to every ``nn.Linear`` under ``llm``/``encoder`` except ``embed_out``.

        Args:
            model: module to adapt (``__init__`` passes ``self``).
            training_args: provides ``lora_r``, ``lora_alpha``, ``lora_dropout``,
                ``lora_bias`` and ``bf16``.

        Returns:
            The wrapper returned by ``get_peft_model``.
        """
        target_modules = sorted(
            name
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear)
            and name.startswith(('llm', 'encoder'))
            and 'embed_out' not in name
        )

        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=target_modules,
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)
        if training_args.bf16:
            # nn.Module.to() converts in place, so no rebinding is needed.
            for module in model.modules():
                if isinstance(module, LoraLayer):
                    module.to(torch.bfloat16)
        return model

    def requires_grad_layers(self):
        """Mark embedding, output-head, projection and patch-embedding parameters trainable.

        Matches on substrings of the parameter name, then prints every parameter that
        ends up trainable (name, shape, dtype).
        """
        # Previously the last operand was the bare string 'patch_embedding', which is
        # always truthy, so every parameter was unfrozen (defeating LoRA).
        trainable_keys = ("embed_in", "embed_out", "visual_proj", "patch_embedding")
        for name, param in self.named_parameters():
            if any(key in name for key in trainable_keys):
                param.requires_grad_(True)
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape, param.dtype)

    def _replace_module(self, parent, child_name, new_module, child):
        """Install ``new_module`` as ``parent.<child_name>``, reusing ``child``'s Parameters.

        The weight (and the bias, when present) are shared rather than copied. Any
        ``state`` attribute is carried over, and ``lora_``/``ranknum`` sub-modules are
        moved to the weight's device.
        """
        setattr(parent, child_name, new_module)
        new_module.weight = child.weight
        if getattr(child, "bias", None) is not None:
            new_module.bias = child.bias

        if getattr(child, "state", None) is not None:
            new_module.state = child.state
            new_module.to(child.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name or "ranknum" in name:
                module.to(child.weight.device)

    def merge_and_unload(self, merge=True, progressbar: bool = False):
        """Replace every ``LoraLayer`` in this module with a plain ``nn.Linear``.

        Args:
            merge: if True, call ``target.merge()`` before unloading so the LoRA update is
                folded into the weight being reused. Otherwise the adapters are dropped.
            progressbar: show a tqdm progress bar over the module names.

        Returns:
            ``self``, modified in place. ``ModulesToSaveWrapper`` instances are also
            replaced by their active adapter's module.
        """
        key_list = [key for key, _ in self.named_modules() if "lora" not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self, key)
            except AttributeError:
                # The module may already have been removed by an earlier replacement.
                continue
            if isinstance(target, LoraLayer):
                bias = target.bias is not None
                new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
                if merge:
                    target.merge()
                self._replace_module(parent, target_name, new_module, target)

            # save any additional trainable modules part of `modules_to_save`
            if isinstance(target, ModulesToSaveWrapper):
                setattr(parent, target_name, target.modules_to_save[target.active_adapter])

        return self

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM on text (plus optional images) and compute the next-token loss.

        When ``inputs_embeds`` is None, embeddings are built from ``input_ids`` and
        ``images`` by ``prepare_inputs_embeds``. If ``labels`` is given, logits at
        position t are scored against ``labels[:, t + 1]``. Labels equal to
        ``IGNORE_INDEX`` are ignored.

        Returns:
            A ``CausalLMOutputWithPast``, or a tuple ``([loss,] logits, *trunk_outputs)``
            when ``return_dict`` resolves to False.
        """
        return_dict = return_dict if return_dict is not None else self.llm.config.use_return_dict
        if inputs_embeds is None:
            inputs_embeds = self.prepare_inputs_embeds(input_ids, images)
        outputs = self.llm.gpt_neox(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        lm_logits = self.llm.embed_out(hidden_states)

        lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # next-token prediction: drop the last logit and the first label
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=IGNORE_INDEX)
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithPast(
            loss=lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            # NOTE(review): this is the last-layer tensor, not the per-layer tuple
            # (outputs.hidden_states) that CausalLMOutputWithPast normally carries.
            # Kept as-is because callers may rely on it.
            hidden_states=hidden_states,
            attentions=outputs.attentions,
        )

    def _project_images(self, pixel_values):
        """Encode ``pixel_values`` with CLIP and project the features into the LM space.

        Returns a 2-D tensor. With ``use_patch`` it has one row per patch: position 0
        (the CLIP class token) is dropped and images are flattened together. Otherwise
        it has one row per image, taken from ``pooler_output``.
        """
        vision_out = self.encoder(pixel_values)
        if self.use_patch:
            feats = vision_out.last_hidden_state[:, 1:, :].contiguous()
            feats = feats.view(-1, feats.shape[-1])
        else:
            feats = vision_out.pooler_output
        return self.visual_proj(feats)

    def prepare_inputs_embeds(self, input_ids, images=None):
        """Embed ``input_ids`` and write projected image features into image-token positions.

        Args:
            input_ids: token ids (presumably ``(batch, seq_len)``; confirm with callers).
            images: optional pixel values for the vision encoder. They are ignored when
                ``input_ids`` contains no ``IMAGE_TOKEN_INDEX``.

        Returns:
            The token embeddings. Image-token positions, in row-major order, are
            overwritten by the rows produced by ``_project_images``.
        """
        text_embeds = self.llm.gpt_neox.embed_in(input_ids)
        image_token_indices = torch.where(input_ids == self.IMAGE_TOKEN_INDEX)
        if images is not None and len(image_token_indices[0]):
            # NOTE(review): the number of image-token positions must equal the number of
            # feature rows. A single row would silently broadcast to every position.
            text_embeds[image_token_indices] = self._project_images(images.to(self.compute_dtype))
        return text_embeds

    def _render_board_image(self, text):
        """Parse a FEN from ``text``, render the board, and return it as an RGB PIL image.

        The FEN is the text after the first ``'format'``, up to the next ``'and'``. The
        intermediate SVG is also written to ``eval.svg``. Raises IndexError when
        ``'format'`` is absent; ``chess.Board`` raises ValueError for an invalid FEN.
        """
        fen = text.split('format')[1].split('and')[0].strip()
        board = chess.Board(fen)
        board_svg = chess.svg.board(board, size=224, coordinates=None)
        with open('eval.svg', "w") as outputfile:
            outputfile.write(board_svg)
        svg_data = io.BytesIO(cairosvg.svg2png(bytestring=board_svg))
        return Image.open(svg_data).convert('RGB')

    @torch.no_grad()
    def generate(self, input_ids, past_key_values=None, max_output_len=100):
        """Greedy KV-cached decoding that feeds back a rendered chess board after each image token.

        Assumes a batch of one (``.item()`` is taken on the last token). When the model
        emits the image token, the text generated so far is decoded and printed, and a
        board is rendered from it (see ``_render_board_image``). The board's projected
        features are then used as the next input step. Runs without autograd.

        Args:
            input_ids: prompt token ids, shape ``(1, seq_len)``.
            past_key_values: optional cache to continue from.
            max_output_len: cap on collected id chunks (the prompt counts as one), so at
                most ``max_output_len - 1`` new tokens are generated.

        Returns:
            ``(text, image)``: the decoded prompt plus generation, and the last rendered
            board (``None`` if none). If a board cannot be parsed or rendered, returns
            ``(text_so_far, None)`` immediately.
        """
        # Loop-invariant stop token, computed once instead of per step.
        eos_token_id = self.tokenizer(self.tokenizer.eos_token).input_ids[-1]
        output_input_ids = [input_ids]
        generate_an_image = False
        image = None
        while input_ids[:, -1].item() != eos_token_id and len(output_input_ids) < max_output_len:
            if generate_an_image:
                # The previous step emitted the image token: feed the board's features instead.
                pixel_values = transforms.ToTensor()(image)[None].to(self.compute_device, self.compute_dtype)
                inputs_embeds = self._project_images(pixel_values)[None]
            else:
                inputs_embeds = self.llm.gpt_neox.embed_in(input_ids)
            output = self.forward(inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=True)
            input_ids = output.logits[:, -1:].argmax(dim=-1)
            output_input_ids.append(input_ids)
            generate_an_image = input_ids.item() == self.IMAGE_TOKEN_INDEX
            past_key_values = output.past_key_values

            if generate_an_image:
                current_text = self.tokenizer.batch_decode(torch.cat(output_input_ids, dim=-1))[0]
                print(current_text)
                try:
                    image = self._render_board_image(current_text)
                except Exception:
                    # Best effort: stop generating when no valid board can be produced.
                    return current_text, None

        output_text = self.tokenizer.batch_decode(torch.cat(output_input_ids, dim=-1))[0]
        return output_text, image

    def prepare_inputs_for_generation(self, **kargs):
        """Delegate to the wrapped LM's ``prepare_inputs_for_generation``."""
        return self.llm.prepare_inputs_for_generation(**kargs)