import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import argparse
import learner
import os
import numpy as np
import time
import torch.autograd.profiler as profiler
from tqdm import tqdm
import pdb

# ---------------------------------------------------------------------------
# Command-line configuration. All options are integers; the log2_* options are
# base-2 exponents whose plain values are derived further down.
# ---------------------------------------------------------------------------
_INT_OPTIONS = (
    ('--log2_hdim', 4),
    ('--gpu', 0),
    ('--idim', 16),
    ('--odim', 2),
    ('--log2_num_layers', 0),
    ('--log2_batchsize', 8),
    ('--log2_validbatchsize', 12),
    ('--log2_testbatchsize', 12),
    ('--log2_trainsize', 18),
    ('--log2_validsize', 12),
    ('--log2_testsize', 12),
    ('--max_epochs', 1000),
    ('--no_improve_epochs', 50),
    ('--first_inst', 0),
    ('--num_inst', 8),
)

parser = argparse.ArgumentParser(description='main')
for _flag, _default in _INT_OPTIONS:
    parser.add_argument(_flag, default=_default, type=int)

args = parser.parse_args()

# NOTE: the hidden width is tied to the depth, so any --log2_hdim given on the
# command line is overwritten here.
args.log2_hdim = args.log2_num_layers + 4

# Expand each exponent: args.<name> = 2 ** args.log2_<name>.
for _name in (
    'num_layers',
    'hdim',
    'batchsize',
    'validbatchsize',
    'testbatchsize',
    'trainsize',
    'validsize',
    'testsize',
):
    setattr(args, _name, 2 ** getattr(args, 'log2_' + _name))

torch.manual_seed(42)
# Set device: a negative --gpu selects the CPU.
device = torch.device(f"cuda:{args.gpu}" if args.gpu >= 0 else "cpu")

# Load the saved splits. Each split is indexed as split[0]=inputs,
# split[1]=labels, split[2]=sobs.
datasets = torch.load('./data/wavy_data.pt')


def _sorted_split(split, size):
    """Return (inputs, labels, sobs) for the first `size` samples of `split`.

    The samples are ordered by ascending sob value and moved to `device`.
    Raises ValueError if the split holds fewer than `size` samples.
    """
    available = len(split[2])
    if size > available:
        raise ValueError(f"requested {size} samples but the split only has {available}")
    # One bulk conversion instead of `size` separate `.item()` calls. The
    # resulting Python scalars, and therefore the argsort order (ties
    # included), are identical. reshape(size) still requires one value per sample.
    sob_values = split[2][:size].reshape(size).tolist()
    order = np.argsort(sob_values)
    return tuple(split[k][order].to(device) for k in range(3))


# Training data keeps its stored order (the DataLoader shuffles it); slicing
# silently clamps if fewer than trainsize samples are stored.
inputs_train, labels_train, sobs_train = (
    datasets["train"][k][:args.trainsize].to(device) for k in range(3)
)
dataloader_train = DataLoader(TensorDataset(inputs_train, labels_train, sobs_train), batch_size=args.batchsize, shuffle=True)

# Valid/test sets are sorted by sob so result rows come out in that order.
inputs_valid, labels_valid, sobs_valid = _sorted_split(datasets["valid"], args.validsize)
dataloader_valid = DataLoader(TensorDataset(inputs_valid, labels_valid, sobs_valid), batch_size=args.validbatchsize, shuffle=False)

inputs_test, labels_test, sobs_test = _sorted_split(datasets["test"], args.testsize)
dataloader_test = DataLoader(TensorDataset(inputs_test, labels_test, sobs_test), batch_size=args.testbatchsize, shuffle=False)

# Initialize the model, loss function, and optimizer.
# NOTE(review): the per-instance loop at the end of the script rebinds all
# three names before any training happens. `criterion` is also read as a
# module-level global by compute_validation_loss. The model is built with the
# same arguments (including odim) as the per-instance models.
model = learner.Learner(idim=args.idim, odim=args.odim, hdim=args.hdim, num_layers=args.num_layers).to(device)
criterion = nn.MSELoss()  # Use appropriate loss function for your problem
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)  # Adjust learning rate as needed

