import torch
import os
import copy
from transformers import BitsAndBytesConfig, ChameleonProcessor
import pytorch_lightning as pl
from pytorch_lightning.strategies import DeepSpeedStrategy
from deepspeed.ops.adam import DeepSpeedCPUAdam
from pytorch_lightning.loggers import WandbLogger
import peft
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import apply_activation_checkpointing, checkpoint_wrapper
# Copy the launcher's LOCAL_RANK (default '-1' when unset) into the OpenMPI
# local-rank variable before the project data module below is imported.
# NOTE(review): presumably some import-time code reads it — confirm. The OMPI_*
# variables are deleted again once the remaining imports have run.
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ.get('LOCAL_RANK', '-1')
from data.data_OpenHermes import OpenHermesDataModule

from pytorch_lightning.callbacks import ModelCheckpoint
from datetime import datetime
import shutil
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer
from PIL import Image
import torch.nn.functional as F
import sys
# Append this script's own directory to the import search path.
# NOTE(review): this runs after `from data.data_OpenHermes import ...`, so it does not
# help that import; the "/" split is also POSIX-only (os.path.dirname is portable).
sys.path.append(os.path.abspath(__file__).rsplit("/", 1)[0])
# import pdb
from pytorch_lightning.utilities import rank_zero_only
import torch.nn as nn
import wandb

