import torch
from utils import *
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def train(train_dataloader,
          val_dataloader,
          transformer,
          model,
          loss_fn,
          optimizer,
          scheduler,
          local_rank,
          world_size,
          grad_clip,
          epoch,
          log_print_interval_epoch,
          model_save_interval_epoch,
          log_dir,
          checkpoint_dir):
    """Run the distributed training loop, validating after every epoch.

    Each batch is unrolled for a fixed number of steps: the model predicts the
    next slice of ``y2`` from ``(x, y1)``, the per-step losses are summed and
    back-propagated, and the ground-truth slice (teacher forcing) is appended
    to ``y1`` for the following step.  Losses are summed across ranks with
    ``all_reduce`` and divided by ``world_size``.  Only the process with
    ``local_rank == 0`` writes logs, TensorBoard scalars and checkpoints.

    Args:
        train_dataloader: Yields ``(x, y1, y2)`` batches; its ``sampler`` must
            provide ``set_epoch``.
        val_dataloader: Forwarded to ``val`` once per epoch.
        transformer: Not used in this function; forwarded to ``val``.
        model: Called as ``model(x, y1)``.
        loss_fn: Called as ``loss_fn(prediction, target)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler, stepped once per batch right after the
            optimizer.
        local_rank: Used both as the target device and as the rank test for
            logging.
        world_size: Divisor applied to the all-reduced batch losses.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        epoch: Number of epochs to run.
        log_print_interval_epoch: Log every this many epochs.
        model_save_interval_epoch: Save a checkpoint every this many epochs.
        log_dir: Directory handed to ``SummaryWriter`` and ``Logger``.
        checkpoint_dir: Directory handed to ``Checkpoint``.
    """
    device = local_rank
    is_main_rank = local_rank == 0

    # NOTE(review): rollout length is hard-coded; presumably it matches the
    # size of y2's last dimension -- confirm against the dataset.
    T = 10
    step = 1

    if is_main_rank:
        writer = SummaryWriter(log_dir)
        logger = Logger(log_dir)
        checker = Checkpoint(checkpoint_dir, model, device)
        epoch_history = []
        lr_history = []
        train_loss_step_history = []
        train_loss_full_history = []
        val_loss_step_history = []
        val_loss_full_history = []
        param_msg = "Number of Model Parameters: {}".format(get_num_params(model))
        print(param_msg)
        print(model)
        print("Start Training...")
        logger.print(param_msg)
        logger.print(model)
        logger.print("Start Training...")

    for i in range(epoch):
        torch.distributed.barrier()
        train_dataloader.sampler.set_epoch(i)
        # val() leaves the model in eval mode, so switch back once per epoch.
        model.train()
        train_loss_step = 0
        train_loss_full = 0

        for x, y1, y2 in tqdm(train_dataloader):
            # Move to the device and collapse all middle axes into one.
            x = x.to(device)
            x = torch.reshape(x, (x.shape[0], -1, x.shape[-1]))
            y1 = y1.to(device)
            y1 = torch.reshape(y1, (y1.shape[0], -1, y1.shape[-1]))
            y2 = y2.to(device)
            y2 = torch.reshape(y2, (y2.shape[0], -1, y2.shape[-1]))

            loss = 0
            step_preds = []
            for t in range(0, T, step):
                gt = y2[..., t:t+step]
                pred_step = model(x, y1)
                loss += loss_fn(pred_step, gt)
                # Only needed for the logging metric below, so drop the graph.
                step_preds.append(pred_step.detach())
                # Teacher forcing: slide the window using the ground truth.
                y1 = torch.cat((y1[..., step:], gt), dim=-1)

            # Full-sequence loss is reported only, never back-propagated.
            with torch.no_grad():
                pred_full = torch.cat(step_preds, -1)
                train_loss_full_temp = loss_fn(pred_full, y2)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            scheduler.step()

            # all_reduce sums over ranks by default; divide to get the mean.
            train_batch_loss_step = torch.tensor(loss.item()).to(device)
            torch.distributed.all_reduce(train_batch_loss_step)
            train_loss_step = train_loss_step + train_batch_loss_step / world_size

            train_batch_loss_full = torch.tensor(train_loss_full_temp.item()).to(device)
            torch.distributed.all_reduce(train_batch_loss_full)
            train_loss_full = train_loss_full + train_batch_loss_full / world_size

        num_batches = len(train_dataloader)
        train_loss_step = (train_loss_step / num_batches).item()
        train_loss_full = (train_loss_full / num_batches).item()

        val_loss_step, val_loss_full = val(val_dataloader, transformer, model, loss_fn, local_rank, world_size)
        val_loss_step = val_loss_step.item()
        val_loss_full = val_loss_full.item()

        if is_main_rank:
            epoch_number = i + 1

            if epoch_number % log_print_interval_epoch == 0:
                lr = optimizer.param_groups[0]['lr']

                # Every scalar needs the epoch as global_step; without it all
                # points are written at the same step and no curve is drawn.
                writer.add_scalar("Learning Rate", lr, epoch_number)
                writer.add_scalar("Train Loss Step", train_loss_step, epoch_number)
                writer.add_scalar("Train Loss Full", train_loss_full, epoch_number)
                writer.add_scalar("Val Loss Step", val_loss_step, epoch_number)
                writer.add_scalar("Val Loss Full", val_loss_full, epoch_number)

                epoch_history.append(epoch_number)
                lr_history.append(lr)
                train_loss_step_history.append(train_loss_step)
                train_loss_full_history.append(train_loss_full)
                val_loss_step_history.append(val_loss_step)
                val_loss_full_history.append(val_loss_full)

                epoch_msg = "Epoch: {}\tLearning Rate :{}\tTrain Loss Step: {}\tTrain Loss Full: {}\tVal Loss Step: {}\tVal Loss Full: {}".format(epoch_number, lr, train_loss_step, train_loss_full, val_loss_step, val_loss_full)
                print(epoch_msg)
                logger.print(epoch_msg)

            if epoch_number % model_save_interval_epoch == 0:
                checker.save(epoch_number)

    if is_main_rank:
        writer.close()
        logger.save([epoch_history, lr_history, train_loss_step_history, train_loss_full_history, val_loss_step_history, val_loss_full_history], ["Epoch", "LR", "Train_Loss_Step", "Train_Loss_Full", "Val_Loss_Step", "Val_Loss_Full"])
        print("Finish Training !")
        logger.print("Finish Training !")