# Function to compute validation loss
def compute_validation_loss(model, dataloader, *, loss_fn=None):
    """Evaluate `model` on every batch of `dataloader` without gradients.

    Args:
        model: Module to evaluate. It is switched to eval mode and left there.
        dataloader: Yields `(inputs, labels, sobs)` batches.
        loss_fn: Loss to average. Defaults to the module-level `criterion`,
            so existing two-argument calls behave as before.

    Returns:
        A pair `(avg_loss, table)`.
        - `avg_loss` is the batch-size-weighted mean of `loss_fn` over the
          whole dataset. That equals the per-sample mean for a mean-reduction
          loss such as `nn.MSELoss()`.
        - `table` is the dim-1 concatenation, in dataloader order, of: the L2
          norm of `outputs - labels` per sample, then `sobs`, `outputs` and
          `labels`.

    Raises:
        ValueError: If `dataloader` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion  # module-level loss defined by the script
    # Move batches to wherever the model lives (mirrors train_model). A
    # parameterless model leaves the batches where they are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    model.eval()
    total_loss = 0.0
    num_samples = 0
    outputs_list, labels_list, sobs_list = [], [], []
    with torch.no_grad():
        for inputs, labels, sobs in dataloader:
            if target_device is not None:
                inputs, labels, sobs = (t.to(target_device) for t in (inputs, labels, sobs))
            outputs = model(inputs)
            batch_size = inputs.shape[0]
            # Weight each batch mean by its size so a smaller final batch does
            # not skew the dataset-level average.
            total_loss += loss_fn(outputs, labels).item() * batch_size
            num_samples += batch_size
            outputs_list.append(outputs)
            labels_list.append(labels)
            sobs_list.append(sobs)

        if num_samples == 0:
            raise ValueError("compute_validation_loss: dataloader yielded no samples")

        all_outputs = torch.cat(outputs_list, dim=0)
        all_labels = torch.cat(labels_list, dim=0)
        all_sobs = torch.cat(sobs_list, dim=0)
        errors = (all_outputs - all_labels).norm(dim=1, keepdim=True)
        table = torch.cat([errors, all_sobs, all_outputs, all_labels], dim=1)
    return total_loss / num_samples, table

# Training function
def train_model(model, dataloader_train, dataloader_valid, dataloader_test, criterion, optimizer, max_epochs, no_improve_epochs, inst):
    """Train `model` with early stopping and save its test-set results.

    The output directory name is built from the module-level `args`. If the
    result file for `inst` already exists, the function returns immediately,
    so an interrupted sweep can be resumed.

    Training:
    - Each epoch makes one optimisation pass over `dataloader_train`, then
      evaluates on `dataloader_valid`.
    - The weights with the lowest validation loss so far are snapshotted.
    - Training stops after `max_epochs` epochs, or once `no_improve_epochs`
      consecutive epochs pass without a new best.

    Afterwards the best weights are loaded back into `model` and evaluated on
    `dataloader_test`. The table returned by compute_validation_loss is saved
    to `<outdir>/results_inst_<inst>.pt`.
    """
    outdir = f"./results/results_log2layers_{args.log2_num_layers}_log2hdim_{args.log2_hdim}_log2trainsize_{args.log2_trainsize}"
    os.makedirs(outdir, exist_ok=True)
    result_path = f"{outdir}/results_inst_{inst}.pt"
    # Check before any evaluation so finished instances cost nothing.
    if os.path.exists(result_path):
        return

    model_device = next(model.parameters()).device  # Get device of the model
    initial_valid_loss, _ = compute_validation_loss(model, dataloader_valid)
    print(f"Initial Validation Loss: {initial_valid_loss:.4f}")

    def _snapshot():
        # Detached copies: holding a reference to `model` itself would just
        # track the latest (possibly worse) parameters.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Start from the initial weights so a "best" state always exists, even if
    # no epoch runs or no validation loss ever compares lower (e.g. NaN).
    best_state = _snapshot()
    best_loss = float("inf")
    best_epoch = -1
    for epoch in tqdm(range(max_epochs)):
        if epoch > best_epoch + no_improve_epochs:
            print(f"No valid_loss improvements in {no_improve_epochs} epochs. Training is stopped.")
            break
        model.train()
        running_loss = 0.0
        for inputs, labels, _ in tqdm(dataloader_train):
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_train_loss = running_loss / len(dataloader_train)
        print(f"Epoch [{epoch+1}/{max_epochs}], Training Loss: {avg_train_loss:.4f}")

        # Compute and print validation loss
        valid_loss, _ = compute_validation_loss(model, dataloader_valid)
        print(f"Epoch [{epoch+1}/{max_epochs}], Validation Loss: {valid_loss:.4f}")
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch
            best_state = _snapshot()

    # Evaluate the best weights, not whatever the last epoch left behind.
    model.load_state_dict(best_state)
    final_test_loss, test_results = compute_validation_loss(model, dataloader_test)
    print(f"Final test Loss: {final_test_loss:.4f}")
    # Columns: [||pred - target|| | sob | pred | target], rows in dataloader_test order.
    torch.save(test_results, result_path)

# Train and evaluate each requested instance, each with its own seed.
for inst in range(args.first_inst, args.first_inst + args.num_inst):
    # Reseed per instance so every run's initialisation is reproducible.
    torch.manual_seed(42 + inst)
    model = learner.Learner(
        idim=args.idim,
        odim=args.odim,
        hdim=args.hdim,
        num_layers=args.num_layers,
    ).to(device)
    # `criterion` is deliberately rebound at module level:
    # compute_validation_loss reads the global name.
    criterion = nn.MSELoss()  # Use appropriate loss function for your problem
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)  # Adjust learning rate as needed
    train_model(
        model,
        dataloader_train,
        dataloader_valid,
        dataloader_test,
        criterion,
        optimizer,
        max_epochs=args.max_epochs,
        no_improve_epochs=args.no_improve_epochs,
        inst=inst,
    )


