import torch
from utils import *
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


def train(train_dataloader,
          val_dataloader,
          transformer,
          model,
          loss_fn,
          optimizer,
          scheduler,
          local_rank,
          world_size,
          grad_clip,
          epoch,
          log_print_interval_epoch,
          model_save_interval_epoch,
          log_dir,
          checkpoint_dir):
    """Run the distributed training loop for ``epoch`` epochs.

    Each batch is unpacked as ``(x, y1, y2)``; the model is fed ``y1`` and
    its output is fitted to ``y2`` with ``loss_fn``.  Per-batch losses are
    summed across ranks with ``all_reduce`` and divided by ``world_size``.
    After every epoch ``val`` is evaluated on both ``val_dataloader`` and
    ``train_dataloader``.  Only the process with ``local_rank == 0`` writes
    TensorBoard scalars, log lines and checkpoints.

    ``torch.distributed.barrier`` and ``all_reduce`` are called, so the
    default process group must already be initialised by the caller.

    Args:
        train_dataloader: Sized iterable of ``(x, y1, y2)`` tensor batches
            whose ``sampler`` exposes ``set_epoch``.
        val_dataloader: Loader forwarded to ``val`` for the validation loss.
        transformer: Forwarded unchanged to ``val``; not used here.
        model: Module to optimise.
        loss_fn: Callable ``loss_fn(prediction, y2)`` returning a scalar
            tensor.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler; stepped once per *batch*, right after the
            optimizer.
        local_rank: Used both as the target device for tensors and to pick
            the logging process (rank 0).
        world_size: Divisor applied to the all-reduced losses.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        epoch: Number of epochs to run.
        log_print_interval_epoch: Log every this many epochs.
        model_save_interval_epoch: Save a checkpoint every this many epochs.
        log_dir: Directory handed to ``SummaryWriter`` and ``Logger``.
        checkpoint_dir: Directory handed to ``Checkpoint``.

    Raises:
        ZeroDivisionError: If ``len(train_dataloader)`` is zero.
    """
    device = local_rank
    is_main_rank = local_rank == 0

    if is_main_rank:
        writer = SummaryWriter(log_dir)
        logger = Logger(log_dir)
        checker = Checkpoint(checkpoint_dir, model, device)
        epoch_history = []
        lr_history = []
        train_loss_history = []
        val_loss_history = []
        train_val_loss_history = []

        param_msg = "Number of Model Parameters: {}".format(get_num_params(model))
        print(param_msg)
        print(model)
        print("Start Training...")
        logger.print(param_msg)
        logger.print(model)
        logger.print("Start Training...")

    try:
        for i in range(epoch):
            torch.distributed.barrier()
            train_dataloader.sampler.set_epoch(i)
            train_loss = 0

            # val() switches the model to eval mode at the end of every
            # epoch, so training mode has to be restored each time.
            model.train()

            for data in tqdm(train_dataloader):
                # The first batch element is not used for training, so it
                # is neither moved to the device nor reshaped.
                _, y1, y2 = data

                # Collapse all middle dimensions: (batch, ..., last) ->
                # (batch, -1, last).
                y1 = y1.to(device)
                y1 = torch.reshape(y1, (y1.shape[0], -1, y1.shape[-1]))
                y2 = y2.to(device)
                y2 = torch.reshape(y2, (y2.shape[0], -1, y2.shape[-1]))

                res = model(y1)
                loss = loss_fn(res, y2)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()
                scheduler.step()

                # Sum the batch loss over all ranks, then average.
                train_batch_loss = torch.tensor(loss.item()).to(device)
                torch.distributed.all_reduce(train_batch_loss)
                train_loss = train_loss + train_batch_loss / world_size

            train_loss = train_loss / len(train_dataloader)
            train_loss = train_loss.item()
            val_loss = val(val_dataloader, transformer, model, loss_fn, local_rank, world_size)
            val_loss = val_loss.item()
            # Loss on the training set measured in eval mode.
            train_val_loss = val(train_dataloader, transformer, model, loss_fn, local_rank, world_size)
            train_val_loss = train_val_loss.item()

            if is_main_rank:
                step = i + 1

                if step % log_print_interval_epoch == 0:
                    current_lr = optimizer.state_dict()['param_groups'][0]['lr']

                    # Every scalar is tagged with the epoch number; without
                    # global_step all loss points would share one step.
                    writer.add_scalar("Learning Rate", current_lr, step)
                    writer.add_scalar("Train Loss", train_loss, step)
                    writer.add_scalar("Val Loss", val_loss, step)
                    writer.add_scalar("Train Val Loss", train_val_loss, step)

                    epoch_history.append(step)
                    lr_history.append(current_lr)
                    train_loss_history.append(train_loss)
                    val_loss_history.append(val_loss)
                    train_val_loss_history.append(train_val_loss)

                    message = "Epoch: {}\tLearning Rate :{}\tTrain Loss: {}\tVal Loss: {}\tTrain Val Loss: {}".format(
                        step, current_lr, train_loss, val_loss, train_val_loss)
                    print(message)
                    logger.print(message)

                if step % model_save_interval_epoch == 0:
                    checker.save(step)
    finally:
        # Flush and close the event file even when training aborts.
        if is_main_rank:
            writer.close()

    if is_main_rank:
        logger.save([epoch_history, lr_history, train_loss_history, val_loss_history, train_val_loss_history], ["Epoch", "LR", "Train_Loss", "Val_Loss", "Train_val_loss"])
        print("Finish Training !")
        logger.print("Finish Training !")


