import sys
sys.path.append('.')

import os.path as osp
from pathlib import Path
import warnings
warnings.simplefilter("once", category=FutureWarning)

import torch
from einops import rearrange

from lib.trainers.canny_aug_trainer import CannyAugTrainer
from lib.unified_dataset_wm.dataset.ac_dataset import MixAC
from lib.utils.torch_utils import seed_everything, save_video
from lib.utils.misc import ProgressTracker
from lib.utils.memory_utils import free_memory
from img_edit.flux_pipeline import prepare_relighter

def prepare_model_and_data():
    """Build the validation trainer, dataset and models used by the script.

    Creates a ``CannyAugTrainer`` in validation-only mode, overrides its
    checkpoint path, output directory and data configuration with the values
    hard-coded below, rebuilds the validation dataset/dataloader from the
    updated configuration and finally calls ``trainer.prepare_models()``.

    Returns:
        tuple: ``(trainer, chunk_size, n_prev, SEP, accelerator, save_dir,
        fps)``. ``chunk_size`` is the clip length written into the data
        config, ``accelerator`` is ``trainer.state.accelerator`` and
        ``save_dir`` is the (created) root directory for results.

    Raises:
        ValueError: If ``chunk_size`` is not of the form ``4n + 1``.
    """
    # video clip params
    n_prev = 4
    chunk_size = 73  # must equal 4n+1
    fps = 30
    SEP = 1

    # Check the documented constraint up front, before any expensive setup,
    # so that editing the constant above cannot silently break it.
    if chunk_size % 4 != 1:
        raise ValueError(f'chunk_size must equal 4n+1, got {chunk_size}')

    # NOTE: Replace config and data paths with generic placeholders for anonymization
    config_file = '<CONFIG_FILE_PATH>'
    trainer = CannyAugTrainer(config_file, val_only=True)

    # model and path
    root_dir = '<EXPERIMENT_ROOT_DIR>'
    model_dir = osp.join(root_dir, '<EXPERIMENT_TIMESTAMP>/<MODEL_STEP>')
    model_path = osp.join(model_dir, 'diffusion_pytorch_model.safetensors')
    trainer.args.transformer['model_path'] = model_path

    # Example dataset configuration (anonymized)
    val_cfg = trainer.args.data['val']
    val_cfg['jsonl_path_list'] = ['<VAL_JSONL_PATH>']
    val_cfg['video_folder_list'] = ['<VAL_VIDEO_ROOT>']
    val_cfg['dataset_name_list'] = ['<VAL_DATASET_NAME>']
    val_cfg['dataset_source_list'] = ['<VAL_DATASET_SOURCE>']
    val_cfg['cam_use'] = [['head', 'hand_left', 'hand_right']]
    save_dir = '<RESULT_SAVE_ROOT>'

    # parents=True: the result root may be nested under directories that do
    # not exist yet; without it mkdir raises FileNotFoundError.
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    trainer.args.output_dir = save_dir
    trainer.save_folder = save_dir

    # data: keep the train and val configs in sync for the shared clip settings
    for split in ('train', 'val'):
        split_cfg = trainer.args.data[split]
        split_cfg['sample_all_frames'] = True
        split_cfg['chunk'] = chunk_size
        split_cfg['action_chunk'] = chunk_size
        split_cfg['fps'] = fps
    trainer.args.data['val']['ignore_seek'] = False
    trainer.args.wo_hand_cond = False

    # Rebuild the validation dataset so it picks up the overrides above.
    trainer.val_dataset = MixAC(**trainer.args.data['val'])
    trainer.val_dataloader = torch.utils.data.DataLoader(
        trainer.val_dataset, batch_size=1, shuffle=False
    )

    # prepare models
    trainer.prepare_models()

    accelerator = trainer.state.accelerator
    return trainer, chunk_size, n_prev, SEP, accelerator, save_dir, fps

