import os
from copy import deepcopy
from functools import partial

import numpy as np
import torch

from model_utils import *

def train_DISR(
    train_loader, test_loader, num_epochs, device, log_interval, save_interval, save_dir, generator, regressor, lr, w_sindy, sindy_reg_type, w_sindy_reg, w_sym_reg, st_freq, threshold, pde, **kwargs
):
    """Train ``regressor`` with Adam on a SINDy loss plus optional regularizers.

    For every mini-batch ``x`` of ``train_loader`` the minimized objective is::

        w_sindy * regressor.loss(x)
        + w_sym_reg * symm_loss(x=x, pde=pde, f=partial(regressor.F, pde=pde))  # if w_sym_reg > 0
        + w_sindy_reg * sum(||p||_1 for p in regressor.parameters())             # if sindy_reg_type == 'l1'

    where ``symm_loss = make_symmreg_pttrain(generator)``.

    Args:
        train_loader: Iterable of dict batches; every value is moved to
            ``device`` in place via ``.to(device)``.
        test_loader: Same batch format; evaluated with ``regressor.loss``
            (no gradients, eval mode) every ``log_interval`` epochs.
        num_epochs: Number of passes over ``train_loader``.
        device: Target passed to ``.to()`` for every batch value.
        log_interval: Print train/test losses every this many epochs.
        save_interval: Save ``regressor.state_dict()`` to
            ``saved_models/{save_dir}/regressor_{epoch}.pt`` every this many epochs.
        save_dir: Sub-directory of ``saved_models`` used for checkpoints.
        generator: Passed to ``make_symmreg_pttrain`` to build the symmetry loss.
        regressor: Model providing ``loss``, ``F``, ``set_threshold`` and ``print``.
        lr: Adam learning rate.
        w_sindy: Weight of the SINDy data-fit loss (applied exactly once).
        sindy_reg_type: ``'l1'`` (L1 penalty on all parameters) or ``'none'``.
        w_sindy_reg: Weight of the L1 penalty.
        w_sym_reg: Weight of the symmetry loss; the loss is not computed when <= 0.
        st_freq: If > 0, call ``regressor.set_threshold(threshold)`` every
            ``st_freq`` epochs.
        threshold: Value forwarded to ``regressor.set_threshold``.
        pde: Forwarded to the symmetry loss and to ``regressor.F``.
        **kwargs: ``print_eq`` (default False) prints the current equation via
            ``regressor.print()`` at each log step.

    Raises:
        ValueError: If ``sindy_reg_type`` is not ``'l1'`` or ``'none'``.
    """
    # Fail fast instead of after the first forward/backward pass.
    if sindy_reg_type not in ('l1', 'none'):
        raise ValueError(f'Unknown regularization type: {sindy_reg_type}')

    optimizer = torch.optim.Adam(regressor.parameters(), lr=lr)
    symm_loss = make_symmreg_pttrain(generator)
    log_items = ['loss_sindy', 'loss_sindy_reg', 'loss_sym_reg']
    save_path = f'saved_models/{save_dir}'

    for epoch in range(num_epochs):
        running_losses = {k: [] for k in log_items}
        regressor.train()
        for x in train_loader:
            for key in x:
                x[key] = x[key].to(device)
            # Keep the data-fit term unweighted so w_sindy is applied only once.
            loss_sindy = regressor.loss(x)
            running_losses['loss_sindy'].append(loss_sindy.item())
            loss = w_sindy * loss_sindy
            if w_sym_reg > 0.0:
                loss_sym_reg = symm_loss(x=x, pde=pde, f=partial(regressor.F, pde=pde))
                running_losses['loss_sym_reg'].append(loss_sym_reg.item())
                loss = loss + w_sym_reg * loss_sym_reg
            if sindy_reg_type == 'l1':
                loss_sindy_reg = sum(torch.norm(p, 1) for p in regressor.parameters())
                running_losses['loss_sindy_reg'].append(loss_sindy_reg.item())
                loss = loss + w_sindy_reg * loss_sindy_reg
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if st_freq > 0 and (epoch + 1) % st_freq == 0:
            regressor.set_threshold(threshold)

        if (epoch + 1) % log_interval == 0:
            # Report only terms that were computed this epoch, so a disabled
            # regularizer does not print np.mean([]) == nan.
            print(', '.join(
                [f'Epoch {epoch}']
                + [f'{k}: {np.mean(v):.4f}' for k, v in running_losses.items() if v]
            ))
            regressor.eval()
            with torch.no_grad():
                test_losses = []
                for x in test_loader:
                    for key in x:
                        x[key] = x[key].to(device)
                    test_losses.append(regressor.loss(x).item())
                print(f'Epoch {epoch}, test_loss: {np.mean(test_losses):.4f}')
                if kwargs.get('print_eq', False):
                    regressor.print()
            # Restore training mode (the loop was in train mode before evaluation).
            regressor.train()

        if (epoch + 1) % save_interval == 0:
            os.makedirs(save_path, exist_ok=True)
            torch.save(regressor.state_dict(), f'{save_path}/regressor_{epoch}.pt')

