from torch.utils.data import DataLoader, TensorDataset, random_split
import matplotlib.pyplot as plt
import numpy as np
import pickle
import time
import copy
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import torch_geometric
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from torch_geometric.utils import add_self_loops, degree, to_dense_batch
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
from sklearn.model_selection import train_test_split
import time

import utils 
from utils import new_empty_dict, apply_canon_to_pair, smush_for_MLP, apply_canon_just_to_data
from datasets import RotatedQM9Dataset, DetectRotatedQM9Dataset, AugmentedQM9Dataset, MetaAugmentedDataset, MetaDetectionDataset
import os
import wandb
from collections import defaultdict
from collections.abc import Iterable

# Random seeding is currently disabled; uncomment the next line for reproducible runs.
# torch.manual_seed(42)
def last_two_levels(path):
    """Return the last two components of ``path`` joined as a relative path.

    The path is normalised first, so trailing separators and redundant
    segments are collapsed. A path with a single component is returned as
    that one component.
    """
    components = os.path.normpath(path).split(os.sep)
    tail = components[-2:]
    return os.path.join(*tail)

def is_iterable_but_not_tensor(obj):
    """Return True when ``obj`` is an ``Iterable`` that is not a ``torch.Tensor``."""
    # Tensors are iterable themselves, so rule them out before the generic check.
    if isinstance(obj, torch.Tensor):
        return False
    return isinstance(obj, Iterable)

def send_to_device(elt, device):
    """Move ``elt`` to ``device``, recursing through lists and tuples.

    - ``torch_geometric.data.Data`` instances are moved with ``.to(device)``
      (added so that DataBatch objects are handled).
    - Lists and tuples are rebuilt as the same container type, with every
      entry processed recursively.
    - Tensors are moved with ``.to(device)``.
    - Anything else is returned unchanged.
    """
    if isinstance(elt, torch_geometric.data.Data):
        return elt.to(device)

    if isinstance(elt, (list, tuple)):
        # Rebuild with the original container type so callers get back what they passed in.
        container = type(elt)
        return container(send_to_device(item, device) for item in elt)

    if isinstance(elt, torch.Tensor):
        return elt.to(device)

    return elt

def _format_metric(value):
    """Render a scalar metric, or each entry of a sequence of metrics, with two decimals."""
    if isinstance(value, torch.Tensor):
        # 0-dim tensors become plain Python numbers, higher-dim tensors become lists.
        value = value.tolist()
    if isinstance(value, Iterable):
        return ", ".join(f"{x:.2f}" for x in value)
    return f"{value:.2f}"


def _entry_at(history, index):
    """Return ``history[index]``, or None when nothing was recorded."""
    return history[index] if history else None