def val(val_dataloader,
        transformer,
        model,
        loss_fn,
        local_rank,
        world_size):
    """Return the mean loss of ``model`` over ``val_dataloader``.

    Runs under ``torch.no_grad()`` with the model in eval mode (the model is
    left in eval mode on return).  Each batch is unpacked as
    ``(x, y1, y2)``; the model is fed ``y1`` and compared to ``y2``.  Every
    batch loss is summed across ranks with ``all_reduce`` and divided by
    ``world_size``, so the default process group must be initialised.

    Args:
        val_dataloader: Sized iterable of ``(x, y1, y2)`` tensor batches.
        transformer: Currently unused; kept for signature compatibility.
        model: Module to evaluate.
        loss_fn: Callable ``loss_fn(prediction, y2)`` returning a scalar
            tensor.
        local_rank: Device the batches are moved to.
        world_size: Divisor applied to the all-reduced losses.

    Returns:
        A 0-dim tensor on ``local_rank`` holding the loss averaged over
        ranks and batches.

    Raises:
        ZeroDivisionError: If ``len(val_dataloader)`` is zero.
    """
    device = local_rank

    # Switching mode once is enough; nothing in the loop changes it.
    model.eval()

    with torch.no_grad():
        val_loss = 0

        for data in tqdm(val_dataloader):
            # The first batch element is not used for evaluation, so it is
            # neither moved to the device nor reshaped.
            _, y1, y2 = data

            # Collapse all middle dimensions: (batch, ..., last) ->
            # (batch, -1, last).
            y1 = y1.to(device)
            y1 = torch.reshape(y1, (y1.shape[0], -1, y1.shape[-1]))
            y2 = y2.to(device)
            y2 = torch.reshape(y2, (y2.shape[0], -1, y2.shape[-1]))

            res = model(y1)

            # NOTE(review): the loss is computed directly on the model
            # output; an inverse transform through `transformer` is not
            # applied here - confirm this is the intended metric.
            loss = loss_fn(res, y2)

            # Sum the batch loss over all ranks, then average.
            val_batch_loss = torch.tensor(loss.item()).to(device)
            torch.distributed.all_reduce(val_batch_loss)
            val_loss = val_loss + val_batch_loss / world_size

        val_loss /= len(val_dataloader)

    return val_loss


def test(test_dataloader,
         transformer,
         model,
         local_rank,
         world_size):
    """Render the model output for the first test batch, then exit.

    The first batch is unpacked as ``(x, y1, y2)``; the model is fed ``y1``
    and both its output and ``y2`` are passed through
    ``transformer.apply_y2(..., inverse=True)``.  For sample 0 of the batch,
    the prediction, the target and their difference are handed to
    ``show_NS`` once per step, with output paths under ``../_misc/``.

    NOTE(review): the function terminates the whole process (exit status 0)
    after the first batch - presumably a deliberate debugging shortcut;
    confirm before reusing it from a longer pipeline.

    Args:
        test_dataloader: Iterable of ``(x, y1, y2)`` tensor batches.
        transformer: Object exposing ``apply_y2(tensor, inverse=True)``.
        model: Module to evaluate.
        local_rank: Device the batch is moved to.
        world_size: Unused; kept for signature compatibility.

    Returns:
        None, and only when ``test_dataloader`` yields no batch at all.

    Raises:
        SystemExit: After the first batch has been rendered.
    """
    # Layout the fields are reshaped to after the inverse transform:
    # (batch, grid_size, grid_size, num_steps).
    grid_size = 64
    num_steps = 10
    # Added to the step index in output file names
    # (presumably the number of input steps - TODO confirm).
    step_offset = 10

    with torch.no_grad():
        device = local_rank

        for data in tqdm(test_dataloader):
            # The first batch element is not used here, so it is neither
            # moved to the device nor reshaped.
            _, y1, y2 = data

            y1 = y1.to(device)
            y1 = torch.reshape(y1, (y1.shape[0], -1, y1.shape[-1]))
            y2 = y2.to(device)
            y2 = torch.reshape(y2, (y2.shape[0], -1, y2.shape[-1]))

            model.eval()
            res = model(y1)
            res = transformer.apply_y2(res, inverse=True)
            y2 = transformer.apply_y2(y2, inverse=True)
            res = torch.reshape(res, (res.shape[0], grid_size, grid_size, num_steps))
            y2 = torch.reshape(y2, (y2.shape[0], grid_size, grid_size, num_steps))

            for t in range(num_steps):
                step_label = t + step_offset
                predicted = res[0, :, :, t]
                target = y2[0, :, :, t]
                show_NS(predicted.cpu().numpy(), "../_misc/res_train_{}.png".format(step_label), grid_size, grid_size)
                show_NS(target.cpu().numpy(), "../_misc/y2_train_{}.png".format(step_label), grid_size, grid_size)
                show_NS((predicted - target).cpu().numpy(), "../_misc/error_{}.png".format(step_label), grid_size, grid_size)

            # Same effect as the interactive-shell helper exit(), without
            # depending on the `site` module having injected it.
            raise SystemExit
    return None