def val(val_dataloader,
        transformer,
        model,
        loss_fn,
        local_rank,
        world_size):
    """Evaluate the model with an autoregressive rollout, without gradients.

    For every batch the model is called ten times; each prediction is scored
    against the matching slice of ``y2`` and then appended to the input
    window in place of the oldest entry, so later steps consume the model's
    own outputs.  Batch losses are summed over ranks with ``all_reduce``,
    divided by ``world_size`` and averaged over the number of batches.

    Args:
        val_dataloader: Yields ``(x, y1, y2)`` batches.
        transformer: Accepted for interface compatibility; not used here.
        model: Called as ``model(x, y1)``; put into eval mode for each batch.
        loss_fn: Called as ``loss_fn(prediction, target)``.
        local_rank: Used as the target device.
        world_size: Divisor applied to the all-reduced batch losses.

    Returns:
        A pair ``(step_loss, full_loss)``: the mean of the summed per-step
        losses and the mean loss of the concatenated rollout against ``y2``.
    """
    device = local_rank
    horizon = 10
    stride = 1

    def to_device_3d(tensor):
        # Move to the target device, then merge every middle axis into one.
        moved = tensor.to(device)
        return torch.reshape(moved, (moved.shape[0], -1, moved.shape[-1]))

    def mean_over_ranks(batch_loss):
        # Copy the scalar, sum it across ranks, then divide by world_size.
        scalar = torch.tensor(batch_loss.item()).to(device)
        torch.distributed.all_reduce(scalar)
        return scalar / world_size

    with torch.no_grad():
        step_total = 0
        full_total = 0

        for x, y1, y2 in tqdm(val_dataloader):
            x = to_device_3d(x)
            window = to_device_3d(y1)
            y2 = to_device_3d(y2)

            model.eval()

            rollout_loss = 0
            predictions = []
            for start in range(0, horizon, stride):
                target = y2[..., start:start + stride]
                step_pred = model(x, window)
                rollout_loss += loss_fn(step_pred, target)
                predictions.append(step_pred)
                # Feed the prediction back in; drop the oldest entries.
                window = torch.cat((window[..., stride:], step_pred), dim=-1)

            full_pred = torch.cat(predictions, -1)
            full_loss = loss_fn(full_pred, y2)

            step_total = step_total + mean_over_ranks(rollout_loss)
            full_total = full_total + mean_over_ranks(full_loss)

        step_total = step_total / len(val_dataloader)
        full_total = full_total / len(val_dataloader)

    return step_total, full_total
