import logging as log
import math
import os
from collections import defaultdict
from typing import Dict, MutableMapping, Union, Any, List
import pdb
import pandas as pd
import torch
import torch.utils.data

from datasets.kpop_hue_dataset import kpop_dataset

from utils.ema import EMA
from utils.my_tqdm import tqdm
from ops.image import metrics
from ops.image.io import write_video_to_file
from models.lowrank_model import LowrankModel
from .base_trainer import BaseTrainer, init_dloader_random, initialize_model
from .regularization import (
    PlaneTV, TimeSmoothness, HistogramLoss, L1TimePlanes, DistortionLoss,MotionLoss,BigMotionLoss
)


class KpopTrainer(BaseTrainer):
    """Trainer for the K-pop dynamic-scene dataset.

    On top of ``BaseTrainer`` this class adds:
      * chunked full-frame evaluation;
      * switching the training dataset at ``isg_step`` / ``ist_step``;
      * video export during validation;
      * a set of special-purpose renderers (inpainting, hue changes, light removal).
        These delegate to dedicated methods of ``self.model``.
    """

    # Substrings selecting which model outputs ``eval_step`` keeps (moved to CPU).
    # "class" also matches keys containing "classes".
    _EVAL_OUTPUT_KEYS = ("rgb", "depth", "class")
    # Substrings selecting which model outputs the special-purpose renderers keep.
    _RENDER_OUTPUT_KEYS = ("rgb", "depth")

    def __init__(self,
                 tr_loader: torch.utils.data.DataLoader,
                 tr_dset: torch.utils.data.TensorDataset,
                 ts_dset: torch.utils.data.TensorDataset,
                 loss_weights: Dict,
                 num_steps: int,
                 logdir: str,
                 expname: str,
                 train_fp16: bool,
                 save_every: int,
                 valid_every: int,
                 save_outputs: bool,
                 isg_step: int,
                 ist_step: int,
                 device: Union[str, torch.device],
                 **kwargs
                 ):
        """Store dataset and schedule state, then initialise ``BaseTrainer``.

        ``save_outputs`` is kept as ``self.save_video``; ``validate`` uses it to
        decide whether to write videos. ``BaseTrainer`` itself always receives
        ``save_outputs=False``. ``isg_step`` / ``ist_step`` are the global steps at
        which ``train_step`` calls ``enable_isg`` / ``switch_isg2ist`` on the
        training dataset.
        """
        self.loss_weight = loss_weights
        print("loss_weight:", self.loss_weight)
        self.train_dataset = tr_dset
        self.test_dataset = ts_dset
        self.ist_step = ist_step
        self.isg_step = isg_step
        self.save_video = save_outputs
        # Switch to compute extra video metrics (FLIP, JOD)
        self.compute_video_metrics = False
        super().__init__(
            loss_weights=loss_weights,
            train_data_loader=tr_loader,
            num_steps=num_steps,
            logdir=logdir,
            expname=expname,
            train_fp16=train_fp16,
            save_every=save_every,
            valid_every=valid_every,
            save_outputs=False,  # False since we're saving video
            device=device,
            **kwargs)

    # ------------------------------------------------------------------ #
    # Chunked-rendering helpers (shared by eval_step and the renderers)
    # ------------------------------------------------------------------ #
    def _split_render_inputs(self, data):
        """Unpack a test-frame dict into ``(rays_o, rays_d, near_far, bg_color)``.

        ``near_far`` is moved to ``self.device``. ``bg_color`` is moved only when it
        is a tensor. The ray tensors are left where they are; ``_iter_ray_chunks``
        moves them one chunk at a time.
        """
        near_far = data["near_fars"].to(self.device)
        bg_color = data["bg_color"]
        if isinstance(bg_color, torch.Tensor):
            bg_color = bg_color.to(self.device)
        return data["rays_o"], data["rays_d"], near_far, bg_color

    def _iter_ray_chunks(self, rays_o, rays_d):
        """Yield ``(rays_o_b, rays_d_b)`` slices of at most ``self.eval_batch_size`` rays.

        The slices are taken along dim 0 and moved to ``self.device``. A whole frame
        is too large to trace at once, hence the chunking.
        """
        step = self.eval_batch_size
        for start in range(0, rays_o.shape[0], step):
            stop = start + step
            yield rays_o[start:stop].to(self.device), rays_d[start:stop].to(self.device)

    def _expand_time(self, t, num_rays):
        """Broadcast time tensor ``t`` to ``num_rays`` entries on ``self.device``.

        Assumes ``t`` is a scalar or single-element tensor (TODO confirm).
        """
        return t.expand(num_rays).to(self.device)

    @staticmethod
    def _collect_outputs(preds, outputs, key_filters):
        """Append (on CPU) each output whose key contains any substring in ``key_filters``."""
        for key, value in outputs.items():
            if any(token in key for token in key_filters):
                preds[key].append(value.cpu())

    @staticmethod
    def _concat_chunks(preds):
        """Concatenate the per-chunk tensors collected for each key along dim 0."""
        return {k: torch.cat(v, 0) for k, v in preds.items()}

    def _render_frame(self, data, render_chunk, output_keys=None):
        """Render every ray of ``data`` chunk by chunk and gather the selected outputs.

        Args:
            data: test-frame dict with "rays_o", "rays_d", "near_fars" and "bg_color".
            render_chunk: callable ``(rays_o_b, rays_d_b, near_far, bg_color)``.
                It must return a mapping of output name to tensor for that chunk.
            output_keys: substrings selecting which outputs to keep. Defaults to
                ``_RENDER_OUTPUT_KEYS``.

        Returns:
            dict mapping each kept output key to the CPU tensor concatenated over chunks.
        """
        if output_keys is None:
            output_keys = self._RENDER_OUTPUT_KEYS
        with torch.cuda.amp.autocast(enabled=self.train_fp16), torch.no_grad():
            rays_o, rays_d, near_far, bg_color = self._split_render_inputs(data)
            preds = defaultdict(list)
            for rays_o_b, rays_d_b in self._iter_ray_chunks(rays_o, rays_d):
                outputs = render_chunk(rays_o_b, rays_d_b, near_far, bg_color)
                self._collect_outputs(preds, outputs, output_keys)
        return self._concat_chunks(preds)

    @staticmethod
    def _reduce_features(feature_chunks, reduction_type):
        """Collapse gathered feature chunks to a single row along dim 0.

        Returns ``None`` for any ``reduction_type`` other than 'mean' / 'median'.
        Also returns ``None`` when no chunk produced features; this is the same
        value the model already receives for an unknown reduction type.
        """
        if reduction_type not in ("mean", "median"):
            return None
        if not feature_chunks:
            log.warning("rendering_inpainting: no light features were returned; "
                        "rendering without inpainting features")
            return None
        features = torch.cat(feature_chunks, dim=0)
        if reduction_type == "mean":
            return torch.mean(features, dim=0, keepdim=True)
        # With ``dim`` given, torch.median returns a (values, indices) named tuple;
        # only the values are a feature.
        return torch.median(features, dim=0, keepdim=True).values

    # ------------------------------------------------------------------ #
    # BaseTrainer hooks
    # ------------------------------------------------------------------ #
    def eval_step(self, data, **kwargs) -> MutableMapping[str, torch.Tensor]:
        """Render a whole test frame.

        ``data`` contains every ray of the frame, so the rays are traced in chunks of
        ``self.eval_batch_size`` to bound memory. Every chunk uses the frame's
        ``data["timestamps"]`` value.

        Returns:
            dict mapping each model output key containing "rgb", "depth" or "class"
            to its CPU tensor, concatenated over chunks.
        """
        super().eval_step(data, **kwargs)
        timestamp = data["timestamps"]

        def render_chunk(rays_o_b, rays_d_b, near_far, bg_color):
            return self.model(
                rays_o_b, rays_d_b,
                timestamps=self._expand_time(timestamp, rays_o_b.shape[0]),
                bg_color=bg_color, near_far=near_far)

        return self._render_frame(data, render_chunk, output_keys=self._EVAL_OUTPUT_KEYS)

    def train_step(self, data: Dict[str, Union[int, torch.Tensor]], **kwargs):
        """Run one optimisation step; switch the training dataset at the scheduled steps.

        At ``isg_step`` the dataset's ``enable_isg`` is called. At ``ist_step`` its
        ``switch_isg2ist`` is called. Both raise ``StopIteration`` right after the
        switch. Presumably ``BaseTrainer`` catches this and rebuilds the data-loader
        iterator so the change takes effect (confirm in base_trainer).

        Returns:
            whatever ``BaseTrainer.train_step`` returned.
        """
        scale_ok = super().train_step(data, **kwargs)

        if self.global_step == self.isg_step:
            self.train_dataset.enable_isg()
            raise StopIteration  # Whenever we change the dataset
        if self.global_step == self.ist_step:
            self.train_dataset.switch_isg2ist()
            raise StopIteration  # Whenever we change the dataset

        return scale_ok

    def post_step(self, progress_bar):
        """Delegate to ``BaseTrainer.post_step`` (no extra behaviour)."""
        super().post_step(progress_bar)

    def pre_epoch(self):
        """Run ``BaseTrainer.pre_epoch``, then reset the training dataset's iteration state."""
        super().pre_epoch()
        # Reset randomness in train-dataset
        self.train_dataset.reset_iter()

    # ------------------------------------------------------------------ #
    # Special-purpose renderers
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def rendering_inpainting(self, data, human_time, light_time, reduction_type):
        """Render a frame whose light features are replaced by one representative feature.

        The frame's rays are processed in two passes:
          1. ``self.model.get_features_to_inpaint`` is queried per chunk. The non-None
             results are gathered and reduced to a single row with ``reduction_type``
             ('mean' or 'median'; any other value gives ``None``).
          2. ``self.model.inpainting`` renders each chunk using that representative
             feature.

        Args:
            data: test-frame dict (see ``_render_frame``).
            human_time: time tensor broadcast to every ray as ``time_human``.
            light_time: time tensor broadcast to every ray as ``time_light``.
            reduction_type: 'mean', 'median', or anything else for no feature.

        Returns:
            dict mapping output keys containing "rgb" or "depth" to CPU tensors
            concatenated over chunks.
        """
        with torch.cuda.amp.autocast(enabled=self.train_fp16), torch.no_grad():
            rays_o, rays_d, near_far, bg_color = self._split_render_inputs(data)

            # Pass 1: gather light features from every chunk, concatenating once at the end.
            feature_chunks = []
            for rays_o_b, rays_d_b in self._iter_ray_chunks(rays_o, rays_d):
                num_rays = rays_o_b.shape[0]
                light_features = self.model.get_features_to_inpaint(
                    rays_o_b, rays_d_b, bg_color=bg_color, near_far=near_far,
                    time_human=self._expand_time(human_time, num_rays),
                    time_light=self._expand_time(light_time, num_rays))
                if light_features is not None:
                    feature_chunks.append(light_features)
            representative_feature = self._reduce_features(feature_chunks, reduction_type)

            # Pass 2: render every chunk with the representative feature.
            preds = defaultdict(list)
            for rays_o_b, rays_d_b in self._iter_ray_chunks(rays_o, rays_d):
                num_rays = rays_o_b.shape[0]
                out = self.model.inpainting(
                    rays_o_b, rays_d_b, bg_color=bg_color, near_far=near_far,
                    time_human=self._expand_time(human_time, num_rays),
                    time_light=self._expand_time(light_time, num_rays),
                    inpainting_features=representative_feature)
                self._collect_outputs(preds, out, self._RENDER_OUTPUT_KEYS)

        return self._concat_chunks(preds)

    def rendering_no_inpaint(self, data, human_time, light_time):
        """Render via ``self.model.assemble_two_time`` with separate human and light times.

        Returns:
            dict of "rgb"/"depth" outputs (CPU, concatenated over chunks).
        """
        def render_chunk(rays_o_b, rays_d_b, near_far, bg_color):
            num_rays = rays_o_b.shape[0]
            return self.model.assemble_two_time(
                rays_o=rays_o_b, rays_d=rays_d_b, near_far=near_far,
                time_human=self._expand_time(human_time, num_rays),
                time_light=self._expand_time(light_time, num_rays))

        return self._render_frame(data, render_chunk)

    def rendering_changing_hue(self, data, human_time, light_time):
        """Render via ``self.model.change_hue``, passing ``light_time`` as ``hue_time``.

        Returns:
            dict of "rgb"/"depth" outputs (CPU, concatenated over chunks).
        """
        def render_chunk(rays_o_b, rays_d_b, near_far, bg_color):
            num_rays = rays_o_b.shape[0]
            return self.model.change_hue(
                rays_o=rays_o_b, rays_d=rays_d_b, near_far=near_far,
                time_human=self._expand_time(human_time, num_rays),
                hue_time=self._expand_time(light_time, num_rays))

        return self._render_frame(data, render_chunk)

    def rendering_custom_hue(self, data, hue, fix_time):
        """Render via ``self.model.custom_hue`` with an explicit ``hue`` at time ``fix_time``.

        Returns:
            dict of "rgb"/"depth" outputs (CPU, concatenated over chunks).
        """
        def render_chunk(rays_o_b, rays_d_b, near_far, bg_color):
            return self.model.custom_hue(
                rays_o=rays_o_b, rays_d=rays_d_b, near_far=near_far,
                time_human=self._expand_time(fix_time, rays_o_b.shape[0]), hue=hue)

        return self._render_frame(data, render_chunk)

    def rendering_custom_hue_vid(self, data, hue, timestamps):
        """Same as ``rendering_custom_hue``, with the per-frame time given as ``timestamps``.

        Returns:
            dict of "rgb"/"depth" outputs (CPU, concatenated over chunks).
        """
        def render_chunk(rays_o_b, rays_d_b, near_far, bg_color):
            return self.model.custom_hue(
                rays_o=rays_o_b, rays_d=rays_d_b, near_far=near_far,
                time_human=self._expand_time(timestamps, rays_o_b.shape[0]), hue=hue)

        return self._render_frame(data, render_chunk)

    def no_light(self, data, human_time):
        """Render via ``self.model.no_light`` at ``human_time``.

        Returns:
            dict of "rgb"/"depth" outputs (CPU, concatenated over chunks).
        """
        def render_chunk(rays_o_b, rays_d_b, near_far, bg_color):
            return self.model.no_light(
                rays_o=rays_o_b, rays_d=rays_d_b, near_far=near_far,
                time_human=self._expand_time(human_time, rays_o_b.shape[0]))

        return self._render_frame(data, render_chunk)

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def validate(self):
        """Evaluate every test frame and write metrics and (optionally) videos.

        Per-frame metrics come from ``self.evaluate_metrics``. When
        ``self.save_video`` is set, the predicted frames are written to
        ``step{N}.mp4`` and each non-empty optional output stream to
        ``step{N}<suffix>.mp4``. Aggregated metrics go to
        ``test_metrics_step{N}.csv`` in ``self.log_dir``.
        """
        dataset = self.test_dataset
        per_scene_metrics: Dict[str, Union[float, List]] = defaultdict(list)
        pred_frames = []
        # Optional per-frame outputs, keyed by the file-name suffix of their video.
        # Insertion order is the order the videos are written in.
        optional_frames: Dict[str, List] = {
            "_reh": [], "-depth": [], "-class": [], "-classes": [], "-d_liv": [], "-l_liv": [],
        }
        pb = tqdm(total=len(dataset), desc=f"Test scene ({dataset.name})")
        for img_idx, data in enumerate(dataset):
            preds = self.eval_step(data)
            (out_metrics, out_img, out_d_frame, out_l_frame, out_reh,
             out_depth, out_class, out_class_chunk) = self.evaluate_metrics(
                data, preds, dset=dataset, img_idx=img_idx, name=None,
                save_outputs=self.save_outputs)
            pred_frames.append(out_img)
            frame_by_suffix = (
                ("_reh", out_reh), ("-depth", out_depth), ("-class", out_class),
                ("-classes", out_class_chunk), ("-d_liv", out_d_frame), ("-l_liv", out_l_frame),
            )
            for suffix, frame in frame_by_suffix:
                # Skip frames that were not produced so the videos never contain None.
                if frame is not None:
                    optional_frames[suffix].append(frame)

            for k, v in out_metrics.items():
                per_scene_metrics[k].append(v)
            pb.set_postfix_str(f"PSNR={out_metrics['psnr']:.2f}", refresh=False)
            pb.update(1)
        pb.close()

        if self.save_video:
            write_video_to_file(
                os.path.join(self.log_dir, f"step{self.global_step}.mp4"),
                pred_frames
            )
            for suffix, frames in optional_frames.items():
                if frames:
                    write_video_to_file(
                        os.path.join(self.log_dir, f"step{self.global_step}{suffix}.mp4"),
                        frames
                    )

        # Calculate JOD (on whole video). Each frame stacks prediction rows
        # [0, img_h) above the rows [img_h, 2*img_h) it is compared against.
        if self.compute_video_metrics:
            per_scene_metrics["JOD"] = metrics.jod(
                [f[:dataset.img_h, :, :] for f in pred_frames],
                [f[dataset.img_h: 2*dataset.img_h, :, :] for f in pred_frames],
            )
            per_scene_metrics["FLIP"] = metrics.flip(
                [f[:dataset.img_h, :, :] for f in pred_frames],
                [f[dataset.img_h: 2*dataset.img_h, :, :] for f in pred_frames],
            )

        val_metrics = [
            self.report_test_metrics(per_scene_metrics, extra_name=None),
        ]
        df = pd.DataFrame.from_records(val_metrics)
        df.to_csv(os.path.join(self.log_dir, f"test_metrics_step{self.global_step}.csv"))

    # ------------------------------------------------------------------ #
    # Checkpointing / model / regularizers
    # ------------------------------------------------------------------ #
    def get_save_dict(self):
        """Return ``BaseTrainer``'s checkpoint dict unchanged."""
        base_save_dict = super().get_save_dict()
        return base_save_dict

    def load_model(self, checkpoint_data, training_needed: bool = True):
        """Load a checkpoint and re-apply the dataset mode implied by ``self.global_step``.

        Resuming between ``isg_step`` and ``ist_step`` re-enables ISG. Resuming
        past ``ist_step`` switches to IST. A negative step disables that switch.
        """
        super().load_model(checkpoint_data, training_needed)
        if self.train_dataset is not None:
            if -1 < self.isg_step < self.global_step < self.ist_step:
                self.train_dataset.enable_isg()
            elif -1 < self.ist_step < self.global_step:
                self.train_dataset.switch_isg2ist()

    def init_epoch_info(self):
        """Return a dict that lazily creates an ``EMA(0.9)`` tracker per loss name."""
        ema_weight = 0.9
        loss_info = defaultdict(lambda: EMA(ema_weight))
        return loss_info

    def init_model(self, **kwargs) -> LowrankModel:
        """Build the model via the shared ``initialize_model`` helper."""
        return initialize_model(self, **kwargs)

    def get_regularizers(self, **kwargs):
        """Build the regularizer list from ``*_weight`` config entries.

        Every weight defaults to 0.0 when absent. This includes the motion losses,
        which previously defaulted to ``None``.
        """
        return [
            PlaneTV(kwargs.get('plane_tv_weight', 0.0), what='s_field'),
            PlaneTV(kwargs.get('plane_tv_weight', 0.0), what='d_field'),
            PlaneTV(kwargs.get('plane_tv_weight', 0.0), what='l_field'),
            PlaneTV(kwargs.get('plane_tv_weight_proposal_net', 0.0), what='proposal_network'),
            L1TimePlanes(kwargs.get('l1_time_planes', 0.0), what='d_field'),
            L1TimePlanes(kwargs.get('l1_time_planes', 0.0), what='l_field'),
            L1TimePlanes(kwargs.get('l1_time_planes_proposal_net', 0.0), what='proposal_network'),
            TimeSmoothness(kwargs.get('time_smoothness_weight', 0.0), what='d_field'),
            TimeSmoothness(kwargs.get('time_smoothness_weight_proposal_net', 0.0), what='proposal_network'),
            HistogramLoss(kwargs.get('histogram_loss_weight', 0.0)),
            DistortionLoss(kwargs.get('distortion_loss_weight', 0.0)),
            MotionLoss(kwargs.get('motion_loss_weight', 0.0), what='d_field'),
            BigMotionLoss(kwargs.get('big_motion_loss_weight', 0.0), what='d_field'),
        ]

    @property
    def calc_metrics_every(self):
        """Metric-computation interval reported to ``BaseTrainer`` (fixed at 5)."""
        return 5


