import copy
import os
import numpy as np
import torch
from pytorch_lightning.core import LightningModule
from torch.optim.lr_scheduler import CosineAnnealingLR
from models.embedder import Embedder
from models.loss import AlignmentLoss
from dataset.video_align_dataset import VideoAlignmentTrainDataset
from utils.evaluation import get_kendalls_tau
from evaluation.evalulate_features import prepare_data_loader, extract_embedding
from evaluation.classification import classification
from evaluation.event_completion import progression
from evaluation.frame_retrieval import frame_retrieval
from evaluation.kendalls_tau import kendalls_tau


class VideoAlignment(LightningModule):
    """Lightning module that trains an ``Embedder`` with ``AlignmentLoss``.

    Besides the usual train/validation steps, the module periodically runs a
    set of downstream evaluations (Kendall's tau, classification, few-shot
    classification, frame retrieval, progression) on embeddings extracted with
    the current model. Which evaluations run is selected by the characters
    contained in ``args.eval_task`` (see :meth:`evaluate_downstream`).

    Args:
        args: Configuration namespace. This class reads ``task``,
            ``train_eval_mode``, ``cls_every_n_epoch``, ``output_dir``,
            ``label_all``, ``epochs``, ``eval_task``, ``lr``, ``wd``,
            ``use_scheduler``, ``min_lr``, ``batch_size`` and ``num_workers``
            from it; the same object is also forwarded to ``Embedder``,
            ``AlignmentLoss``, ``prepare_data_loader`` and the datasets.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model = Embedder(args)
        self.loss = AlignmentLoss(args)
        self.checkpoint_metric = "train_loss"
        self.data_path = None
        # Bounding-box inputs are enabled whenever the task name mentions 'bbox'.
        self.object_box = 'bbox' in args.task
        # Loaders used only for downstream evaluation (embedding extraction),
        # separate from the loaders returned by train_dataloader/val_dataloader.
        self.cls_loader_train, self.ds_dataset_train = prepare_data_loader(
            args, 'train', batch_size=256, bbox=self.object_box)
        self.cls_loader_val, self.ds_dataset_val = prepare_data_loader(
            args, args.train_eval_mode, batch_size=256, bbox=self.object_box)
        print(f'constructing cls loader train {len(self.cls_loader_train)} | val {len(self.cls_loader_val)}')

    def evaluate_downstream(self, batch_idx, device):
        """Extract embeddings and log the selected downstream metrics.

        The evaluation only runs when all of the following hold:
        ``args.cls_every_n_epoch`` is positive, this process is global rank 0,
        ``batch_idx`` is 0, and ``current_epoch + 1`` is a multiple of
        ``args.cls_every_n_epoch``. Otherwise the call is a no-op.

        Task selection is driven by the characters found in ``args.eval_task``:
        ``'0'`` Kendall's tau, ``'1'`` classification, ``'2'`` few-shot
        classification, ``'3'`` frame retrieval, ``'4'`` progression,
        ``'7'`` classification with modified labels and with modified
        embeddings, ``'8'`` progression with modified embeddings. When
        ``current_epoch + 1 == args.epochs`` the task set ``'01234'`` is used
        instead of ``args.eval_task``.

        Args:
            batch_idx: Index of the current validation batch.
            device: Device passed through to ``extract_embedding``.
        """
        every_n = self.args.cls_every_n_epoch
        # Same short-circuit order as a single `and` chain: the modulo below is
        # never evaluated when every_n is not positive.
        if every_n <= 0 or self.global_rank != 0 or batch_idx != 0:
            return
        epoch_number = self.current_epoch + 1
        if epoch_number % every_n != 0:
            return

        # NOTE(review): everything below is logged from rank 0 only; depending
        # on the Lightning version, `self.log(..., rank_zero_only=True)` may be
        # the recommended form for that - confirm before changing.
        extract_embedding('train', self.cls_loader_train, self.model, self.args.output_dir, device,
                          label_all=self.args.label_all, object_box=self.object_box)
        extract_embedding('val', self.cls_loader_val, self.model, self.args.output_dir, device,
                          label_all=self.args.label_all, object_box=self.object_box)

        # On the last epoch the fixed task set '01234' is evaluated. A local is
        # used so the shared `args` object is not mutated and the user's
        # configured `eval_task` stays intact for any later evaluation.
        if epoch_number == self.args.epochs:
            tasks = '01234'
        else:
            tasks = self.args.eval_task

        if '0' in tasks:
            train_tau = kendalls_tau(self.args.output_dir, self.ds_dataset_train.video_len_list,
                                     self.ds_dataset_train.video_paths1, 'train', False)
            val_tau = kendalls_tau(self.args.output_dir, self.ds_dataset_val.video_len_list,
                                   self.ds_dataset_val.video_paths1, 'val', False)
            self.log('train_tau', train_tau)
            self.log('val_tau', val_tau)

        if '1' in tasks:
            train_acc, val_acc, val_f1 = classification(self.args.output_dir, self.args.label_all, cls=True,
                                                        few_shot=False)
            self.log('train_acc', train_acc)
            self.log('val_acc', val_acc)
            self.log('val_f1', val_f1)

        if '2' in tasks:
            train_acc, val_acc, val_f1 = classification(self.args.output_dir, self.args.label_all, cls=False,
                                                        few_shot=True)
            self.log('fs_train_acc', train_acc)
            self.log('fs_val_acc', val_acc)
            self.log('fs_val_f1', val_f1)

        if '3' in tasks:
            map_5, map_10, map_15 = frame_retrieval(self.args.output_dir, self.ds_dataset_val.video_len_list,
                                                    self.ds_dataset_val.video_paths1)
            self.log('map_5', map_5)
            self.log('map_10', map_10)
            self.log('map_15', map_15)

        if '4' in tasks:
            train_score, val_score = progression(self.args.output_dir, self.ds_dataset_train.video_len_list,
                                                 self.ds_dataset_val.video_len_list)
            self.log('train_score', train_score)
            self.log('val_score', val_score)

        if '7' in tasks:
            # Two classification variants: modified labels, then modified embeddings.
            train_acc, val_acc, val_f1 = classification(self.args.output_dir, self.args.label_all, cls=True,
                                                        few_shot=False, modify_labels=True)
            self.log('modify_label_train_acc', train_acc)
            self.log('modify_label_val_acc', val_acc)
            self.log('modify_label_val_f1', val_f1)

            train_acc, val_acc, val_f1 = classification(self.args.output_dir, label_all=self.args.label_all,
                                                        cls=True, few_shot=False, modify_embeddings=True,
                                                        train_video_len_list=self.ds_dataset_train.video_len_list,
                                                        val_video_len_list=self.ds_dataset_val.video_len_list)
            self.log('modify_embs_train_acc', train_acc)
            self.log('modify_embs_val_acc', val_acc)
            self.log('modify_embs_val_f1', val_f1)

        if '8' in tasks:
            train_score, val_score = progression(self.args.output_dir, self.ds_dataset_train.video_len_list,
                                                 self.ds_dataset_val.video_len_list, self.args.label_all,
                                                 modify_embeddings=True)
            self.log('modify_embs_train_score', train_score)
            self.log('modify_embs_val_score', val_score)

    def _alignment_losses(self, batch):
        """Embed both videos of each pair in ``batch`` and compute the losses.

        Args:
            batch: A ``(frames, steps, seq_lens)`` tuple. ``frames`` is indexed
                as ``frames[:, 0]`` / ``frames[:, 1]`` for the two videos of a
                pair, each a 5-D tensor.

        Returns:
            ``(loss, loss_reg, device)``: the two values returned by
            ``self.loss`` and the device of the stacked embeddings.
        """
        frames, steps, seq_lens = batch
        # Move the last axis to position 2 for each video of the pair; original
        # shape note: result is (bs, 64, 3, 168, 168) - TODO confirm for other configs.
        first_video = frames[:, 0, ...].permute(0, 1, 4, 2, 3)
        second_video = frames[:, 1, ...].permute(0, 1, 4, 2, 3)

        first_embeds = self.model(first_video)
        second_embeds = self.model(second_video)
        # Original shape note: (bs, 2, 32, 128).
        embeddings = torch.stack((first_embeds, second_embeds), dim=1)

        loss, loss_reg = self.loss(embeddings, steps, seq_lens, self.global_step)
        return loss, loss_reg, embeddings.device

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the main loss."""
        loss, loss_reg, _ = self._alignment_losses(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_loss_reg', loss_reg, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses, then run downstream evaluation.

        ``evaluate_downstream`` is invoked for every batch and decides by
        itself (first batch, rank 0, epoch schedule) whether to do any work.
        """
        loss, loss_reg, device = self._alignment_losses(batch)
        self.log('val_loss', loss, on_step=True, on_epoch=True)
        self.log('val_loss_reg', loss_reg, on_step=True, on_epoch=True)

        self.evaluate_downstream(batch_idx, device)
        return loss

    def configure_optimizers(self):
        """Create the Adam optimizer and, optionally, a cosine LR schedule.

        Returns:
            The optimizer alone when ``args.use_scheduler`` is falsy; otherwise
            ``([optimizer], [scheduler])`` where the scheduler is a
            ``CosineAnnealingLR`` with ``T_max=args.epochs`` and
            ``eta_min=args.min_lr``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr, weight_decay=self.args.wd)
        if not self.args.use_scheduler:
            return optimizer
        scheduler = CosineAnnealingLR(optimizer, T_max=self.args.epochs, eta_min=self.args.min_lr)
        return [optimizer], [scheduler]

    def _build_split_loader(self, split):
        """Return a ``DataLoader`` over ``VideoAlignmentTrainDataset`` for ``split``.

        Batch size and worker count come from ``args``; all other
        ``DataLoader`` options are left at their defaults.
        """
        # NOTE(review): no `shuffle` is requested; presumably the dataset
        # randomizes its own sampling - verify against the dataset class.
        dataset = VideoAlignmentTrainDataset(self.args, split)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.args.batch_size,
                                           num_workers=self.args.num_workers)

    def train_dataloader(self):
        """Return the loader over the ``'train'`` split used for optimization."""
        return self._build_split_loader('train')

    def val_dataloader(self):
        """Return the loader over the ``'val'`` split used for validation losses."""
        return self._build_split_loader('val')
