import os
import random
import sys
import time
import numpy as np
import torch
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import MinMaxScaler
from pickle import load
from utils import *
from puzzle_loader import *
from COMONet import *
from sklearn.metrics import (mean_absolute_error,
                            mean_absolute_percentage_error,
                            mean_squared_error, r2_score)


# Command-line interface for the training hyper-parameters.
# None of the options declares a default, so an omitted flag is left as None.
parser = argparse.ArgumentParser(description='train_COMONet')
for _flag, _help, _type in (
    ('--bs', 'batch_size', int),
    ('--lr', 'learning_rate', float),
    ('--epochs', 'epoch', int),
):
    parser.add_argument(_flag, help=_help, type=_type)

# Parsed at module level so that the entry point can hand it to train().
args = parser.parse_args()


def train(args: argparse.Namespace) -> None:
    """Train a COMONet regressor on the puzzle CSV data and checkpoint it.

    Reads ``./train.csv``, ``./validation.csv`` and ``./test.csv`` through
    ``PuzzledataLoader``, optimises the network with Adam on an MSE loss and
    evaluates it on the validation set after every epoch. The validation
    error is additionally computed after undoing the min-max scaling with the
    statistics stored in ``./scaler.pkl`` (column index 3, presumably the
    target column -- confirm against the preprocessing script). The whole
    model object is written to ``./model.pth`` after every epoch, and a
    progress line is printed every 10th epoch.

    Args:
        args: Parsed command-line namespace providing ``bs`` (batch size),
            ``lr`` (learning rate) and ``epochs`` (number of epochs).
    """
    train_dataset = PuzzledataLoader(csv_path="./train.csv")
    validation_dataset = PuzzledataLoader(csv_path="./validation.csv")
    # The test split is only loaded to report its size; it is not evaluated here.
    test_dataset = PuzzledataLoader(csv_path="./test.csv")
    print(f"Training Data Size : {len(train_dataset)}")
    print(f"Validation Data Size : {len(validation_dataset)}")
    print(f"Testing Data Size : {len(test_dataset)}")

    train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=False)
    # The validation set is served as a single full-size batch.
    validation_dataloader = DataLoader(validation_dataset, batch_size=len(validation_dataset), shuffle=True, drop_last=True)

    model = COMONet(input_size=3,
                    conv_features=[],
                    monoconv_features=[],
                    conc_features=[0, 2],
                    monoconc_features=[1],
                    mono_features=[],
                    conv_layer_size=(4, 4, 4),
                    monoconv_layer_size=(4, 4, 4),
                    conc_layer_size=(64, 64, 32),
                    monoconc_layer_size=(64, 64, 32),
                    mono_layer_size=(4, 4, 4),
                    unconst_layer_size=(4, 4, 4),
                    batch_norm=False,
                    activation=['relu', 'relu-n'])

    # Report the total number of parameters of the network.
    param_amount = sum(p.numel() for p in model.parameters())
    print('total param amount:', param_amount)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    model.to(device)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr))

    # Load the scaler once instead of re-opening (and never closing) the file
    # on every epoch. This is done after the datasets are built in case the
    # dataset class is what writes the file.
    # NOTE(review): pickle is only safe here because scaler.pkl is a locally
    # produced artifact; never point this at untrusted files.
    with open("./scaler.pkl", "rb") as scaler_file:
        scaler = load(scaler_file)
    target_min = float(scaler.data_min_[3])
    target_range = float(scaler.data_max_[3] - scaler.data_min_[3])

    num_epochs = args.epochs
    total_batch = len(train_dataloader)
    # Best (lowest) un-scaled validation MSE seen so far, kept as a plain float.
    min_mse = float("inf")

    model.train()
    start = time.time()
    for epoch in range(num_epochs):
        running_loss = 0.0
        batches_done = 0
        for batches_done, (inputs, targets) in enumerate(train_dataloader, start=1):
            outputs = model(inputs.to(device))
            loss = criterion(outputs, targets.to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        # Mean loss per batch; guarded so that an empty loader cannot divide by zero.
        train_loss = running_loss / total_batch if total_batch else float("nan")

        # Evaluate in eval mode so that mode-dependent layers (dropout, batch
        # norm), if the network has any, behave deterministically.
        model.eval()
        validation_loss = float("nan")
        with torch.no_grad():
            for inputs, targets in validation_dataloader:
                x = inputs.to(device)
                y = targets.to(device)
                outputs = model(x)
                validation_loss = criterion(outputs, y).item()

                # Undo the min-max scaling to measure the error in original units.
                pred_y = outputs * target_range + target_min
                true_y = y * target_range + target_min
                mse = criterion(pred_y, true_y).item()
                min_mse = min(min_mse, mse)
        # Back to training mode before saving, so the pickled model carries the
        # same mode flag as before and the next epoch trains normally.
        model.train()

        # The checkpoint is deliberately overwritten every epoch (latest model,
        # not best model).
        torch.save(model, "./model.pth")

        elapsed = time.time() - start
        if (epoch + 1) % 10 == 0:
            print('Epoch [{}/{}], Batch [{}/{}], Training Loss: {:.5f}, Validation Loss: {:.5f}, Validation mse: {:.5f}, Elapsed time (s): {:.4f}'.format(epoch + 1, num_epochs, batches_done, total_batch, train_loss, validation_loss, min_mse, elapsed))






# Script entry point: run training with the command-line arguments that were
# parsed at module level.
if __name__ == "__main__":
    train(args)