def train_model(model, dataloaders, criterion, optimizer, dataset_cfg, save_dir=None, aux_criteria={'accuracy': utils.classification_loss}, num_epochs=50, use_wandb=False, device=torch.device('cpu'), print_every=20, verbose=True, train_viz_function=None, canon_model=None):
    """Train ``model`` and record per-epoch losses and auxiliary metrics.

    Args:
        model: Called as ``model(data, dataset_type=..., filter_mol=...)``.
        dataloaders: Mapping with a required ``'train'`` loader and optional
            ``'val'`` and ``'test'`` loaders. Each batch must unpack into
            ``(data, expected)``.
        criterion: Loss used for backpropagation; called as
            ``criterion(outputs, expected)``.
        optimizer: Optimizer stepped once per training batch.
        dataset_cfg: Config exposing ``name``, ``filter_mol`` and ``task``
            (plus ``task_dependent_args.binary_detection`` when ``task`` is
            ``"task_dependent"``).
        save_dir: Directory for ``model.pt`` checkpoints; nothing is saved
            when None. It is also used to build the ``savename`` handed to
            ``train_viz_function``, so it must be set whenever that callback
            is provided.
        aux_criteria: Mapping from metric name to a two-argument criterion.
            These metrics are only tracked, never backpropagated. The default
            dict is shared across calls and is not mutated here.
        num_epochs: Number of passes over the training loader.
        use_wandb: When True, log the epoch metrics with ``wandb.log``.
        device: Device that batches are moved to.
        print_every: Print a progress report every this many epochs
            (only when ``verbose`` is True).
        verbose: Enables the progress report; also forwarded to
            ``evaluate_model``.
        train_viz_function: Optional callback invoked each epoch with
            ``model``, ``dataloader`` and ``savename`` keyword arguments, for
            the train loader and (if present) the val loader.
        canon_model: Canonicalization model applied to the inputs when
            ``dataset_cfg.task == "task_dependent"``.

    Returns:
        dict with ``train_losses``, ``val_losses``, ``test_losses``,
        ``lowest_val_loss_ind`` (-1 without validation data),
        ``best_test_loss``, and for every auxiliary metric ``train_<name>``,
        ``val_<name>``, ``test_<name>`` and ``best_test_<name>``. The
        ``best_test_*`` entries are None when no test results were recorded.
        ``checkpoint_path`` is added when ``save_dir`` is given, and
        ``elapsed_time`` holds the total training time in minutes.
    """
    start_time = time.time()
    has_val = 'val' in dataloaders
    has_test = 'test' in dataloaders

    train_losses, val_losses, test_losses = [], [], []
    # One independent history list per metric and per split.
    train_aux_criteria = {name: [] for name in aux_criteria}
    val_aux_criteria = {name: [] for name in aux_criteria}
    test_aux_criteria = {name: [] for name in aux_criteria}

    dataset_type = dataset_cfg.name
    filter_mol = dataset_cfg.filter_mol

    for epoch in range(num_epochs):
        model.train()

        for data, expected in dataloaders['train']:
            data, expected = send_to_device(data, device), send_to_device(expected, device)
            optimizer.zero_grad()

            if dataset_cfg.task == "task_dependent":  # then canon_model should not be None
                if dataset_cfg.task_dependent_args.binary_detection:
                    cx, y = apply_canon_to_pair(canon_model, data)
                    data = smush_for_MLP([cx, y])
                else:
                    # apply canon just to x
                    data = apply_canon_just_to_data(canon_model, data)

            # Pass in all args; the model's forward is expected to ignore extras via **kwargs.
            outputs = model(data, dataset_type=dataset_type, filter_mol=filter_mol)

            loss = criterion(outputs, expected)
            loss.backward()
            optimizer.step()

        if train_viz_function is not None:
            train_viz_function(model=model, dataloader=dataloaders['train'], savename=os.path.join(last_two_levels(save_dir), f'train_viz_epoch_{epoch}.pdf'))

        if has_val:
            epoch_val_loss, epoch_val_criteria = evaluate_model(model=model, val_loader=dataloaders['val'], criterion=criterion, aux_criteria=aux_criteria, dataset_cfg=dataset_cfg, device=device, verbose=verbose, canon_model=canon_model)
            val_losses.append(epoch_val_loss)

            # Checkpoint the best validation loss so far in case training doesn't finish.
            if save_dir is not None and (epoch == 0 or epoch_val_loss < min(val_losses[:-1])):
                best_checkpoint_path = os.path.join(save_dir, 'model.pt')
                # torch.save serialises immediately, so no defensive copy of the state dict is needed.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'results': {
                        'epoch_val_loss': epoch_val_loss,
                        'epoch': epoch
                    }
                }, best_checkpoint_path)

            if train_viz_function is not None:
                train_viz_function(model=model, dataloader=dataloaders['val'], savename=os.path.join(last_two_levels(save_dir), f'val_viz_epoch_{epoch}.pdf'))
        else:
            epoch_val_loss, epoch_val_criteria = 0, defaultdict(lambda: [0])

        if has_test:
            epoch_test_loss, epoch_test_criteria = evaluate_model(model=model, val_loader=dataloaders['test'], criterion=criterion, aux_criteria=aux_criteria, dataset_cfg=dataset_cfg, device=device, verbose=verbose, canon_model=canon_model)
            test_losses.append(epoch_test_loss)
        else:
            # Placeholder that can still be indexed by metric name below.
            epoch_test_loss, epoch_test_criteria = None, defaultdict(lambda: None)

        # Train metrics are reported from a separate evaluation pass over the
        # training loader (model in eval mode), not from the running batch losses.
        epoch_train_loss, epoch_train_criteria = evaluate_model(model=model, val_loader=dataloaders['train'], criterion=criterion, aux_criteria=aux_criteria, dataset_cfg=dataset_cfg, device=device, verbose=verbose, canon_model=canon_model)
        train_losses.append(epoch_train_loss)

        for ac_name in aux_criteria:
            train_aux_criteria[ac_name].append(epoch_train_criteria[ac_name])
            val_aux_criteria[ac_name].append(epoch_val_criteria[ac_name])
            test_aux_criteria[ac_name].append(epoch_test_criteria[ac_name])

        if use_wandb:
            log_dict = {
                "epoch": epoch,
                "train_loss": epoch_train_loss,
                "val_loss": epoch_val_loss,
                "test_loss": epoch_test_loss,
            }
            for ac_name in aux_criteria:
                log_dict[f"train_{ac_name}"] = train_aux_criteria[ac_name][-1]
                log_dict[f"val_{ac_name}"] = val_aux_criteria[ac_name][-1]
                log_dict[f"test_{ac_name}"] = test_aux_criteria[ac_name][-1]

            # Splits that were not evaluated carry None; leave them out of the log.
            wandb.log({key: value for key, value in log_dict.items() if value is not None})

        if verbose and epoch % print_every == 0:
            print(f"Epoch {epoch + 1}, Train Loss: {epoch_train_loss:.2f}", end="")
            for ac_name in aux_criteria:
                print(f" Train {ac_name}: {_format_metric(train_aux_criteria[ac_name][-1])}", end="")

            if has_val:
                print(f"\nValidation Loss: {epoch_val_loss:.2f}", end="")
                for ac_name in aux_criteria:
                    print(f" Validation {ac_name}: {_format_metric(val_aux_criteria[ac_name][-1])}", end="")

            if has_test:
                print(f"\nTest Loss: {epoch_test_loss:.2f}", end="")
                for ac_name in aux_criteria:
                    print(f" Test {ac_name}: {_format_metric(test_aux_criteria[ac_name][-1])}", end="")

            # Printed for every report (not only when a test loader exists) so
            # the line is always terminated.
            elapsed_time = (time.time() - start_time) / 60.0
            print(f'\n Elapsed: {elapsed_time:.2f} min')

    results = {'train_losses': train_losses, 'val_losses': val_losses, 'test_losses': test_losses}

    if has_val and val_losses:
        lowest_val_loss_ind = int(torch.argmin(torch.tensor(val_losses)))
    else:
        # Without validation data, fall back to the last epoch.
        lowest_val_loss_ind = -1
    results['lowest_val_loss_ind'] = lowest_val_loss_ind
    results['best_test_loss'] = _entry_at(test_losses, lowest_val_loss_ind)

    for ac_name in aux_criteria:
        results[f'train_{ac_name}'] = train_aux_criteria[ac_name]
        results[f'val_{ac_name}'] = val_aux_criteria[ac_name]
        results[f'test_{ac_name}'] = test_aux_criteria[ac_name]

        results[f'best_test_{ac_name}'] = _entry_at(test_aux_criteria[ac_name], lowest_val_loss_ind)

    if save_dir is not None:
        checkpoint_path = os.path.join(save_dir, 'model.pt')
        results['checkpoint_path'] = checkpoint_path

        # NOTE(review): this is the same path as the best-validation checkpoint
        # written during training, so the final-epoch weights replace it.
        # Confirm that this is intended.
        checkpoint = {
            'epoch': num_epochs - 1,  # index of the last completed epoch
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            #'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'results': results
        }

        torch.save(checkpoint, checkpoint_path)

        # change this to save model + make symlinks periodically! based on a config, probably

    # Measured at the end of training, independent of whether progress was printed.
    elapsed_time = (time.time() - start_time) / 60.0
    results['elapsed_time'] = elapsed_time

    return results

