'''
TODO
1. Set up the initial distillation framework; for now, run it even on a very small number of examples.
2. Set up a framework to blur objects and generate complete datasets.

'''

import config
import argparse
import os
import re
from transformers import AutoProcessor, BitsAndBytesConfig, AutoModelForImageTextToText ,TrainingArguments, Trainer, TrainerCallback
import torch
from convert_utils import *
#os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from functools import partial
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from processing_smolvlm import SmolVLMProcessor
import numpy as np
from model_w_cfg import SmolVLMForConditionalGeneration
from model_joint_learning import SmolVLMForConditionalGeneration as model_jl
from modules.utils.callback import LossLoggerCallback2, CustomTrainer



import config
import argparse
def parse_args():
    """Parse command-line options and derive the run's paths and model id.

    Besides the CLI flags, the returned namespace carries:
      * whatever ``config.get_config(args)`` adds (``args.paths`` is read here),
      * ``save_dir``: ``<paths["finetuned_models"]>/<mode>/<model_size>``,
        created on disk if missing,
      * ``model_id``: the Hugging Face hub id of the SmolVLM2 checkpoint that
        matches ``model_size``.
    """
    parser = argparse.ArgumentParser(description='Distillation of SmolVLM2-500M-Video-Instruct for poor-apperance sequence')

    parser.add_argument('--mode', type=str,
                        choices=[
                            'cfg_cap_simple_extended_freeze',
                            'cfg_cap_simple_extended',
                            'cfg_BB_simple_extended_freeze',
                            'cfg_BB_simple_extended',

                            'joint_learning_extended',
                            'joint_learning_extended_mask',
                            'joint_learning_extended_mask_emph',

                            'joint_learning_extended_freeze',
                            'joint_learning_extended_mask_freeze',
                            'joint_learning_extended_mask_emph_freeze',

                            'joint_learning_extended_mask_foc_freeze',
                            'joint_learning_extended_foc_freeze',
                            'joint_learning_extended_foc_hard_freeze',

                            'joint_learning_extended_l2_freeze',
                            'joint_learning_extended_mask_l2_freeze',
                        ],
                        default='cfg_cap_simple_extended',
                        help='training variant (selects the model class, data collator and frozen parameters)')
    parser.add_argument('--resume', action='store_true',
                        help='resume training from the latest checkpoint in the save directory')

    parser.add_argument('--model-size', type=str,
                        choices=['500M', '2.2B', '2.2B_quant'],
                        default='2.2B',
                        help='SmolVLM2 size ("500M" uses the 500M video model, otherwise 2.2B; "quant" trains adapters)')

    args = parser.parse_args()
    config.get_config(args)

    args.save_dir = f'{args.paths["finetuned_models"]}/{args.mode}/{args.model_size}'

    # To start from a local fine-tuned checkpoint instead, point model_id at it,
    # e.g. "finetuned_models/uniform_blur/2.2B/checkpoint-2445".
    args.model_id = (
        "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        if "500M" in args.model_size
        else "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
    )

    os.makedirs(args.save_dir, exist_ok=True)
    return args




class LossLoggerCallback(TrainerCallback):
    """Trainer callback that appends the training loss to a text file.

    On every logging event that contains a ``"loss"`` entry, one line of the
    form ``Step <global_step>: loss = <loss>`` (loss with four decimals) is
    appended to ``log_file``. Other log events are ignored.
    """

    def __init__(self, log_file="training_loss.txt"):
        # Destination file; it is reopened in append mode for every write.
        self.log_file = log_file

    def on_log(self, args, state, control, model=None,  logs=None, **kwargs):
        # Guard clause: nothing to record without a loss value.
        if logs is None or "loss" not in logs:
            return
        entry = f"Step {state.global_step}: loss = {logs['loss']:.4f}\n"
        with open(self.log_file, "a") as sink:
            sink.write(entry)



