import os
import torch
import hydra

from ema_pytorch import EMA
from omegaconf import DictConfig
from lightning.fabric import Fabric
from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter #tensorboard --logdir=runs --port=6010
from lightning.fabric.strategies import DDPStrategy

from dataset import get_data
from model.transform import *
from training import get_train
from model import trainer_model
from choice import chocie_function
from util import safe_state, IOStream

def _build_optimizer(cfg: DictConfig, params):
    """Create the optimizer and LR scheduler selected by ``cfg.opt.name``.

    Args:
        cfg: Resolved config. ``cfg.opt.name`` must be ``'SGD'`` or ``'Adam'``.
            The ``'Adam'`` branch also reads ``cfg.opt.lr``,
            ``cfg.opt.weight_decay``, ``cfg.opt.epochs`` and ``cfg.opt.lr_min``.
        params: Iterable of parameters to optimize.

    Returns:
        ``(optimizer, scheduler)`` tuple.

    Raises:
        ValueError: if ``cfg.opt.name`` is not a supported optimizer name.
            Previously an unknown name left ``optimizer``/``scheduler``
            unbound and failed later with a confusing error.
    """
    if cfg.opt.name == 'SGD':
        # NOTE: the SGD hyper-parameters are hard-coded and ignore cfg.opt.*.
        optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=0.0001, nesterov=True)
        # Halve the LR at epochs 60 and 80 (milestones count scheduler steps;
        # assumes get_train steps the scheduler once per epoch — TODO confirm).
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 80], gamma=0.5)
    elif cfg.opt.name == 'Adam':
        # The 'Adam' setting actually uses AdamW (decoupled weight decay).
        optimizer = torch.optim.AdamW(params, lr=cfg.opt.lr, weight_decay=cfg.opt.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.opt.epochs, eta_min=cfg.opt.lr_min)
    else:
        raise ValueError(
            f"Unsupported optimizer cfg.opt.name={cfg.opt.name!r}; expected 'SGD' or 'Adam'"
        )
    return optimizer, scheduler


@hydra.main(version_base=None, config_path='configs', config_name='branch')
def main(cfg: DictConfig):
    """Build the model, EMA, optimizer and data loaders, then run training.

    Reads from ``cfg``: ``general.num_devices``, ``model.feature_map_num``,
    ``model.name``, ``opt.name``, ``opt.batch_size`` (plus the optimizer keys
    read by ``_build_optimizer``). ``cfg`` is first passed through
    ``chocie_function``, whose effect is defined elsewhere.

    Side effects: launches Fabric with DDP over NCCL on CUDA. Creates
    ``checkpoints/<model.name>/``, opens a log at
    ``checkpoints/<model.name>/train_random.log`` and a TensorBoard
    ``SummaryWriter``, then hands everything to ``get_train``.
    """
    cfg = chocie_function(cfg)
    torch.set_float32_matmul_precision('high')
    fabric = Fabric(accelerator="cuda", devices=cfg.general.num_devices, strategy=DDPStrategy(process_group_backend='nccl'))
    fabric.launch()
    # The return value was never used (Fabric handles device placement). The
    # call is kept in case safe_state has side effects (presumably seeding —
    # verify).
    safe_state(cfg)

    # model
    # All-zero background feature vector: feature_map_num + 3 entries.
    bg_feature = torch.zeros(cfg.model.feature_map_num + 3, device=fabric.device)
    Trainer = trainer_model(cfg, bg_feature)
    ema = EMA(Trainer, beta=0.999, update_every=10, update_after_step=100)

    fabric.print(f"Use {cfg.opt.name}")
    optimizer, scheduler = _build_optimizer(cfg, Trainer.parameters())

    # data
    train_3dgs = get_data(cfg, "train", transform=[Shuffle(),
                                                   AttrNormalize(),
                                                   RandomScalePointCloud(),
                                                   RandomShiftPointCloud(),
                                                   RandomPointDropout(0.5),
                                                   JitterPointCloud(),
                                                   ])
    # NOTE(review): the test split uses the same *random* augmentations as
    # training (scale/shift/50% dropout/jitter), so evaluation results are
    # stochastic. Confirm this test-time augmentation is intentional.
    test_3dgs = get_data(cfg, "test", transform=[Shuffle(),
                                                 AttrNormalize(),
                                                 RandomScalePointCloud(),
                                                 RandomShiftPointCloud(),
                                                 RandomPointDropout(0.5),
                                                 JitterPointCloud(),
                                                 ])
    train_dataloader_3dgs = DataLoader(train_3dgs, batch_size=cfg.opt.batch_size, shuffle=True, drop_last=True, num_workers=0)
    test_dataloader_3dgs = DataLoader(test_3dgs, batch_size=cfg.opt.batch_size, shuffle=False, drop_last=False, num_workers=0)

    ema = fabric.to_device(ema)
    Trainer, optimizer = fabric.setup(Trainer, optimizer)
    train_dataloader = fabric.setup_dataloaders(train_dataloader_3dgs)
    test_dataloader = fabric.setup_dataloaders(test_dataloader_3dgs)
    fabric.print('finish load')

    # logging
    ckpt_dir = os.path.join('checkpoints', cfg.model.name)
    # Every rank constructs IOStream with a path inside ckpt_dir. Previously
    # only rank 0 created the directory, so the other ranks could race ahead
    # of it. exist_ok=True makes the call idempotent and safe on all ranks.
    os.makedirs(ckpt_dir, exist_ok=True)
    io = IOStream(os.path.join(ckpt_dir, 'train_random.log'))
    io.cprint(str(cfg))
    # NOTE(review): a SummaryWriter is created on every rank (default ./runs
    # directory), so multiple event files appear under DDP.
    # View with: tensorboard --logdir=runs --port=6010
    writer = SummaryWriter()
    fabric.print('begin train')

    # train course
    get_train(cfg, fabric, scheduler, ema, Trainer, optimizer, io, writer, train_dataloader, test_dataloader)


if __name__ == "__main__":
    # Hydra parses command-line overrides and supplies cfg to main().
    main()