def init_tr_data(data_downsample, data_dir, **kwargs):
    """Create the K-pop training dataset and a deterministically seeded DataLoader.

    Required kwargs: ``batch_size``, ``num_frames``, ``contract``, ``ndc``,
    ``scene_bbox``, ``pose_type``, ``use_mask`` and ``masked_weight``.
    ``max_train_tsteps`` is also required when ``keyframes`` is truthy.
    Optional kwargs: ``isg``, ``ist``, ``keyframes``, ``max_train_cameras``,
    ``near_scaling`` and ``ndc_far``.

    Returns:
        ``{"tr_loader": DataLoader, "tr_dset": dataset}``.
    """
    use_isg = kwargs.get('isg', False)
    use_ist = kwargs.get('ist', False)
    use_keyframes = kwargs.get('keyframes', False)
    rays_per_batch = kwargs['batch_size']
    log.info(f"Loading Kpopdataset with downsample={data_downsample}")
    dataset = kpop_dataset(
        data_dir,
        split='train',
        downsample=data_downsample,
        num_frames=kwargs['num_frames'],
        batch_size=rays_per_batch,
        max_cameras=kwargs.get('max_train_cameras', None),
        max_tsteps=(kwargs['max_train_tsteps'] if use_keyframes else None),
        isg=use_isg,
        keyframes=use_keyframes,
        contraction=kwargs['contract'],
        ndc=kwargs['ndc'],
        near_scaling=float(kwargs.get('near_scaling', 0)),
        ndc_far=float(kwargs.get('ndc_far', 0)),
        scene_bbox=kwargs['scene_bbox'],
        pose_type=kwargs['pose_type'],
        use_mask=kwargs['use_mask'],
        masked_weight=kwargs['masked_weight'],
    )

    # Only relevant when reloading a run that had already reached the IST phase.
    if use_ist:
        dataset.switch_isg2ist()

    # The dataset yields ready-made batches (batch_size=None disables auto-batching);
    # the generator is seeded so shuffling/worker seeds are reproducible.
    seeded_generator = torch.Generator()
    seeded_generator.manual_seed(0)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,
        num_workers=4,
        prefetch_factor=4,
        pin_memory=True,
        worker_init_fn=init_dloader_random,
        generator=seeded_generator,
    )
    return {"tr_loader": loader, "tr_dset": dataset}


