"""
Copyright 2025 [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os, time
import numpy as np
import torch
from torch import Tensor, LongTensor
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pack_padded_sequence, PackedSequence
from .nn import OUFlow



# Module-wide compute device: the CUDA device when one is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')



def _build_optimizer(model, optimizer_params: dict) -> torch.optim.Optimizer:
    """Create the optimizer with three parameter groups (default, fast, mixture).

    Parameters are assigned to a group by substring match on their name as
    reported by ``model.named_parameters()``.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        optimizer_params: Dict providing ``method`` ('Adam' or 'AdamW'), ``lr``,
            ``lr_fast``, ``lr_mixture`` and ``weight_decay``.

    Raises:
        NotImplementedError: If ``method`` is neither 'Adam' nor 'AdamW'.
    """
    paramnames_fast = ['decayrates', 'frequencies', 'prior.mean', 'prior.cov_choleskys', 'observer.W', 'observer._logvar', 'sys_cov_choleskys']
    paramnames_mixture = ['mixture_weights']

    optimizer_classes = {'Adam': torch.optim.Adam, 'AdamW': torch.optim.AdamW}
    method = optimizer_params['method']
    if method not in optimizer_classes:
        raise NotImplementedError("The specified optimizer is not supported.")

    params, params_fast, params_mixture = [], [], []
    for name, param in model.named_parameters():
        is_fast = any(pname in name for pname in paramnames_fast)
        is_mixture = any(pname in name for pname in paramnames_mixture)
        if is_fast:
            params_fast.append(param)
        if is_mixture:
            params_mixture.append(param)
        if not is_fast and not is_mixture:
            params.append(param)

    param_groups = [
        {'params': params},
        {'params': params_fast, 'lr': optimizer_params['lr_fast']},
        {'params': params_mixture, 'lr': optimizer_params['lr_mixture']},
    ]
    return optimizer_classes[method](param_groups, lr=optimizer_params['lr'], weight_decay=optimizer_params['weight_decay'])


def _format_losses(losses) -> str:
    """Format the first three loss terms as '[a, b, c]' in scientific notation."""
    return f'[{losses[0]:.3e}, {losses[1]:.3e}, {losses[2]:.3e}]'


def _describe_model(model) -> str:
    """Summarize mixture weights and mean/std of selected model tensors for the log file."""
    # Pure reporting: no autograd graph is needed for these statistics.
    with torch.no_grad():
        weights = ''.join(f'{w:.3e}, ' for w in model.mixture_weights())
        sections = [
            ('decayrates', model.propagator.decayrates()),
            ('frequencies', model.propagator.frequencies),
            ('sys_cov_choleskys_diag', model.propagator.sys_cov_choleskys().diagonal(dim1=-2, dim2=-1)),
            ('prior.means', model.prior.means),
            ('prior.cov_choleskys_diag', model.prior.cov_choleskys().diagonal(dim1=-2, dim2=-1)),
            ('observer.W', model.observer.W),
            ('observer.logvar', model.observer.logvar(torch.arange(model.num_mode))),
        ]
        stats = ' | '.join(
            f'{label}: mean={values.mean().item():.3e}, std={values.std().item():.3e}'
            for label, values in sections
        )
    return f'mixture_weights=[{weights}] | {stats}'


def train(model: OUFlow, xs_train: PackedSequence, ts_train: PackedSequence, xs_val: PackedSequence, ts_val: PackedSequence, checkpoint_dir: str, log_path: str, train_params: dict):
    """Train ``model``, keeping only the checkpoint with the best validation loss.

    Each epoch performs one optimizer step on a freshly sampled training batch,
    evaluates one sampled validation batch, steps the LR scheduler, saves a
    checkpoint when the validation loss improves, and appends a line to the log.

    Args:
        model: Model to optimize; its ``losses`` method is called through ``Loss``.
        xs_train: Packed training observations.
        ts_train: Packed training time stamps.
        xs_val: Packed validation observations.
        ts_val: Packed validation time stamps.
        checkpoint_dir: Directory in which the best checkpoint file is written.
        log_path: Text file to which one line per epoch is appended.
        train_params: Configuration dict. Keys read here:
            ``optimizer`` (``method``, ``lr``, ``lr_fast``, ``lr_mixture``, ``weight_decay``),
            ``scheduler`` (``method``, ``factor``, ``patience``, optional ``start_epoch``),
            ``loss_train`` / ``loss_val`` (keyword arguments forwarded to ``Loss``),
            optional ``clip_grad`` (max gradient norm),
            ``stopping`` (``method`` 'MaxEpoch' with ``epoch``, or 'Patience' with ``patience``).

    Raises:
        NotImplementedError: For an unsupported optimizer, scheduler or stopping method.
    """
    # optimizer
    optimizer = _build_optimizer(model, train_params['optimizer'])

    # scheduler
    scheduler_params = train_params['scheduler']
    if scheduler_params['method'] == 'ReduceLROnPlateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=scheduler_params['factor'], patience=scheduler_params['patience'])
    else:
        raise NotImplementedError("The specified scheduler is not supported.")

    # stopping criterion: reject an unknown method before spending an epoch on it
    stopping = train_params['stopping']
    if stopping['method'] not in ('MaxEpoch', 'Patience'):
        raise NotImplementedError("The specified stopping method is not supported.")

    # loss function
    loss_train = Loss(model, xs_train, ts_train, **train_params['loss_train'])
    loss_val = Loss(model, xs_val, ts_val, **train_params['loss_val'])

    # training loop
    loss_best = float('inf')
    # Must exist before the first comparison: a NaN/inf first validation loss
    # goes straight to the "no improvement" branch.
    patience_count = 0
    checkpoint_path = ''
    t0 = time.time()
    for epoch in range(1_000_000_000):
        model.train()
        optimizer.zero_grad()
        loss_train(epoch).backward()
        if 'clip_grad' in train_params:
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_params['clip_grad'])
        optimizer.step()

        # validation
        model.eval()
        with torch.no_grad():
            loss_val(epoch)
        val_loss = float(loss_val.loss)
        if 'start_epoch' not in scheduler_params or epoch > scheduler_params['start_epoch']:
            scheduler.step(val_loss)

        # save checkpoint
        if val_loss < loss_best:
            patience_count = 0
            loss_best = val_loss
            previous_checkpoint = checkpoint_path
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{epoch:06d}_{loss_best:.3e}.pt')
            torch.save({
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': epoch
                    }, checkpoint_path)
            # Remove the old file only once the new one is safely on disk.
            if previous_checkpoint != '':
                os.remove(previous_checkpoint)
        else:
            patience_count += 1

        # logging
        t1 = time.time()
        summary = (
            f'epoch={epoch}, elapsed_time={t1-t0:.1f}s, lr={optimizer.param_groups[0]["lr"]:.3e}, '
            f'loss_train={_format_losses(loss_train.losses)}, loss_val={_format_losses(loss_val.losses)}'
        )
        print(summary)
        log_line = f'{summary} | {_describe_model(model)}\n'
        with open(log_path, 'a') as f:
            f.write(log_line)

        # stopping
        if stopping['method'] == 'MaxEpoch':
            if epoch >= stopping['epoch']:
                break
        elif patience_count >= stopping['patience']:
            break



class Loss():
    """Stochastic mini-batch loss built from random forecast windows of packed sequences.

    Calling the object samples two half-batches of windows (one with a random
    continuous start time, one starting at an observed time stamp), evaluates
    ``model.losses`` on their concatenation and returns a (possibly annealed)
    weighted sum of the resulting loss terms.

    Attributes set by ``__call__``:
        losses: ``[nll.mean(), nll_mode_mean.mean(), responsibility_imbalance]``.
        loss: The scalar that was returned by the last call.
    """

    def __init__(self, model: OUFlow, xs: PackedSequence, ts: PackedSequence,
                 batch_size: int, forecast_interval: float,
                 multiplier_start: list=None, multiplier_end: list=None, anneal_start: list=None, anneal_end: list=None):
        """Store the data and precompute padded views and per-scenario time ranges.

        Args:
            model: Object exposing ``losses(xs, ts)`` returning three loss terms.
            xs: Packed observations, one sequence per scenario.
            ts: Packed time stamps, one sequence per scenario.
            batch_size: Total number of windows drawn per call (half per sampling mode).
            forecast_interval: Length of each sampled time window.
            multiplier_start: Per-term loss weight before annealing starts.
            multiplier_end: Per-term loss weight after annealing ends.
            anneal_start: Per-term epoch at which the linear annealing starts.
            anneal_end: Per-term epoch at which the linear annealing ends.
                If any of the four annealing lists is None, only the first
                loss term is used.
        """
        self.model = model
        self.xs = xs
        self.ts = ts
        self.batch_size = batch_size
        self.forecast_interval = forecast_interval
        self.multiplier_start = multiplier_start
        self.multiplier_end = multiplier_end
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end

        # Filled in by __call__; defined here so the attributes always exist.
        self.losses = None
        self.loss = None

        # Padding time stamps with NaN lets padded steps be recognised later:
        # every comparison against NaN is False.
        self.ts_padded, n_timesteps = pad_packed_sequence(self.ts, batch_first=True, padding_value=torch.nan)   # (scenario, n_time_max, 1)
        self.xs_padded = pad_packed_sequence(self.xs, batch_first=True, padding_value=0)[0]   # (scenario, n_time_max, dim)

        # (scenario, time index) pairs of every step except the last one of each scenario.
        pairs = [[scenario, it] for scenario, n_steps in enumerate(n_timesteps.tolist()) for it in range(n_steps - 1)]
        self.id_pack_to_pad_nofinal = LongTensor(pairs).to(device)  # (sum of n_timesteps-1, 2)

        self.n_scenario = self.ts_padded.shape[0]
        self.n_time_max = self.ts_padded.shape[1]

        # Per-scenario first/last time stamp, ignoring the NaN padding.
        # torch.as_tensor keeps the array's dtype (torch.Tensor(...) would round
        # float64 time stamps through float32).
        ts_numpy = self.ts_padded[:, :, 0].cpu().numpy()
        self.tmin = torch.as_tensor(np.nanmin(ts_numpy, axis=-1)).to(dtype=self.ts_padded.dtype, device=self.ts_padded.device)  # (n_scenario)
        self.tmax = torch.as_tensor(np.nanmax(ts_numpy, axis=-1)).to(dtype=self.ts_padded.dtype, device=self.ts_padded.device)  # (n_scenario)


    @staticmethod
    def _pad_time(data: Tensor, length: int) -> Tensor:
        """Zero-pad ``data`` of shape (batch, time, feature) along dim 1 up to ``length``."""
        n_missing = length - data.shape[1]
        if n_missing <= 0:
            return data
        # new_zeros keeps the dtype and device of the data being padded.
        filler = data.new_zeros(data.shape[0], n_missing, data.shape[2])
        return torch.cat([data, filler], dim=1)


    def __call__(self, epoch: int) -> Tensor:
        """Sample a batch, evaluate the model losses and return the weighted loss.

        Args:
            epoch: Current epoch, used to position the linear annealing of the weights.

        Returns:
            The scalar loss, also stored in ``self.loss``.
        """
        half_batch = self.batch_size // 2
        xs, ts = self.sample(half_batch, self.forecast_interval, random_initial=True, down_sample=True)
        xs0, ts0 = self.sample(half_batch, self.forecast_interval, random_initial=False, down_sample=True, need_initial=True)

        # concatenate the two batches after padding both to a common length
        xs_data, xs_lengths = pad_packed_sequence(xs, batch_first=True, padding_value=0)
        xs0_data, xs0_lengths = pad_packed_sequence(xs0, batch_first=True, padding_value=0)
        ts_data, _ = pad_packed_sequence(ts, batch_first=True, padding_value=0)
        ts0_data, _ = pad_packed_sequence(ts0, batch_first=True, padding_value=0)
        max_length = max(xs_data.shape[1], xs0_data.shape[1])
        xs_data = torch.cat([self._pad_time(xs_data, max_length), self._pad_time(xs0_data, max_length)], dim=0)
        ts_data = torch.cat([self._pad_time(ts_data, max_length), self._pad_time(ts0_data, max_length)], dim=0)
        lengths = torch.cat([xs_lengths, xs0_lengths], dim=0)
        xs = pack_padded_sequence(xs_data, lengths=lengths, batch_first=True, enforce_sorted=False)
        ts = pack_padded_sequence(ts_data, lengths, batch_first=True, enforce_sorted=False)

        nll, nll_mode_mean, responsibility_imbalance = self.model.losses(xs, ts)
        self.losses = [nll.mean(), nll_mode_mean.mean(), responsibility_imbalance]

        annealing = (self.multiplier_start, self.multiplier_end, self.anneal_start, self.anneal_end)
        if any(setting is None for setting in annealing):
            self.loss = self.losses[0]
        else:
            self.loss = 0.0
            for i, start_value in enumerate(self.multiplier_start):
                end_value = self.multiplier_end[i]
                first_epoch = self.anneal_start[i]
                last_epoch = self.anneal_end[i]
                if epoch < first_epoch:
                    multiplier = start_value
                elif epoch < last_epoch:
                    # linear interpolation between the start and end weights
                    multiplier = start_value + (end_value - start_value) * (epoch - first_epoch) / (last_epoch - first_epoch)
                else:
                    multiplier = end_value
                self.loss += self.losses[i] * multiplier

        return self.loss


    def sample(self, batch_size: int, forecast_interval: float, random_initial: bool=False, down_sample: bool=False, need_initial: bool=False) -> tuple[PackedSequence, PackedSequence]:
        """Draw random time windows from random scenarios.

        Args:
            batch_size: Number of windows to draw; windows containing no data
                point are dropped, so fewer sequences may be returned.
            forecast_interval: Window length; a window covers
                ``[t0, t0 + forecast_interval]`` (upper bound with a 1e-4 tolerance).
            random_initial: If True, ``t0`` is uniform in
                ``[tmin - forecast_interval, tmax]`` of the drawn scenario;
                otherwise ``t0`` is the time stamp at a random time index.
            down_sample: If True, each time step is additionally kept with probability 1/2.
            need_initial: With ``down_sample``, always keep the step whose time stamp equals ``t0``.

        Returns:
            Packed observations and packed time stamps, the latter shifted so
            that each window starts at ``t0 = 0``.

        Raises:
            RuntimeError: If every drawn window is empty.
        """
        id_scenario = torch.randint(0, self.n_scenario, [batch_size])   # (batch_size)
        if random_initial:
            window_range = (self.tmax - self.tmin + forecast_interval)[id_scenario]
            earliest_start = (self.tmin - forecast_interval)[id_scenario]
            t0 = torch.rand(batch_size, device=device) * window_range + earliest_start  # (batch_size)
        else:
            # An index beyond a scenario's length hits NaN padding; such a window
            # matches nothing and is dropped below.
            it = torch.randint(0, self.n_time_max-1, [batch_size])  # (batch_size)
            t0 = self.ts_padded[id_scenario, it, 0]    # (batch_size)

        ts_batch = self.ts_padded[id_scenario][:, :, 0]   # (batch_size, n_time_max)
        in_window = (ts_batch <= (t0 + forecast_interval)[:, None] + 1e-4) & (ts_batch >= t0[:, None])  # (batch_size, n_time_max)

        if down_sample:
            keep = torch.randint(0, 2, (batch_size, self.n_time_max), dtype=torch.bool, device=device)    # (batch_size, n_time_max)
            if need_initial:
                keep |= (ts_batch == t0[:, None])
            in_window = in_window & keep    # (batch_size, n_time_max)

        # Remove windows with no data
        has_data = in_window.any(dim=1)   # (batch_size)
        in_window = in_window[has_data]  # (batch_size_with_data, n_time_max)
        id_scenario = id_scenario[has_data.cpu()]  # (batch_size_with_data)
        t0 = t0[has_data]  # (batch_size_with_data)
        batch_size_with_data = id_scenario.shape[0]
        if batch_size_with_data == 0:
            raise RuntimeError("Loss.sample drew no window containing data; increase batch_size or forecast_interval.")

        xs = pack_sequence([self.xs_padded[id_scenario[batch]][in_window[batch]] for batch in range(batch_size_with_data)], enforce_sorted=False)
        ts = pack_sequence([self.ts_padded[id_scenario[batch]][in_window[batch]] - t0[batch] for batch in range(batch_size_with_data)], enforce_sorted=False)

        return xs, ts