def evaluate_model(model, val_loader, criterion, aux_criteria, dataset_cfg, device=torch.device('cpu'), verbose=True, canon_model=None):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is switched to eval mode and left in that mode. Each batch must
    unpack into ``(data, expected)``. When ``dataset_cfg.task`` is
    ``"task_dependent"``, the inputs are first passed through ``canon_model``
    (pairwise when ``task_dependent_args.binary_detection`` is set, otherwise
    on the data alone).

    ``verbose`` is accepted but not used in this function.

    Returns:
        (mean_loss, mean_metrics): ``criterion`` averaged over the number of
        batches, and a dict with each entry of ``aux_criteria`` averaged the
        same way. Averages are per batch, not per sample.
    """
    model.eval()
    forward_kwargs = {'dataset_type': dataset_cfg.name, 'filter_mol': dataset_cfg.filter_mol}
    loss_total = 0.0
    metric_totals = new_empty_dict(aux_criteria, value=0)
    batch_count = len(val_loader)

    with torch.no_grad():
        for inputs, expected in val_loader:
            inputs = send_to_device(inputs, device)
            expected = send_to_device(expected, device)

            if dataset_cfg.task == "task_dependent":  # canon_model is expected to be set here
                if not dataset_cfg.task_dependent_args.binary_detection:
                    # canonicalize the data only
                    inputs = apply_canon_just_to_data(canon_model, inputs)
                else:
                    canon_x, partner = apply_canon_to_pair(canon_model, inputs)
                    inputs = smush_for_MLP([canon_x, partner])

            outputs = model(inputs, **forward_kwargs)

            for metric_name, metric_fn in aux_criteria.items():
                metric_totals[metric_name] += metric_fn(outputs, expected)

            loss_total += criterion(outputs, expected).item()

        # Convert the running sums into per-batch averages.
        mean_loss = loss_total / batch_count
        for metric_name in aux_criteria:
            metric_totals[metric_name] /= batch_count

    return mean_loss, metric_totals
        