import numpy as np
from scipy.linalg import orthogonal_procrustes
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.linalg import orthogonal_procrustes

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from collections import OrderedDict
from peft import PeftModel, PeftConfig, get_peft_model

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
# Remove OMPI-related environment variables if they exist (!important for amlt).
# NOTE(review): presumably so DeepSpeed / Lightning do not detect an MPI launch and
# fall back to the torchrun-style variables instead — confirm.
for var in ["OMPI_COMM_WORLD_LOCAL_RANK", "OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]:
    if var in os.environ:
        del os.environ[var]
import os  # redundant: os is already imported at the top of the file
# Disable Hugging Face tokenizers' internal parallelism (presumably to avoid
# fork-related warnings/deadlocks with DataLoader worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Tell wandb never to resume a previous run.
os.environ["WANDB_RESUME"] = "never"

import re
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from typing import Union

            
import torch
from torch import nn
from copy import deepcopy

@torch.no_grad()
def compute_accuracy(logits, labels, top_k: int = 5):
    """Return (top-1, top-k) next-token accuracy in percent over valid positions.

    Positions whose label is negative (e.g. the ``-100`` ignore index) are
    excluded. Runs without autograd: the result is only a logged metric.

    Args:
        logits: scores whose last dimension is the vocabulary; all leading
            dimensions must match ``labels``.
        labels: integer target ids, same leading shape as ``logits``.
        top_k: k for the second metric (default 5); clamped to the vocabulary
            size so tiny vocabularies do not make ``topk`` raise.

    Returns:
        ``(top1_percent, topk_percent)`` as Python floats. Both are NaN when
        there are no valid positions (mean over an empty selection).
    """
    valid_mask = labels >= 0

    # argmax/topk are unchanged by softmax (it preserves ordering), so work on
    # the raw logits and avoid materialising a full-vocabulary probability tensor.
    top1_pred = logits.argmax(dim=-1)
    correct_top1 = (top1_pred[valid_mask] == labels[valid_mask]).float().mean()

    k = min(top_k, logits.shape[-1])
    topk_preds = logits.topk(k, dim=-1).indices
    correct_topk = (
        topk_preds[valid_mask].eq(labels[valid_mask].unsqueeze(-1)).any(dim=-1).float().mean()
    )

    return correct_top1.item() * 100, correct_topk.item() * 100

def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy that can be normalised by an externally supplied token count.

    Without ``num_items_in_batch`` this is the ordinary mean over non-ignored
    targets. With it, the per-token losses are summed and divided by that count,
    which lets callers normalise by the number of tokens across a whole batch.
    Extra ``kwargs`` are accepted for call-site compatibility and ignored.
    """
    if num_items_in_batch is None:
        return nn.functional.cross_entropy(
            source, target, ignore_index=ignore_index, reduction="mean"
        )
    total = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction="sum"
    )
    return total / num_items_in_batch

def ForCausalLMLoss(
    logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
    """Next-token cross-entropy plus top-1 / top-5 accuracy for a causal LM.

    The logit at position t is scored against the label at position t + 1, so
    the last logit and the first label are dropped before flattening.

    Returns:
        ``(loss, top1_percent, top5_percent)``.
    """
    # Compute in float32 to avoid precision issues with half-precision logits.
    logits_fp32 = logits.float()

    flat_logits = logits_fp32[..., :-1, :].contiguous().view(-1, vocab_size)
    # Labels are moved to the logits' device in case the model is split across devices.
    flat_labels = labels[..., 1:].contiguous().view(-1).to(flat_logits.device)

    loss = fixed_cross_entropy(flat_logits, flat_labels, num_items_in_batch, ignore_index, **kwargs)
    acc_top1, acc_top5 = compute_accuracy(flat_logits, flat_labels)
    return loss, acc_top1, acc_top5



def group_all_gather(tensor, group, group_size, group_rank=-1, dim=-1):
    """Differentiably all-gather ``tensor`` across ``group``, concatenating along ``dim``.

    ``group_rank=-1`` lets ``GroupAllGather`` look up this process's rank in ``group``.
    """
    # Positional order must match GroupAllGather.forward(ctx, tensor, dim, group, group_size, group_rank).
    forward_args = (tensor, dim, group, group_size, group_rank)
    return GroupAllGather.apply(*forward_args)

class GroupAllGather(torch.autograd.Function):
    """Differentiable all-gather over a (sub)group of processes.

    ``forward`` gathers ``tensor`` from every rank of ``group`` and concatenates
    the pieces along ``dim`` in rank order.

    ``backward`` works as follows:
    * it divides the incoming gradient by ``group_size``;
    * it splits the gradient into ``group_size`` equal chunks along ``dim``;
    * it reduce-scatters those chunks using ``dist.reduce_scatter``'s default SUM.

    Each rank therefore receives the group-wide *mean* of the gradient for its
    own chunk.

    All ranks must contribute tensors of identical shape and dtype, because the
    receive buffers are ``torch.empty_like`` of the local tensor.
    """

    @staticmethod
    def forward(
        ctx, tensor: torch.Tensor, dim: int, group, group_size: int, group_rank: int
    ):
        if group_rank == -1:
            # Resolve this process's rank *within* ``group``.
            if group is None:
                raise ValueError("GroupAllGather: `group` is required when group_rank == -1")
            if hasattr(dist, "get_group_rank"):
                group_rank = dist.get_group_rank(group, dist.get_rank())
            else:
                # Fallback for torch.distributed versions without get_group_rank.
                group_rank = dist.get_rank(group)
        ctx.group_rank = group_rank
        ctx.group_size = group_size
        ctx.dim = dim
        ctx.group = group
        # The NCCL backend rejects non-contiguous tensors ("Tensors must be
        # contiguous"). A dense input also gives dense empty_like receive buffers.
        # .contiguous() is a no-op when the tensor is already dense.
        tensor = tensor.contiguous()
        gathered = [torch.empty_like(tensor) for _ in range(group_size)]
        dist.all_gather(gathered, tensor, group=group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, gathered_grad: torch.Tensor):
        # Chunks along a non-leading dim are strided views; densify each one
        # before handing the list to the collective.
        grad_chunks = [
            chunk.contiguous()
            for chunk in (gathered_grad / ctx.group_size).chunk(ctx.group_size, ctx.dim)
        ]
        grad_tensor = torch.empty_like(grad_chunks[ctx.group_rank])  # receive buffer
        dist.reduce_scatter(grad_tensor, grad_chunks, group=ctx.group)
        # Only ``tensor`` is differentiable; dim/group/group_size/group_rank get None.
        return grad_tensor, None, None, None, None
    


def find_all_linear_names(model):
    """Return the distinct leaf names of all ``nn.Linear`` submodules (LoRA targets).

    Modules whose qualified name contains ``'vqmodel'`` are skipped, and
    ``'lm_head'`` is dropped from the result. The collected set is printed
    before ``lm_head`` is removed.
    """
    skip_keywords = ['vqmodel']
    found = set()
    for qualified_name, submodule in model.named_modules():
        if any(keyword in qualified_name for keyword in skip_keywords):
            continue
        if not isinstance(submodule, torch.nn.Linear):
            continue
        # Keep only the last dotted component (the whole name when there is no dot).
        found.add(qualified_name.rsplit('.', 1)[-1])
    print("lora_module_names: ", found)
    found.discard('lm_head')  # needed for 16-bit
    return list(found)

@rank_zero_only
def copy_src_folder(src_dir, dest_dir, ignore_ckpts):
    """Recursively copy ``src_dir`` into ``dest_dir``; executed only on rank zero.

    ``ignore_ckpts`` is passed as the ``shutil.copytree`` ignore callback. An
    existing ``dest_dir`` is merged into rather than treated as an error.
    """
    message = f"Copying src folder to {dest_dir}"
    print(message)
    shutil.copytree(
        src=src_dir,
        dst=dest_dir,
        ignore=ignore_ckpts,
        dirs_exist_ok=True,
    )

def apply_activation_checkpointing(model):
    """Wrap every ``nn.TransformerEncoderLayer`` / ``nn.TransformerDecoderLayer`` in place.

    Each matching submodule is replaced on its direct parent with
    ``checkpoint_wrapper(layer)``. The root module itself is never replaced,
    because it has no parent to hold the wrapper.

    NOTE: this module-level function shadows the ``apply_activation_checkpointing``
    imported from ``torch.distributed.algorithms._checkpoint.checkpoint_wrapper``.
    """
    layer_types = (torch.nn.TransformerEncoderLayer, torch.nn.TransformerDecoderLayer)
    # Snapshot the targets first: changing a module's children while
    # named_modules() is still iterating raises
    # "dictionary changed size during iteration".
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name and isinstance(module, layer_types)
    ]
    for name, module in targets:
        print(f"Applying activation checkpointing to: {name}")
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        # Replace the child on its real parent. Writing the dotted path into
        # model._modules would instead register an extra top-level child and
        # leave the original layer unwrapped.
        setattr(parent, child_name, checkpoint_wrapper(module))


class QwenLightningModule(pl.LightningModule):
    """LightningModule that fine-tunes Qwen2.5-VL with a next-token (causal LM) loss.

    Two training modes, selected by ``args.use_peft``:

    * LoRA: rank-8 adapters on the attention/MLP projections. Only the adapters
      (and, with ``tune_mm_mlp``, the vision merger) are trainable.
    * Selective fine-tuning: ``set_model`` unfreezes the vision tower, the vision
      merger and/or the language model according to the ``tune_mm_*`` flags.

    Args:
        model_name_or_path: checkpoint passed to ``from_pretrained``.
        processor: stored as ``self.processor``; not used inside this class.
        args: parsed CLI namespace, also saved as hyperparameters.

    Raises:
        NotImplementedError: if ``args.use_4bit`` is set (there is no 4-bit
            loading path).
    """

    def __init__(self, model_name_or_path, processor, args):
        super().__init__()
        # Not used within this class; kept for code elsewhere that may read them.
        self.all_text_embeds = []
        self.all_image_embeds = []
        self.save_hyperparameters(args)
        self.args = args
        self.ce = F.cross_entropy
        self.groups = None

        if args.use_4bit:
            # Previously model loading was silently skipped and the first access to
            # ``self.model`` crashed later; fail fast with a clear message instead.
            raise NotImplementedError("--use-4bit is not implemented for QwenLightningModule")

        load_kwargs = {
            # NOTE(review): every precision other than "bf16" (including "none")
            # loads fp16 weights — confirm this is intended for fp32 runs.
            "torch_dtype": torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
            "_attn_implementation": "sdpa",  # or "eager"
        }
        if args.use_peft:
            load_kwargs["low_cpu_mem_usage"] = True
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name_or_path, **load_kwargs
        )

        if args.use_peft:
            peft_config = LoraConfig(
                r=8,
                lora_alpha=32,
                target_modules=['q_proj', 'gate_proj', 'o_proj', 'k_proj', 'down_proj', 'v_proj', 'up_proj'],
                lora_dropout=0.1,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.model = get_peft_model(self.model, peft_config, autocast_adapter_dtype=False)
            self._freeze_for_lora()
            self.model.train()
        else:
            self.set_model()
        self.processor = processor

    def _freeze_for_lora(self):
        """Leave only LoRA adapters (and, with tune_mm_mlp, the merger) trainable."""
        for name, param in self.model.named_parameters():
            # Adapters inside the vision patch embedding are kept frozen.
            trainable = "lora_" in name and "visual.patch_embed." not in name
            param.requires_grad = trainable
            if trainable:
                print(f"✅ trainable param: {name}")
        if self.args.tune_mm_mlp:
            for n, p in self.model.visual.merger.named_parameters():
                p.requires_grad = True
                print(f"✅ trainable param:: {n}")
        # lm_head is always frozen in LoRA mode.
        for p in self.model.lm_head.parameters():
            p.requires_grad = False

    @staticmethod
    def _set_requires_grad(module, flag):
        """Set ``requires_grad`` on every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = flag

    def set_model(self):
        """Freeze or unfreeze model parts from the tune_mm_vision / tune_mm_mlp / tune_mm_llm flags.

        Order matters: the merger lives inside ``visual``, so its flag is applied
        after the vision-tower flag.
        """
        self._set_requires_grad(self.model.visual, self.args.tune_mm_vision)
        self._set_requires_grad(self.model.visual.merger, self.args.tune_mm_mlp)
        self._set_requires_grad(self.model.model, self.args.tune_mm_llm)
        # Must act on the parameters; assigning ``lm_head.requires_grad`` only set
        # a plain attribute on the module and left its weights trainable.
        self._set_requires_grad(self.model.lm_head, self.args.tune_mm_llm)

    def forward(self, inputs):
        """Run the wrapped model on a dict of keyword inputs."""
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Compute the causal-LM loss for one batch and log loss, LR and accuracies."""
        # Previously read the module-level ``args``, which only exists when the
        # file is run as a script.
        batch_size = self.args.batch_size
        inputs = batch["inputs_pure_text"] if self.args.pure_text else batch["inputs_text_image"]
        # The loss is computed below from the logits. Withholding ``labels`` from
        # the model avoids a second, unused loss inside the HF forward pass.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        output = self.model(**model_inputs)
        logits = output.logits
        loss_ar, top1_acc, top5_acc = self.cal_lm_loss(
            logits=logits,
            labels=inputs["labels"],
            ignore_index=-100,
            vocab_size=logits.shape[-1],
        )
        loss = loss_ar

        if self.logger is not None:
            self.logger.log_metrics({
                "ar_loss": loss_ar.item(),
                "learning_rate": self.trainer.optimizers[0].param_groups[0]['lr'],
                "top1_acc": top1_acc,
                "top5_acc": top5_acc,
            }, step=self.global_step)

        self.log("ar_loss", loss_ar, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=batch_size)

        return loss

    def configure_optimizers(self):
        """DeepSpeed CPU Adam over all model parameters, using the lr/weight_decay hyperparameters."""
        optimizer = DeepSpeedCPUAdam(self.model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        return optimizer

    def cal_lm_loss(
        self,
        logits,
        labels,
        ignore_index,
        vocab_size,
        is_image=False,
    ):
        """Delegate to ``ForCausalLMLoss``; returns (loss, top1_percent, top5_percent).

        ``is_image`` is accepted for API compatibility and currently unused.
        """
        return ForCausalLMLoss(
            logits=logits,
            labels=labels,
            ignore_index=ignore_index,
            vocab_size=vocab_size,
        )

def main(args):
    """Build processor, data module, logger, model and checkpointing, then train.

    NOTE: statement order matters. ``pl.seed_everything`` runs first, so objects
    created afterwards presumably draw from the seeded RNG: the data module and
    then the model, including freshly initialised LoRA weights. Confirm before
    reordering.
    """
    seed = 42
    pl.seed_everything(seed)  
    processor = AutoProcessor.from_pretrained(
        args.model_path,
    )  
    
    # NOTE(review): data_path is hard-coded, and max_length=4096 ignores --max-length
    # (default 2048); --train-file is not used here either. Confirm intent.
    # NOTE(review): requires args.clip_token_num — the CLI parser must define
    # --clip-token-num, otherwise this line raises AttributeError.
    data_module = OpenHermesDataModule(
                    data_path="/storage/OpenHermes-2.5/", 
                    vl_chat_processor = processor, 
                    batch_size=args.batch_size, 
                    max_length=4096, 
                    font_path=args.font_path,
                    font_size=args.font_size,
                    n_parts=args.n_parts,
                    clip_token_num=args.clip_token_num,     
                    cache_dir=args.cache_dir,
                    pure_text=args.pure_text
                )

    if args.use_wandb:
        # Close any run still open in this process, then start a brand-new run
        # (freshly generated id, resume=False) rather than resuming an old one.
        wandb.finish()
        wandb_logger = WandbLogger(project=args.wandb_project, name=args.wandb_run_name, log_model="all", resume=False, id = wandb.util.generate_id())
    else:
        wandb_logger = None
    lightning_model = QwenLightningModule(args.model_path, processor, args)
    # Append the current date/time to output_dir so each run gets its own folder.
    # NOTE(review): every process calls datetime.now() independently. In a
    # multi-process launch, ranks that straddle a second boundary end up with
    # different output_dirs, which may split per-rank checkpoint files. Consider
    # passing a shared timestamp in from the launcher.
    args.output_dir = os.path.join(args.output_dir, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    # Ensure output_dir exists
    os.makedirs(args.output_dir, exist_ok=True)

    # ------------------- Copy src Folder ------------------- #
    # shutil.copytree ignore callback: called once per directory, it returns the
    # entries to skip. Here any entry named 'ckpts' is skipped at every level.
    def ignore_ckpts(path, names):
        # If 'ckpts' is in the directory listing, skip it
        ignored = []
        if 'ckpts' in names:
            ignored.append('ckpts')
        return ignored

    if args.save_src:
        # Relative path: resolved against the current working directory.
        src_dir = "src"
        dest_dir = os.path.join(args.output_dir, "src")
        if os.path.exists(src_dir):
            try:
                copy_src_folder(src_dir, dest_dir, ignore_ckpts)
            except Exception as e:
                # Best-effort snapshot: a copy failure must not abort training.
                print(f"Error copying src folder: {e}")
    # ------------------------------------------------------------------- #

    # Define custom checkpoint callback
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.output_dir,  # Directory to save checkpoints
        filename="Qwenvl_2_5_3B-{epoch:02d}-{step:05d}",  # Custom checkpoint filename
        save_top_k=3,  # Keep the 3 best checkpoints by the monitored metric
        monitor="ar_loss",  # Training loss logged via self.log in training_step
        mode="min", 
        save_weights_only=False,  
        every_n_epochs=1,
        save_last=True,  
    )

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        strategy=DeepSpeedStrategy(config=args.ds_config),
        logger=wandb_logger,
        # "bf16" -> "bf16", "fp16" -> "16", anything else ("none") -> 32
        precision="bf16" if args.mixed_precision == "bf16" else "16" if args.mixed_precision == "fp16" else 32,
        log_every_n_steps=10,
        accumulate_grad_batches=args.grad_accum_steps,
        accelerator="gpu",
        # devices="auto",
        # nodes = total processes / processes per node (launcher env vars, default 1)
        num_nodes = int(os.environ.get("WORLD_SIZE", 1)) // int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
        # all CUDA devices visible to this process
        devices = torch.cuda.device_count(),
        default_root_dir=args.output_dir,  # Save checkpoints in the output_dir
        callbacks=[checkpoint_callback],  # Add the checkpoint callback
    )


    trainer.fit(lightning_model, datamodule=data_module)
  

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    # parser.add_argument("--json-file", type=str, required=True, help="Path to the JSON file with data")
    parser.add_argument("--ds-config", type=str, help="Path to the DeepSpeed config file")
    parser.add_argument("--model-path", type=str, required=True, help="Path to pretrained Chameleon model")
    parser.add_argument("--output-dir", type=str, required=True, help="Path to save the output")
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--num-workers", type=int, default=4, help="Number of data loader workers")
    parser.add_argument("--max-length", type=int, default=2048, help="Maximum sequence length")
    parser.add_argument("--mixed-precision", type=str, default="bf16", choices=["none", "fp16", "bf16"], help="Precision for training")
    parser.add_argument("--grad-accum-steps", type=int, default=1, help="Number of gradient accumulation steps")
    parser.add_argument("--use-4bit", action="store_true", help="Use 4-bit quantization for the model")
    parser.add_argument("--use-peft", action="store_true", help="Enable PEFT for fine-tuning.")
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='Local rank for distributed training')
    parser.add_argument('--deepspeed', action='store_true',
                        help='Enable DeepSpeed')
    parser.add_argument("--use-wandb", action="store_true", help="Enable WandB logging.")
    parser.add_argument("--use-kl", action="store_true", help="Enable kl loss.")
    parser.add_argument('--wandb-project', type=str,
                        help='Weights & Biases project name')
    parser.add_argument('--wandb-run-name', type=str,
                        help='Weights & Biases run name')
    parser.add_argument('--save-src', action='store_true',
                        help='Save the src folder to the output_dir if it exists')
   
    parser.add_argument('--pure-text', action='store_true', help='pure text or text image')
    parser.add_argument("--train-file", type=str, required=True, help="Path to the training file")
    parser.add_argument("--font-path", type=str, default="/storage/GoNotoCurrent.ttf", help="Font path for training")
    parser.add_argument("--font-size", type=int, default=7, help="Font size for training")
    parser.add_argument("--n-parts", type=int, default=1, help="Number of parts for training")
    parser.add_argument("--cache-dir", type=str, default="/storage", help="Cache directory for training")
    parser.add_argument("--contrastive-gather-way", type=str, default="single_node", choices=["all_nodes", "single_node", "None"], help="Contrastive gather way")
    parser.add_argument("--contrastive-temperature", type=float, default=0.2, help="Contrastive temperature for training")
    parser.add_argument("--temperature_kl", type=float, default=1.0, help="Contrastive temperature for training")
    parser.add_argument("--lambda_kl", type=float, default=0.1, help="Contrastive temperature for training")
    parser.add_argument("--alpha_kl", type=float, default=0.2, help="Contrastive temperature for training")
    parser.add_argument('--tune_mm_llm', action='store_true', help='tuning')
    parser.add_argument('--tune_mm_vision', action='store_true', help='tuning')
    parser.add_argument('--tune_mm_mlp', action='store_true', help='tuning')
    parser.add_argument("--image-size",type=int,nargs=2, default=(896, 14), help="Image size (width height), e.g., 896 14")
    parser.add_argument("--pairs", type=str, help="language pairs")
    parser.add_argument("--do-predict", action="store_true", help="Enable test.")
    parser.add_argument("--max_test_samples", type=int, default=None, help="Chunk size for training")


    args = parser.parse_args()
    main(args)


