import torch as th
from torch import nn
from torch.utils.data import Dataset, DataLoader
import cv2

from torch.nn.parallel import DistributedDataParallel
import numpy as np
import os
from typing import Tuple, Union, List
from einops import rearrange, repeat, reduce
from pathlib import Path

from utils.configuration import Configuration
from utils.parallel import run_parallel, DatasetPartition
from utils.loss import depth_smooth_loss, UncertaintyGANLoss
from utils.io import model_path
from utils.optimizers import SDAdam, SDAMSGrad, RAdam, Ranger
import torch.distributed as dist
import time
import random
from utils.io import Timer, BinaryStatistics, UEMA, SequenceToImgs
from utils.data import DeviceSideDataset
from nn.background import ViTDepthUncertantyBackground
from einops import rearrange, repeat, reduce
import torch.nn.functional as F

def interpolate(tensor, size=(64, 64)):
    """Bilinearly resize a single, un-batched image tensor.

    Args:
        tensor: Image tensor without a batch axis. A temporary batch axis is
            added before resizing, so for a 2-D target size the input must be
            3-D (channels, height, width).
        size: Target spatial size ``(height, width)``. Defaults to ``(64, 64)``,
            the resolution this function was originally fixed to.

    Returns:
        The resized tensor with the temporary batch axis removed again.
    """
    # F.interpolate works on batched input, so wrap the image in a batch of one.
    batched = tensor.unsqueeze(dim=0)
    resized = F.interpolate(batched, size=size, mode='bilinear', align_corners=False)
    return resized[0]

def _make_export_loader(dataset: Dataset, batch_size: int, num_workers: int, prefetch_factor: int) -> DataLoader:
    """Create a loader that walks *dataset* once, in order, keeping the last partial batch."""
    loader_args = dict(
        pin_memory  = True,
        num_workers = num_workers,
        batch_size  = batch_size,
        shuffle     = False,
        drop_last   = False,
    )

    # DataLoader only accepts an explicit prefetch_factor when worker
    # processes are used; passing it with num_workers == 0 raises ValueError.
    if num_workers > 0:
        loader_args['prefetch_factor'] = prefetch_factor

    return DataLoader(dataset, **loader_args)


def _export_frame(dst_path: str, t: int, images, downscale: bool) -> None:
    """Write all images of time step *t* as JPEG files into *dst_path*.

    Args:
        dst_path: Existing directory the files are written to.
        t: Time step, used as the zero-padded numeric suffix of each file name.
        images: Iterable of ``(name, image, is_color)`` where ``image`` is a
            channel-first CPU tensor. Color images are written with all
            channels, the others with their first channel only.
        downscale: If true, every image is passed through ``interpolate``
            before being written.
    """
    for name, image, is_color in images:
        if downscale:
            image = interpolate(image)

        if is_color:
            # channel-first -> channel-last, the layout cv2 expects
            pixels = image.numpy().transpose(1, 2, 0) * 255
        else:
            pixels = image[0].numpy() * 255

        # NOTE(review): pixels is still a float array here; this relies on
        # cv2.imwrite converting it to 8 bit and assumes values in [0, 1] — confirm.
        cv2.imwrite(os.path.join(dst_path, f'{name}{t:03d}.jpg'), pixels)