def train_DISR_lbfgs(
    train_loader, test_loader, num_epochs, device, log_interval, save_interval, save_dir, generator, regressor, lr, w_sindy, sindy_reg_type, w_sindy_reg, w_sym_reg, st_freq, threshold, pde, **kwargs
):
    """Train ``regressor`` with L-BFGS plus sequential thresholding on one batch.

    Only the first batch of ``train_loader`` is used. Each outer iteration
    ("epoch") performs one ``optimizer.step(closure)`` on::

        w_sindy * regressor.loss(train_data)
        + w_sym_reg * symm_loss(x=train_data, pde=pde, f=partial(regressor.F, pde=pde))  # if w_sym_reg > 0
        + w_sindy_reg * sum(||p||_1 for p in regressor.parameters())                      # if sindy_reg_type == 'l1'

    After each step, ``d`` is the sum over parameters of the L2 change since
    the previous step. Then:

    * If ``d < 1e-3`` and the parameters are also within ``1e-3`` of the
      snapshot taken at the last convergence-triggered reset (initially the
      starting parameters): final convergence. A checkpoint is saved and
      training stops.
    * Else, if ``d < 1e-3``: ``regressor.set_threshold(threshold)`` is applied
      and L-BFGS is restarted with a fresh optimizer.
    * Else, if ``st_freq > 0`` and ``st_freq`` steps have passed since the last
      reset: threshold and restart L-BFGS.

    Training also stops early if any parameter becomes NaN.

    Args:
        train_loader: Iterable of dict batches; only the first batch is used,
            and each of its values is moved to ``device``.
        test_loader: Iterable of dict batches, evaluated with ``regressor.loss``
            (no gradients, eval mode) every ``log_interval`` steps.
        num_epochs: Maximum number of outer L-BFGS steps.
        device: Target passed to ``.to()`` for every batch value.
        log_interval: Print losses every this many steps.
        save_interval: Save ``regressor.state_dict()`` to
            ``saved_models/{save_dir}/regressor_{epoch}.pt`` every this many steps.
        save_dir: Sub-directory of ``saved_models`` used for checkpoints.
        generator: Passed to ``make_symmreg_pttrain`` to build the symmetry loss.
        regressor: Model providing ``loss``, ``F``, ``set_threshold`` and ``print``.
        lr: L-BFGS learning rate.
        w_sindy: Weight of the SINDy data-fit loss.
        sindy_reg_type: ``'l1'`` (L1 penalty on all parameters) or ``'none'``.
        w_sindy_reg: Weight of the L1 penalty.
        w_sym_reg: Weight of the symmetry loss; the loss is not computed when <= 0.
        st_freq: If > 0, maximum number of steps between forced thresholding
            and optimizer resets.
        threshold: Value forwarded to ``regressor.set_threshold``.
        pde: Forwarded to the symmetry loss and to ``regressor.F``.
        **kwargs: ``print_eq`` (default False) prints the current equation via
            ``regressor.print()`` at each log step.

    Raises:
        ValueError: If ``sindy_reg_type`` is not ``'l1'`` or ``'none'``.
    """
    # Validate before touching data/optimizer rather than inside the closure.
    if sindy_reg_type not in ('l1', 'none'):
        raise ValueError(f'Unknown regularization type: {sindy_reg_type}')

    save_path = f'saved_models/{save_dir}'
    train_data = next(iter(train_loader))
    for key in train_data:
        train_data[key] = train_data[key].to(device)
    optimizer = torch.optim.LBFGS(regressor.parameters(), lr=lr)
    symm_loss = make_symmreg_pttrain(generator)
    losses = {}  # latest value of each loss term, filled by the closure
    tol = 1e-3

    def snapshot():
        """Return detached copies of the current parameters."""
        return [p.detach().clone() for p in regressor.parameters()]

    def distance_to(ref):
        """Return sum over parameters of ||current - ref||_2, without autograd."""
        with torch.no_grad():
            return sum(torch.norm(p - q) for p, q in zip(regressor.parameters(), ref))

    def save_checkpoint(epoch):
        """Write the current state dict to ``{save_path}/regressor_{epoch}.pt``."""
        os.makedirs(save_path, exist_ok=True)
        torch.save(regressor.state_dict(), f'{save_path}/regressor_{epoch}.pt')

    prev_params = snapshot()   # parameters after the previous step
    pprev_params = snapshot()  # parameters at the last convergence-triggered reset

    def closure():
        # L-BFGS may call this several times per step; it reads whichever
        # optimizer is currently bound, since the name is rebound on resets.
        optimizer.zero_grad()
        loss_sindy = regressor.loss(train_data)
        losses['loss_sindy'] = loss_sindy.item()
        loss = w_sindy * loss_sindy
        if w_sym_reg > 0.0:
            loss_sym_reg = symm_loss(x=train_data, pde=pde, f=partial(regressor.F, pde=pde))
            losses['loss_sym_reg'] = loss_sym_reg.item()
            loss = loss + w_sym_reg * loss_sym_reg
        if sindy_reg_type == 'l1':
            loss_sindy_reg = sum(torch.norm(p, 1) for p in regressor.parameters())
            losses['loss_sindy_reg'] = loss_sindy_reg.item()
            loss = loss + w_sindy_reg * loss_sindy_reg
        loss.backward()
        return loss

    n_iters = 0  # steps since the last optimizer reset
    for epoch in range(num_epochs):
        n_iters += 1
        optimizer.step(closure)
        if any(torch.isnan(p).any() for p in regressor.parameters()):
            print(f'NaN encountered at iteration {epoch}; exit training.')
            break
        if distance_to(prev_params) < tol:
            if distance_to(pprev_params) < tol:
                print(f'Final convergence reached at iteration {epoch}; exit training.')
                save_checkpoint(epoch)
                break
            n_iters = 0
            regressor.set_threshold(threshold)
            # Fresh optimizer: the L-BFGS curvature history is stale after thresholding.
            optimizer = torch.optim.LBFGS(regressor.parameters(), lr=lr)
            pprev_params = snapshot()
            print(f'Convergence reached at iteration {epoch}; apply parameter thresholding and reset optimizer.')
        elif st_freq > 0 and n_iters % st_freq == 0:
            n_iters = 0
            regressor.set_threshold(threshold)
            optimizer = torch.optim.LBFGS(regressor.parameters(), lr=lr)
            print('Max number of LBFGS iterations reached; apply parameter thresholding and reset optimizer.')
        prev_params = snapshot()

        if (epoch + 1) % log_interval == 0:
            print(', '.join([f'Epoch {epoch}'] + [f'{k}: {v:.4f}' for k, v in losses.items()]))
            was_training = regressor.training
            regressor.eval()
            with torch.no_grad():
                test_losses = []
                for x in test_loader:
                    for key in x:
                        x[key] = x[key].to(device)
                    test_losses.append(regressor.loss(x).item())
            regressor.train(was_training)
            print(f'Epoch {epoch}, test_loss: {np.mean(test_losses):.4f}')
            if kwargs.get('print_eq', False):
                regressor.print()

        if (epoch + 1) % save_interval == 0:
            save_checkpoint(epoch)
