import math, bisect
from argparse import ArgumentParser
from typing import Callable, Optional

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.optimizer import LightningOptimizer
from torch import nn
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer

from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.transforms.dataset_normalizations import (
    cifar10_normalization,
    imagenet_normalization,
    stl10_normalization,
)

from plb.models.encoder import Transformer
from plb.models.validator import construct_validator
from common_pytorch.networks.temporal_conv_net import get_default_net_config as get_default_tconv_net_config
from common_pytorch.networks.temporal_conv_net import get_net as get_tconv_net

# Additive penalty subtracted from masked-out entries before a softmax in
# TCC.classification_loss; it only has to be a number >> T (sequence length).
big_number = 1 << 13  # 8192
# Loss selector for TCC.classification_loss: True -> F.cross_entropy,
# False -> F.binary_cross_entropy against an identity-matrix target.
use_ce = True

class SyncFunction(torch.autograd.Function):
    """Differentiable all-gather along dim 0 that allows a different dim-0 size on every rank.

    Both passes issue ``torch.distributed`` collectives, so a process group must
    already be initialised and every rank has to call this in the same order.
    """
    @staticmethod
    def forward(ctx, tensor):
        """Return the dim-0 concatenation, in rank order, of ``tensor`` from every rank.

        Args:
            ctx: autograd context; the per-rank dim-0 sizes are saved on it as ``ctx.sizes``.
            tensor: this rank's tensor. Trailing dimensions are taken from the local
                tensor when allocating receive buffers, so they presumably have to
                match on all ranks -- TODO confirm against callers.

        Returns:
            A tensor with ``sum(ctx.sizes)`` rows and the same trailing shape as ``tensor``.
        """
        # gather sizes on different GPU's
        size = torch.tensor(tensor.size(0), device=tensor.device)
        gathered_size = [torch.zeros_like(size) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gathered_size, size)
        ctx.sizes = [_.item() for _ in gathered_size]
        max_bs = max(ctx.sizes)

        # one receive buffer per rank, each sized for the largest dim-0 size
        gathered_tensor = [tensor.new_zeros((max_bs,) + tensor.shape[1:]) for _ in
                           range(torch.distributed.get_world_size())]
        # zero-pad the local tensor up to max_bs rows so every rank sends the same shape
        tbg = torch.cat([tensor, tensor.new_zeros((max_bs - tensor.size(0),) + tensor.shape[1:])], dim=0)
        torch.distributed.all_gather(gathered_tensor, tbg)
        # strip each rank's padding again before concatenating
        gathered_tensor = torch.cat([_[:s] for (_, s) in zip(gathered_tensor, ctx.sizes)], 0)
        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        """Sum the gathered-tensor gradient over all ranks and return this rank's rows of it.

        Args:
            ctx: autograd context carrying ``ctx.sizes`` from ``forward``.
            grad_output: gradient with respect to the concatenated tensor returned by ``forward``.

        Returns:
            The slice of the summed gradient that corresponds to this rank's input rows.
        """
        # all_reduce works in place, so operate on a copy rather than on grad_output itself
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
        my_rank = torch.distributed.get_rank()
        # this rank's rows start right after the rows of all lower-ranked processes
        idx_from = sum(ctx.sizes[:my_rank])
        idx_to = idx_from + ctx.sizes[my_rank]
        return grad_input[idx_from:idx_to]


class Projection(nn.Module):
    """MLP head: Linear -> BatchNorm1d -> ReLU -> Linear (the last layer has no bias).

    Args:
        input_dim: size of the last dimension of the input features.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the returned features.
    """

    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.input_dim, self.hidden_dim, self.output_dim = input_dim, hidden_dim, output_dim

        # Layers are created in the same order as they are applied, and are
        # stored under `model` so state_dict keys stay `model.<index>.*`.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head to ``x`` and return the projected features."""
        return self.model(x)