def run(cfg: Configuration, trainset: Dataset, valset: Dataset, testset: Dataset, file, root_path: str = "/media/chief/data/movi-e-bg"):
    """Run the background network over all three dataset splits and export the results as JPEGs.

    For every sample the input frames, input depth, predicted background
    RGB, predicted background depth and estimated uncertainty are written
    per time step to
    ``<root_path>/<split>/<sample index>/256x256`` (as produced) and
    ``<root_path>/<split>/<sample index>/64x64`` (resized via ``interpolate``),
    where ``<split>`` is ``train``, ``test`` or ``validation``.

    Args:
        cfg: Configuration providing ``device``, ``num_workers``,
            ``prefetch_factor`` and the network settings under ``model``.
        trainset: Dataset exported to the ``train`` directory.
        valset: Dataset exported to the ``validation`` directory.
        testset: Dataset exported to the ``test`` directory.
        file: Path of a checkpoint whose ``'model'`` entry is loaded into the
            network. An empty value (``""`` or ``None``) skips loading.
        root_path: Root directory of the export. Defaults to the location
            that used to be hard-coded.

    Each batch is expected to provide the RGB sequence at index 0 and the
    depth sequence at index 1, both indexed as
    (batch, time, channel, height, width).
    """

    device = th.device(cfg.device)

    cfg_net = cfg.model

    net = ViTDepthUncertantyBackground(
        latent_size               = cfg_net.latent_size,
        reg_lambda                = cfg_net.background.reg_lambda,
        batch_size                = cfg_net.batch_size,
        hidden_channels           = cfg_net.background.channels,
        num_embedding_layers      = cfg_net.background.num_embedding_layers,
        num_attention_layers      = cfg_net.background.num_attention_layers,
        num_hyper_layers          = cfg_net.background.num_hyper_layers,
        hyper_channels            = cfg_net.background.num_hyper_channels,
        num_heads                 = cfg_net.background.num_heads,
        uncertainty_base_channels = cfg_net.background.uncertainty_base_channels,
        uncertainty_blocks        = cfg_net.background.uncertainty_blocks,
        entity_pretraining_steps  = cfg_net.entity_pretraining_steps,
        uncertainty_threshold     = cfg_net.background.uncertainty_threshold,
        depth_input               = cfg_net.background.depth_input
    )

    net = net.to(device=device)

    trainloader = _make_export_loader(trainset, cfg_net.batch_size, cfg.num_workers, cfg.prefetch_factor)
    testloader  = _make_export_loader(testset, cfg_net.batch_size, cfg.num_workers, cfg.prefetch_factor)
    valloader   = _make_export_loader(valset, cfg_net.batch_size, cfg.num_workers, cfg.prefetch_factor)

    if file:
        # map_location keeps loading independent of the device the
        # checkpoint was saved from (e.g. a CUDA checkpoint on a CPU run).
        state = th.load(file, map_location=device)
        net.load_state_dict(state['model'])
        print(f'loaded[{device}] {file}', flush=True)

    th.backends.cudnn.benchmark = True
    timer = Timer()

    # NOTE(review): net.eval() is never called, so the network runs in
    # whatever mode its constructor leaves it in — confirm this is intended.
    with th.no_grad():
        for dataloader, split in [(trainloader, 'train'), (testloader, 'test'), (valloader, 'validation')]:
            # running sample counter per split, used as the directory name
            samples_index = 0
            for batch_index, batch in enumerate(dataloader):

                tensor = batch[0]
                depth  = batch[1]

                # number of transitions between consecutive frames
                sequence_len = tensor.shape[1] - 1

                net.reset_state()
                depth_next = depth[:,0].to(device)
                rgb_next   = tensor[:,0].to(device)

                # NOTE(review): the uncertainty estimator gets (depth, rgb)
                # while the network itself gets (rgb, depth) — confirm.
                uncertainty_next = net.uncertainty_estimation(th.cat((depth_next, rgb_next), dim=1))[0]

                # The first frame is fed three times and only the last output
                # is kept; presumably a warm-up of the network state — confirm.
                first_input = th.cat((rgb_next, depth_next), dim=1)
                net(first_input, uncertainty_next)
                net(first_input, uncertainty_next)
                output_rgb_next, output_depth_next = net(first_input, uncertainty_next)

                # per-time-step results, moved to the CPU to free device memory
                output_rgb   = [output_rgb_next.cpu()]
                output_depth = [output_depth_next.cpu()]
                uncertainty  = [uncertainty_next.cpu()]

                for t in range(sequence_len):
                    rgb_cur, depth_cur, uncertainty_cur = rgb_next, depth_next, uncertainty_next
                    rgb_next   = tensor[:,t+1].to(device)
                    depth_next = depth[:,t+1].to(device)

                    # the network sees frame t, the uncertainty is estimated for frame t+1
                    output_rgb_next, output_depth_next = net(th.cat((rgb_cur, depth_cur), dim=1), uncertainty_cur)
                    uncertainty_next = net.uncertainty_estimation(th.cat((depth_next, rgb_next), dim=1))[0]

                    output_rgb.append(output_rgb_next.cpu())
                    output_depth.append(output_depth_next.cpu())
                    uncertainty.append(uncertainty_next.cpu())

                for i in range(tensor.shape[0]):
                    sample_dir  = os.path.join(root_path, split, f'{samples_index:06d}')
                    dst_path256 = os.path.join(sample_dir, '256x256')
                    dst_path64  = os.path.join(sample_dir, '64x64')
                    os.makedirs(dst_path256, exist_ok = True)
                    os.makedirs(dst_path64, exist_ok = True)
                    samples_index += 1

                    for t in range(len(uncertainty)):
                        images = (
                            ('frame',       tensor[i,t],       True),
                            ('depth',       depth[i,t],        False),
                            ('bg-rgb',      output_rgb[t][i],  True),
                            ('bg-depth',    output_depth[t][i], False),
                            ('uncertainty', uncertainty[t][i], False),
                        )

                        _export_frame(dst_path256, t, images, downscale=False)
                        _export_frame(dst_path64, t, images, downscale=True)

                print(f"Exporting[{split}|{batch_index * 100 / len(dataloader):.2f}%|{batch_index}|{len(dataloader)}] {str(timer)}")
