import argparse
import ruamel_yaml as yaml
from pathlib import Path
import os
from transformers import logging
logging.set_verbosity_error()
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from pl_data import MMTSFMDatasetModule
from model import DualForecaster


class LitMMTSFM(LightningModule):
    """Lightning wrapper that trains a ``DualForecaster`` on (time series, text) batches.

    Each batch is unpacked as ``(ts_x, ts_y, text_h, text_f)``. The two text
    fields are tokenized on the fly with ``tokenizer``, moved to this module's
    device, and passed to the model together with the series tensors.

    Args:
        model_params: Keyword arguments forwarded to ``DualForecaster``. Must
            contain ``'textAug'`` and ``'addfuture'``.
        tokenizer: Tokenizer called as ``tokenizer(texts, padding='max_length',
            truncation=True, max_length=max_len, return_tensors="pt")``
            (HuggingFace-style; its result must support ``.to(device)``).
        max_len: Padding/truncation length for both text inputs.
        lr: Base learning rate for Adam.
        weight_decay: Penalty applied to the "decay" parameter group (see
            ``add_weight_decay``).
        warmup_ratio: Fraction of total optimizer steps used for warmup. Only
            consulted when ``warmup_steps`` is ``None``.
        model_version: Free-form version tag stored on the module.
        warmup_steps: Number of warmup steps. Defaults to 32, the value the
            scheduler has always used; pass ``None`` to derive it from
            ``warmup_ratio``.
        warmup_step_size: Width in steps of each stair of the ``'mmtsfm'``
            staircase warmup (must be >= 1).
        lr_schedule: ``'mmtsfm'`` (staircase warmup, then constant lr),
            ``'cosine'`` (linear warmup, cosine decay down to 10% of lr), or
            ``'linear'`` (linear warmup, linear decay down to 0).
    """

    _LR_SCHEDULES = ('mmtsfm', 'cosine', 'linear')

    def __init__(self, model_params, tokenizer, max_len, lr=1e-4, weight_decay=1e-2, warmup_ratio=0.02,
                 model_version='', warmup_steps=32, warmup_step_size=4, lr_schedule='mmtsfm'):
        super().__init__()
        # Validate scheduler settings up front so a typo fails before the model is built.
        if lr_schedule not in self._LR_SCHEDULES:
            raise ValueError(f"lr_schedule must be one of {self._LR_SCHEDULES}, got {lr_schedule!r}")
        if warmup_step_size < 1:
            raise ValueError(f"warmup_step_size must be >= 1, got {warmup_step_size}")
        self.model = DualForecaster(**model_params)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.lr = lr
        self.weight_decay = weight_decay
        self.warmup_ratio = warmup_ratio
        self.warmup_steps = warmup_steps
        self.warmup_step_size = warmup_step_size
        self.lr_schedule = lr_schedule
        self.model_version = model_version
        # textAug selects whether the model returns a contrastive loss (3 outputs) or not (2 outputs).
        self.textAug = model_params['textAug']
        self.addfuture = model_params['addfuture']

    def forward(self, ts_x, ts_y, text_input_h, text_input_f):
        """Run the model in loss mode.

        Returns ``(loss_forecast, loss_contrastive, forecast_prob)`` when
        ``textAug`` is enabled, otherwise ``(loss_forecast, forecast_prob)``
        (as unpacked by ``_shared_step``).
        """
        return self.model(ts=ts_x, y=ts_y, text_h=text_input_h, text_f=text_input_f, return_loss=True, return_embeddings=False)

    def _tokenize(self, texts):
        """Tokenize a batch of strings and move the resulting tensors to this module's device."""
        encoded = self.tokenizer(texts, padding='max_length', truncation=True, max_length=self.max_len, return_tensors="pt")
        # Lightning moves the batch's tensors, but token tensors are created here on
        # the CPU; move them explicitly so GPU runs do not hit a device mismatch.
        return encoded.to(self.device)

    def _shared_step(self, batch, stage):
        """Compute and log the losses for one batch.

        Args:
            batch: ``(ts_x, ts_y, text_h, text_f)`` tuple from the data module.
            stage: Metric-name prefix, ``'train'`` or ``'val'``.

        Returns:
            The total loss: forecast + contrastive when ``textAug`` is set,
            otherwise the forecast loss alone.
        """
        ts_x, ts_y, text_h, text_f = batch
        text_input_h = self._tokenize(text_h)
        text_input_f = self._tokenize(text_f)
        outputs = self(ts_x, ts_y, text_input_h, text_input_f)

        if self.textAug:
            # Same computation with or without `addfuture`; the model handles that flag.
            loss_forecast, loss_contrastive, _forecast_prob = outputs
            loss = loss_forecast + loss_contrastive
        else:
            loss_forecast, _forecast_prob = outputs
            loss_contrastive = None
            loss = loss_forecast

        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True)
        self.log(f'{stage}_loss_forecast', loss_forecast, on_step=False, on_epoch=True)
        if loss_contrastive is not None:
            self.log(f'{stage}_loss_contrastive', loss_contrastive, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log the current learning rate."""
        loss = self._shared_step(batch, 'train')

        # Read the current learning rate from the (single) optimizer and log it per step.
        optimizer = self.optimizers()
        current_lr = optimizer.param_groups[0]['lr']
        self.log('learning_rate', current_lr, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        return self._shared_step(batch, 'val')

    def add_weight_decay(self, model, weight_decay=1e-5):
        """Split trainable parameters into no-decay and decay optimizer groups.

        Biases and all parameters with ``ndim <= 1`` (norm weights, 1-D and
        scalar parameters such as learnable temperatures) get no weight decay.
        Frozen parameters (``requires_grad=False``) are skipped entirely.

        Returns:
            ``[{'params': no_decay, 'weight_decay': 0.},
               {'params': decay, 'weight_decay': weight_decay}]``
        """
        decay = []
        no_decay = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            if param.ndim <= 1 or name.endswith(".bias"):
                no_decay.append(param)
            else:
                decay.append(param)
        return [
            {'params': no_decay, 'weight_decay': 0.},
            {'params': decay, 'weight_decay': weight_decay}]

    @staticmethod
    def _lr_factor(step, schedule, warmup_steps, max_steps, step_size):
        """Return the multiplier LambdaLR applies to the base lr at optimizer ``step``.

        Args:
            step: Current optimizer step (0-based).
            schedule: One of ``'mmtsfm'``, ``'cosine'``, ``'linear'``.
            warmup_steps: Warmup length in steps (0 disables warmup).
            max_steps: Total number of optimizer steps (may be ``inf``).
            step_size: Stair width for the ``'mmtsfm'`` warmup.
        """
        import math

        if schedule == 'mmtsfm':
            if step < warmup_steps:
                # Staircase warmup: hold each level for `step_size` steps. Returning 1.0
                # for off-stair steps here would spike the lr to full value mid-warmup.
                return (step // step_size) * step_size / warmup_steps
            return 1.0
        if schedule not in ('cosine', 'linear'):
            raise ValueError(f"Unknown lr_schedule: {schedule!r}")

        if step < warmup_steps:
            return step / warmup_steps
        # Guard against division by zero when warmup covers the whole run, and clamp
        # so steps past max_steps stay at the schedule's final value.
        decay_steps = max(1, max_steps - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        if schedule == 'cosine':
            # Cosine decay from 1.0 down to a floor of 0.1.
            return 0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))
        # Linear decay from 1.0 down to 0.0.
        return 1.0 - progress

    def configure_optimizers(self):
        """Build Adam with per-group weight decay and a per-step LambdaLR scheduler."""
        parameters = self.add_weight_decay(self.model, self.weight_decay)
        # The per-group 'weight_decay' values override this optimizer-level default.
        # NOTE: torch.optim.Adam applies weight decay as a coupled L2 penalty (not AdamW-style).
        optimizer = torch.optim.Adam(parameters, lr=self.lr, weight_decay=0.)

        max_steps = self.trainer.estimated_stepping_batches
        self.print(f"Trainer max_steps: {max_steps}")

        if self.warmup_steps is None:
            warmup_steps = int(self.warmup_ratio * max_steps)
        else:
            warmup_steps = self.warmup_steps
        schedule = self.lr_schedule
        step_size = self.warmup_step_size

        def lr_lambda(step):
            return self._lr_factor(step, schedule, warmup_steps, max_steps, step_size)

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # update the learning rate every optimizer step
            },
        }


if __name__ == '__main__':
    # Local import: the tokenizer class is only needed when running as a script.
    from transformers import AutoTokenizer

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/forecast.yaml', help='Configuration of MMTSFM')
    parser.add_argument('--output_dir', default='./output/forecast')
    # LitMMTSFM requires a tokenizer; it should match the text encoder used inside
    # DualForecaster (TODO confirm the intended checkpoint for your configs).
    parser.add_argument('--tokenizer', required=True,
                        help='HuggingFace tokenizer name or path for AutoTokenizer.from_pretrained')
    parser.add_argument('--model_version', default='test', help='Sub-directory name for checkpoints')
    args = parser.parse_args()

    # NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Snapshot the effective config next to the outputs; `with` guarantees it is flushed and closed.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    model_version = args.model_version

    model_params = {
        'dim': config['dim'],
        'num_vars': config['num_vars'],
        'ts_token_size': config['token_size'],
        'ts_output_size': config['output_size'],
        'unimodal_depth': config['unimodal_depth'],
        'multimodal_depth': config['multimodal_depth'],
        'text_encoder_layers': config['text_encoder_layers'],
        'max_len': config['max_len'],
        'text_dim': config['text_dim'],
        'num_text_queries': config['num_text_queries'],
        'heads': config['heads'],
        'dim_head': config['dim_head'],
        'ff_mult': config['ff_mult'],
        'dropout_rate_fcst': config['dropout_rate_fcst'],
        'dropout_rate_cons': config['dropout_rate_cons'],
        'textAug': config['textAug_flag'],
        'addfuture': config['addfuture_flag'],
        'text_encoder_frozen_flag': config['text_encoder_frozen_flag'],
        'forecast_loss_weight': config['forecast_loss_weight'],
        'contrastive_loss_weight': config['contrastive_loss_weight'],
    }
    lr = 1e-4
    weight_decay = 1e-2

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    lit_model = LitMMTSFM(
        model_params=model_params,
        tokenizer=tokenizer,
        max_len=config['max_len'],  # same padding length the model is configured with
        lr=lr,
        weight_decay=weight_decay,
        model_version=model_version,
    )

    data_module = MMTSFMDatasetModule(config=config)

    logger = TensorBoardLogger(os.path.join(args.output_dir, 'MyLogs'), name='MM_model')

    # Keep the 3 checkpoints with the lowest val_loss.
    # NOTE: only active if added to the Trainer's callbacks below.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=os.path.join(args.output_dir, 'model_ckpts', model_version),
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min',
    )

    # Use every visible GPU if any are available, otherwise fall back to a single CPU device.
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    devices = torch.cuda.device_count() if torch.cuda.is_available() else 1

    trainer = Trainer(
        # log_every_n_steps=2,
        max_epochs=5,
        logger=logger,
        accelerator=accelerator,
        devices=devices,
        # callbacks=[checkpoint_callback],
        # val_check_interval = 0.5,
    )
    trainer.fit(lit_model, data_module)

    # Debug helper: run one training step manually. Note that training_step calls
    # self.optimizers(), which requires the module to be attached to a Trainer.
    # data_module.setup()
    # for batch_index, batch in enumerate(data_module.train_dataloader()):
    #     loss = lit_model.training_step(batch=batch, batch_idx=0)
    #     print(loss)
    #     break