def init_ts_data(data_dir, split, **kwargs):
    """Create the evaluation dataset for ``split`` (e.g. 'test' or 'render').

    The downsample factor is 1.0 when ``data_dir`` contains 'dnerf' and 2.0
    otherwise.

    Returns:
        ``{"ts_dset": dataset}``.
    """
    downsample = 1.0 if 'dnerf' in data_dir else 2.0
    eval_dataset = kpop_dataset(
        data_dir,
        split=split,
        downsample=downsample,
        num_frames=kwargs['num_frames'],
        max_cameras=kwargs.get('max_test_cameras', None),
        max_tsteps=kwargs.get('max_test_tsteps', None),
        contraction=kwargs['contract'],
        ndc=kwargs['ndc'],
        near_scaling=float(kwargs.get('near_scaling', 0)),
        ndc_far=float(kwargs.get('ndc_far', 0)),
        scene_bbox=kwargs['scene_bbox'],
        pose_type=kwargs['pose_type'],
        use_mask=kwargs['use_mask'],
        masked_weight=kwargs['masked_weight'],
    )
    return {"ts_dset": eval_dataset}


def load_data(data_downsample, data_dirs,  render_only, **kwargs):
    """Load training (unless ``render_only``) and evaluation data for a single scene.

    ``data_dirs`` must contain exactly one directory. When ``render_only`` is true,
    the training entries are ``None`` and the evaluation split is 'render';
    otherwise it is 'test'.

    Returns:
        dict with keys ``tr_loader``, ``tr_dset`` and ``ts_dset``.
    """
    assert len(data_dirs) == 1
    scene_dir = data_dirs[0]
    if render_only:
        od: Dict[str, Any] = {"tr_loader": None, "tr_dset": None}
        eval_split = 'render'
    else:
        od = dict(init_tr_data(data_downsample, scene_dir, **kwargs))
        eval_split = 'test'
    od.update(init_ts_data(scene_dir, split=eval_split, **kwargs))
    return od
