import os
import numpy as np
import torch
import wandb
import argparse
import torch.nn as nn
from tqdm import tqdm
from datetime import datetime
import subprocess
import multiprocessing
import cv2

from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

from idm.eval_dataset import EvalDataSet
from idm.idm import *
from idm.preprocessor import DinoPreprocessor
from idm.utils import seed_torch


def parse_args(argv=None):
    """Parse command-line options for IDM evaluation.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing ``parse_args()`` calls
            behave exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Eval IDM")
    parser.add_argument("--load_from", type=str, default=None, help="Load from path")
    parser.add_argument("--wandb_mode", type=str, default="online", help="Wandb mode")
    parser.add_argument("--use_transform", action="store_true", default=False, help="Use transform")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU")
    parser.add_argument("--eval_batch_size", type=int, default=32, help="Evaluation batch size per GPU")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of data loading workers")
    parser.add_argument("--prefetch_factor", type=int, default=4, help="Number of batches to prefetch")
    parser.add_argument("--dataset_path", type=str, default="", help="Path of the dataset")
    # The default is evaluated once, when the parser is built.
    parser.add_argument("--run_name", type=str, default=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), help="Run name")
    parser.add_argument("--save_dir", type=str, default="output", help="Save dir")
    parser.add_argument("--model_name", type=str, default="mask", help="Choose a suitable model.")
    parser.add_argument("--use_normalization", action="store_true", default=False, help="Use mean/std normalization")
    # main() reads `args.eval_only` to decide whether a missing or invalid
    # --load_from is an error. Both flags share one destination. A checkpoint
    # is required by default; --no_eval_only permits running on random weights.
    parser.add_argument("--eval_only", dest="eval_only", action="store_true", default=True,
                        help="Require a valid --load_from checkpoint (default)")
    parser.add_argument("--no_eval_only", dest="eval_only", action="store_false",
                        help="Allow running without a checkpoint (randomly initialized weights)")
    return parser.parse_args(argv)


def collate_fn(batch):
    """Collate a DataLoader batch that holds exactly one clip.

    Each dataset item is indexed as ``item[0]`` (a sequence of same-shaped
    tensors, stacked along a new leading dimension) and ``item[1]`` (returned
    unchanged). Presumably these are a clip's frames and its video path;
    TODO: confirm against EvalDataSet.

    Raises:
        ValueError: if the batch does not contain exactly one item. The
            evaluation loop handles one video per batch, and silently dropping
            the extra items would skip videos.
    """
    if len(batch) != 1:
        raise ValueError(
            f"collate_fn expects one video per batch (batch_size=1), got {len(batch)} items"
        )
    item = batch[0]
    return torch.stack(item[0], dim=0), item[1]


def get_data_generator(dataloader):
    """Yield items from ``dataloader`` endlessly, restarting it whenever it is exhausted."""
    while True:
        yield from dataloader


def save_model(accelerator: Accelerator, net: torch.nn.Module, optimizer: torch.optim.Optimizer, step, save_path):
    """Write a checkpoint with model/optimizer state and ``step`` to ``save_path``.

    Every process must call this. All ranks synchronize with
    ``accelerator.wait_for_everyone()`` before and after the write, and only
    the main process writes. The save is best-effort: a missing write
    permission or any exception is printed rather than raised.

    Args:
        accelerator: Accelerate handle, used for synchronization and to unwrap ``net``.
        net: the (possibly wrapped) model; its unwrapped ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved.
        step: value stored under the ``"step"`` key.
        save_path: destination file. Its parent directory is created if needed.
    """
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # dirname() is "" for a bare file name, and os.makedirs("") would raise.
        save_dir = os.path.dirname(save_path) or "."
        try:
            os.makedirs(save_dir, exist_ok=True)
            if os.access(save_dir, os.W_OK):
                state_dict = {
                    "model_state_dict": accelerator.unwrap_model(net).state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "step": step,
                }
                torch.save(state_dict, save_path)
            else:
                print(f"Warning: No write permission for directory {save_dir}")
        except Exception as e:
            print(f"Error saving model: {str(e)}")
    # Every rank must reach this barrier, including the main process after a
    # failed save. Returning early here would leave the other ranks blocked forever.
    accelerator.wait_for_everyone()


def save_ffmpeg(images, save_path, width=640, height=720, fps=30):
    """Encode a sequence of frames into an H.264 video by piping raw RGB to ffmpeg.

    Each frame is resized with ``cv2.resize`` to ``width`` x ``height``. A
    single-channel frame is copied into three identical channels, and the
    ``rgb24`` bytes are streamed to an ``ffmpeg`` subprocess that writes
    ``save_path`` (ffmpeg picks the container from the file extension).

    Args:
        images: iterable of frames. They are streamed as 8-bit ``rgb24``, so
            they are expected to be ``uint8``. ``(H, W)``, ``(H, W, 1)`` and
            ``(H, W, 3)`` arrays are accepted.
        save_path: output video path; overwritten if it exists (``-y``).
        width: output frame width in pixels (default 640).
        height: output frame height in pixels (default 720).
        fps: output frame rate (default 30).

    Returns:
        ffmpeg's exit code (0 on success). ffmpeg prints its own errors to the
        inherited stderr at ``-loglevel error``.

    Raises:
        ValueError: if a frame has an unsupported channel layout.
    """
    command = [
        "ffmpeg",
        "-y",
        "-loglevel",
        "error",
        "-f",
        "rawvideo",
        "-pixel_format",
        "rgb24",
        "-video_size",
        f"{width}x{height}",
        "-framerate",
        str(fps),
        "-i",
        "-",
        "-pix_fmt",
        "yuv420p",
        "-vcodec",
        "libx264",
        "-crf",
        "23",
        f"{save_path}",
    ]
    video_ffmpeg = subprocess.Popen(command, stdin=subprocess.PIPE)
    try:
        for image in images:
            # cv2.resize takes dsize as (width, height).
            frame = cv2.resize(image, (width, height))
            if frame.ndim == 3 and frame.shape[-1] == 1:
                frame = frame[..., 0]
            if frame.ndim == 2:
                # Copy the single channel into R, G and B.
                frame = np.stack((frame, frame, frame), axis=-1)
            elif frame.ndim != 3 or frame.shape[-1] != 3:
                raise ValueError(
                    f"Unsupported frame shape {frame.shape}; expected (H, W), (H, W, 1) or (H, W, 3)"
                )
            video_ffmpeg.stdin.write(frame.tobytes())
    finally:
        # Always close the pipe and reap ffmpeg, even if a frame failed.
        try:
            video_ffmpeg.stdin.close()
        except BrokenPipeError:
            # ffmpeg already exited; its exit code is returned below.
            pass
        video_ffmpeg.wait()
    return video_ffmpeg.returncode


def _join_writers(processes):
    """Wait for every background video-writer process and report any that failed."""
    for process in processes:
        process.join()
        if process.exitcode != 0:
            print(f"Warning: video writer process {process.pid} exited with code {process.exitcode}")


# NOTE: this name shadows the builtin `eval`; it is kept for compatibility with callers.
def eval(accelerator: Accelerator, net: torch.nn.Module, dataloader: DataLoader, save_dir='output',
         batch_size=16, max_pending_writes=20):
    """Run ``net`` over every clip in ``dataloader`` and save the predicted masks as videos.

    ``dataloader`` is iterated in full on every process, so all ranks see the
    same ``images`` for each clip.

    Frames are processed in chunks of ``batch_size * accelerator.num_processes``:
    - Each rank runs ``net(..., return_mask=True)[1]`` on its own
      ``batch_size`` slice.
    - The slices are reassembled in rank order with ``accelerator.gather``.
    - A trailing chunk that cannot fill every rank is run whole on every rank,
      and only the main process keeps the result.

    On the main process, masks are indexed as 4-D ``(N, C, H, W)``. They are
    scaled by 255, transposed to ``(N, H, W, C)``, clipped to ``[0, 255]``,
    cast to ``uint8`` and handed to ``save_ffmpeg`` in a background
    ``multiprocessing.Process``. Once more than ``max_pending_writes`` writers
    are pending, they are all joined.

    Every process must call this function in lock-step, because it issues
    collectives (``gather``, ``wait_for_everyone``).

    Args:
        accelerator: Accelerate handle for rank info and synchronization.
        net: model returning a tuple whose element ``[1]`` is the mask when
            called with ``return_mask=True``.
        dataloader: yields ``(images, video_path)`` pairs.
        save_dir: directory that receives the output videos.
        batch_size: frames per rank per forward pass (default 16).
        max_pending_writes: limit on pending writer processes before they are joined.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    os.makedirs(save_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    net.eval()
    rank = accelerator.process_index
    stride = batch_size * accelerator.num_processes
    processes = []
    with torch.no_grad():
        for images, video_path in tqdm(dataloader, disable=not accelerator.is_main_process):
            num_frames = len(images)
            if num_frames == 0:
                # Every rank sees the same clip, so all of them skip here together
                # and the collectives below stay in lock-step.
                if accelerator.is_main_process:
                    print(f"Warning: no frames in {video_path}, skipping")
                continue

            chunks = []
            current_idx = 0
            while current_idx < num_frames:
                if current_idx + stride > num_frames:
                    # Too few frames left for every rank: all ranks run the whole
                    # remainder so any collectives inside the forward still match.
                    output = net(images[current_idx:], return_mask=True)[1]
                    current_idx = num_frames
                else:
                    start = current_idx + batch_size * rank
                    output = accelerator.gather(net(images[start:start + batch_size], return_mask=True)[1])
                    current_idx += stride
                if accelerator.is_main_process:
                    # Move each chunk off the device right away and concatenate once
                    # at the end, instead of repeatedly re-copying a growing tensor.
                    chunks.append(output.detach().cpu())

            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                mask = torch.cat(chunks, dim=0).numpy()
                mask = 255 * np.transpose(mask, (0, 2, 3, 1))
                # Casting out-of-range floats to uint8 is undefined, so clip first.
                mask = np.clip(mask, 0, 255).astype(np.uint8)
                path_parts = video_path.split("/")
                file_name = path_parts[-3] + "_" + path_parts[-2] + "_" + path_parts[-1].split("_")[-1]
                process = multiprocessing.Process(target=save_ffmpeg, args=(mask, os.path.join(save_dir, file_name)))
                process.start()
                processes.append(process)
            accelerator.wait_for_everyone()

            if len(processes) > max_pending_writes:
                _join_writers(processes)
                processes = []
            accelerator.wait_for_everyone()
    _join_writers(processes)
    accelerator.wait_for_everyone()


def main(args):
    """Build the eval dataset and IDM model, load a checkpoint, and run ``eval``.

    Args:
        args: parsed command-line namespace (see ``parse_args``).

    Raises:
        ValueError: if ``args.load_from`` is missing or not a file and
            ``args.eval_only`` is true. A namespace without ``eval_only`` is
            treated as true.
        RuntimeError: if the checkpoint exists but cannot be loaded.
    """
    seed_torch(1234)
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    save_dir = args.save_dir

    if accelerator.is_main_process:
        print(f"{args.__dict__}")

    # Initialize preprocessor
    preprocessor = DinoPreprocessor(args)

    # Load dataset
    dataset = EvalDataSet(args, dataset_path=args.dataset_path, disable_pbar=not accelerator.is_main_process, preprocessor=preprocessor)
    dataset_size = len(dataset)
    if accelerator.is_main_process:
        print('dataset_size', dataset_size)

    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True, collate_fn=collate_fn, drop_last=False)
    net = IDM(model_name=args.model_name, output_dim=14)
    net.eval()

    # `args` may not define `eval_only`. Default to requiring a checkpoint,
    # since evaluating randomly initialized weights is almost always a mistake.
    require_checkpoint = getattr(args, "eval_only", True)
    has_checkpoint = bool(args.load_from) and os.path.isfile(args.load_from)
    if not has_checkpoint:
        if require_checkpoint:
            raise ValueError(f"Must specify --load_from with a valid model path (got {args.load_from!r})")
        if accelerator.is_main_process:
            print("Warning: no checkpoint loaded; evaluating a randomly initialized model")
    else:
        try:
            # map_location="cpu": otherwise every rank deserializes onto the device
            # the checkpoint was saved from. The weights are moved later by prepare().
            # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
            loaded_dict = torch.load(args.load_from, map_location="cpu", weights_only=False)
            net.load_state_dict(loaded_dict["model_state_dict"])
        except Exception as e:
            raise RuntimeError(f"Failed to load checkpoint from {args.load_from}: {str(e)}") from e
        if accelerator.is_main_process:
            print(f"Loaded model from {args.load_from}")

    net = accelerator.prepare(net)
    # Expose the inner module's `normalize` on the (possibly wrapped) model.
    # Presumably some code uses it without unwrapping; TODO: confirm.
    net.normalize = accelerator.unwrap_model(net).normalize

    # Set after the dataset is built. Presumably the preprocessor reads it lazily
    # when items are fetched; TODO: confirm in EvalDataSet/DinoPreprocessor.
    preprocessor.use_transform = False
    eval(accelerator, net, dataloader, save_dir=save_dir)


if __name__ == "__main__":
    # Script entry point: parse the CLI, then run evaluation.
    cli_args = parse_args()
    main(cli_args)
