import logging
from logging import Logger
import os
import os.path as osp
import torch

from util.train_utils import find_last_checkpoint_path
from torch.utils.data import DataLoader

from omegaconf import DictConfig, OmegaConf
import hydra
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import seed_everything
# from fvcore.nn import flop_count_table, FlopCountAnalysis
from pytorch_lightning.profilers import PyTorchProfiler
import datetime
import torch.nn as nn
from tasks.tuab_pretrain import TuabPretrainTask
#from lightning.pytorch.profilers import SimpleProfiler, AdvancedProfiler

# Make `hydra.utils.get_method` available inside configs as the
# `${get_method:...}` interpolation resolver.
OmegaConf.register_new_resolver(name="get_method", resolver=hydra.utils.get_method)

# Module-level logger (progress output in this file currently goes through print()).
logger: Logger = logging.getLogger(name=__name__)

# Per PyTorch docs, "high" lets float32 matmuls use faster reduced-precision
# internals (e.g. TF32) when the hardware supports them.
torch.set_float32_matmul_precision(precision="high")


def _load_finetune_backbone(model, cfg: DictConfig) -> None:
    """Graft a pretrained ``TuabPretrainTask`` backbone onto ``model`` with a fresh head.

    ``model.model`` is replaced by the backbone loaded from
    ``cfg.finetune_pretrained_path``, and its ``classifier`` is replaced by a new
    ``nn.Linear(d_model, cfg.num_classes_finetune)``. When ``cfg.freeze_layers``
    is set, every backbone parameter is frozen except those of the new classifier.

    Args:
        model: Task module whose ``.model`` attribute is overwritten in place.
        cfg: Full run configuration.
    """
    print("===> Finetune from pretrained model")
    pretrainmodel = TuabPretrainTask.load_from_checkpoint(checkpoint_path=cfg.finetune_pretrained_path)
    model.model = pretrainmodel.model

    # NOTE(review): the channel layout (`cfg.model.ch`) was documented as C*D and
    # the other layout as S*D (sequence length * embed dim), yet both computed
    # embed_dim * num_channels. That value is kept; confirm whether the non-`ch`
    # layout should use a sequence-length config key instead.
    d_model = cfg.model.embed_dim * cfg.model.num_channels
    model.model.classifier = nn.Linear(d_model, cfg.num_classes_finetune)

    if cfg.freeze_layers:
        print("===> Freezing layers")
        # Freeze the whole backbone first, then re-enable only the new head.
        model.model.requires_grad_(False)
        model.model.classifier.requires_grad_(True)


def _freeze_decoder(model) -> None:
    """Disable gradients for ``model.model.decoder`` and ``model.model.upsampler_layers``."""
    print("===> Freezing decoder")
    for module in (model.model.decoder, model.model.upsampler_layers):
        for param in module.parameters():
            param.requires_grad = False


def _build_trainer(cfg: DictConfig, tb_logger, callbacks) -> Trainer:
    """Create the Lightning ``Trainer`` from ``cfg.trainer``.

    When ``cfg.trainer.strategy == "ddp"``, the string is replaced by an explicit
    ``DDPStrategy``. That strategy uses ``cfg.find_unused_parameters`` and a one-year
    collective timeout. Gradient clipping is set to global-norm clipping at 1.0
    unless ``cfg.trainer`` already specifies it.

    Args:
        cfg: Full run configuration; only ``trainer`` and ``find_unused_parameters`` are read.
        tb_logger: Logger passed to the Trainer.
        callbacks: Instantiated callbacks passed to the Trainer.

    Returns:
        The configured ``Trainer``.
    """
    # Work on a plain-dict copy. It unpacks cleanly into Trainer(...), and it lets
    # `strategy` be dropped without mutating the (normally struct-mode) Hydra config.
    trainer_kwargs = OmegaConf.to_container(cfg.trainer, resolve=True)
    if trainer_kwargs.get("strategy") == "ddp":
        trainer_kwargs.pop("strategy")
        # setdefault so that values set in cfg.trainer win, rather than
        # raising "got multiple values for keyword argument".
        trainer_kwargs.setdefault("gradient_clip_val", 1.0)
        trainer_kwargs.setdefault("gradient_clip_algorithm", "norm")
        return Trainer(
            **trainer_kwargs,
            logger=tb_logger,
            callbacks=callbacks,
            strategy=DDPStrategy(
                find_unused_parameters=cfg.find_unused_parameters,
                # A one-year timeout effectively disables the collective-op timeout.
                timeout=datetime.timedelta(days=365),
            ),
        )
    return Trainer(
        **trainer_kwargs,
        logger=tb_logger,
        callbacks=callbacks,
    )


