import os
import time
import math
import random
import numpy as np

import torch
import torch.nn.functional as F

from tqdm import tqdm

from read_args import get_args
from datasets import davis17, movi, ytvis19
from utils.eval_utils import Evaluator
from model.model import Slot_Attention_Auto_Encoder
from model.helper_modules import DINO

from utils.utils import bipartiate_match_video

class Trainer:
    def __init__(self, args):
        self.args = args
        self.device = args.device
        self.device_type = "cuda" if self.device != "cpu" else "cpu"

        self.fix_seed()
        self.init_dataloader()
        self.init_train_elements()
        self.args.logger.info("Initialized")

    def fix_seed(self, seed=None):
        if seed is None:
            seed = self.args.seed

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def init_dataloader(self):

        if self.args.dataset == "davis19":
            train_dataset = davis17.DAVIS17_train(self.args)
            val_dataset = davis17.DAVIS17_val(self.args)

        elif self.args.dataset == "ytvis19":
            train_dataset = ytvis19.YTVIS_train(self.args)
            val_dataset = ytvis19.YTVIS_val(self.args)

        elif self.args.dataset == "movi":
            train_dataset = movi.MOVi_train(self.args)
            val_dataset = movi.MOVi_val(self.args)

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=self.args.batch_size,
            shuffle=True, num_workers=5*self.args.workers, drop_last=True, pin_memory=True)

        self.val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=1,
            shuffle=False, num_workers=5*self.args.workers, drop_last=False, pin_memory=True)

    def save_model(self, epoch):
        path = os.path.join(self.args.model_save_path, "checkpoint.pt")

        checkpoint = {"epoch": epoch,
                      "state_dict": self.model.state_dict(),
                      "optimizer": self.optimizer.state_dict(),
                      "scheduler": self.scheduler.state_dict()}
        torch.save(checkpoint, path)

        if (epoch + 1) % 5 == 0:
            path = os.path.join(self.args.model_save_path, f"checkpoint_epoch_{epoch}.pt")
            torch.save(checkpoint, path)

    def init_train_elements(self):
        """Build the models, optimizer, LR schedule and evaluator.

        When ``args.use_checkpoint`` is set, the model weights are restored
        from ``args.checkpoint_path``. Unless ``args.finetuning`` is also set,
        the optimizer state, scheduler state and start epoch are restored too.

        Sets ``self.dino``, ``self.model``, ``self.optimizer``,
        ``self.scheduler``, ``self.start_epoch`` and ``self.evaluator``.
        Requires ``self.train_loader`` to exist already (its length determines
        the schedule length).
        """
        checkpoint = None
        if self.args.use_checkpoint:
            # Load onto the CPU so the checkpoint can be read regardless of the
            # device it was saved from; load_state_dict() copies the values
            # onto the device of the already-placed parameters.
            checkpoint = torch.load(self.args.checkpoint_path, map_location="cpu")

        self.args.logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
        # === Model ===
        self.dino = torch.nn.DataParallel(DINO(self.args)).to(self.device)
        self.dino.eval()

        self.model = torch.nn.DataParallel(Slot_Attention_Auto_Encoder(self.args)).to(self.device)
        if self.args.use_checkpoint:
            self.model.load_state_dict(checkpoint["state_dict"], strict=True)

        params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.args.logger.info(f"Number of Trainable Parameters: {params/1e6:.2f}M")

        # Full resume (optimizer, scheduler, epoch counter) only happens when a
        # checkpoint is used and the run is not a fine-tuning run.
        resume_training = self.args.use_checkpoint and not self.args.finetuning

        # === Optimizer ===
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        if resume_training:
            self.optimizer.load_state_dict(checkpoint["optimizer"])

        # === LR Scheduler ===
        T_max = len(self.train_loader) * self.args.num_epochs
        self.args.logger.info(f"{T_max} training iterations")

        warmup_steps = int(T_max * 0.05)
        steps = T_max - warmup_steps
        # The lr is halved every third of the post-warm-up steps. The interval
        # is clamped to at least one step so very short runs (fewer than three
        # post-warm-up steps) do not divide by zero.
        halving_interval = max(1, steps // 3)
        gamma = math.exp(math.log(0.5) / halving_interval)

        # Linear warm-up first, exponential decay afterwards.
        linear_scheduler = torch.optim.lr_scheduler.LinearLR(
            self.optimizer, start_factor=1e-5, total_iters=warmup_steps)

        scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=gamma)

        self.scheduler = torch.optim.lr_scheduler.SequentialLR(
            self.optimizer,
            schedulers=[linear_scheduler, scheduler], milestones=[warmup_steps])
        if resume_training:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

        # === Start Epoch ===
        self.start_epoch = 0
        if resume_training:
            self.start_epoch = checkpoint["epoch"] + 1

        # === Evaluators ===
        self.evaluator = Evaluator()

    def train_epoch(self):
        """Run one optimisation pass over ``self.train_loader``.

        For every batch the DINO wrapper is run under float16 autocast twice:
        once on frame index ``args.N`` with ``reduce=False`` to obtain the
        reconstruction target, and once on all frames to obtain the model
        input and token indices. The model output ``"rec"`` is trained against
        the target with an MSE loss, gradients are clipped to norm 1.0, and
        both the optimizer and the LR scheduler are stepped once per batch.
        ``self.total_iter`` is incremented once per completed step.

        Returns:
            The mean loss over all batches of the epoch, or ``None`` when
            ``args.calculate_memory`` is set (in that mode the method logs the
            memory used by the first forward + backward pass and returns).

        Raises:
            RuntimeError: If the training loader yields no batches.
        """
        total_loss = 0.0
        num_batches = 0

        self.model.train()
        train_loader = tqdm(self.train_loader)

        for num_batches, (frames, masks) in enumerate(train_loader, start=1):
                                                                                # F = 2N + 1
            frames = frames.to(self.device, non_blocking=True)                  # (B, F, 3, H, W)
            masks = masks.to(self.device, non_blocking=True)                    # (B, F)

            if self.args.calculate_memory:
                torch.cuda.empty_cache()
                mem_allocated_before = torch.cuda.memory_allocated(self.device)

            # === === DINO === ===
            with torch.autocast(device_type=self.device_type, dtype=torch.float16):
                output_features, _ = self.dino(frames[:, [self.args.N]], reduce=False) # (B, token_num, 768)
                dropped_features, token_indices = self.dino(frames)                    # (B * F, token_reduced, 768)

                assert not output_features.isnan().any(), f"{torch.sum(output_features.isnan())} items are NaN"
                assert not dropped_features.isnan().any(), f"{torch.sum(dropped_features.isnan())} items are NaN"

            # The model itself runs in full precision.
            output_features = output_features.to(torch.float32)
            dropped_features = dropped_features.to(torch.float32)
            # === === === === ===

            # === === UTMOST === ===
            reconstruction = self.model(dropped_features, masks, token_indices)
            # === === === === ===

            # === === Loss and Updates === ===
            loss = F.mse_loss(reconstruction["rec"], output_features)

            total_loss += loss.item()

            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()

            if self.args.calculate_memory:
                # Memory-profiling mode: report the first pass and stop before
                # any parameter update.
                mem_allocated_after = torch.cuda.memory_allocated(self.device)
                mem_usage = (mem_allocated_after - mem_allocated_before) / (1024 ** 2)
                self.args.logger.info(f"Memory usage for one forward + backward pass: {mem_usage:.2f} MB")
                return None

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()
            # === === === === ===

            # === Logger ===
            # Read the lr straight from the param group instead of building a
            # full optimizer state dict on every step.
            current_lr = self.optimizer.param_groups[0]['lr']
            loss_desc = f"lr: {current_lr:.6f} | loss: {total_loss / num_batches:.5f}"
            train_loader.set_description(loss_desc)
            # === === ===

            self.total_iter += 1

        if num_batches == 0:
            raise RuntimeError(
                "train_loader produced no batches; with drop_last=True the dataset "
                "must contain at least batch_size samples")

        return total_loss / num_batches

    @torch.no_grad()
    def val_epoch(self, epoch):
        """Evaluate the model on ``self.val_loader`` and log the metrics.

        For every validation sample the DINO features of all frames are
        extracted in chunks, a window of ``2 * args.N + 1`` consecutive frames
        is assembled for each target frame, the model predicts slot masks for
        each window, and the upsampled arg-max masks are post-processed by
        ``bipartiate_match_video`` before being passed to the evaluator.

        Args:
            epoch: Step index under which the metrics are written to
                ``args.writer``.

        Returns:
            Tuple ``(miou, fg_miou, fg_ari)`` as returned by
            ``self.evaluator.get_results()``.
        """
        self.model.eval()
        val_loader = tqdm(self.val_loader)

        # Frames are processed in chunks of ``bs``. Both divisions are clamped
        # so that ``workers == 0`` or ``batch_size < workers`` cannot produce a
        # division by zero or an empty chunk size.
        bs = max(1, self.args.batch_size // max(1, self.args.workers))
        ps = self.args.patch_size
        window = 2 * self.args.N + 1

        for model_input, input_masks, gt_masks in val_loader:

            model_input = model_input.to(self.device, non_blocking=True)                # (1, #frames + 2N, 3, H, W)
            input_masks = input_masks.to(self.device, non_blocking=True)                # (1, #frames + 2N)
            gt_masks = gt_masks.to(self.device, non_blocking=True)                      # (1, #frames, #objects, H_t, W_t)

            H_t, W_t = gt_masks.shape[-2:]
            frame_num = gt_masks.shape[1]

            H, W = self.args.resize_to

            # === DINO feature extraction ===
            all_dino_features = []
            all_token_indices = []
            for s in range(0, model_input.shape[1], bs):
                e = s + bs
                with torch.autocast(device_type=self.device_type, dtype=torch.float16):
                    features, token_indices = self.dino.module(model_input[:, s:e], reduce=False)     # (bs, token_num, 768), (bs, token_num)
                    assert not features.isnan().any(), f"{torch.sum(features.isnan())} items are NaN"

                all_dino_features.append(features.to(torch.float32))
                all_token_indices.append(token_indices)

            all_dino_features = torch.cat(all_dino_features, dim=0)                        # (#frames + 2N, token_num, 768)
            all_token_indices = torch.cat(all_token_indices, dim=0)                        # (#frames + 2N, token_num)

            # === One window of 2N + 1 consecutive frames per target frame ===
            all_model_inputs = []
            all_model_tokens = []
            all_masks_input = []
            for t in range(frame_num):
                indices = list(range(t, t + window))
                all_model_inputs.append(all_dino_features[indices].unsqueeze(dim=0))      # (1, 2N + 1, token_num, 768)
                all_model_tokens.append(all_token_indices[indices].unsqueeze(dim=0))      # (1, 2N + 1, token_num)
                all_masks_input.append(input_masks[:, indices])                           # (1, 2N + 1)

            all_model_inputs = torch.cat(all_model_inputs, dim=0)                         # (#frames, 2N + 1, token_num, 768)
            all_model_tokens = torch.cat(all_model_tokens, dim=0)                         # (#frames, 2N + 1, token_num)
            all_masks_input = torch.cat(all_masks_input, dim=0)                           # (#frames, 2N + 1)
            # === === ===

            out_masks = []
            all_slots = []
            all_slot_nums = []
            for s in range(0, frame_num, bs):
                e = s + bs

                # === Input features ===
                features = all_model_inputs[s:e]                    # (bs, 2N + 1, token_num, 768)
                features = torch.flatten(features, 0, 1)            # (bs * (2N + 1), token_num, 768)

                # === Token indices ===
                token_indices = all_model_tokens[s:e]               # (bs, 2N + 1, token_num)
                token_indices = torch.flatten(token_indices, 0, 1)  # (bs * (2N + 1), token_num)

                # === Attention masks ===
                input_masks_j = all_masks_input[s:e]

                reconstruction = self.model.module(features, input_masks_j, token_indices)

                masks = reconstruction["mask"]                                              # (bs, S, token)
                slots = reconstruction["slots"]                                             # (bs, S, D_slot)
                slot_nums = reconstruction["slot_nums"]                                     # (bs)
                # Keep only the first ``slot_num`` slots of every frame.
                for b in range(slot_nums.shape[0]):
                    slot_num = slot_nums[b]
                    all_slots.append(slots[b, :slot_num])                                   # (S', D_slot)

                out_masks.append(masks)
                all_slot_nums.append(slot_nums)

            all_slots = torch.cat(all_slots, dim=0)                                         # (#slots, D_slot)
            all_slot_nums = torch.cat(all_slot_nums, dim=0)                                 # (#frames)
            masks = torch.cat(out_masks, dim=0)                                             # (#frames, S, token)

            S = masks.shape[1]

            # Tokens back onto the patch grid, then upsample to the ground-truth size.
            masks = masks.view(-1, S, H // ps, W // ps)                                     # (#frames, S, H // ps, W // ps)
            predictions = F.interpolate(masks, size=(H_t, W_t), mode="bilinear")            # (#frames, S, H_t, W_t)
            predictions = torch.argmax(predictions, dim=1)                                  # (#frames, H_t, W_t)

            predictions = bipartiate_match_video(all_slots, all_slot_nums, predictions)

            # === Instance Segmentation Evaluation ===
            miou, _, _, _ = self.evaluator.update(predictions, gt_masks[0])
            loss_desc = f"mIoU: {miou:.5f}"

            # === Logger ===
            val_loader.set_description(loss_desc)
            # === === ===

        # === Evaluation Results ====
        # NOTE(review): the same evaluator instance is reused across calls;
        # whether get_results() resets its accumulated state is not visible
        # here - confirm against Evaluator.
        miou, fg_miou, fg_ari, miou_per_video = self.evaluator.get_results()

        # === Logger ===
        # The "mIoU" log line reports the foreground mIoU.
        self.args.logger.info("\n=== Results ===")
        self.args.logger.info(f"\tmIoU: {fg_miou:.5f}")
        self.args.logger.info(f"\tFG-ARI: {fg_ari:.5f}")

        # === Tensorboard ===
        self.args.writer.add_scalar("multi_object/mIoU", miou, epoch)
        self.args.writer.add_scalar("multi_object/FG-mIoU", fg_miou, epoch)
        self.args.writer.add_scalar("multi_object/FG-ARI", fg_ari, epoch)

        return miou, fg_miou, fg_ari


    def validate(self):

        mious = []
        fg_aris = []

        self.model.module.slot_cluster.cluster_drop_p = 0
        for k in range(self.args.validate_k_times):
            self.init_dataloader()

            _, miou, fg_ari = self.val_epoch(k)
            mious.append(miou * 100)
            fg_aris.append(fg_ari* 100)

        miou_mean = np.mean(mious)
        miou_std = np.std(mious)

        fg_ari_mean = np.mean(fg_aris)
        fg_ari_std = np.std(fg_aris)

        self.args.logger.info("\n=== Final Results ===")
        self.args.logger.info(f"mIoU mean: {miou_mean:.4f} std: {miou_std:.4f}")
        self.args.logger.info(f"FG-ARI mean: {fg_ari_mean:.4f} std: {fg_ari_std:.4f}")


    def train(self):
        self.total_iter = 0


        for epoch in range(self.start_epoch, self.args.num_epochs):
            self.args.logger.info(f"===== ===== [Epoch {epoch}] ===== =====")

            # === === === Training === === ===
            self.model.module.slot_cluster.cluster_drop_p = 1 - (math.log(epoch + 1) / math.log(self.args.num_epochs))
            train_mean_loss = self.train_epoch()
            if self.args.calculate_memory:
                return
            self.save_model(epoch)
            # === === === === === === ===
            
            # === === === Validation === === ===
            self.model.module.slot_cluster.cluster_drop_p = 0
            if (epoch == 0) or ((epoch + 1) % self.args.validation_epoch == 0):
                self.val_epoch(epoch)
            # === === === === === === ===

            # === Tensorboard (Epoch Log) ===
            self.args.writer.add_scalar("epoch/train-loss", train_mean_loss, epoch)
            
            self.args.writer.flush()
            self.args.writer.close()
            # === === === === === === ===
            self.args.logger.info(f"===== ===== ===== ===== =====\n")

    
    def main(self):
        if self.args.validate:
            self.validate()
        else:
            self.train()


if __name__ == "__main__":
    # Build the trainer from the run arguments and hand control to it.
    Trainer(get_args()).main()
