import argparse
from collections import deque
from typing import Optional, Sequence
import numpy as np
from PIL import Image
import csv
from absl import flags
import json
import torch.nn.functional as F

import torch
from torchvision import transforms as T, utils
import sys
import os

sys.path.append('/home/t-sye/World-Model/Phenaki')
from phenaki_pytorch import CViViT_single_linear_nsvq_3


# Side length (pixels) of the square frames produced by `transform`.
image_size = 256


def _ensure_rgb(img):
    """Return `img` unchanged when it is already RGB, otherwise an RGB copy."""
    if img.mode == 'RGB':
        return img
    return img.convert('RGB')


# Preprocessing applied to every PIL image before it is handed to the model:
# force RGB, resize to (image_size, image_size), convert to a tensor.
# (A centre crop -- T.CenterCrop(image_size) -- is intentionally left disabled.)
_preprocess_steps = [
    T.Lambda(_ensure_rgb),
    T.Resize((image_size, image_size)),
    T.ToTensor(),
]
transform = T.Compose(_preprocess_steps)


# load_checkpoint "params::/home/t-sye/World-Model/checkpoints/bridge_carrot_256_batch_128_multitask_40traj/streaming_params_${checkpoint}"
# update-llama-config "dict(action_vocab_size=245,sample_mode='text',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \

# Example usage:
# python -m lwm.latent_lwm_checking_phenaki  --load_cvivit_checkpoint "/home/t-sye/World-Model/checkpoints/phenaki/vast_ai_backup/whole_window3/ddp_gpu_2_batch224_layer8_seq4_code2/vae.20000.pt" --tokens_per_delta 4 --codebook_size 2 --latent_action_file "/home/t-sye/World-Model/data/bridge_window3_codebook2_codeseqlen4_epoch1_latent_actions.jsonl"
def frame_mse(target, predicted):
    """Return the mean squared error between two equally shaped frames.

    Both inputs are promoted to float64 before subtracting. Doing the
    arithmetic directly on uint8 arrays wraps around modulo 256 and silently
    produces a wrong loss.

    Raises:
        ValueError: if the two frames do not have the same shape.
    """
    target = np.asarray(target, dtype=np.float64)
    predicted = np.asarray(predicted, dtype=np.float64)
    if target.shape != predicted.shape:
        raise ValueError(
            f"frame shape mismatch: target {target.shape} vs predicted {predicted.shape}"
        )
    return float(np.mean((target - predicted) ** 2))


def to_uint8_frame(frame):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.

    Values are rounded rather than truncated, so a tensor that came from
    `ToTensor()` (uint8 / 255) maps back to its original pixel values.
    """
    scaled = torch.clamp(frame.detach().float(), 0.0, 1.0) * 255.0
    hwc = scaled.round().to(torch.uint8).permute(1, 2, 0).cpu().numpy()
    return np.ascontiguousarray(hwc)


def read_jsonl_records(path, limit=None):
    """Read up to `limit` JSON objects from a JSON-lines file.

    Blank lines are skipped. With `limit=None` the whole file is read.
    """
    records = []
    with open(path, 'r') as f:
        for line in f:
            if limit is not None and len(records) >= limit:
                break
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def hstack_tiles(tiles, tile_size):
    """Paste equally sized (H, W, 3) uint8 arrays side by side into one RGB image."""
    canvas = Image.new('RGB', (tile_size * len(tiles), tile_size))
    for slot, tile in enumerate(tiles):
        canvas.paste(Image.fromarray(tile), (slot * tile_size, 0))
    return canvas


def main(argv=None):
    """Reconstruct next frames from stored latent actions and report the MSE.

    For each record of the latent-action file, the two frames listed under
    'image' are preprocessed with `transform`, passed to the model together
    with the record's 'latent_action', and the first predicted frame is
    compared against the preprocessed second frame. A side-by-side panel
    (first frame, second frame, prediction) is saved per record.
    """
    parser = argparse.ArgumentParser(description="LWM Inference script")
    parser.add_argument('--codebook_size', type=int, default=2, help='Codebook size for the model')
    parser.add_argument('--tokens_per_delta', type=int, default=1, help='Tokens per delta')
    parser.add_argument('--load_cvivit_checkpoint', type=str, default="")
    parser.add_argument('--latent_action_file', type=str, default="")
    parser.add_argument('--last_index', type=int, default=100,
                        help='Index (inclusive) of the last record to evaluate')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory for the comparison panels '
                             '(defaults to the historical data/analysis_image/... path)')
    args = parser.parse_args(argv)

    # Fail early with a readable message instead of a late open()/load() error.
    if not args.load_cvivit_checkpoint:
        parser.error("--load_cvivit_checkpoint is required")
    if not args.latent_action_file:
        parser.error("--latent_action_file is required")

    cvivit = CViViT_single_linear_nsvq_3(
        dim=1024,
        quant_dim=32,
        codebook_size=args.codebook_size,
        image_size=image_size,
        patch_size=32,
        temporal_patch_size=2,
        spatial_depth=8,
        temporal_depth=8,
        dim_head=64,
        heads=16,
        lookup_free_quantization=False,
        use_vgg_and_gan=False,
        code_seq_len=args.tokens_per_delta,
    ).cuda()
    cvivit.load(args.load_cvivit_checkpoint)

    # Only records 0..last_index are evaluated, so do not parse more than that.
    records = read_jsonl_records(args.latent_action_file, limit=args.last_index + 1)

    output_dir = args.output_dir
    if output_dir is None:
        output_dir = (f'data/analysis_image/window3_codebook{args.codebook_size}'
                      f'_codeseqlen{args.tokens_per_delta}_epoch3')
    os.makedirs(output_dir, exist_ok=True)

    total_recon_loss = 0.0
    num_evaluated = 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for index, record in enumerate(records):
            first_path, next_path = record['image'][0], record['image'][1]
            action_tokens = record['latent_action']

            # Context managers close the underlying files once the tensors exist.
            with Image.open(first_path) as first_pil:
                print("image shape", np.asarray(first_pil).shape)
                first_frame = transform(first_pil)
            with Image.open(next_path) as next_pil:
                next_frame = transform(next_pil)

            # (C, H, W) x 2 -> (1, C, 2, H, W), the layout the original script fed in.
            clip = torch.stack([first_frame, next_frame], dim=1).unsqueeze(0).cuda()

            video = cvivit.inference(clip, user_action_token_num=action_tokens)
            # Assumes inference returns (N, C, H, W), as the original permute implied.
            video = torch.clamp(video, 0.0, 1.0) * 255.0
            frames = video.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
            predicted = frames[0]

            # Compare against the frame the model actually saw (resized, RGB),
            # not the original-resolution file.
            first_u8 = to_uint8_frame(first_frame)
            target_u8 = to_uint8_frame(next_frame)

            panel = hstack_tiles([first_u8, target_u8, predicted], image_size)
            panel.save(os.path.join(output_dir, f'{index}.jpg'))

            total_recon_loss += frame_mse(target_u8, predicted)
            num_evaluated += 1

    if num_evaluated == 0:
        print("no records evaluated; average total_recon_loss is undefined")
        return
    # Average over the records actually processed (the loop bound is inclusive).
    print("average total_recon_loss", total_recon_loss / num_evaluated)


if __name__ == "__main__":
    main()

# 57.56497360229494
# 58.57464584350588


    