def _select_relight_type(relight_candidates, local_rank):
    """Return the entry of ``relight_candidates`` assigned to ``local_rank``."""
    if relight_candidates is None:
        raise ValueError('relight_candidates must be provided when multiprocess=True')
    if isinstance(relight_candidates, str):
        # A lone candidate may arrive as a bare string (e.g. from the command
        # line); indexing it directly would return a single character.
        relight_candidates = [relight_candidates]
    if len(relight_candidates) > 8:
        raise ValueError(
            f'at most 8 relight candidates are supported, got {len(relight_candidates)}'
        )
    if local_rank is None:
        raise ValueError('multiprocess=True requires a device with an explicit index')
    local_rank = int(local_rank)
    if not 0 <= local_rank < len(relight_candidates):
        raise ValueError(
            f'no relight candidate for local_rank {local_rank} '
            f'(only {len(relight_candidates)} candidates given)'
        )
    return relight_candidates[local_rank]


def _pad_chunk(video_chunk, chunk_size):
    """Pad ``video_chunk`` along dim 3 to ``chunk_size`` by repeating its last frame."""
    missing = chunk_size - video_chunk.shape[3]
    if missing <= 0:
        return video_chunk
    last_frame = video_chunk[:, :, :, -1:]
    padded = torch.cat([video_chunk] + [last_frame] * missing, dim=3)
    assert padded.shape[3] == chunk_size
    return padded