class TCC(pl.LightningModule):
    def __init__(
            self,
            gpus: int,
            num_samples: int,
            batch_size: int,
            length: int,
            dataset: str,
            num_nodes: int = 1,
            arch: str = 'resnet50',
            hidden_mlp: int = 512,  # 2048, this is revised
            feat_dim: int = 128,
            warmup_epochs: int = 10,
            max_epochs: int = 100,
            temperature: float = 0.1,
            first_conv: bool = True,
            maxpool1: bool = True,
            optimizer: str = 'adam',
            lars_wrapper: bool = True,
            exclude_bn_bias: bool = False,
            start_lr: float = 0.,
            learning_rate: float = 1e-3,
            final_lr: float = 0.,
            weight_decay: float = 1e-6,
            val_configs=None,
            log_dir=None,
            protection=-1,  # no such concept, to be removed
            tr_layer=6,
            tr_dim=512,
            neg_dp=0.0,
            **kwargs
    ):
        """
        Build the encoder, the projection head, the per-step LR schedule and the validators.

        Args:
            gpus: multiplied with ``num_nodes`` and ``batch_size`` to get the global batch
                size when > 0; otherwise ``torch.cuda.device_count()`` is used instead
            num_samples: num samples in the dataset
            batch_size: the batch size (used to compute the iterations per epoch)
            length: multiplied with ``batch_size`` to form ``real_batch_size``, which
                scales the three learning rates below
            dataset: stored as ``self.dataset``
            num_nodes: see ``gpus``
            arch: encoder type; ``init_model`` accepts ``"Transformer"`` and ``"Tconv"``
                only, so the default ``'resnet50'`` is rejected there
            hidden_mlp: stored as ``self.hidden_mlp``; the projection head built here
                uses ``tr_dim`` for its input and hidden widths instead
            feat_dim: output width of the projection head
            warmup_epochs: epochs to warmup the lr for
            max_epochs: total epochs covered by the LR schedule
            temperature: stored as ``self.temperature``; not read by the loss in the
                code visible here
            first_conv: stored as ``self.first_conv``
            maxpool1: stored as ``self.maxpool1``
            optimizer: ``'sgd'`` or ``'adam'`` (see ``configure_optimizers``)
            lars_wrapper: whether ``configure_optimizers`` wraps the optimizer in ``LARSWrapper``
            exclude_bn_bias: whether ``configure_optimizers`` builds a separate
                zero-weight-decay parameter group
            start_lr: LR at the first warm-up step, before scaling by ``real_batch_size / 256``
            learning_rate: peak LR, before scaling by ``real_batch_size / 256``
            final_lr: LR the cosine phase decays towards, before scaling by ``real_batch_size / 256``
            weight_decay: the optimizer weight decay
            val_configs: optional mapping from validator name to its config dict; every
                config dict is updated in place with ``log_dir``, ``rank`` and ``world_size``
            log_dir: stored as ``self.log_dir`` and handed to every validator config
            protection: not used in this constructor
            tr_layer: forwarded to ``init_model``
            tr_dim: forwarded to ``init_model``; also the input and hidden width of the projection head
            neg_dp: not used in this constructor
            **kwargs: not used in this constructor
        """
        super().__init__()
        self.save_hyperparameters()

        self.gpus = gpus
        self.num_nodes = num_nodes
        self.arch = arch
        self.dataset = dataset
        self.num_samples = num_samples
        self.batch_size = batch_size  # batch size from the view of scheduler
        self.real_batch_size = batch_size * length  # batch size from the view of optimizer

        self.hidden_mlp = hidden_mlp
        self.feat_dim = feat_dim
        self.first_conv = first_conv
        self.maxpool1 = maxpool1

        self.optim = optimizer
        self.lars_wrapper = lars_wrapper
        self.exclude_bn_bias = exclude_bn_bias
        self.weight_decay = weight_decay
        self.temperature = temperature

        # all three learning rates are scaled linearly with real_batch_size, relative to 256
        self.start_lr = start_lr / 256 * self.real_batch_size
        self.final_lr = final_lr / 256 * self.real_batch_size
        self.learning_rate = learning_rate / 256 * self.real_batch_size
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.log_dir = log_dir

        self.encoder = self.init_model(tr_layer=tr_layer, tr_dim=tr_dim)

        self.projection = Projection(input_dim=tr_dim, hidden_dim=tr_dim, output_dim=self.feat_dim)
        # originally using hidden_mlp for input_dim and hidden_dim

        # compute iters per epoch
        global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size * torch.cuda.device_count()
        if global_batch_size != 0:
            self.train_iters_per_epoch = math.ceil(self.num_samples / global_batch_size)
        else:
            # gpus <= 0 and no visible CUDA device: the schedule below ends up empty
            self.train_iters_per_epoch = 0

        # define LR schedule
        # phase 1: linear ramp from start_lr to learning_rate, one entry per warm-up step
        warmup_lr_schedule = np.linspace(self.start_lr, self.learning_rate,
                                         self.train_iters_per_epoch * self.warmup_epochs)
        # phase 2: half-cosine from learning_rate towards final_lr over the remaining steps
        iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))
        cosine_lr_schedule = np.array([
            self.final_lr + 0.5 * (self.learning_rate - self.final_lr) *
            (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))))
            for t in iters
        ])

        # indexed by trainer.global_step in training_step and optimizer_step
        self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))

        # construct validator
        self.validators = []
        # validators are only built when at least one CUDA device is visible
        if val_configs is not None and torch.cuda.device_count() != 0:
            for k, val in val_configs.items():
                # NOTE: these writes modify the caller's config dicts in place
                val["log_dir"] = self.log_dir
                val["rank"] = self.global_rank
                val["world_size"] = torch.cuda.device_count()
                self.validators.append(construct_validator(k, val))

    def init_model(self, tr_layer, tr_dim):
        """Build the sequence encoder selected by ``self.arch``.

        Args:
            tr_layer: first positional argument of ``Transformer``; ignored for ``"Tconv"``.
            tr_dim: second positional argument of ``Transformer``; ignored for ``"Tconv"``.

        Returns:
            The encoder module for ``"Transformer"`` or ``"Tconv"``.

        Raises:
            AssertionError: if ``self.arch`` is any other value.
        """
        if self.arch == "Transformer":
            return Transformer(tr_layer, tr_dim)

        if self.arch == "Tconv":
            # TODO: move to config
            config = get_default_tconv_net_config()
            config.tempconv_dim_in = 51
            # NOTE(review): hard-coded, while the projection head created in __init__
            # takes tr_dim inputs -- presumably tr_dim must be 512 for this arch; confirm.
            config.tempconv_dim_out = 512
            config.tempconv_filter_widths = [3, 3, 3, 3, 3, 3, 3, 3]
            config.tempconv_channels = 1024
            return get_tconv_net(config)

        # Raised explicitly instead of `assert 0, ...`: assert statements are removed
        # under `python -O`, which would let an unknown arch silently return None.
        raise AssertionError("Unknown model!")

    def forward(self, *args):
        """Run the encoder on ``args`` and return its output.

        For the ``"Tconv"`` arch the output axes are reordered with
        ``permute(2, 0, 1)`` and made contiguous; other archs are returned as is.
        ``shared_step`` treats the result as [T, B, f].
        """
        encoded = self.encoder(*args)
        if self.arch != "Tconv":
            return encoded
        return encoded.permute(2, 0, 1).contiguous()

    def shared_step(self, batch):
        """Encode both views of a batch, project the selected frames and return the loss.

        ``batch`` unpacks into eight items (shapes as noted by the original author):
            img1, img2: [B, maxT1, 51], [B, maxT2, 51] padded sequences
            len1, len2: [B] real lengths (actually the same for both views)
            m: [t1, t2] with t1 = sum_B l1b and t2 = sum_B l2b; not used in this method
            indices1, indices2: indices into the flattened [B * maxT] encoder output
                that pick the frames to project
            chopped_bs: segment lengths handed to ``classification_loss``
        """
        img1, img2, len1, len2, _m, indices1, indices2, chopped_bs = batch

        # encode both views before touching the projection head
        enc1 = self(img1, len1)  # [maxT1, B, f]
        enc2 = self(img2, len2)

        # to batch-major, merge (B, T) into one axis, then keep only the indexed frames
        frames1 = enc1.permute(1, 0, 2).contiguous().flatten(0, 1)[indices1]
        frames2 = enc2.permute(1, 0, 2).contiguous().flatten(0, 1)[indices2]

        z1 = self.projection(frames1)
        z2 = self.projection(frames2)
        return self.classification_loss(z1, z2, chopped_bs)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for ``batch`` and log it with the scheduled LR.

        Returns:
            The loss from ``shared_step``.
        """
        loss = self.shared_step(batch)

        # The LR is logged by hand because the LearningRateLogger callback
        # doesn't work with LARSWrapper.
        scheduled_lr = self.lr_schedule[self.trainer.global_step]
        self.log('learning_rate', scheduled_lr, on_step=True, on_epoch=False)
        self.log('train_loss', loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for a validation ``batch`` and log it as an epoch-level ``val_loss``.

        Returns:
            The loss from ``shared_step``.
        """
        val_loss = self.shared_step(batch)
        # TODO: add a tensorboard figure logger here
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, sync_dist=True)
        return val_loss

    def on_validation_epoch_end(self, ):
        """Run every validator on the model in eval mode and log the metrics it returns.

        Each validator is called as ``validator(self, save=...)`` and must return a
        mapping from metric name to value. ``save`` is the current epoch on global
        rank 0 when ``current_epoch % 500 == 474``, and ``-1`` in every other case.
        """
        target_device = self.device
        self.eval()
        for validator in self.validators:
            # note it has to be a subset of 5Z - 1 on Azure, debug time use 0
            is_save_epoch = self.global_rank == 0 and self.current_epoch % 500 == 474
            save = self.current_epoch if is_save_epoch else -1

            for name, value in validator(self, save=save).items():
                as_tensor = torch.tensor([value], device=target_device)
                self.log(name, as_tensor, on_step=False, on_epoch=True, sync_dist=True)

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=('bias', 'bn')):
        """Split trainable parameters into a weight-decayed and a non-decayed optimizer group.

        Args:
            named_params: iterable of ``(name, parameter)`` pairs, e.g. ``named_parameters()``.
            weight_decay: weight decay assigned to the first (decayed) group.
            skip_list: substrings; a parameter whose name contains any of them goes
                into the second group with weight decay 0. The default is an immutable
                tuple so it cannot be mutated and shared between calls.

        Returns:
            A list of two param-group dicts ``{'params': [...], 'weight_decay': ...}``:
            decayed parameters first, excluded parameters second. Parameters with
            ``requires_grad == False`` appear in neither group.
        """
        decayed, excluded = [], []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            # matching is a plain substring test on the parameter name
            if any(token in name for token in skip_list):
                excluded.append(param)
            else:
                decayed.append(param)

        return [
            {'params': decayed, 'weight_decay': weight_decay},
            {'params': excluded, 'weight_decay': 0.},
        ]

    def configure_optimizers(self):
        """Create the optimizer selected by ``self.optim``, optionally wrapped in ``LARSWrapper``.

        ``'sgd'`` builds ``torch.optim.SGD`` with momentum 0.9 and ``'adam'`` builds
        ``torch.optim.AdamW``. When ``self.exclude_bn_bias`` is set, the parameters
        are first split by ``exclude_from_wt_decay``.

        Returns:
            The optimizer, wrapped in ``LARSWrapper`` when ``self.lars_wrapper`` is true.

        Raises:
            ValueError: if ``self.optim`` is neither ``'sgd'`` nor ``'adam'``.
        """
        if self.exclude_bn_bias:
            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
        else:
            params = self.parameters()

        if self.optim == 'sgd':
            optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay)
        elif self.optim == 'adam':
            optimizer = torch.optim.AdamW(params, lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            # Without this branch `optimizer` stays unbound and the code below
            # fails with an UnboundLocalError that hides the real cause.
            raise ValueError(f"Unsupported optimizer {self.optim!r}; expected 'sgd' or 'adam'")

        if self.lars_wrapper:
            optimizer = LARSWrapper(
                optimizer,
                eta=0.001,  # trust coefficient
                clip=False
            )

        return optimizer

    def optimizer_step(
            self,
            epoch: int = None,
            batch_idx: int = None,
            optimizer: Optimizer = None,
            optimizer_idx: int = None,
            optimizer_closure: Optional[Callable] = None,
            on_tpu: bool = None,
            using_native_amp: bool = None,
            using_lbfgs: bool = None,
    ) -> None:
        """Write the scheduled LR into every param group, then run one optimizer step.

        The warm-up + decay schedule is applied here, rather than through an LR
        scheduler, since LARSWrapper is not an optimizer class. Only ``optimizer``
        and ``optimizer_closure`` are used; the other arguments are accepted for
        the hook signature.
        """
        # adjust LR of optim contained within LARSWrapper
        for group in optimizer.param_groups:
            group["lr"] = self.lr_schedule[self.trainer.global_step]

        # from lightning: wrap into a LightningOptimizer only for running the step
        stepper = optimizer
        if not isinstance(stepper, LightningOptimizer):
            stepper = LightningOptimizer.to_lightning_optimizer(stepper, self.trainer)
        stepper.step(closure=optimizer_closure)

    def classification_loss(self, z1, z2, chopped_bs):
        """Cycle-back classification loss, following the TCC classification variant.

        Each row of ``z1`` is softly matched to the rows of ``z2`` in the same
        segment, giving a soft neighbour ``v_tilde``. The rows of ``z1`` are then
        classified against the ``v_tilde`` rows of the same segment, with the
        diagonal as the correct class.

        Args:
            z1: [B, f] projected frames of the first view.
            z2: [B, f] projected frames of the second view.
            chopped_bs: segment lengths; their sum has to equal B, because they
                define the [B, B] block-diagonal mask of allowed matches.

        Returns:
            A scalar loss: cross-entropy on the masked logits when the module-level
            ``use_ce`` flag is true, otherwise binary cross-entropy between the
            softmax probabilities and an identity-matrix target.
        """
        # squared Euclidean distance between every z1 row and every z2 row
        dist = torch.linalg.norm(z1.unsqueeze(1) - z2.unsqueeze(0), dim=-1) ** 2  # [B, B]
        # ones inside each segment's diagonal block, zeros elsewhere
        mask = torch.block_diag(*[z1.new_ones((int(seg), int(seg))) for seg in chopped_bs])  # [B, B]
        # large penalty that pushes off-block entries to zero probability after a softmax
        off_block_penalty = (1 - mask) * big_number

        gather_mask = torch.softmax(- dist - off_block_penalty, dim=-1)  # [B, B], with non-block-diag zero
        v_tilde = gather_mask @ z2  # [B, f]

        dist_again = torch.linalg.norm(z1.unsqueeze(1) - v_tilde.unsqueeze(0), dim=-1) ** 2  # [B, B]
        logits = - dist_again - off_block_penalty  # [B, B]

        if use_ce:
            # F.cross_entropy applies log-softmax itself, so it must receive the raw
            # logits. Passing softmax probabilities applied softmax twice, which let
            # the masked entries (probability 0) back in as logit 0 and flattened
            # the gradients.
            return F.cross_entropy(logits, torch.arange(z1.size(0), device=z1.device))

        # F.binary_cross_entropy expects probabilities, so normalise first
        y_hat = torch.softmax(logits, dim=-1)  # [B, B], with non-block-diag zero
        return F.binary_cross_entropy(y_hat, torch.diag(torch.ones((z1.size(0)), device=z1.device)))
