import math

from config.base_config import Config
import numpy as np
import torch
from collections import defaultdict, deque
from trainer.base_trainer import BaseTrainer
from modules.metrics import sim_matrix_training
from modules.metrics import *
from modules.metrics import sim_matrix_training, sim_matrix_inference, generate_embeds_per_video_id
from tqdm import tqdm
from config.all_config import gen_log
import time
import gc
import multiprocessing
import torch.multiprocessing as mp
from multiprocessing import Pool, get_start_method,set_start_method


class Trainer(BaseTrainer):
    """
    Trainer with a diffusion-only pre-training phase and a joint training phase.

    * ``_pretrain_epoch`` back-propagates only ``dm_loss`` (the last model output)
      and steps ``pretrain_optimizer``.
    * ``_train_epoch`` back-propagates ``self.loss`` on DM-aligned text embeddings
      plus ``config.dm_loss_weight * dm_loss`` and steps ``optimizer`` (and
      ``lr_scheduler`` once per batch, when given).

    Both phases periodically run ``_valid_epoch_step`` and track the best R@1 and
    best window-averaged R@1 seen so far.

    Note:
        Inherited from BaseTrainer. NOTE(review): based on usage in this class,
        BaseTrainer is assumed to provide ``model``, ``loss``, ``metrics``,
        ``optimizer``, ``config``, ``device``, ``global_step``, ``log_step``,
        ``evals_per_epoch`` and ``_save_checkpoint`` -- confirm there.
    """

    def __init__(self, model, loss,  metrics, optimizer, pretrain_optimizer, config: Config, train_data_loader,
                 valid_data_loader, tokenizer, lr_scheduler=None, writer=None):

        super().__init__(model, loss, metrics, optimizer, config, writer)
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.lr_scheduler = lr_scheduler
        # When not None, raw text in each batch is tokenized before the forward pass.
        self.tokenizer = tokenizer
        # Optimizer used only by the diffusion-only pre-training phase.
        self.pretrain_optimizer = pretrain_optimizer

        self.pooling_type = config.pooling_type
        # Per-metric sliding window of the last `eval_window_size` validation values.
        self.window_metric = defaultdict(lambda: deque(maxlen=config.eval_window_size))
        self.best_window = -1.0
        self.best = -1.0

    # ------------------------------------------------------------------ helpers

    def _gen_log(self, msg, log_name='log_trntst'):
        """Forward ``msg`` to ``gen_log`` for log ``log_name`` under ``config.model_path``."""
        gen_log(model_path=self.config.model_path, log_name=log_name, msg=msg)

    def _prepare_batch(self, data):
        """Tokenize ``data['text']`` (if a tokenizer is set) and move text/video to ``self.device``.

        :param data: Batch dict with at least ``'text'`` and ``'video'`` entries.
        :return: The same dict, mutated in place.
        """
        if self.tokenizer is not None:
            # Text arrives as raw strings; the tokenizer returns a dict of tensors.
            data['text'] = self.tokenizer(data['text'], return_tensors='pt', padding=True,
                                          truncation=True)
        if isinstance(data['text'], torch.Tensor):
            data['text'] = data['text'].to(self.device)
        else:
            data['text'] = {key: val.to(self.device) for key, val in data['text'].items()}

        data['video'] = data['video'].to(self.device)
        return data

    def _compute_eval_steps(self, num_steps):
        """Return the batch indices at which to validate.

        These are ``evals_per_epoch`` evenly spaced indices in ``[0, num_steps - 1]``,
        excluding index 0.
        """
        steps = np.linspace(0, num_steps - 1, self.evals_per_epoch + 1, dtype=int)[1:]
        return set(steps.tolist())

    def _maybe_eval_and_checkpoint(self, epoch, batch_idx, last_step, tag, save_on_window):
        """Validate (unless ``config.skip_eval``) and update best-metric bookkeeping.

        :param tag: Prefix for the log line, e.g. ``'[Train]'``.
        :param save_on_window: If True, save a best checkpoint when the window-averaged R@1
            improves; otherwise save when the plain R@1 improves.
        """
        if self.config.skip_eval:
            self._gen_log('\nSkip eval due to long time usage!\n')
            return

        val_res = self._valid_epoch_step(epoch, batch_idx, last_step)
        # _valid_epoch_step switches the model to eval mode; resume training mode.
        self.model.train()

        if val_res['R1-window'] > self.best_window:
            self.best_window = val_res['R1-window']
            if save_on_window:
                self._save_checkpoint(epoch, save_best=True)

        if val_res['R1'] > self.best:
            self.best = val_res['R1']
            if not save_on_window:
                self._save_checkpoint(epoch, save_best=True)

        self._gen_log(f"{tag} Current Best Window Average R@1 is {self.best_window}\n"
                      f"{tag} Current Best R@1 is {self.best}\n\n")

    # ------------------------------------------------------------------ training

    def _pretrain_epoch(self, epoch):
        """
        Run one pre-training epoch that optimizes only the diffusion loss.

        :param epoch: Current training epoch.
        :return: ``{'loss_train': mean dm_loss over the epoch's batches}``.
        """
        self._gen_log(f'\n=============[pre-train epo={epoch}]=============\n')

        self.model.train()
        total_loss = 0.0
        num_steps = len(self.train_data_loader)
        eval_steps = self._compute_eval_steps(num_steps)

        for batch_idx, data in enumerate(self.train_data_loader):
            data = self._prepare_batch(data)

            # Only the diffusion loss (last output) is used in this phase.
            *_, dm_loss = self.model(data, no_aligned_embed=True)

            dm_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.pretrain_optimizer.step()
            self.pretrain_optimizer.zero_grad()
            # Clamp logit_scale so that exp(logit_scale) never exceeds 100.
            torch.clamp_(self.model.clip.logit_scale.data, max=np.log(100))

            self.global_step += 1
            dm_loss_value = dm_loss.detach().item()
            total_loss += dm_loss_value

            if not self.config.noloss_record:
                self._gen_log(dm_loss_value, log_name='log_dm_loss')

            if batch_idx % self.log_step == 0:
                self._gen_log('Pre-Train Epoch: {} dl: {}/{} Diffusion Loss: {:.6f}'.format(
                    epoch,
                    batch_idx,
                    num_steps - 1,
                    dm_loss_value,
                ))

            if batch_idx in eval_steps:
                # Pre-training checkpoints on the best plain R@1.
                self._maybe_eval_and_checkpoint(epoch, batch_idx, num_steps - 1,
                                                tag='[Pre-Train]', save_on_window=False)

        # max(..., 1) guards against an empty training loader.
        return {'loss_train': total_loss / max(num_steps, 1)}

    def _train_epoch(self, epoch):
        """
        Run one joint training epoch (alignment loss + weighted diffusion loss).

        Validation only runs once ``epoch >= config.start_eval_epo``.

        :param epoch: Current training epoch.
        :return: ``{'loss_train': mean total loss over the epoch's batches}``.
        """
        self._gen_log(f'\n=============[train epo={epoch}]=============\n')
        self.model.train()
        total_loss = 0.0
        num_steps = len(self.train_data_loader)
        eval_steps = self._compute_eval_steps(num_steps)
        # Evaluated once per epoch instead of once per batch.
        eval_enabled = epoch >= self.config.start_eval_epo

        for batch_idx, data in enumerate(self.train_data_loader):
            data = self._prepare_batch(data)

            _, _, video_embeds_pooled, text_embeds_stochstic, _, _, dm_loss = self.model(data, no_aligned_embed=True)

            text_embeds_DMalign = self.model.text_cond_processor(text_embeds_stochstic, video_embeds_pooled,
                                                                 no_aligned_embed=False)

            output = sim_matrix_training(text_embeds_DMalign, video_embeds_pooled, self.pooling_type)
            loss = self.loss(output, self.model.clip.logit_scale)

            loss_all = loss + self.config.dm_loss_weight * dm_loss
            loss_all.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

            # Clamp logit_scale so that exp(logit_scale) never exceeds 100.
            torch.clamp_(self.model.clip.logit_scale.data, max=np.log(100))

            self.global_step += 1
            total_loss += loss_all.detach().item()

            loss_value = loss.detach().item()
            dm_loss_value = dm_loss.detach().item()

            if not self.config.noloss_record:
                self._gen_log(loss_value, log_name='log_ori_loss')
                self._gen_log(dm_loss_value, log_name='log_dm_loss')

            if batch_idx % self.log_step == 0:
                self._gen_log('Train Epoch: {} dl: {}/{} Original Loss: {:.6f},  Diffusion Loss: {:.6f}'.format(
                    epoch,
                    batch_idx,
                    num_steps - 1,
                    loss_value,
                    dm_loss_value,
                ))

            if eval_enabled and batch_idx in eval_steps:
                # Joint training checkpoints on the best window-averaged R@1.
                self._maybe_eval_and_checkpoint(epoch, batch_idx, num_steps - 1,
                                                tag='[Train]', save_on_window=True)

        # max(..., 1) guards against an empty training loader.
        return {'loss_train': total_loss / max(num_steps, 1)}

    # ---------------------------------------------------------------- validation

    def _collect_val_embeddings(self):
        """Run the model over ``valid_data_loader`` and gather its first two outputs on CPU.

        :return: ``(text_embeds, vid_embeds, all_vid_ids)``. The tensors are concatenated
            along dim 0 across batches; ``all_vid_ids`` lists ``data['video_id']`` entries
            in the same order.
        """
        text_embed_arr = []
        vid_embed_arr = []
        all_vid_ids = []

        for data in tqdm(self.valid_data_loader):
            data = self._prepare_batch(data)

            text_embed, vid_embed, *_ = self.model(data, no_aligned_embed=False)

            text_embed_arr.append(text_embed.cpu())
            vid_embed_arr.append(vid_embed.cpu())
            all_vid_ids.extend(data['video_id'])

        return torch.cat(text_embed_arr), torch.cat(vid_embed_arr), all_vid_ids

    def _select_stochastic_text_embeds(self, text_embeds, vid_embeds, vid_embeds_pooled):
        """Build, for every (video, text) pair, the best of several stochastic text embeddings.

        For each video and each of ``config.stochasic_trials`` trials, all texts go through
        ``model.stochastic`` (paired with that video's embedding repeated once per text) and
        then ``model.text_cond_processor``. The trial outputs and the video's pooled embedding
        are L2-normalized, and per text the trial with the highest cosine similarity is kept.

        :return: CPU tensor of shape ``(num_videos, num_texts, text_embeds.shape[1])`` holding
            the selected (normalized) text embeddings.
        """
        num_texts, embed_dim = text_embeds.shape[0], text_embeds.shape[1]
        text_embeds_stochastic_allpairs = torch.zeros(size=(vid_embeds.shape[0], num_texts, embed_dim))
        text_index = torch.arange(num_texts)

        for idx_vid, (single_vid, single_vid_embed_pooled) in tqdm(enumerate(zip(vid_embeds, vid_embeds_pooled))):
            # One copy of this video's embedding per text so the pairs line up in stochastic().
            single_vid_repeat = single_vid.unsqueeze(0).tile((num_texts, 1, 1))

            trial_outputs = []
            for _ in range(self.config.stochasic_trials):
                text_stochastic, _, _ = self.model.stochastic(text_embeds, single_vid_repeat)
                text_aligned = self.model.text_cond_processor(text_stochastic.to(self.device),
                                                              video_features_pooled=None,
                                                              no_aligned_embed=False)
                trial_outputs.append(text_aligned)

            # Stacked along dim 0: one slice per trial.
            trials = torch.stack(trial_outputs, dim=0).cpu()
            # Normalize both sides so the dot product below is a cosine similarity.
            trials = trials / trials.norm(dim=-1, keepdim=True)
            pooled = single_vid_embed_pooled / single_vid_embed_pooled.norm(dim=-1, keepdim=True)
            sim_select = torch.sum(torch.mul(trials, pooled), dim=-1)  # (num_trials, num_texts)

            # For text i keep trials[best_trial[i], i, :] (vectorized gather).
            max_indices = torch.argmax(sim_select, dim=0)
            text_embeds_stochastic_allpairs[idx_vid] = trials[max_indices, text_index]

        return text_embeds_stochastic_allpairs

    def _valid_epoch_step(self, epoch, step, num_steps):
        """
        Validate on the whole validation set at a given step of a training epoch.

        :param epoch: Current epoch (used for logging only).
        :param step: Current batch index (used for logging only).
        :param num_steps: Last batch index of the epoch (used for logging only).
        :return: Dict of metrics from ``self.metrics``, plus ``'<metric>-window'`` sliding-window
            means and ``'loss_val'``.
        """
        import os  # only needed for the optional .npy dumps below

        self.model.eval()
        # NOTE(review): the per-batch validation loss is not computed, so this stays 0.0.
        total_val_loss = 0.0

        with torch.no_grad():
            text_embeds, vid_embeds, all_vid_ids = self._collect_val_embeddings()

            # Since we have all pairs, remove duplicate videos when there's multiple captions per video
            # (the first occurrence of each video id is kept).
            vid_embeds_per_video_id = {}
            for idx, v_id in enumerate(all_vid_ids):
                if v_id not in vid_embeds_per_video_id:
                    vid_embeds_per_video_id[v_id] = vid_embeds[idx]
            vid_embeds = torch.stack(list(vid_embeds_per_video_id.values()))

            # Pool frames on CPU once all texts and videos are gathered; always move the module
            # back to the training device, even if pooling raises.
            self.model.pool_frames.cpu()
            try:
                vid_embeds_pooled = self.model.pool_frames(text_embeds, vid_embeds)
            finally:
                self.model.pool_frames.to(self.device)

            start_time = time.time()
            text_embeds_stochastic_allpairs = self._select_stochastic_text_embeds(
                text_embeds, vid_embeds, vid_embeds_pooled)
            end_time = time.time()
            self._gen_log(
                f'To compute all stochastic-DM-aligned text embeddings for the whole dataset, the time usage is {end_time - start_time}\n')

            del text_embeds, vid_embeds
            gc.collect()

            if self.config.save_VT_embed_for_plot:
                print(f'>>>check dim: text_embeds_stochastic_allpairs={text_embeds_stochastic_allpairs.shape}, vid_embeds_pooled={vid_embeds_pooled.shape}')
                np.save(os.path.join(self.config.model_path, 'resulting_txt_embedding.npy'), {'res': text_embeds_stochastic_allpairs})
                np.save(os.path.join(self.config.model_path, 'resulting_vid_embedding.npy'), {'res': vid_embeds_pooled})

            start_time_gen = time.time()
            text_embeds_per_video_id, vid_embeds_pooled_per_video_id = generate_embeds_per_video_id(
                text_embeds_stochastic_allpairs, vid_embeds_pooled, all_vid_ids, self.pooling_type)
            end_time_gen = time.time()
            self._gen_log(f'generate_embeds_per_video_id_stochastic() time usage={end_time_gen - start_time_gen}')
            self._gen_log(f'>>> check text_embeds_per_video_id={text_embeds_per_video_id.shape}, '
                          f'vid_embeds_pooled_per_video_id={vid_embeds_pooled_per_video_id.shape}')

            del text_embeds_stochastic_allpairs, vid_embeds_pooled
            gc.collect()

            self._gen_log('Use sim_matrix_inference_stochastic()')
            sims = sim_matrix_inference(text_embeds_per_video_id, vid_embeds_pooled_per_video_id, self.pooling_type)

            total_val_loss = total_val_loss / len(self.valid_data_loader)

            if self.config.save_sims_for_plot:
                print(f'>>>sims={type(sims)}, shape={sims.shape}')
                np.save(os.path.join(self.config.model_path, 'save_sims_tmass.npy'), {'res': sims})

            res = self.metrics(sims)

            # Push the new values into the per-metric sliding windows ...
            for m in res:
                self.window_metric[m].append(res[m])

            # ... and report the window averages alongside the raw values.
            for m in self.window_metric:
                res[m + "-window"] = np.mean(self.window_metric[m])

            self._gen_log(f"-----Val Epoch: {epoch}, dl: {step}/{num_steps}-----\n"
                          f"R@1: {res['R1']} (window: {res['R1-window']})\n"
                          f"R@5: {res['R5']} (window: {res['R5-window']})\n"
                          f"R@10: {res['R10']} (window: {res['R10-window']})\n"
                          f"MedR: {res['MedR']} (window: {res['MedR-window']})\n"
                          f"MeanR: {res['MeanR']} (window: {res['MeanR-window']})\n"
                          f"Loss: {total_val_loss}")

            res['loss_val'] = total_val_loss

            return res