def train(cfg: DictConfig):
    """Build the model, data module and trainer from ``cfg``, then run the requested loops.

    Stages, each gated by a config entry:
      * ``cfg.pretrained_checkpoint`` (not None): reload the task module from that checkpoint.
      * ``cfg.finetune_pretrained``: graft a ``TuabPretrainTask`` backbone and a new classifier.
      * ``cfg.freeze_decoder``: freeze the decoder and upsampler parameters.
      * ``cfg.training`` / ``cfg.final_validate`` / ``cfg.final_test``: run
        ``fit`` / ``validate`` / ``test`` respectively.

    When ``cfg.resume`` is set, ``fit`` resumes from the checkpoint that
    ``find_last_checkpoint_path`` returns for this run's checkpoint directory.
    When ``cfg.training`` is false, the final trainer state is saved to
    ``<log_dir>/checkpoints/last.ckpt``.
    """
    seed_everything(cfg.seed)

    # TensorBoard logger; its log_dir also anchors the checkpoint directory below.
    tb_logger = TensorBoardLogger(save_dir=osp.expanduser(cfg.io.base_output_path), name=cfg.tag, version=cfg.io.version)

    # PyTorch Lightning module
    print("===> Start building model")
    model = hydra.utils.instantiate(cfg.task, cfg)
    print('\n\nMODEL INITIATED', model)
    if cfg.pretrained_checkpoint is not None:
        print(f"===> Loading pretrained_checkpoint from {cfg.pretrained_checkpoint}")
        # load_from_checkpoint is a classmethod that returns a new module. Call it
        # on the class: recent Lightning releases reject calls made on an instance.
        model = type(model).load_from_checkpoint(checkpoint_path=cfg.pretrained_checkpoint)

    print("===> Loading datasets")
    data_module = hydra.utils.instantiate(cfg.data_module)

    checkpoint_dirpath = f"{tb_logger.log_dir}/checkpoints"
    print(f"Checkpoint path: {checkpoint_dirpath}")
    # None means fit() starts fresh. This assumes find_last_checkpoint_path
    # returns None when no checkpoint exists — TODO confirm.
    last_ckpt = find_last_checkpoint_path(checkpoint_dirpath) if cfg.resume else None

    print("===> Instantiate callbacks")
    callbacks = [hydra.utils.instantiate(callback) for callback in cfg.callbacks.values()]

    if cfg.finetune_pretrained:
        _load_finetune_backbone(model, cfg)

    if cfg.freeze_decoder:
        _freeze_decoder(model)

    print("===> Instantiate trainer")
    trainer = _build_trainer(cfg, tb_logger, callbacks)

    if cfg.training:
        print("===> Start training")
        # Forward the resume checkpoint so that `cfg.resume` actually takes effect.
        trainer.fit(model, data_module, ckpt_path=last_ckpt)

    if cfg.final_validate:
        print("===> Start validation")
        trainer.validate(model, data_module)
    if cfg.final_test:
        print("===> Start testing")
        trainer.test(model, data_module)

    if not cfg.training:
        # NOTE(review): Trainer.save_checkpoint needs a model attached by a prior
        # fit/validate/test call. With final_validate and final_test both off,
        # this is likely to raise — confirm the intended behaviour.
        trainer.save_checkpoint(f"{checkpoint_dirpath}/last.ckpt")


@hydra.main(config_path="./config", config_name="defaults", version_base="1.1")
def run(cfg: DictConfig):
    """Hydra entry point: report the Lightning version, dump the resolved config, then train.

    Args:
        cfg: Config composed by Hydra from ``./config/defaults``.
    """
    banner = f"PyTorch-Lightning Version: {pl.__version__}"
    print(banner)
    # Resolve interpolations so the printed config shows the effective values.
    resolved = OmegaConf.to_yaml(cfg, resolve=True)
    print(resolved)
    train(cfg)


if __name__ == "__main__":
    # Ask Hydra for full tracebacks unless the caller has already set a preference.
    os.environ.setdefault("HYDRA_FULL_ERROR", "1")
    run()