def collate_fn(examples,image_token_id,model,processor):
    """Collate caption-distillation examples into one padded training batch.

    Each example supplies ``"generated_prompt"`` (the assistant target text)
    and ``"blur_full_video_path"`` (the video shown with the user prompt
    "Caption the video.").

    Args:
        examples: dataset rows.
        image_token_id: id of the ``<image>`` token; those positions are set
            to -100 in ``labels`` so they are excluded from the loss.
        model: only ``model.dtype`` is used, to cast the processor output.
        processor: SmolVLM processor used via ``apply_chat_template``.

    Returns:
        dict with ``input_ids`` (right-padded with the tokenizer pad id),
        ``attention_mask`` (padded with 0), ``labels`` (a copy of the ids,
        padded with -100) and ``pixel_values`` (zero-padded to the largest
        frame count / height / width in the batch).
    """
    instances = []
    for example in examples:
        user_content = [
            {"type": "text", "text": "Caption the video."},
            {"type": "video", "path": example["blur_full_video_path"]},
        ]
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": [{"type": "text", "text": f"{example['generated_prompt']}"}]},
        ]
        instance = processor.apply_chat_template(
            messages, add_generation_prompt=False,
            tokenize=True, return_dict=True, return_tensors="pt",
        ).to("cuda").to(model.dtype)
        instances.append(instance)

    input_ids = pad_sequence(
        [inst["input_ids"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id,
    )
    attention_mask = pad_sequence(
        [inst["attention_mask"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=0,
    )
    labels = pad_sequence(
        [inst["input_ids"].squeeze(0).clone() for inst in instances],
        batch_first=True,
        padding_value=-100,
    )
    labels[labels == image_token_id] = -100

    out = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

    # Zero-pad every video to the largest frame count / height / width in the batch.
    pvs = [inst["pixel_values"].squeeze(0) for inst in instances
           if inst.get("pixel_values") is not None]
    if pvs:
        max_frames = max(pv.shape[0] for pv in pvs)
        max_h = max(pv.shape[-2] for pv in pvs)
        max_w = max(pv.shape[-1] for pv in pvs)
        # Fillers for text-only examples must match the real videos'
        # channels/dtype/device, otherwise torch.stack below fails.
        fill_channels, fill_dtype, fill_device = pvs[0].shape[1], pvs[0].dtype, pvs[0].device
    else:
        max_h = max_w = processor.video_size['longest_edge']
        max_frames = 1
        fill_channels, fill_dtype, fill_device = 3, torch.float32, None

    padded_pixel_values_list = []
    for inst in instances:
        # Check for None BEFORE squeezing (calling .squeeze on None crashed before).
        raw_pv = inst.get("pixel_values")
        if raw_pv is None:
            # text-only example => all-zero video
            padded_pv = torch.zeros(
                (max_frames, fill_channels, max_h, max_w),
                dtype=fill_dtype,
                device=fill_device,
            )
        else:
            pv = raw_pv.squeeze(0)
            f, c, h, w = pv.shape
            padded_pv = torch.zeros(
                (max_frames, c, max_h, max_w),
                dtype=pv.dtype,
                device=pv.device,
            )
            padded_pv[:f, :, :h, :w] = pv
        padded_pixel_values_list.append(padded_pv)

    out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
    return out





def collate_fn_blur(examples,image_token_id,model,processor, withBlur, extended, num_frames):
    """Collate examples for the bounding-box ("BB") task into a padded batch.

    For each example the processor is first run on a placeholder conversation
    only to obtain the indices of the frames it samples from the video. The
    per-frame box strings in ``example["bbox"]`` (a stringified Python list)
    are taken at those indices and joined with ``;``. That string is the
    assistant target for the prompt
    "Return xyxy coordinates for the object in the video".

    Args:
        examples: dataset rows with a ``"bbox"`` column and the video path column.
        image_token_id: id of ``<image>``; those positions become -100 in ``labels``.
        model: only ``model.dtype`` is used.
        processor: SmolVLM processor whose ``apply_chat_template`` accepts
            ``extended``, ``num_frames`` and ``return_frame_indices``.
        withBlur: read ``"blur_full_video_path"`` when true, else ``"origin_video_path"``.
        extended, num_frames: forwarded to ``processor.apply_chat_template``.

    Returns:
        dict with padded ``input_ids``, ``attention_mask``, ``labels`` and
        zero-padded ``pixel_values``.
    """
    import ast  # "bbox" is stored as a stringified Python list

    videoPath = "origin_video_path" if not withBlur else "blur_full_video_path"
    instances = []
    for example in examples:
        # Placeholder pass: only the sampled frame indices are needed.
        dummy_input = [
            {"role": "user", "content": [
                {"type": "text", "text": "Caption the video."},
                {"type": "video", "path": example[videoPath]},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": ""}]},
        ]
        _, indices = processor.apply_chat_template(
            dummy_input, add_generation_prompt=False, extended=extended, num_frames=num_frames,
            tokenize=True, return_dict=True, return_tensors="pt", return_frame_indices=True,
        )

        box_strings = np.array(ast.literal_eval(example["bbox"]))
        # The last sampled index may point past the final annotated frame; clamp it.
        indices[-1] = min(indices[-1], len(box_strings) - 1)
        prompt = ';'.join(box_strings[indices].tolist())

        user_content = [
            {"type": "text", "text": "Return xyxy coordinates for the object in the video"},
            {"type": "video", "path": example[videoPath]},
        ]
        messages = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": [{"type": "text", "text": f"{prompt}"}]},
        ]
        instance = processor.apply_chat_template(
            messages, add_generation_prompt=False, extended=extended, num_frames=num_frames,
            tokenize=True, return_dict=True, return_tensors="pt",
        ).to("cuda").to(model.dtype)
        instances.append(instance)

    input_ids = pad_sequence(
        [inst["input_ids"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id,
    )
    attention_mask = pad_sequence(
        [inst["attention_mask"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=0,
    )
    labels = pad_sequence(
        [inst["input_ids"].squeeze(0).clone() for inst in instances],
        batch_first=True,
        padding_value=-100,
    )
    labels[labels == image_token_id] = -100

    out = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

    # Zero-pad every video to the largest frame count / height / width in the batch.
    pvs = [inst["pixel_values"].squeeze(0) for inst in instances
           if inst.get("pixel_values") is not None]
    if pvs:
        max_frames = max(pv.shape[0] for pv in pvs)
        max_h = max(pv.shape[-2] for pv in pvs)
        max_w = max(pv.shape[-1] for pv in pvs)
        # Fillers must match the real videos' channels/dtype/device for torch.stack.
        fill_channels, fill_dtype, fill_device = pvs[0].shape[1], pvs[0].dtype, pvs[0].device
    else:
        max_h = max_w = processor.video_size['longest_edge']
        max_frames = 1
        fill_channels, fill_dtype, fill_device = 3, torch.float32, None

    padded_pixel_values_list = []
    for inst in instances:
        raw_pv = inst.get("pixel_values")
        if raw_pv is None:
            # text-only example => all-zero video
            padded_pv = torch.zeros(
                (max_frames, fill_channels, max_h, max_w),
                dtype=fill_dtype,
                device=fill_device,
            )
        else:
            pv = raw_pv.squeeze(0)
            f, c, h, w = pv.shape
            padded_pv = torch.zeros(
                (max_frames, c, max_h, max_w),
                dtype=pv.dtype,
                device=pv.device,
            )
            padded_pv[:f, :, :h, :w] = pv
        padded_pixel_values_list.append(padded_pv)

    out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
    return out






def _parse_loc_bboxes(box_strings, target_W, target_H):
    """Convert ``<locNNNN>`` box strings into integer pixel xyxy boxes.

    Each string is expected to hold four ``<locNNNN>`` tokens in
    (y_min, x_min, y_max, x_max) order on a 0-1024 grid. They are rescaled to
    ``target_W`` x ``target_H`` and truncated with ``int``. A string with any
    other number of tokens becomes ``[-1, -1, -1, -1]``.

    Returns:
        ``np.ndarray`` of int32 with one ``[x_min, y_min, x_max, y_max]`` row per string.
    """
    loc_pattern = re.compile(r'<loc(\d+)>')
    scale_x = target_W / 1024
    scale_y = target_H / 1024
    boxes = []
    for text in box_strings:
        numbers = loc_pattern.findall(text)
        if len(numbers) != 4:
            boxes.append([-1, -1, -1, -1])
            continue
        y_min, x_min, y_max, x_max = map(int, numbers)
        boxes.append([
            int(x_min * scale_x),
            int(y_min * scale_y),
            int(x_max * scale_x),
            int(y_max * scale_y),
        ])
    return np.array(boxes).astype(np.int32)


def _format_motion_prompt(inner_means, outer_means, coarse):
    """Serialize per-frame object/camera motion vectors into the target text.

    Each frame becomes ``<objdx±NNN><objdy±NNN><camdx±NNN><camdy±NNN>``. Values
    are rounded and their magnitude is clipped to 512. When ``coarse`` is true
    the format is single-digit with the magnitude clipped to 1. Frames are
    joined by ``;``. Each vector element must support ``.item()``.
    """
    frames = []
    for vec_idx in range(len(inner_means)):
        parts = []
        for tag, vectors in (("obj", inner_means), ("cam", outer_means)):
            val_x, val_y = vectors[vec_idx]
            val_x = round(val_x.item())
            val_y = round(val_y.item())
            sign_x = '+' if val_x >= 0 else '-'
            sign_y = '+' if val_y >= 0 else '-'
            if not coarse:
                parts.append(f"<{tag}dx{sign_x}{min(abs(val_x), 512):03d}><{tag}dy{sign_y}{min(abs(val_y), 512):03d}>")
            else:
                parts.append(f"<{tag}dx{sign_x}{min(abs(val_x), 1):01d}><{tag}dy{sign_y}{min(abs(val_y), 1):01d}>")
        frames.append("".join(parts))
    return ";".join(frames)


def collate_fn_optFlow(examples,image_token_id,model,processor, withBlur, extended, num_frames, coarse, is_joint_learning,foc,foc_hard):
    """Collate examples into a batch carrying per-frame motion supervision.

    For each example:
      1. The processor is run once on a placeholder conversation to get the
         sampled frame indices and the processed frame size.
      2. The per-frame ``<locNNNN>`` boxes in ``example["bbox"]`` are picked at
         those indices, rescaled to the processed frame size and passed to
         ``collect_difference_vectors``. Its ``diffs``, ``pred_visibility`` and
         ``inner_mask`` outputs are concatenated over the batch.
      3. The training conversation is built. With ``is_joint_learning`` the
         target is ``example["generated_prompt"]`` for "Caption the video.".
         Otherwise the target is the serialized obj/cam motion vectors.

    Args:
        examples: dataset rows with ``"bbox"``, the video path column and,
            for joint learning, ``"generated_prompt"``.
        image_token_id: id of ``<image>``; those positions become -100 in ``labels``.
        model: only ``model.dtype`` is used.
        processor: SmolVLM processor whose ``apply_chat_template`` accepts
            ``extended``, ``num_frames`` and ``return_frame_indices``.
        withBlur: read ``"blur_full_video_path"`` when true, else ``"origin_video_path"``.
        extended, num_frames: forwarded to the processor; ``num_frames`` is
            also the frame count that shorter videos are topped up toward.
        coarse: use the single-digit motion-token format.
        is_joint_learning: captioning target instead of motion-vector target.
        foc, foc_hard: forwarded to ``collect_difference_vectors``.

    Returns:
        dict with padded ``input_ids``, ``attention_mask``, ``labels``,
        zero-padded ``pixel_values``, plus ``movement_vectors``,
        ``pred_visibility`` and ``inner_mask``.
    """
    import ast  # "bbox" is stored as a stringified Python list

    video_key = "origin_video_path" if not withBlur else "blur_full_video_path"
    instances = []
    B_diffs = []
    B_vis = []
    B_inner_mask = []

    for example in examples:
        video_path = example[video_key]

        # Placeholder pass: only the sampled frame indices and frame size are used.
        dummy_input = [
            {"role": "user", "content": [
                {"type": "text", "text": "Caption the video."},
                {"type": "video", "path": video_path},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": ""}]},
        ]
        dummy_instance, indices = processor.apply_chat_template(
            dummy_input, add_generation_prompt=False, extended=extended, num_frames=num_frames,
            tokenize=True, return_dict=True, return_tensors="pt", return_frame_indices=True,
        )

        target_H = dummy_instance["pixel_values"].shape[-2]
        target_W = dummy_instance["pixel_values"].shape[-1]
        target_T = dummy_instance["pixel_values"].shape[1]

        box_strings = np.array(ast.literal_eval(example["bbox"]))
        # The last sampled index may point past the final annotated frame; clamp it.
        indices[-1] = min(indices[-1], len(box_strings) - 1)
        if indices.shape[0] < target_T:
            # Repeat the last index once to cover a single missing frame.
            indices = np.append(indices, indices[-1])

        bboxes = _parse_loc_bboxes(box_strings[indices], target_W, target_H)
        inner_means, outer_means, diffs, pred_visibility, inner_mask = collect_difference_vectors(
            dummy_instance["pixel_values"],
            video_path,
            bboxes,
            indices,
            target_W, target_H,
            modified=True, foc=foc, foc_hard=foc_hard,
        )
        B_diffs.append(diffs)
        B_vis.append(pred_visibility)
        B_inner_mask.append(inner_mask)

        if is_joint_learning:
            prompt = example["generated_prompt"]
            user_text_learning = "Caption the video."
        else:
            user_text_learning = "Return the 2D movement vectors (dx, dy) for the object and for the camera, for every frame in the video."
            prompt = _format_motion_prompt(inner_means, outer_means, coarse)

        messages = [
            {"role": "user", "content": [
                {"type": "text", "text": user_text_learning},
                {"type": "video", "path": video_path},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": f"{prompt}"}]},
        ]
        instance = processor.apply_chat_template(
            messages, add_generation_prompt=False, extended=extended, num_frames=num_frames,
            tokenize=True, return_dict=True, return_tensors="pt",
        ).to("cuda").to(model.dtype)

        if instance["pixel_values"].shape[1] < num_frames:
            # Duplicate the last frame once along the frame axis (dim 1).
            last_frame = instance["pixel_values"][:, -1:, ...]
            instance["pixel_values"] = torch.cat([instance["pixel_values"], last_frame], dim=1)
        instances.append(instance)

    movement_vectors = torch.cat(B_diffs, dim=0)
    pred_visibility = torch.cat(B_vis, dim=0)
    inner_mask = torch.cat(B_inner_mask, dim=0)

    input_ids = pad_sequence(
        [inst["input_ids"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=processor.tokenizer.pad_token_id,
    )
    attention_mask = pad_sequence(
        [inst["attention_mask"].squeeze(0) for inst in instances],
        batch_first=True,
        padding_value=0,
    )
    labels = pad_sequence(
        [inst["input_ids"].squeeze(0).clone() for inst in instances],
        batch_first=True,
        padding_value=-100,
    )
    labels[labels == image_token_id] = -100

    out = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "movement_vectors": movement_vectors,
        'pred_visibility': pred_visibility,
        "inner_mask": inner_mask,
    }

    # Zero-pad every video to the largest frame count / height / width in the batch.
    pvs = [inst["pixel_values"].squeeze(0) for inst in instances
           if inst.get("pixel_values") is not None]
    if pvs:
        max_frames = max(pv.shape[0] for pv in pvs)
        max_h = max(pv.shape[-2] for pv in pvs)
        max_w = max(pv.shape[-1] for pv in pvs)
        # Fillers must match the real videos' channels/dtype/device for torch.stack.
        fill_channels, fill_dtype, fill_device = pvs[0].shape[1], pvs[0].dtype, pvs[0].device
    else:
        max_h = max_w = processor.video_size['longest_edge']
        max_frames = 1
        fill_channels, fill_dtype, fill_device = 3, torch.float32, None

    padded_pixel_values_list = []
    for inst in instances:
        raw_pv = inst.get("pixel_values")
        if raw_pv is None:
            # text-only example => all-zero video
            padded_pv = torch.zeros(
                (max_frames, fill_channels, max_h, max_w),
                dtype=fill_dtype,
                device=fill_device,
            )
        else:
            pv = raw_pv.squeeze(0)
            f, c, h, w = pv.shape
            padded_pv = torch.zeros(
                (max_frames, c, max_h, max_w),
                dtype=pv.dtype,
                device=pv.device,
            )
            padded_pv[:f, :, :h, :w] = pv
        padded_pixel_values_list.append(padded_pv)

    out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
    return out





def basic_distillation(args):
    """Fine-tune SmolVLM2 on ``args.dataset_csv`` according to ``args.mode``.

    Steps:
      1. Load the processor and the model. When the model size contains
         "quant", a LoRA-config adapter wraps ``SmolVLMForConditionalGeneration``.
         Otherwise the bf16 model is loaded on CUDA, using the joint-learning
         class when "joint_learning" is in the mode.
      2. (non-quant only) Set which parameters are trainable from the mode flags.
      3. Build the collator: ``collate_fn_blur`` for "BB" modes,
         ``collate_fn_optFlow`` otherwise.
      4. Train with ``Trainer`` (``CustomTrainer`` for joint learning),
         writing checkpoints and logs under ``args.save_dir``.

    Args:
        args: namespace from ``parse_args`` with ``dataset_csv`` set; reads
            ``model_id``, ``model_size``, ``mode``, ``resume`` and ``save_dir``.
    """
    model_id = args.model_id
    processor_id = (
        "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
        if "500M" in args.model_size
        else "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
    )
    quant = "quant" in args.model_size
    USE_QLORA = False

    # Both branches need the processor (image token id, collators). It used to be
    # created only in the non-quant branch, so "quant" runs hit a NameError below.
    processor = SmolVLMProcessor.from_pretrained(processor_id)

    if quant:
        lora_config = LoraConfig(
            r=8,
            lora_alpha=8,
            lora_dropout=0.1,
            target_modules=['down_proj', 'o_proj', 'k_proj', 'q_proj', 'gate_proj', 'up_proj', 'v_proj'],
            use_dora=False if USE_QLORA else True,
            init_lora_weights="gaussian",
        )
        lora_config.inference_mode = False
        bnb_config = None
        if USE_QLORA:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )

        model = SmolVLMForConditionalGeneration.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
        )
        model.add_adapter(lora_config)
        model.enable_adapters()
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, lora_config)
        print(model.get_nb_trainable_parameters())
    else:
        # (flash_attention_2 can be enabled via _attn_implementation if available.)
        if "joint_learning" in args.mode:
            model = model_jl.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                use_mask="mask" in args.mode,
                use_emph=("mask" in args.mode) and ("emph" in args.mode),
                freeze="freeze" in args.mode,
                foc="foc" in args.mode,
                l2="l2" in args.mode,
            ).to("cuda")
        else:
            model = SmolVLMForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
            ).to("cuda")

        # Parameter-name substrings that always stay trainable.
        motion_keys = ('movement_encoder', 'MCA_layers')

        if 'freeze' in args.mode and "joint_learning" in args.mode:
            # Freeze the whole model except the motion modules.
            for name, param in model.named_parameters():
                if not any(key in name for key in motion_keys):
                    param.requires_grad = False

        # Names below are relative to vision_model (e.g. "encoder.layers.0...."),
        # so the encoder-layer key must not carry a "vision_model." prefix. With
        # that prefix the check never matched, and the non-freeze cfg modes froze
        # the whole vision encoder exactly like the "_freeze" modes.
        if 'freeze' in args.mode:
            vision_trainable_keys = motion_keys
        else:
            vision_trainable_keys = ('encoder.layers',) + motion_keys

        for name, param in model.model.vision_model.named_parameters():
            if not any(key in name for key in vision_trainable_keys):
                param.requires_grad = False

            # On a fresh run, zero-initialise this MCA output projection.
            if not args.resume and name == "encoder.MCA_layers.3.transformer.out_proj.weight":
                param.data.zero_()

    print("\n\n")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    dataset = load_dataset("csv", data_files=args.dataset_csv)["train"]

    image_token_id = processor.tokenizer.additional_special_tokens_ids[
        processor.tokenizer.additional_special_tokens.index("<image>")
    ]

    num_train_epochs = 10
    b_size = 8 if "joint_learning" in args.mode else 16

    training_args = TrainingArguments(
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=b_size,
        gradient_accumulation_steps=1,
        warmup_steps=50,
        learning_rate=1e-4,
        weight_decay=0.01,
        logging_steps=25,
        save_strategy="steps",
        save_steps=300,
        save_total_limit=3,
        optim="adamw_torch" if not quant else "paged_adamw_8bit",
        bf16=True,
        report_to="tensorboard",
        dataloader_drop_last=True,
        output_dir=f"./{args.save_dir}",
        hub_model_id=f"./{args.save_dir}",
        logging_dir=f"./{args.save_dir}/logs",
        remove_unused_columns=False,
        gradient_checkpointing=True,
        dataloader_pin_memory=False,
    )

    withExtended = True
    num_frames = 30  # same frame budget for every model size

    if "BB" in args.mode:
        data_collator_fn = partial(
            collate_fn_blur, image_token_id=image_token_id, model=model, processor=processor,
            withBlur=False, extended=withExtended, num_frames=num_frames,
        )
    else:
        data_collator_fn = partial(
            collate_fn_optFlow, image_token_id=image_token_id, model=model, processor=processor,
            withBlur=False, extended=withExtended, num_frames=num_frames,
            coarse="coarse" in args.mode,
            is_joint_learning="joint_learning" in args.mode,
            foc="foc" in args.mode,
            foc_hard="hard" in args.mode,
        )

    if "joint_learning" not in args.mode:
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator_fn,
            train_dataset=dataset,
            callbacks=[LossLoggerCallback(f"./{args.save_dir}/logs/log.txt")],
        )
    else:
        trainer = CustomTrainer(
            model=model,
            args=training_args,
            data_collator=data_collator_fn,
            train_dataset=dataset,
            callbacks=[LossLoggerCallback2(f"./{args.save_dir}/logs/log.txt")],
        )

    # resume=True makes the Trainer continue from the latest checkpoint in output_dir.
    trainer.train(resume_from_checkpoint=args.resume)



if __name__ == "__main__":
    # Every mode trains on the CSV that includes the "bbox" column, which both
    # collate_fn_blur and collate_fn_optFlow read.
    DATASET_CSV = 'dataset/got10k/teacher/train/uniform_blur/combined_w_BB.csv'

    cli_args = parse_args()
    # parse_args() already calls config.get_config once; this second call is
    # kept unchanged because get_config's effects are defined in another module.
    config.get_config(cli_args)
    cli_args.dataset_csv = DATASET_CSV

    basic_distillation(cli_args)