import logging
import os
import time

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam, AdamW
from pl_bolts.optimizers.lars import LARS
from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay

from utils.metric import KnnPrecisionMetric
from utils.hydra_utils import save_config_to_disk

import einops
import numpy as np
import random

import pandas as pd


class PretrainingWrapper(pl.LightningModule):
    """LightningModule driving self-supervised pre-training of a shot encoder.

    Besides ``training_step`` it provides:

    * ``validation_step`` / ``on_validation_epoch_end``: kNN precision (via
      ``KnnPrecisionMetric``) of the L2-normalized embedding of shot ``cidx``.
    * ``test_step``: saves each sample's ``[s d]`` embeddings to
      ``cfg.FEAT_PATH/cfg.LOAD_FROM/<vid>/shot_<sid>.npy``.
    * ``predict_step`` / ``on_predict_epoch_end``: reloads those embeddings,
      derives boundary indices with DTW helpers of ``loss`` and writes them to
      ``/dev/shm/anno.pseudo_boundary.ndjson``.

    Args:
        cfg: experiment config, accessed by attribute and ``.get`` (presumably
            a Hydra/OmegaConf node -- confirm).
        shot_encoder: backbone applied to one keyframe per shot; its outputs
            are averaged over keyframes.
        loss: loss module; must also expose ``on_train_start``, ``head_nce``,
            ``_compute_dtw_path`` and ``_compute_boundary``.
        crn: optional module forwarded to ``loss`` for the bassl / shotcol /
            asymmetric sampling methods.
    """

    def __init__(self, cfg, shot_encoder, loss, crn=None):
        super().__init__()

        self.cfg = cfg
        self.shot_encoder = shot_encoder
        self.loss = loss
        self.crn = crn

        self.metric = KnnPrecisionMetric(top_k_list=[1, 5, 10, 15, 20])
        self.best_score = None

        self.num_keyframe = cfg.TRAIN.NUM_KEYFRAME

        self.dtw_preds = {}

        # Index along the shot axis of the shot evaluated during validation.
        # It also sets the dense window size (2 * cidx + 1) used when shuffling.
        sampling = cfg.LOSS.sampling_method
        if sampling.name == "asymmetric":
            self.cidx = sampling.params["asymmetric"].neighbor_left
        else:
            self.cidx = sampling.params[sampling.name].neighbor_size

        self.use_duration = sampling.get("use_duration", False)
        self.use_random = sampling.get("use_random", False)

        # Accumulators filled by predict_step and dumped in on_predict_epoch_end.
        self.pred_vid = []
        self.pred_sid = []
        self.pred_s_idx = []
        self.pred_bd_idx = []

    def on_train_start(self) -> None:
        """Save the config (rank 0 only, best effort) and initialize the loss."""
        if self.global_rank == 0:
            try:
                save_config_to_disk(self.cfg)
            except Exception as err:
                # Best effort: training continues, but make the failure visible.
                logging.warning("save_config_to_disk failed: %s", err)

        self.loss.on_train_start(dist_rank=self.global_rank, device=self.device)

    def extract_shot_representation(self, inputs, is_train=True):
        """Encode every keyframe of every shot and average over keyframes.

        inputs [b s k c h w] -> output [b s d], regardless of ``is_train``.
        The flag is currently unused. Callers that need a single shot index
        the shot axis themselves.

        Raises:
            NotImplementedError: for an unsupported ``cfg.MODEL.shot_encoder.name``.
        """
        b, s, k, c, h, w = inputs.shape

        encoder_name = self.cfg.MODEL.shot_encoder.name
        if encoder_name in ["resnet", "vit", "vit_x_ge"]:
            extra_args = ()
        elif encoder_name == "trans4mer":
            # trans4mer additionally receives s (presumably shots per sample).
            extra_args = (s,)
        else:
            raise NotImplementedError

        inputs = einops.rearrange(inputs, "b s k c h w -> (b s) k c h w", s=s)
        keyframe_repr = [self.shot_encoder(inputs[:, _k], *extra_args) for _k in range(k)]
        x = torch.stack(keyframe_repr).mean(dim=0)  # [k (b s) d] -> [(b s) d]
        return einops.rearrange(x, "(b s) d -> b s d", s=s, b=b)

    def forward(self, x, **kwargs):
        """Run the shot encoder directly on ``x``."""
        return self.shot_encoder(x, **kwargs)

    def _shuffle_inputs(self, inputs):
        """Randomly perturb the temporal structure of a training batch.

        The shot axis is split into 2 sparse shots followed by ``2 * cidx + 1``
        dense shots, so its length must be exactly ``2 * cidx + 3``. One of
        three options is picked uniformly:

        * "left": the left sparse shot and the first ``n`` dense shots
          (``n`` in ``[1, cidx]``) are taken from the next sample in the
          batch (cyclically).
        * "right": the same for the right sparse shot and the last ``n``
          dense shots.
        * "flip": the sparse and dense shots are each reversed in time.
        """
        n_sparse = 2
        n_dense = 2 * self.cidx + 1
        sparse, dense = torch.split(inputs, [n_sparse, n_dense], dim=1)
        option = random.choice(["left", "right", "flip"])

        if option == "flip":
            return torch.cat((torch.flip(sparse, [1]), torch.flip(dense, [1])), 1)

        n_shots = random.randrange(1, self.cidx + 1)
        sparse_left, sparse_right = torch.split(sparse, [1, 1], dim=1)
        # torch.roll(x, -1, 0) == cat(x[1:], x[:1]): sample i gets sample i+1's shots.
        if option == "left":
            dense_left, dense_right = torch.split(dense, [n_shots, n_dense - n_shots], dim=1)
            sparse_left = torch.roll(sparse_left, shifts=-1, dims=0)
            dense_left = torch.roll(dense_left, shifts=-1, dims=0)
        else:  # "right"
            dense_left, dense_right = torch.split(dense, [n_dense - n_shots, n_shots], dim=1)
            sparse_right = torch.roll(sparse_right, shifts=-1, dims=0)
            dense_right = torch.roll(dense_right, shifts=-1, dims=0)

        return torch.cat((sparse_left, sparse_right, dense_left, dense_right), 1)

    def training_step(self, batch, batch_idx):
        """Compute the pre-training loss for one batch.

        Each entry of the dict returned by ``self.loss`` is logged per step.
        The sum of all entries is returned as the training loss.

        Raises:
            NotImplementedError: for an unsupported sampling method.
        """
        inputs = batch["video"]
        assert len(inputs.shape) == 6, "{}".format(inputs.shape)  # [b s k c h w]

        shuffle_cfg = self.cfg.TRAIN.SHUFFLE
        if shuffle_cfg.enabled and random.random() < shuffle_cfg.probability:
            inputs = self._shuffle_inputs(inputs)

        shot_repr = self.extract_shot_representation(inputs)

        method = self.cfg.LOSS.sampling_method.name
        if method in ["instance", "temporal"]:
            loss = self.loss(shot_repr)
        elif method in ["bassl", "shotcol", "bassl+shotcol", "asymmetric"]:
            loss_kwargs = dict(
                crn=self.crn,
                mask=batch["mask"],
                n_sparse=batch["sparse_idx"].shape[1],
                n_dense=batch["dense_idx"].shape[1],
            )
            # The sparse positions are only needed by the duration/random variants.
            if self.use_duration or self.use_random:
                loss_kwargs["sparse_idx"] = batch["sparse_idx"]
            loss = self.loss(shot_repr, **loss_kwargs)
        else:
            raise NotImplementedError

        total_loss = 0
        for name, value in loss.items():
            self.log(name, value, on_step=True, on_epoch=False)
            total_loss += value

        return total_loss

    def validation_step(self, batch, batch_idx):
        """Measure kNN precision during pre-training as validation.

        The L2-normalized embedding of shot ``cidx`` of each sample is fed to
        the metric together with its video id and in-video scene id.
        """
        inputs = batch["video"]
        # Check the rank before anything unpacks the shape, so a malformed
        # batch fails with this message rather than an unpacking error.
        assert len(inputs.shape) == 6, "{}".format(inputs.shape)  # [b s k c h w]

        global_video_ids = batch["global_video_id"]
        invideo_scene_ids = batch["invideo_scene_id"]

        x = self.extract_shot_representation(inputs, is_train=False)  # [b s d]
        x = x[:, self.cidx, :]  # [b d]

        x = F.normalize(x, dim=1, p=2)
        for gv_id, vs_id, feat in zip(global_video_ids, invideo_scene_ids, x):
            self.metric.update(gv_id, vs_id, feat)

    def on_validation_epoch_end(self):
        """Log kNN precision@k, track the best top-1 score and reset the metric."""
        t_s = time.time()
        # torch.cuda.* raises on CPU-only setups, so only query it when available.
        has_cuda = torch.cuda.is_available()
        device_id = torch.cuda.current_device() if has_cuda else self.device
        logging.info(f"[device: {device_id}] compute metric scores ...")

        score = self.metric.compute()
        for k, v in score.items():
            self.log(f"pretrain_test/precision_at_{k}", v["precision"],
                     on_epoch=True, prog_bar=True, logger=True, sync_dist=True,)
            if k == 1 and (
                self.best_score is None
                or v["precision"] > self.best_score[1]["precision"]
            ):
                self.best_score = score
        self.log("pretrain_test/validation_time_min", float(time.time() - t_s) / 60,
                 on_epoch=True, prog_bar=False, logger=True, sync_dist=True,)
        if has_cuda:
            torch.cuda.synchronize()
        self.metric.reset()
        logging.info(dict(score))

    def test_step(self, batch, batch_idx):
        """Extract shot representations and save them to disk.

        Each sample's ``[s d]`` float32 embedding is written with ``np.save``
        to ``cfg.FEAT_PATH/cfg.LOAD_FROM/<vid>/shot_<sid>.npy``. ``np.save``
        appends the ``.npy`` suffix, and ``predict_step`` relies on it when
        reloading.
        """
        inputs = batch["video"]
        assert len(inputs.shape) == 6, "{}".format(inputs.shape)  # [b s k c h w]

        vids = batch["vid"]
        sids = batch["sid"]

        x = self.extract_shot_representation(inputs, is_train=False)  # [b s d]
        embedding = x.float().cpu().numpy()

        out_root = os.path.join(self.cfg.FEAT_PATH, self.cfg.LOAD_FROM)
        for vid, sid, feat in zip(vids, sids, embedding):
            os.makedirs(os.path.join(out_root, vid), exist_ok=True)
            np.save(os.path.join(out_root, f"{vid}/shot_{sid}"), feat)

    def on_predict_epoch_start(self):
        """Clear predictions left over from a previous ``trainer.predict`` run."""
        self.pred_vid = []
        self.pred_sid = []
        self.pred_s_idx = []
        self.pred_bd_idx = []

    def on_predict_epoch_end(self, arg=None):
        """Dump the collected pseudo boundaries as newline-delimited JSON.

        ``arg`` is optional so the hook works with Lightning 1.x, which passes
        the epoch outputs, and with 2.x, which calls it without arguments.
        """
        # NOTE(review): every rank writes the same fixed path; under DDP the
        # last writer wins and other ranks' predictions are lost.
        df = pd.DataFrame()
        df["video_id"] = self.pred_vid
        df["shot_id"] = self.pred_sid
        df["anchors"] = self.pred_s_idx
        df["bd_idx"] = self.pred_bd_idx
        df.to_json("/dev/shm/anno.pseudo_boundary.ndjson", orient='records', lines=True)

    def predict_step(self, batch, batch_idx):
        """Predict pseudo boundaries from the embeddings saved by ``test_step``.

        The cached embeddings of each ``(vid, sid)`` are projected with
        ``self.loss.head_nce``. The shots at ``batch["sparse_idx"]`` are then
        gathered and aligned against all shots via
        ``self.loss._compute_dtw_path``. Boundary indices are accumulated for
        ``on_predict_epoch_end``.
        """
        vids = batch["vid"]
        sids = batch["sid"]

        feat_root = os.path.join(self.cfg.FEAT_PATH, self.cfg.LOAD_FROM)
        embedding = np.array([
            np.load(os.path.join(feat_root, f"{vid}/shot_{sid}.npy"))
            for vid, sid in zip(vids, sids)
        ])
        sparse_idx = batch["sparse_idx"]  # [b n_sparse]

        x = torch.from_numpy(embedding).to(device=sparse_idx.device)

        shot_repr = self.loss.head_nce(x)
        _, _, d = shot_repr.shape  # [b s d]
        expanded_idx = sparse_idx.unsqueeze(-1).expand(-1, -1, d)
        s_emb = torch.gather(shot_repr, dim=1, index=expanded_idx)  # [b n_sparse d]
        d_emb = shot_repr

        dtw_path = self.loss._compute_dtw_path(s_emb, d_emb, sparse_idx.cpu().numpy())
        bd_idxs = self.loss._compute_boundary(dtw_path)

        # Fresh copy: the array passed above may be modified by the helper.
        sparse_idxs = sparse_idx.cpu().numpy()

        for vid, sid, s_idx, bd_idx in zip(vids, sids, sparse_idxs, bd_idxs):
            self.pred_vid.append(vid)
            self.pred_sid.append(sid)
            self.pred_s_idx.append(s_idx)
            self.pred_bd_idx.append(bd_idx)

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list):
        """Split trainable params into decayed and non-decayed groups.

        A parameter goes to the zero-decay group when any string in
        ``skip_list`` is a substring of its name. Frozen parameters are
        dropped entirely.

        Returns:
            Two optimizer param-group dicts (decayed first, then excluded).
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {"params": params, "weight_decay": weight_decay},
            {"params": excluded_params, "weight_decay": 0.0},
        ]

    def configure_optimizers(self):
        """Build the optimizer and a per-step cosine schedule with linear warmup.

        Raises:
            NotImplementedError: for an unsupported optimizer or scheduler name.
        """
        opt_cfg = self.cfg.TRAIN.OPTIMIZER

        skip_list = []
        weight_decay = opt_cfg.weight_decay
        if not opt_cfg.regularize_bn:
            skip_list.append("bn")
        if not opt_cfg.regularize_bias:
            skip_list.append("bias")
        params = self.exclude_from_wt_decay(
            self.named_parameters(), weight_decay=weight_decay, skip_list=skip_list
        )

        if opt_cfg.name == "sgd":
            optimizer = SGD(
                params,
                lr=opt_cfg.lr.scaled_lr,
                momentum=0.9,
                weight_decay=weight_decay,
            )
        elif opt_cfg.name == "lars":
            optimizer = LARS(
                params,
                lr=opt_cfg.lr.scaled_lr,
                momentum=0.9,
                weight_decay=weight_decay,
                trust_coefficient=0.001,
            )
        elif opt_cfg.name == "adam":
            optimizer = Adam(
                params,
                lr=opt_cfg.lr.scaled_lr,
                weight_decay=weight_decay,
            )
        elif opt_cfg.name == "adamw":
            optimizer = AdamW(
                params,
                lr=opt_cfg.lr.scaled_lr,
                weight_decay=weight_decay,
            )
        else:
            raise NotImplementedError

        # scheduler.warmup is a fraction of the total number of training steps.
        total_steps = int(
            self.cfg.TRAIN.TRAIN_ITERS_PER_EPOCH * self.cfg.TRAINER.max_epochs
        )
        warmup_steps = int(
            self.cfg.TRAIN.TRAIN_ITERS_PER_EPOCH
            * self.cfg.TRAINER.max_epochs
            * opt_cfg.scheduler.warmup
        )

        if opt_cfg.scheduler.name == "cosine_with_linear_warmup":
            scheduler = {
                "scheduler": torch.optim.lr_scheduler.LambdaLR(
                    optimizer,
                    linear_warmup_decay(warmup_steps, total_steps, cosine=True),
                ),
                "interval": "step",
                "frequency": 1,
            }
        else:
            raise NotImplementedError

        return [optimizer], [scheduler]