@torch.no_grad()
def main_chunk_wise(
    merge_view_into_width=False,
    save_gt=False,
    relight_type='relight_0', steps=4,
    relight_candidates=None,
    multiprocess=False
):
    """Relight every validation video chunk by chunk and save the results.

    For each sample the video (laid out as ``b c v t h w``) is split into
    clips of ``chunk_size`` frames. The first frame of the first clip is
    edited by the relighter; every clip is then passed to
    ``trainer.validate`` together with ``n_prev`` leading frames: copies of
    the first frame for the first clip, and frames sampled evenly from the
    results generated so far for later clips. A short last clip is padded by
    repeating its final frame; the padded frames are removed again before
    saving.

    Args:
        merge_view_into_width: If True, concatenate the three views along the
            width and save one video per sample; otherwise save one video per
            view inside a per-episode directory.
        save_gt: Also save the ground-truth video. Only honoured when
            ``merge_view_into_width`` is True.
        relight_type: Identifier passed to ``prepare_relighter``; also used
            as the name of the output sub-directory.
        steps: Number of inference steps passed to ``trainer.validate``.
        relight_candidates: Sequence of relight types; with
            ``multiprocess=True`` each process picks the entry at its local
            rank.
        multiprocess: Select ``relight_type`` from ``relight_candidates`` by
            the index of the accelerator device.

    Raises:
        ValueError: If ``multiprocess`` is True and ``relight_candidates`` is
            missing, has more than 8 entries, or has no entry for this rank.
        NotImplementedError: If per-view saving is requested for a video
            whose path does not contain the known validation video root.
    """
    seed_everything(42)

    trainer, chunk_size, n_prev, _sep, accelerator, save_dir, fps = prepare_model_and_data()
    device = accelerator.device
    local_rank = device.index

    if multiprocess:
        relight_type = _select_relight_type(relight_candidates, local_rank)
        local_rank = int(local_rank)
        print(f'[INFO] {relight_type=} for local_rank {local_rank}')

    @accelerator.on_local_main_process
    def print_on_master(msg):
        print(msg)

    save_dir = osp.join(save_dir, relight_type)
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    print(f'[INFO] save_dir updated to {save_dir}')

    # Number of camera views; matches the three cameras listed in ``cam_use``.
    n_view = 3

    dataset = trainer.val_dataloader.dataset
    for idx in range(len(dataset)):
        data = dataset[idx]
        prompt = [data['caption']]
        video = data['video'].to('cpu').unsqueeze(0)
        video_path = [data['path']]
        clip_name = video_path[0].split('/')[-1]

        # Fail before running any inference instead of after it.
        if not merge_view_into_width and '<VAL_VIDEO_ROOT>' not in video_path[0]:
            raise NotImplementedError(
                f'per-view saving is not implemented for {video_path[0]}'
            )

        # Layout is (b, c, v, t, h, w); only the frame count is needed here.
        _, _, _, t_all, _, _ = video.shape

        # split() already returns a single chunk when chunk_size >= t_all.
        all_chunks = video.split(chunk_size, dim=3)

        results = None
        gt = None
        tracker = ProgressTracker(len(all_chunks), description=f'sample_{idx} chunk loop')
        for video_chunk in all_chunks:
            video_chunk = _pad_chunk(video_chunk, chunk_size)

            if results is None:  # first iteration, edit frame_0
                # Park the video models on the CPU while the relighter runs
                # (presumably to make room for it on the device).
                trainer.text_encoder = trainer.text_encoder.to('cpu')
                trainer.vae = trainer.vae.to('cpu')
                trainer.transformer = trainer.transformer.to('cpu')
                free_memory()
                relighter = prepare_relighter(relight_type, device)
                relighted_img = relighter.generate(
                    video_chunk[:, :, 0, 0].to(device), enable_tqdm=False
                )
                del relighter
                free_memory()
                trainer.text_encoder = trainer.text_encoder.to(device)
                trainer.vae = trainer.vae.to(device)
                trainer.transformer = trainer.transformer.to(device)

                relighted = video_chunk.clone().to(device)
                # Cast straight to the destination's dtype/device rather than
                # bouncing through the CPU copy of the chunk.
                relighted[:, :, 0, 0] = relighted_img.to(relighted)
                video_for_canny = video_chunk.clone().to(device)
                # No history yet: use n_prev copies of the first frame.
                relighted = torch.cat([relighted[:, :, :, :1]] * n_prev + [relighted], dim=3)
                video_for_canny = torch.cat(
                    [video_for_canny[:, :, :, :1]] * n_prev + [video_for_canny], dim=3
                )
                tracker.start()
            else:
                # Use n_prev frames spread evenly over everything generated so far.
                mem_index = torch.linspace(0, results.shape[2] - 1, n_prev).long()
                mem = results[:, :, mem_index].to(video_chunk)
                mem = rearrange(mem, '(b v) c t h w -> b c v t h w', v=n_view)
                relighted = torch.cat([mem, video_chunk], dim=3).to(device)
                mem_gt = gt[:, :, :, mem_index].to(video_chunk)
                video_for_canny = torch.cat([mem_gt, video_chunk], dim=3).to(device)

            preds = trainer.validate(
                accelerator, save_dir,
                video=relighted, video_for_canny=video_for_canny,
                prompt=prompt, n_prev=n_prev, n_view=n_view,
                chunk_size=video_chunk.shape[3],
                # Views are merged (if requested) after all chunks are done.
                merge_view_into_width=False, fps=trainer.args.data['val']['fps'],
                video_path=video_path, vis_cat_gt=False,
                write_video_to_disk=False,
                guidance_scale=1.0, pipeline_progress=False,
                num_inference_steps=steps
            )
            pred_video = preds.detach().cpu()
            assert pred_video.shape[2] == chunk_size
            if results is None:
                results = pred_video
                gt = video_chunk.detach().cpu()
            else:
                results = torch.cat([results, pred_video], dim=2)
                gt = torch.cat([gt, video_chunk.cpu()], dim=3)

            tracker.update()
            print_on_master(tracker.get_progress_string())
            del preds
            free_memory()

        # Drop the frames that exist only because the last chunk was padded,
        # so the saved videos have the same length as the input video.
        results = results[:, :, :t_all]
        gt = gt[:, :, :, :t_all]

        if merge_view_into_width:
            results = rearrange(results, '(b v) c t h w -> b c t h (v w)', v=n_view)
            save_name = osp.join(save_dir, clip_name + f'_steps{steps}.mp4')
            # Pass fps explicitly, as the per-view branch below does.
            save_video(results[0], save_name, fps=fps)
            print(f'Result saved to {save_name}')

            if save_gt:
                gt = rearrange(gt, 'b c v t h w -> b c t h (v w)')
                save_name = osp.join(save_dir, clip_name + '_gt.mp4')
                save_video(gt[0], save_name, fps=fps)
                print(f'Result saved to {save_name}')
        else:
            # Example of anonymized path routing logic (path validated above)
            episode_dir = osp.join(save_dir, clip_name)
            Path(episode_dir).mkdir(parents=True, exist_ok=True)
            view_to_video_map = ('view1.mp4', 'view2.mp4', 'view3.mp4')
            for iv, view_video in enumerate(results):
                save_name = osp.join(episode_dir, view_to_video_map[iv])
                save_video(view_video, save_name, fps=fps)
                print(f'Result saved to {save_name}')

if __name__ == '__main__':
    # Script entry point: hand main_chunk_wise to python-fire to build the CLI.
    from fire import Fire
    Fire(main_chunk_wise)