from models.rangemodel import LightningRangeModel
from data_utils.datasets import OccuranceDataModule
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import ModelCheckpoint
from utils.seed import seed_everything
import rasterio
from rasterio.resample import Resampling
import os
from config import cfg
import fire


# Text-embedding dimension (passed to the model as ``text_dim``) for each
# supported value of ``cfg.data.llm_type``.
_LLM_TEXT_DIMS = {
    "Llama-2-7b-hf": 4096,
    "Llama-2-13b-hf": 5120,
    "Llama-2-70b-hf": 8192,
}


def _text_dim_for(llm_type):
    """Return the text-embedding dimension for a supported LLM name.

    Args:
        llm_type: Value of ``cfg.data.llm_type``.

    Returns:
        The integer ``text_dim`` to pass to ``LightningRangeModel``.

    Raises:
        NotImplementedError: If ``llm_type`` is not a supported model name.
    """
    try:
        return _LLM_TEXT_DIMS[llm_type]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported llm_type {llm_type!r}; "
            f"expected one of {sorted(_LLM_TEXT_DIMS)}"
        ) from None


def _build_datamodule():
    """Build the ``OccuranceDataModule`` described by ``cfg``.

    ``env_cov_path`` is forwarded only when ``cfg.data.env_cov`` is truthy;
    otherwise the data module's own default for that argument is used.
    """
    kwargs = dict(
        bins=cfg.model.img_size,
        batch_size=cfg.train.batch_size,
        shuffle=cfg.train.shuffle,
        num_workers=cfg.train.num_workers,
    )
    # Use the same ``env_cov`` flag that is handed to the model, so the data
    # module and the model agree on whether covariates are in play.
    if cfg.data.env_cov:
        kwargs["env_cov_path"] = cfg.data.env_cov_path
    return OccuranceDataModule(
        cfg.data.text_embeddings_path,
        cfg.data.train_parquet_path,
        **kwargs,
    )


def _load_mask(path, out_shape=(1, 900, 1800)):
    """Load a raster mask as a boolean tensor.

    The raster is read resampled to ``out_shape`` with nearest-neighbour
    resampling; pixels equal to 1 become ``True``.

    Args:
        path: Path of the raster file to open with rasterio.
        out_shape: Shape requested from ``DatasetReader.read``.
    """
    # Context manager so the rasterio dataset handle is always closed.
    with rasterio.open(path) as src:
        data = src.read(out_shape=out_shape, resampling=Resampling.nearest)
    return torch.tensor(data == 1)


def _build_trainer(checkpoint):
    """Create the Lightning ``Trainer`` configured from ``cfg.train``.

    When ``cfg.train.device == "cuda"`` the GPU accelerator is used with the
    ``ddp_find_unused_parameters_false`` strategy on ``cfg.train.devices``;
    otherwise Lightning's default accelerator selection applies.

    Args:
        checkpoint: The ``ModelCheckpoint`` callback to attach.
    """
    trainer_kwargs = dict(
        max_epochs=cfg.train.num_epochs,
        num_nodes=1,
        accumulate_grad_batches=cfg.train.accumulate_grad_batches,
        callbacks=[checkpoint],
    )
    if cfg.train.device == "cuda":
        trainer_kwargs.update(
            accelerator="gpu",
            devices=cfg.train.devices,
            strategy="ddp_find_unused_parameters_false",
        )
    return pl.Trainer(**trainer_kwargs)


def _save_config(ckpt_dir):
    """Write ``cfg`` as JSON to ``<ckpt_dir>/config.json``.

    NOTE(review): assumes ``cfg`` is JSON-serializable (e.g. a dict subclass
    such as a yacs ``CfgNode``) -- confirm against ``config.py``.
    """
    import json

    # The checkpoint directory may not exist yet if no checkpoint was written.
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)


def train(expt_name):
    """Train a ``LightningRangeModel`` using the settings in ``config.cfg``.

    Args:
        expt_name: Experiment name. Checkpoints and ``config.json`` are written
            under ``os.path.join(cfg.checkpoint.dirpath, expt_name)``.

    Raises:
        NotImplementedError: If ``cfg.data.llm_type`` is not supported.
    """
    seed_everything(cfg.train.seed)

    # Validate the LLM choice before doing any expensive data loading.
    text_dim = _text_dim_for(cfg.data.llm_type)

    training_module = _build_datamodule()
    mask = _load_mask(cfg.data.mask_path)

    range_model = LightningRangeModel(
        filter_type=cfg.model.filter_type,
        spectral_transform=cfg.model.transform,
        operator_type=cfg.model.operator_type,
        img_size=cfg.model.img_size,
        scale_factor=cfg.model.scale_factor,
        in_chans=cfg.model.in_chans,
        out_chans=cfg.model.out_chans,
        embed_dim=cfg.model.embed_dim,
        num_layers=cfg.model.num_layers,
        encoder_layers=cfg.model.encoder_layers,
        spectral_layers=cfg.model.spectral_layers,
        env_cov=cfg.data.env_cov,
        attn_heads=cfg.model.attn_heads,
        attn_dim_head=cfg.model.attn_dim_head,
        text_dim=text_dim,
        loss_type=cfg.loss.type,
        gamma_neg=cfg.loss.gamma_neg,
        gamma_pos=cfg.loss.gamma_pos,
        alpha=cfg.loss.alpha,
        mask=mask,
        lr=cfg.train.lr,
    )

    ckpt_dir = os.path.join(cfg.checkpoint.dirpath, expt_name)
    checkpoint = ModelCheckpoint(
        dirpath=ckpt_dir,
        # Requires the model to log a metric named "train_loss".
        filename="{epoch}-{train_loss:.2f}",
        every_n_epochs=cfg.checkpoint.freq,
        save_last=True,
    )

    trainer = _build_trainer(checkpoint)
    trainer.fit(range_model, training_module)

    # Save config file for reproducibility. Only the global-zero process
    # writes it, so DDP ranks do not race on the same file.
    if trainer.is_global_zero:
        _save_config(ckpt_dir)


if __name__ == "__main__":
    # Expose ``train`` as a command-line entry point via python-fire; the
    # experiment name is given as ``<expt_name>`` or ``--expt_name=<name>``.
    fire.Fire(component=train)
