import math
import time

from fairseq_manual.data_utils import compute_mask_indices
from hydra.utils import instantiate
import torch
import torch.distributed as dist
from torch.optim import AdamW
from pytorch_lightning import LightningModule
import random

from schedulers.warmup_cosine import WarmupCosineScheduler

from espnet.asr.asr_utils import add_results_to_json, torch_load
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.pytorch_backend.lm.transformer import TransformerLM
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.scorers.length_bonus import LengthBonus
from metrics import WER
from schedulers.warmup_cosine import WarmupCosineMomentumScheduler, ConstantThenLinearScheduler
from utils.utils import ids_to_str, ids_to_str_with_delim, set_requires_grad, UNIGRAM1000_LIST, get_param_groups

from data.transforms import AdaptiveLengthTimeMask


class SSLLearner(LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.save_hyperparameters(cfg)
        self.cfg = cfg

        if cfg.compile_model:
            self.model = torch.compile(instantiate(cfg.model.obj, cfg))
        else:
            self.model = instantiate(cfg.model.obj, cfg)

        if cfg.model.pretrained_model_path:
            print("Load pretrained model weights")
            if ".ckpt" in cfg.model.pretrained_model_path:
                ckpt = torch.load(cfg.model.pretrained_model_path, map_location=lambda storage, loc: storage, weights_only=False)["state_dict"]
            else:
                ckpt = torch.load(cfg.model.pretrained_model_path, map_location=lambda storage, loc: storage)
            

            if cfg.model.transfer_only_encoder:                
                ckpt = {k[39:]: v for k, v in ckpt.items() if k.startswith('model._orig_mod.model.backbone.encoder')}
                ckpt = {k: v for k, v in ckpt.items() if not k.startswith("after_norm")}
                self.model.model.backbone.encoder.load_state_dict(ckpt, strict=False)
                self.model.model.target_backbone.encoder.load_state_dict(ckpt, strict=False)
            else:
                self.model.load_state_dict(ckpt, strict=False)

        if cfg.debug.log_gradients:
            self.logger.experiment.watch(self.model, log="gradients")
        
        self.ignore_id = -1
        self.beam_search_video = self.get_beam_search(self.model.model.backbone)
        self.beam_search_audio = self.get_beam_search(self.model.model.backbone)
        self.beam_search_av = self.get_beam_search(self.model.model.backbone)
        self.wer_video = WER()
        self.wer_audio = WER()
        self.wer_av = WER()

        self.wer_att = WER()

        self.mask_video = AdaptiveLengthTimeMask(
            window=int(cfg.data.timemask_window * 25),
            stride=int(cfg.data.timemask_stride * 25),
            replace_with_zero=True
        )
        self.mask_audio = AdaptiveLengthTimeMask(
            window=int(cfg.data.timemask_window_audio * 16_000),
            stride=int(cfg.data.timemask_stride_audio * 16_000),
            replace_with_zero=True
        )
        self.automatic_optimization = False
    

    def get_beam_search(self, model):
        token_list = UNIGRAM1000_LIST

        odim = len(token_list)
        self.token_list = token_list

        scorers = model.scorers()

        if self.cfg.decode.lm_weight and self.cfg.model.pretrained_lm_path:
            lm = TransformerLM(len(token_list), self.cfg.model.language_model)
            set_requires_grad(lm, False)
            print("Load pretrained language model weights")
            torch_load(self.cfg.model.pretrained_lm_path, lm)
        else:
            lm = None

        scorers["lm"] = lm
        scorers["length_bonus"] = LengthBonus(len(token_list))

        weights = dict(
            decoder=1.0 - self.cfg.decode.ctc_weight,
            ctc=self.cfg.decode.ctc_weight,
            lm=self.cfg.decode.lm_weight,
            length_bonus=self.cfg.decode.penalty,
        )
        beam_search = BatchBeamSearch(
            beam_size=self.cfg.decode.beam_size,
            vocab_size=len(token_list),
            weights=weights,
            scorers=scorers,
            sos=odim - 1,
            eos=odim - 1,
            token_list=token_list,
            pre_beam_score_key=None if self.cfg.decode.ctc_weight == 1.0 else "decoder",
        )

        return beam_search

    def get_mask(self, data, padding_mask, mask_prob, mask_length):
        B, C, T, H, W = data["video"].shape
        mask = ~compute_mask_indices(
            (B, T),
            ~padding_mask,
            mask_prob,
            mask_length,
            min_masks=1
        )
        return torch.from_numpy(mask).to(data["video"].device)
            
    def training_step(self, data, batch_idx):
        opt = self.optimizers()
        sch = self.lr_schedulers()

        # zero grad
        opt.zero_grad()  

        data_u, data_l = data["unlabelled"], data["labelled"]
        label_l = data_l["label"].squeeze(1)

        video = data_u["video_aug"].squeeze(1)
        audio = data_u["audio_aug"].transpose(1, 2)

        new_video = []
        new_audio = []
        for b in range(len(video)):
            new_video.append(self.mask_video(video[b].transpose(0, 1)).transpose(0, 1))
            new_audio.append(self.mask_audio(audio[b].transpose(0, 1)).transpose(0, 1))
        video = torch.stack(new_video)
        audio = torch.stack(new_audio)

        x_v, x_a, x_av = self.model.model.get_encoded_features(video, audio, self.padding_mask_u.unsqueeze(-2), True)

        loss = 0.0
        if self.cfg.model.ctc_rel_weight > 0.0:
            self.log(
                "conf_ctc_u", 
                self.conf_ctc.mean(), 
                on_step=True, 
                on_epoch=True, 
                prog_bar=True, 
                batch_size=len(label_l), 
                sync_dist=True
            )

            self.log(
                "avg_above_tresh_ctc_u", self.mask_conf_ctc.sum() / len(self.mask_conf_ctc), 
                on_step=False,
                on_epoch=True, 
                batch_size=len(label_l), 
                sync_dist=True
            )
            
            loss_ctc_v, loss_ctc_a, loss_ctc_av = self.model.model.get_encoder_losses(
                x_v,
                x_a,
                x_av,
                self.padding_mask_u, 
                self.targets_ctc,
                False,
                self.mask_conf_ctc,
                labels_aux=self.targets_ctc_aux,
                mask_conf_aux=self.mask_conf_ctc_aux,
            )
            self.log(
                "loss_ctc_v_u", loss_ctc_v, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "loss_ctc_a_u", loss_ctc_a, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "loss_ctc_av_u", loss_ctc_av, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True
            )

            loss += self.cfg.model.unlab_rel_weight_v*self.cfg.model.ctc_rel_weight*self.cfg.model.v_rel_weight*loss_ctc_v
            loss += self.cfg.model.unlab_rel_weight_a*self.cfg.model.ctc_rel_weight*(1 - self.cfg.model.v_rel_weight)*loss_ctc_a
            loss += self.cfg.model.unlab_rel_weight_a*self.cfg.model.ctc_rel_weight*(1 - self.cfg.model.v_rel_weight)*loss_ctc_av

        if self.cfg.model.ctc_rel_weight < 1.0:
            mask_conf = self.probs_out_max.ge(self.cfg.model.conf_threshold)

            self.log(
                "avg_above_tresh_u", mask_conf.sum() / (mask_conf.size(0)*mask_conf.size(1)), 
                on_step=False, 
                on_epoch=True, 
                batch_size=len(label_l), 
                sync_dist=True
            )

            loss_att_v, loss_att_a, loss_att_av, acc_v, acc_a, acc_av = self.model.model.get_decoder_losses(
                x_v,
                x_a,
                x_av,
                self.padding_mask_u.unsqueeze(-2), 
                self.targets_out,
                False, 
                self.targets_in,
                mask_conf,
                labels_aux=self.targets_out_aux,
                mask_conf_aux=self.mask_conf_aux,
            )

            self.log(
                "loss_att_v_u", loss_att_v, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "loss_att_a_u", loss_att_a, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "loss_att_av_u", loss_att_av, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "acc_v_u", acc_v, on_step=False, on_epoch=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "acc_a_u", acc_a, on_step=False, on_epoch=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "acc_av_u", acc_av, on_step=False, on_epoch=True, batch_size=len(label_l), sync_dist=True
            )
            self.log(
                "avg_max_prob_u", 
                self.probs_out_max.mean(), 
                on_step=True, 
                on_epoch=True, 
                prog_bar=True, 
                batch_size=len(label_l), 
                sync_dist=True
            )

            loss += self.cfg.model.unlab_rel_weight_v*(1-self.cfg.model.ctc_rel_weight)*self.cfg.model.v_rel_weight*loss_att_v
            loss += self.cfg.model.unlab_rel_weight_a*(1-self.cfg.model.ctc_rel_weight)*(1-self.cfg.model.v_rel_weight)*loss_att_a
            loss += self.cfg.model.unlab_rel_weight_a*(1-self.cfg.model.ctc_rel_weight)*(1-self.cfg.model.v_rel_weight)*loss_att_av

        # backward
        self.manual_backward(loss)

        video = data_l["video_aug"].squeeze(1)
        audio = data_l["audio_aug"].transpose(1, 2)

        new_video = []
        new_audio = []
        for b in range(len(video)):
            new_video.append(self.mask_video(video[b].transpose(0, 1)).transpose(0, 1))
            new_audio.append(self.mask_audio(audio[b].transpose(0, 1)).transpose(0, 1))
        video = torch.stack(new_video)
        audio = torch.stack(new_audio)

        x_v, x_a, x_av = self.model.model.get_encoded_features(video, audio, self.padding_mask_l.unsqueeze(-2))

        loss = 0.0
        if self.cfg.model.ctc_rel_weight > 0.0:
            loss_ctc_v, loss_ctc_a, loss_ctc_av = self.model.model.get_encoder_losses(
                x_v,
                x_a,
                x_av,
                self.padding_mask_l, 
                label_l,
                True
            )
            self.log("loss_ctc_v_l", loss_ctc_v, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True)
            self.log("loss_ctc_a_l", loss_ctc_a, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True)
            self.log("loss_ctc_av_l", loss_ctc_av, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True)
        
            loss += (1-self.cfg.model.unlab_rel_weight_v)*self.cfg.model.ctc_rel_weight*self.cfg.model.v_rel_weight*loss_ctc_v 
            loss += (1-self.cfg.model.unlab_rel_weight_a)*self.cfg.model.ctc_rel_weight*(1-self.cfg.model.v_rel_weight)*loss_ctc_a
            loss += (1-self.cfg.model.unlab_rel_weight_a)*self.cfg.model.ctc_rel_weight*(1-self.cfg.model.v_rel_weight)*loss_ctc_av

        if self.cfg.model.ctc_rel_weight < 1.0:
            loss_att_v, loss_att_a, loss_att_av, acc_v, acc_a, acc_av = self.model.model.get_decoder_losses(
                x_v,
                x_a,
                x_av,
                self.padding_mask_l.unsqueeze(-2), 
                label_l,
                True, 
            )

            self.log("loss_att_v_l", loss_att_v, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True)
            self.log("loss_att_a_l", loss_att_a, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True)
            self.log("loss_att_av_l", loss_att_av, on_step=True, on_epoch=True, prog_bar=True, batch_size=len(label_l), sync_dist=True)
            self.log("acc_v_l", acc_v, on_step=False, on_epoch=True, batch_size=len(label_l), sync_dist=True)
            self.log("acc_a_l", acc_a, on_step=False, on_epoch=True, batch_size=len(label_l), sync_dist=True)
            self.log("acc_av_l", acc_av, on_step=False, on_epoch=True, batch_size=len(label_l), sync_dist=True)

            loss += (1-self.cfg.model.unlab_rel_weight_v)*(1-self.cfg.model.ctc_rel_weight)*self.cfg.model.v_rel_weight*loss_att_v
            loss += (1-self.cfg.model.unlab_rel_weight_a)*(1-self.cfg.model.ctc_rel_weight)*(1-self.cfg.model.v_rel_weight)*loss_att_a
            loss += (1-self.cfg.model.unlab_rel_weight_a)*(1-self.cfg.model.ctc_rel_weight)*(1-self.cfg.model.v_rel_weight)*loss_att_av

        self.manual_backward(loss)
        self.clip_gradients(opt, gradient_clip_val=self.cfg.optimizer.gradient_clip_val, gradient_clip_algorithm="norm")
        opt.step()
        sch.step()

        self.log('monitoring_step', self.trainer.global_step)  # this is to save the last k checkpoints

    @torch.no_grad()
    def on_train_batch_start(self, data, batch_idx):
        data_u, data_l = data["unlabelled"], data["labelled"]
        self.padding_mask_u = make_non_pad_mask(data_u["video_lengths"]).to(data_u["video"].device)
        self.padding_mask_l = make_non_pad_mask(data_l["video_lengths"]).to(data_l["video"].device)

        video = data_u["video"].squeeze(1)
        audio = data_u["audio"].transpose(1, 2)

        e_a = self.model.model.get_target_features(
            video, audio, self.padding_mask_u.unsqueeze(-2) 
        )

        if self.cfg.model.ctc_rel_weight > 0.0:
            self.ctc_at_u = self.model.model.get_encoder_targets(e_a)
        
        ctc_probs = self.ctc_at_u.softmax(-1)
        ctc_log_probs = (ctc_probs.max(-1)[0] + 1e-9).masked_fill(~self.padding_mask_u, 1.0).log()
        self.conf_ctc = torch.exp(ctc_log_probs.sum(-1) / self.padding_mask_u.sum(-1))
        self.mask_conf_ctc = self.conf_ctc.ge(self.cfg.model.conf_threshold_ctc)

        self.targets_ctc, mask_lengths_ctc, ctc_symbol_probs = self.model.model.get_ctc_targets_with_probs(
            ctc_probs, self.padding_mask_u
        )

        max_sequence_lengths = ((mask_lengths_ctc.sum(dim=-1) + 1)*1.2).long().clamp(max=e_a.size(1) // 3)
        # max_sequence_lengths = None
        if self.cfg.model.ctc_rel_weight < 1.0:
            self.ar_prob = self.ar_scheduler.get_lr(self.trainer.global_step)
            self.log("ar_prob", self.ar_prob, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
            self.mask_conf_aux = self.targets_out_aux = self.mask_conf_ctc_aux = self.targets_ctc_aux = None

            rand_val = random.random()

            if rand_val < self.ar_prob:
                self.logits_at_u, mask_lengths_att = self.model.model.get_decoder_targets(
                    e_a, self.padding_mask_u.unsqueeze(-2), max_sequence_lengths
                )

                self.targets_out = self.logits_at_u.argmax(-1)
                self.targets_out = torch.where(mask_lengths_att, self.targets_out, -1)
                self.probs_out_max = self.logits_at_u.softmax(-1).max(-1)[0]

                self.targets_in = self.targets_out.clone()
                indices = mask_lengths_att.sum(-1) - 1
                self.targets_in[torch.arange(len(indices)), indices] = self.ignore_id
                self.targets_in = self.targets_in[:, :-1]

                if self.cfg.model.use_att_targets_for_ctc:
                    self.targets_ctc_aux = self.targets_in
                    ctc_probs = self.logits_at_u.softmax(-1)
                    ctc_log_probs = (ctc_probs.max(-1)[0] + 1e-9).masked_fill(~mask_lengths_att, 1.0).log()
                    self.conf_ctc_aux = torch.exp(ctc_log_probs.sum(-1) / mask_lengths_att.sum(-1))
                    self.mask_conf_ctc_aux = self.conf_ctc_aux.ge(self.cfg.model.conf_threshold)
            else:
                self.targets_out, self.probs_out_max = self.model.model.target_backbone.forward_attention_av(
                    e_a, self.padding_mask_u.unsqueeze(-2), self.targets_ctc
                )
                self.targets_in = self.targets_ctc

                if self.cfg.model.use_ctc_targets_for_decoder:
                    self.targets_out_aux = self.targets_ctc
                    zeros = torch.zeros((len(ctc_symbol_probs), 1), dtype=ctc_symbol_probs.dtype, device=ctc_symbol_probs.device)
                    ctc_symbol_probs = torch.cat([ctc_symbol_probs, zeros], dim=-1)
                    self.mask_conf_aux = ctc_symbol_probs.ge(self.cfg.model.conf_threshold_ctc)
                    self.mask_conf_aux[torch.arange(len(self.mask_conf_aux)), mask_lengths_ctc.sum(dim=-1)] = True  # have True for EOS
              
    def on_train_batch_end(self, *args):
        momentum = self.momentum_scheduler.get_lr(self.trainer.global_step)  # global step takes into account 2 optimizer steps in PL > 1.5
        self.model.update_moving_average(momentum)
        self.log("momentum", momentum, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)

    def shared_val_test_step(self, data, dataloader_idx=0):
        audio, label = data["audio"], data["label"]
        if dataloader_idx == 0:
            video = data["video"]
            padding_mask_v = make_non_pad_mask(data["video_lengths"]).to(data["video"].device).unsqueeze(-2)

            features_v, features_a, features_av = self.model.model.get_encoded_features(video.squeeze(1), audio.transpose(1, 2), padding_mask_v)

            if self.cfg.model.ctc_rel_weight > 0.0:
                loss_ctc_v = self.model.model.backbone.ctc_v(
                    features_v, torch.tensor(data["video_lengths"], device=features_v.device), data["label"].squeeze(1)
                )
                loss_ctc_a = self.model.model.backbone.ctc_a(
                    features_a, torch.tensor(data["video_lengths"], device=features_a.device), data["label"].squeeze(1)
                )
                loss_ctc_av = self.model.model.backbone.ctc_av(
                    features_av, torch.tensor(data["video_lengths"], device=features_a.device), data["label"].squeeze(1)
                )

                self.log("loss_ctc_v_val", loss_ctc_v, batch_size=len(label), sync_dist=True)
                self.log("loss_ctc_a_val", loss_ctc_a, batch_size=len(label), sync_dist=True)
                self.log("loss_ctc_av_val", loss_ctc_av, batch_size=len(label), sync_dist=True)

            if self.cfg.model.ctc_rel_weight < 1.0:
                acc_v, acc_a, acc_av = self.model.model.backbone.forward_labelled(
                    features_v, features_a, features_av, padding_mask_v, label
                )[-3:]

                self.log("acc_video_val", acc_v, batch_size=len(label), sync_dist=True)
                self.log("acc_audio_val", acc_a, batch_size=len(label), sync_dist=True)
                self.log("acc_av_val", acc_av, batch_size=len(label), sync_dist=True)
        
        if dataloader_idx == 1:
            lengths = torch.tensor(data["audio_lengths"], device=data["audio"].device) // 640
            padding_mask = make_non_pad_mask(lengths).to(lengths.device).unsqueeze(-2)
            features_a = self.model.model.get_encoded_features_libri(audio.transpose(1, 2), padding_mask)
            logits_a, mask_lengths_att = self.model.model.get_decoder_targets(
                features_a, padding_mask, None
            )
            targets = logits_a.argmax(-1)

            for lab, targ, mask in zip(label, targets, mask_lengths_att):
                groundtruth = lab
                transcription = ids_to_str(targ[mask][:-1], self.token_list) 
                transcription = transcription.replace("▁", " ").strip()

                self.wer_att.update(transcription, groundtruth)

    def validation_step(self, data, batch_idx, dataloader_idx=0):
        self.shared_val_test_step(data, dataloader_idx)

    def calculate_wer(self, video, audio, padding_mask, labels, batch_idx):
        labels = labels.squeeze(1)
        for vid, aud, label, mask in zip(video, audio, labels, padding_mask):
            feat_v, feat_a, feat_av = self.model.model.get_encoded_features(
                vid.unsqueeze(0), aud.unsqueeze(0), mask.unsqueeze(0).unsqueeze(-2)
            )
            
            nbest_hyps_v = self.beam_search_video(
                    x=feat_v.squeeze(0),
                    modality="v",
                    maxlenratio=self.cfg.decode.maxlenratio,
                    minlenratio=self.cfg.decode.minlenratio
                )
            nbest_hyps_a = self.beam_search_audio(
                    x=feat_a.squeeze(0),
                    modality="a",
                    maxlenratio=self.cfg.decode.maxlenratio,
                    minlenratio=self.cfg.decode.minlenratio
                )
            nbest_hyps_av = self.beam_search_av(
                    x=feat_av.squeeze(0),
                    modality="av",
                    maxlenratio=self.cfg.decode.maxlenratio,
                    minlenratio=self.cfg.decode.minlenratio
                )
            
            nbest_hyps_v = [
                h.asdict() for h in nbest_hyps_v[: min(len(nbest_hyps_v), 1)]
            ]
            nbest_hyps_a = [
                h.asdict() for h in nbest_hyps_a[: min(len(nbest_hyps_a), 1)]
            ]
            nbest_hyps_av = [
                h.asdict() for h in nbest_hyps_av[: min(len(nbest_hyps_av), 1)]
            ]

            transcription_v = add_results_to_json(nbest_hyps_v, self.token_list)
            transcription_v = transcription_v.replace("<eos>", "")
            transcription_a = add_results_to_json(nbest_hyps_a, self.token_list)
            transcription_a = transcription_a.replace("<eos>", "")
            transcription_av = add_results_to_json(nbest_hyps_av, self.token_list)
            transcription_av = transcription_av.replace("<eos>", "")

            label = label[label != self.ignore_id]
            groundtruth = ids_to_str(label, self.token_list)

            groundtruth = groundtruth.replace("▁", " ").strip()
            transcription_v = transcription_v.replace("▁", " ").strip()
            transcription_a = transcription_a.replace("▁", " ").strip()
            transcription_av = transcription_av.replace("▁", " ").strip()

            self.wer_video.update(transcription_v, groundtruth)
            self.wer_audio.update(transcription_a, groundtruth)
            self.wer_av.update(transcription_av, groundtruth)

    def test_step(self, data, batch_idx, dataloader_idx=0):
        lengths = torch.tensor(data["video_lengths"], device=data["video"].device)
        print(lengths)
        padding_mask = make_non_pad_mask(lengths).to(lengths.device)
        self.calculate_wer(
            data["video"].squeeze(1), 
            data["audio"].transpose(1, 2),
            padding_mask, 
            data["label"],
            batch_idx, 
        )

        print(self.wer_video.compute())
        print(self.wer_audio.compute())
        print(self.wer_av.compute())

    def on_validation_epoch_end(self):
        self.log("wer_att_libri", self.wer_att, sync_dist=True)

    def on_test_epoch_end(self):
        wer_video = self.wer_video.compute()
        wer_audio = self.wer_audio.compute()
        wer_av = self.wer_av.compute()
        print(wer_video)
        print(wer_audio)
        print(wer_av)
        self.log("wer_video", wer_video)
        self.log("wer_audio", wer_audio)
        self.log("wer_av", wer_av)
        self.wer_video.reset()
        self.wer_audio.reset()
        self.wer_av.reset()
    
    def on_train_epoch_start(self):
        for loader in self.trainer.train_dataloader.values():
            sampler = loader.batch_sampler
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(self.current_epoch)
        return super().on_train_epoch_start()

    def configure_optimizers(self):
        param_groups = get_param_groups(
            self.model.model, 
            self.cfg.model.backbone.elayers, 
            self.cfg.optimizer.base_lr, 
            self.cfg.optimizer.base_lr_other, 
            self.cfg.optimizer.lr_decay_rate,
            min_lr=self.cfg.optimizer.min_lr,
        )

        optimizer = AdamW(
            param_groups, weight_decay=self.cfg.optimizer.weight_decay, betas=self.cfg.optimizer.betas
        )

        warmup_epochs = self.cfg.optimizer.warmup_epochs
        train_len = len(self.trainer.datamodule.train_dataloader()["unlabelled"])
        scheduler = {
            'scheduler': WarmupCosineScheduler(
                optimizer,
                warmup_epochs,
                self.cfg.trainer.max_epochs,
                train_len,
                self.cfg.optimizer.cosine_decay,
                excluded_group=None,
            ),
            'interval': 'step',
            'frequency': 1
        }

        self.momentum_scheduler = instantiate(
            self.cfg.model.momentum_scheduler,
            iter_per_epoch=train_len,
        )

        self.ar_scheduler = ConstantThenLinearScheduler(
            self.cfg.model.ar_start_val, 
            int((1.0 - self.cfg.model.ar_percent_linear)*self.cfg.trainer.max_epochs), 
            self.cfg.trainer.max_epochs, 
            train_len,
        )

        return [optimizer], [scheduler]

